Moved cache to a separate package.

Nathan Osman 2016-04-30 21:00:43 -07:00
parent 397626194f
commit b233dd8e92
7 changed files with 192 additions and 164 deletions

View File

@@ -1,4 +1,4 @@
-package main
+package cache
 
 import (
 	"crypto/md5"
@@ -11,11 +11,18 @@ import (
 	"sync"
 )
 
+// Reader is a generic interface for reading cache entries either from disk or
+// directly attached to a downloader.
+type Reader interface {
+	io.ReadCloser
+	GetEntry() (*Entry, error)
+}
+
 // Cache provides access to entries in the cache.
 type Cache struct {
 	mutex       sync.Mutex
 	directory   string
-	downloaders map[string]*Downloader
+	downloaders map[string]*downloader
 	waitGroup   sync.WaitGroup
 }
 
@@ -26,15 +33,15 @@ func NewCache(directory string) (*Cache, error) {
 	}
 	return &Cache{
 		directory:   directory,
-		downloaders: make(map[string]*Downloader),
+		downloaders: make(map[string]*downloader),
 	}, nil
 }
 
-// GetReader obtains an io.Reader for the specified rawurl. If a downloader
+// GetReader obtains a Reader for the specified rawurl. If a downloader
 // currently exists for the URL, a live reader is created and connected to it.
 // If the URL exists in the cache, it is read using the standard file API. If
 // not, a downloader and live reader are created.
-func (c *Cache) GetReader(rawurl string) (io.ReadCloser, chan *Entry, error) {
+func (c *Cache) GetReader(rawurl string) (Reader, error) {
 	var (
 		b    = md5.Sum([]byte(rawurl))
 		hash = hex.EncodeToString(b[:])
@@ -48,30 +55,21 @@ func (c *Cache) GetReader(rawurl string) (io.ReadCloser, chan *Entry, error) {
 		_, err := os.Stat(jsonFilename)
 		if err != nil {
 			if !os.IsNotExist(err) {
-				return nil, nil, err
+				return nil, err
 			}
 		} else {
-			e := &Entry{}
-			if err = e.Load(jsonFilename); err != nil {
-				return nil, nil, err
-			}
-			if e.Complete {
-				f, err := os.Open(dataFilename)
-				if err != nil {
-					return nil, nil, err
-				}
-				eChan := make(chan *Entry)
-				go func() {
-					eChan <- e
-					close(eChan)
-				}()
-				log.Println("[HIT]", rawurl)
-				return f, eChan, nil
-			}
+			r, err := newDiskReader(jsonFilename, dataFilename)
+			if err != nil {
+				return nil, err
+			}
+			if e, _ := r.GetEntry(); e.Complete {
+				log.Println("[HIT]", rawurl)
+				return r, nil
+			}
 		}
-		d = NewDownloader(rawurl, jsonFilename, dataFilename)
+		d = newDownloader(rawurl, jsonFilename, dataFilename)
 		go func() {
-			d.Wait()
+			d.WaitForDone()
 			c.mutex.Lock()
 			defer c.mutex.Unlock()
 			delete(c.downloaders, hash)
@@ -80,13 +78,8 @@ func (c *Cache) GetReader(rawurl string) (io.ReadCloser, chan *Entry, error) {
 		c.downloaders[hash] = d
 		c.waitGroup.Add(1)
 	}
-	eChan := make(chan *Entry)
-	go func() {
-		eChan <- d.GetEntry()
-		close(eChan)
-	}()
 	log.Println("[MISS]", rawurl)
-	return NewLiveReader(d, dataFilename), eChan, nil
+	return newLiveReader(d, dataFilename)
 }
 
 // TODO: implement some form of "safe abort" for downloads so that the entire
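Taken together, NewCache, GetReader, and the new Reader interface form the package's public surface. The sketch below shows how a caller outside the package might use it, assuming the import path introduced later in this commit; the cache directory and URL are placeholders rather than values taken from the repository.

package main

import (
	"io"
	"log"
	"os"

	"github.com/nathan-osman/go-aptproxy/cache"
)

func main() {
	// Placeholder cache directory.
	c, err := cache.NewCache("/tmp/go-aptproxy")
	if err != nil {
		log.Fatal(err)
	}
	// A Reader comes back whether the entry is served from disk ("[HIT]")
	// or streamed from an in-flight download ("[MISS]").
	r, err := c.GetReader("http://archive.ubuntu.com/ubuntu/dists/xenial/Release")
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	// Entry metadata travels through the same handle.
	e, err := r.GetEntry()
	if err != nil {
		log.Fatal(err)
	}
	log.Println("content type:", e.ContentType)
	if _, err := io.Copy(os.Stdout, r); err != nil {
		log.Fatal(err)
	}
}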

View File: cache/diskreader.go (new file)

@@ -0,0 +1,43 @@
+package cache
+
+import (
+	"os"
+)
+
+// diskReader reads a file from the cache on disk.
+type diskReader struct {
+	entry *Entry
+	file  *os.File
+}
+
+// newDiskReader creates a reader from the provided JSON and data filenames.
+// Failure to open either of these results in an immediate error.
+func newDiskReader(jsonFilename, dataFilename string) (*diskReader, error) {
+	e := &Entry{}
+	if err := e.Load(jsonFilename); err != nil {
+		return nil, err
+	}
+	f, err := os.Open(dataFilename)
+	if err != nil {
+		return nil, err
+	}
+	return &diskReader{
+		entry: e,
+		file:  f,
+	}, nil
+}
+
+// Read attempts to read as much data as possible into the provided buffer.
+func (d *diskReader) Read(p []byte) (int, error) {
+	return d.file.Read(p)
+}
+
+// Close attempts to close the data file.
+func (d *diskReader) Close() error {
+	return d.file.Close()
+}
+
+// GetEntry returns the Entry associated with the file.
+func (d *diskReader) GetEntry() (*Entry, error) {
+	return d.entry, nil
+}
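diskReader is one of the two concrete types that back the new Reader interface (liveReader, added further down, is the other). The commit itself contains no compile-time assertion tying the type to the interface, but one could be written as a one-line sketch:

package cache

// Sketch only, not part of this commit: verify at compile time that
// diskReader satisfies the Reader interface declared in cache.go.
var _ Reader = (*diskReader)(nil)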

View File

@@ -1,4 +1,4 @@
-package main
+package cache
 
 import (
 	"errors"
@@ -9,17 +9,17 @@ import (
 	"sync"
 )
 
-// Downloader attempts to download a file from a remote URL.
-type Downloader struct {
+// downloader attempts to download a file from a remote URL.
+type downloader struct {
 	doneMutex  sync.Mutex
 	err        error
 	entry      *Entry
 	entryMutex sync.Mutex
 }
 
-// NewDownloader creates a new downloader.
-func NewDownloader(rawurl, jsonFilename, dataFilename string) *Downloader {
-	d := &Downloader{}
+// newDownloader creates a new downloader.
+func newDownloader(rawurl, jsonFilename, dataFilename string) *downloader {
+	d := &downloader{}
 	d.doneMutex.Lock()
 	d.entryMutex.Lock()
 	go func() {
@@ -72,16 +72,15 @@ func NewDownloader(rawurl, jsonFilename, dataFilename string) *Downloader {
 	return d
 }
 
-// GetEntry waits until the Entry associated with the download is available.
-// This call will block until the entry is available or an error occurs.
-func (d *Downloader) GetEntry() *Entry {
+// GetEntry retrieves the entry associated with the download.
+func (d *downloader) GetEntry() (*Entry, error) {
 	d.entryMutex.Lock()
 	defer d.entryMutex.Unlock()
-	return d.entry
+	return d.entry, d.err
}
 
-// Wait will block until the download completes.
-func (d *Downloader) Wait() error {
+// WaitForDone will block until the download completes.
+func (d *downloader) WaitForDone() error {
 	d.doneMutex.Lock()
 	defer d.doneMutex.Unlock()
 	return d.err
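The downloader keeps its locked-mutex-as-latch approach: newDownloader locks doneMutex and entryMutex up front, the download goroutine (elided from this diff) presumably unlocks them once the entry and the final error are known, and GetEntry / WaitForDone block simply by acquiring the locks. A standalone sketch of the same idea follows; the names are illustrative and not part of the package.

package main

import (
	"fmt"
	"sync"
	"time"
)

// latch blocks callers of Wait until release is called, using the same
// locked-mutex trick as downloader.doneMutex and entryMutex.
type latch struct {
	mu  sync.Mutex
	err error
}

func newLatch() *latch {
	l := &latch{}
	l.mu.Lock() // held until release; Wait blocks on this
	return l
}

func (l *latch) release(err error) {
	l.err = err
	l.mu.Unlock() // unblocks the first waiter; each waiter unlocks for the next
}

func (l *latch) Wait() error {
	l.mu.Lock() // blocks until release has been called
	defer l.mu.Unlock()
	return l.err
}

func main() {
	l := newLatch()
	go func() {
		time.Sleep(100 * time.Millisecond)
		l.release(nil)
	}()
	fmt.Println("waited, err =", l.Wait())
}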

View File

@@ -1,4 +1,4 @@
-package main
+package cache
 
 import (
 	"encoding/json"
View File: cache/livereader.go (new file)

@@ -0,0 +1,93 @@
+package cache
+
+import (
+	"github.com/fsnotify/fsnotify"
+	"io"
+	"os"
+)
+
+// liveReader reads a file from disk, synchronizing reads with a downloader.
+type liveReader struct {
+	downloader *downloader
+	file       *os.File
+	watcher    *fsnotify.Watcher
+	entry      *Entry
+	done       chan error
+	err        error
+	eof        bool
+}
+
+// newLiveReader creates a reader from the provided downloader and data
+// file. fsnotify is used to watch for writes to the file to avoid using a
+// spinloop. Invoking this function assumes the existence of the data file.
+func newLiveReader(d *downloader, dataFilename string) (*liveReader, error) {
+	f, err := os.Open(dataFilename)
+	if err != nil {
+		return nil, err
+	}
+	w, err := fsnotify.NewWatcher()
+	if err != nil {
+		return nil, err
+	}
+	if err = w.Add(dataFilename); err != nil {
+		return nil, err
+	}
+	l := &liveReader{
+		downloader: d,
+		file:       f,
+		watcher:    w,
+		done:       make(chan error),
+	}
+	go func() {
+		defer close(l.done)
+		l.done <- d.WaitForDone()
+	}()
+	return l, err
+}
+
+// Read attempts to read as much data as possible into the provided buffer.
+// Since data is being downloaded as data is being read, fsnotify is used to
+// monitor writes to the file. This function blocks until the requested amount
+// of data is read, an error occurs, or EOF is encountered.
+func (l *liveReader) Read(p []byte) (int, error) {
+	if l.err != nil {
+		return 0, l.err
+	}
+	bytesRead := 0
+loop:
+	for bytesRead < len(p) {
+		n, err := l.file.Read(p[bytesRead:])
+		bytesRead += n
+		if err != nil {
+			if err != io.EOF || l.eof {
+				l.err = err
+				break loop
+			}
+			for {
+				select {
+				case e := <-l.watcher.Events:
+					if e.Op&fsnotify.Write != fsnotify.Write {
+						continue
+					}
+				case err = <-l.done:
+					l.err = err
+					l.eof = true
+				}
+				continue loop
+			}
+		}
+	}
+	return bytesRead, l.err
+}
+
+// Close attempts to close the data file (if opened).
+func (l *liveReader) Close() error {
+	return l.file.Close()
+}
+
+// GetEntry returns the Entry associated with the file, blocking until either
+// the data is available or an error occurs.
+func (l *liveReader) GetEntry() (*Entry, error) {
+	return l.downloader.GetEntry()
+}
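The select loop in liveReader.Read is the core of the new file: when the reader hits EOF before the download has finished, it parks until fsnotify reports a write to the data file or the downloader signals completion on the done channel, instead of polling in a spinloop. The following self-contained sketch reproduces just that wait pattern; the temporary file and helper names are illustrative and not part of the package.

package main

import (
	"log"
	"os"
	"time"

	"github.com/fsnotify/fsnotify"
)

// waitForWrite blocks until the watched file receives a write or the done
// channel yields a value, mirroring the select loop in liveReader.Read.
func waitForWrite(w *fsnotify.Watcher, done <-chan error) (finished bool, err error) {
	for {
		select {
		case e := <-w.Events:
			if e.Op&fsnotify.Write != fsnotify.Write {
				continue // ignore events other than writes
			}
			return false, nil // new data may be available
		case err := <-done:
			return true, err // download finished (or failed)
		}
	}
}

func main() {
	f, err := os.CreateTemp("", "livereader-demo-*")
	if err != nil {
		log.Fatal(err)
	}
	defer os.Remove(f.Name())

	w, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal(err)
	}
	defer w.Close()
	if err := w.Add(f.Name()); err != nil {
		log.Fatal(err)
	}

	done := make(chan error)
	go func() {
		time.Sleep(100 * time.Millisecond)
		f.WriteString("some data") // triggers a Write event on the watched file
	}()

	finished, err := waitForWrite(w, done)
	log.Printf("finished=%v err=%v", finished, err)
}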

View File

@@ -1,99 +0,0 @@
-package main
-
-import (
-	"github.com/fsnotify/fsnotify"
-	"io"
-	"os"
-)
-
-// LiveReader synchronizes with a downloader to read from a file.
-type LiveReader struct {
-	dataFilename string
-	open         chan bool
-	done         chan error
-	file         *os.File
-	err          error
-	eof          bool
-}
-
-// NewLiveReader creates a new live reader.
-func NewLiveReader(d *Downloader, dataFilename string) *LiveReader {
-	l := &LiveReader{
-		dataFilename: dataFilename,
-		open:         make(chan bool),
-		done:         make(chan error),
-	}
-	go func() {
-		d.GetEntry()
-		close(l.open)
-		l.done <- d.Wait()
-		close(l.done)
-	}()
-	return l
-}
-
-// Read attempts to read data as it is being downloaded. If EOF is reached,
-// fsnotify is used to watch for new data being written. The download is not
-// complete until the "done" channel receives a value.
-func (l *LiveReader) Read(p []byte) (int, error) {
-	if l.err != nil {
-		return 0, l.err
-	}
-	<-l.open
-	if l.file == nil {
-		f, err := os.Open(l.dataFilename)
-		if err != nil {
-			return 0, err
-		}
-		l.file = f
-	}
-	var (
-		bytesRead int
-		watcher   *fsnotify.Watcher
-	)
-loop:
-	for bytesRead < len(p) {
-		n, err := l.file.Read(p[bytesRead:])
-		bytesRead += n
-		if err != nil {
-			if err != io.EOF || l.eof {
-				l.err = err
-				break loop
-			}
-			if watcher == nil {
-				watcher, err = fsnotify.NewWatcher()
-				if err != nil {
-					l.err = err
-					break loop
-				}
-				defer watcher.Close()
-				if err = watcher.Add(l.dataFilename); err != nil {
-					l.err = err
-					break loop
-				}
-			}
-			for {
-				select {
-				case e := <-watcher.Events:
-					if e.Op&fsnotify.Write != fsnotify.Write {
-						continue
-					}
-				case err = <-l.done:
-					l.err = err
-					l.eof = true
-				}
-				continue loop
-			}
-		}
-	}
-	return bytesRead, l.err
-}
-
-// Close frees resources associated with the reader.
-func (l *LiveReader) Close() error {
-	if l.file != nil {
-		l.file.Close()
-	}
-	return nil
-}

View File

@@ -2,6 +2,7 @@ package main
 
 import (
 	"github.com/hectane/go-asyncserver"
+	"github.com/nathan-osman/go-aptproxy/cache"
 	"io"
 	"log"
@@ -16,7 +17,7 @@ import (
 // needless duplication.
 type Server struct {
 	server *server.AsyncServer
-	cache  *Cache
+	cache  *cache.Cache
 }
 
 func rewrite(rawurl string) string {
@@ -31,7 +32,7 @@ func rewrite(rawurl string) string {
 	return rawurl
 }
 
-func (s *Server) writeHeaders(w http.ResponseWriter, e *Entry) {
+func (s *Server) writeHeaders(w http.ResponseWriter, e *cache.Entry) {
 	if e.ContentType != "" {
 		w.Header().Set("Content-Type", e.ContentType)
 	} else {
@@ -48,24 +49,25 @@ func (s *Server) writeHeaders(w http.ResponseWriter, e *Entry) {
 }
 
 // TODO: support for HEAD requests
-// TODO: find a reasonable way for getting errors from eChan
 
 // ServeHTTP processes an incoming request to the proxy. GET requests are
 // served with the storage backend and every other request is (out of
 // necessity) rejected since it can't be cached.
 func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-	if req.Method == "GET" {
-		r, eChan, err := s.cache.GetReader(rewrite(req.RequestURI))
+	if req.Method != "GET" {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+	}
+	r, err := s.cache.GetReader(rewrite(req.RequestURI))
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		log.Println("[ERR]", err)
 		return
 	}
 	defer r.Close()
-	e := <-eChan
-	if e == nil {
-		http.Error(w, "header retrieval error", http.StatusInternalServerError)
-		log.Println("[ERR] header retrieval")
+	e, err := r.GetEntry()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		log.Println("[ERR]", err)
 		return
 	}
 	s.writeHeaders(w, e)
@@ -73,14 +75,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	if err != nil {
 		log.Println("[ERR]", err)
 	}
-	} else {
-		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
-	}
 }
 
 // NewServer creates a new server.
 func NewServer(addr, directory string) (*Server, error) {
-	c, err := NewCache(directory)
+	c, err := cache.NewCache(directory)
 	if err != nil {
 		return nil, err
 	}
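One behavioral detail of the rewritten ServeHTTP is worth noting: in the hunk above, the req.Method != "GET" branch writes a 405 response but does not appear to be followed by a return, so a non-GET request would still fall through to the cache path. For comparison, the conventional early-return shape of that guard is sketched below; the listen address and response body are placeholders.

package main

import (
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
		if req.Method != "GET" {
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return // without this, the handler would keep going past the guard
		}
		w.Write([]byte("ok\n"))
	})
	// Placeholder address; go-aptproxy's real listener is managed elsewhere.
	log.Fatal(http.ListenAndServe("127.0.0.1:8081", nil))
}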