s3gof3r/getter.go

359 строки
7.5 KiB
Go

package s3gof3r
import (
"crypto/md5"
"fmt"
"hash"
"io"
"io/ioutil"
"math"
"net/http"
"net/url"
"sync"
"syscall"
"time"
)
// qWaitMax is the qWaitLen threshold at or above which a worker that has
// just delivered a chunk waits before going on to fetch another one.
const qWaitMax = 2
// getter downloads an object as a series of ranged chunks fetched by
// concurrent workers and hands the bytes back, in order, through Read.
type getter struct {
// url is the object URL requested for the initial GET and for every chunk.
url url.URL
// b is the bucket used to sign requests and to build the md5 object URL.
b *Bucket
// bufsz is the chunk size in bytes, taken from Config.PartSize (at least 1).
bufsz int64
// err is the fatal error recorded by a worker once every retry of a chunk failed.
err error
// chunkID is the id of the chunk Read is currently consuming or waiting for.
chunkID int
// rChunk is the chunk Read is currently copying from; nil between chunks.
rChunk *chunk
// contentLen is the Content-Length reported by the initial GET.
contentLen int64
// bytesRead counts the bytes handed to callers of Read so far.
bytesRead int64
// chunkTotal is the number of chunks needed to cover contentLen.
chunkTotal int
// readCh carries downloaded chunks from the workers to nextChunk.
readCh chan *chunk
// getCh carries chunk descriptors from initChunks to the workers.
getCh chan *chunk
// quit is closed by Close.
quit chan struct{}
// workerAborted is closed when a worker observes err set.
workerAborted chan struct{}
// abortOnce guards the close of workerAborted so it happens only once.
abortOnce sync.Once
// qWait holds chunks that arrived ahead of chunkID, keyed by chunk id.
qWait map[int]*chunk
// qWaitLen tracks the number of entries in qWait; it is updated under cond.L.
qWaitLen uint
// cond lets workers wait while qWaitLen is at or above qWaitMax.
cond sync.Cond
// sp is the buffer pool that supplies chunk buffers and takes them back.
sp *bp
// closed is set by Close; Read and Close return syscall.EINVAL afterwards.
closed bool
// c is a private copy of the caller's Config with NTry and Concurrency
// clamped to at least 1.
c *Config
// md5 accumulates the hash of delivered chunks when c.Md5Check is set.
md5 hash.Hash
// cIdx is the read offset within rChunk.
cIdx int64
}
// chunk describes one byte range of the object and, once fetched, its data.
type chunk struct {
// id is the zero-based position of the chunk within the object.
id int
// header holds the request headers for this chunk, including its Range.
header http.Header
// start is the byte offset of the chunk within the object.
start int64
// size is the length of the chunk in bytes.
size int64
// b is the buffer the chunk is read into; Read consumes b[:size].
b []byte
}
// newGetter starts a concurrent, chunked download of the object at getURL.
//
// It copies c so that the clamping of NTry and Concurrency (both forced to
// at least 1) never touches the caller's Config. The Bucket b is shared with
// the caller, not copied: the original code allocated a copy and then
// immediately overwrote it with b, so that dead copy has been removed while
// keeping the observable behaviour.
//
// An initial GET (rather than HEAD, so that error responses carry a
// message) determines the object size. Responses without a Content-Length
// are rejected. On success the worker goroutines and the chunk dispatcher
// are started and the getter is returned together with the response
// headers of the initial GET.
func newGetter(getURL url.URL, c *Config, b *Bucket) (io.ReadCloser, http.Header, error) {
	g := new(getter)
	g.url = getURL
	g.b = b

	g.c = new(Config)
	*g.c = *c
	g.bufsz = max64(c.PartSize, 1)
	g.c.NTry = max(c.NTry, 1)
	g.c.Concurrency = max(c.Concurrency, 1)

	g.getCh = make(chan *chunk)
	g.readCh = make(chan *chunk)
	g.quit = make(chan struct{})
	g.workerAborted = make(chan struct{})
	g.qWait = make(map[int]*chunk)
	g.md5 = md5.New()
	g.cond = sync.Cond{L: &sync.Mutex{}}

	// use get instead of head for error messaging
	resp, err := g.retryRequest("GET", g.url.String(), nil)
	if err != nil {
		return nil, nil, err
	}
	if resp.StatusCode != 200 {
		return nil, nil, newRespError(resp)
	}
	defer checkClose(resp.Body, err)

	// Golang changes content-length to -1 when chunked transfer encoding / EOF close response detected
	if resp.ContentLength == -1 {
		return nil, nil, fmt.Errorf("Retrieving objects with undefined content-length " +
			" responses (chunked transfer encoding / EOF close) is not supported")
	}

	g.contentLen = resp.ContentLength
	g.chunkTotal = int((g.contentLen + g.bufsz - 1) / g.bufsz) // round up, integer division
	logger.debugPrintf("object size: %3.2g MB", float64(g.contentLen)/float64((1*mb)))

	g.sp = bufferPool(g.bufsz)

	for i := 0; i < g.c.Concurrency; i++ {
		go g.worker()
	}
	go g.initChunks()
	return g, resp.Header, nil
}
// retryRequest builds, signs and sends a request, trying up to g.c.NTry
// times.
//
// A response with status 500 is retried after an exponential back-off; a
// transport error is logged and retried immediately. Any other response is
// returned as soon as it arrives. If the last attempt still yields a 500,
// that response is returned with a nil error and an open body so the caller
// can inspect it. If the last attempt fails with a transport error, that
// error is returned.
//
// body may be nil. When it is not, it is rewound before every retry; a
// failure to rewind aborts the retries and is returned.
func (g *getter) retryRequest(method, urlStr string, body io.ReadSeeker) (resp *http.Response, err error) {
	for i := 0; i < g.c.NTry; i++ {
		var req *http.Request
		req, err = http.NewRequest(method, urlStr, body)
		if err != nil {
			return nil, err
		}
		if body != nil {
			req.Header.Set(sha256Header, shaReader(body))
		}
		g.b.Sign(req)

		resp, err = g.c.Client.Do(req)
		if err == nil && resp.StatusCode != 500 {
			return resp, nil
		}

		if err == nil {
			if i == g.c.NTry-1 {
				// Out of attempts: hand the 500 response to the caller as is.
				return resp, nil
			}
			// Release the connection held by the failed response before retrying.
			if cerr := resp.Body.Close(); cerr != nil {
				logger.debugPrintln(cerr)
			}
			resp = nil
			time.Sleep(time.Duration(math.Exp2(float64(i))) * 100 * time.Millisecond) // exponential back-off
		} else {
			logger.debugPrintln(err)
		}

		if body != nil {
			// Use a separate variable so a successful rewind does not erase
			// the transport error that must be reported if retries run out.
			if _, serr := body.Seek(0, 0); serr != nil {
				return nil, serr
			}
		}
	}
	return resp, err
}
// initChunks splits the object into consecutive ranges of at most g.bufsz
// bytes and hands one chunk descriptor per range to the workers via getCh.
//
// It stops early when the getter is closed (quit) or a worker has aborted
// (workerAborted); without that, the unbuffered send would block forever
// once no worker is receiving any more. getCh is closed on every exit path
// so that workers ranging over it terminate.
func (g *getter) initChunks() {
	defer close(g.getCh)

	id := 0
	for i := int64(0); i < g.contentLen; {
		size := min64(g.bufsz, g.contentLen-i)
		c := &chunk{
			id: id,
			header: http.Header{
				"Range": {fmt.Sprintf("bytes=%d-%d",
					i, i+size-1)},
			},
			start: i,
			size:  size,
			b:     nil,
		}
		i += size
		id++

		select {
		case g.getCh <- c:
		case <-g.quit:
			return
		case <-g.workerAborted:
			return
		}
	}
}
// worker fetches chunks received on getCh until the channel is closed or a
// fatal error has been recorded in g.err. In the latter case it closes
// workerAborted (at most once across all workers) so that a pending Read
// is told to give up, and then exits.
func (g *getter) worker() {
	for {
		next, ok := <-g.getCh
		if !ok {
			return
		}
		g.retryGetChunk(next)
		if g.err == nil {
			continue
		}
		// tell Read() caller that 1 or more chunks can't be read; abort
		g.abortOnce.Do(func() { close(g.workerAborted) })
		return
	}
}
// retryGetChunk takes a buffer from the pool for c and tries to download
// the chunk up to g.c.NTry times, sleeping with exponential back-off after
// each failed attempt. If every attempt fails, the last error is stored in
// g.err, unless the getter has already been closed.
func (g *getter) retryGetChunk(c *chunk) {
	var lastErr error
	c.b = <-g.sp.get

	for attempt := 0; attempt < g.c.NTry; attempt++ {
		if lastErr = g.getChunk(c); lastErr == nil {
			return
		}
		logger.debugPrintf("error on attempt %d: retrying chunk: %v, error: %s", attempt, c.id, lastErr)
		backoff := time.Duration(math.Exp2(float64(attempt))) * 100 * time.Millisecond // exponential back-off
		time.Sleep(backoff)
	}

	// check for closed quit channel before setting error
	select {
	case <-g.quit:
		return
	default:
	}
	g.err = lastErr
}
// getChunk performs one download attempt for c: it sends a signed GET
// carrying the chunk's headers, reads the body into c.b and verifies that
// exactly c.size bytes arrived. Status codes other than 206 and 200 are
// turned into an error via newRespError.
//
// On success the chunk is sent on readCh. The worker then waits while
// qWaitLen is at or above qWaitMax, so that it does not run ahead of the
// reader; the wait ends early once the getter is closed. The lock is now
// released on every path: the original returned from the closed branch
// with g.cond.L still held, blocking every later Lock on it.
func (g *getter) getChunk(c *chunk) error {
	r, err := http.NewRequest("GET", g.url.String(), nil)
	if err != nil {
		return err
	}
	r.Header = c.header
	g.b.Sign(r)
	resp, err := g.c.Client.Do(r)
	if err != nil {
		return err
	}
	if resp.StatusCode != 206 && resp.StatusCode != 200 {
		return newRespError(resp)
	}
	defer checkClose(resp.Body, err)
	n, err := io.ReadAtLeast(resp.Body, c.b, int(c.size))
	if err != nil {
		return err
	}
	if err := resp.Body.Close(); err != nil {
		return err
	}
	if int64(n) != c.size {
		return fmt.Errorf("chunk %d: Expected %d bytes, received %d",
			c.id, c.size, n)
	}
	g.readCh <- c

	// wait for qWait to drain before starting next chunk
	g.cond.L.Lock()
	defer g.cond.L.Unlock()
	for g.qWaitLen >= qWaitMax && !g.closed {
		g.cond.Wait()
	}
	return nil
}
// Read implements io.Reader. It copies object bytes into p in order,
// pulling the next chunk from the workers whenever the current one is
// exhausted and returning each finished buffer to the pool.
//
// It returns syscall.EINVAL after Close and g.err once a worker has
// failed. io.EOF is returned, possibly together with the final bytes, when
// contentLen bytes have been delivered. If fetching the next chunk fails
// part-way through a call, the bytes already copied into p are reported
// alongside the error (the original reported 0, discarding them).
func (g *getter) Read(p []byte) (int, error) {
	var err error
	if g.closed {
		return 0, syscall.EINVAL
	}
	if g.err != nil {
		return 0, g.err
	}
	nw := 0
	for nw < len(p) {
		if g.bytesRead == g.contentLen {
			return nw, io.EOF
		}
		if g.bytesRead > g.contentLen {
			// Here for robustness / completeness
			// Should not occur as golang uses LimitedReader up to content-length
			return nw, fmt.Errorf("Expected %d bytes, received %d (too many bytes)",
				g.contentLen, g.bytesRead)
		}

		// If for some reason no more chunks to be read and bytes are off, error, incomplete result
		if g.chunkID >= g.chunkTotal {
			return nw, fmt.Errorf("Expected %d bytes, received %d and chunkID %d >= chunkTotal %d (no more chunks remaining)",
				g.contentLen, g.bytesRead, g.chunkID, g.chunkTotal)
		}

		if g.rChunk == nil {
			g.rChunk, err = g.nextChunk()
			if err != nil {
				// Report what was already copied into p during this call.
				return nw, err
			}
			g.cIdx = 0
		}

		n := copy(p[nw:], g.rChunk.b[g.cIdx:g.rChunk.size])
		g.cIdx += int64(n)
		nw += n
		g.bytesRead += int64(n)

		if g.cIdx >= g.rChunk.size { // chunk complete
			g.sp.give <- g.rChunk.b
			g.chunkID++
			g.rChunk = nil
		}
	}
	return nw, nil
}
// nextChunk returns the chunk whose id equals g.chunkID. Chunks arriving
// from the workers out of order are parked in qWait until their turn
// comes. Taking a chunk out of qWait decrements qWaitLen and signals one
// waiting worker. When Md5Check is enabled the returned chunk's bytes are
// added to the running hash.
//
// A nil chunk is always accompanied by a non-nil error. If the getter is
// closed while no worker error has been recorded, syscall.EINVAL is
// returned, matching what Read reports for a closed getter. The original
// returned (nil, nil) in that situation, which made Read dereference a nil
// chunk.
func (g *getter) nextChunk() (*chunk, error) {
	for {
		// first check qWait
		if c := g.qWait[g.chunkID]; c != nil {
			delete(g.qWait, g.chunkID)
			g.cond.L.Lock()
			g.qWaitLen--
			g.cond.L.Unlock()
			g.cond.Signal() // wake up waiting worker goroutine
			if g.c.Md5Check {
				if _, err := g.md5.Write(c.b[:c.size]); err != nil {
					return nil, err
				}
			}
			return c, nil
		}

		// if next chunk not in qWait, read from channel
		select {
		case c := <-g.readCh:
			g.qWait[c.id] = c
			g.cond.L.Lock()
			g.qWaitLen++
			g.cond.L.Unlock()
		case <-g.workerAborted:
			return nil, g.err // worker aborted, quit
		case <-g.quit:
			if g.err != nil {
				return nil, g.err // fatal error, quit.
			}
			// Closed without a recorded error: still refuse to return (nil, nil).
			return nil, syscall.EINVAL
		}
	}
}
// Close shuts the getter down: it marks it closed, stops the buffer pool,
// closes quit and wakes every worker waiting on cond.
//
// It returns syscall.EINVAL if the getter was already closed, the recorded
// worker error if there is one, an error if fewer or more than contentLen
// bytes were read, and otherwise the result of the md5 check when
// Md5Check is enabled.
//
// closed is the condition workers re-check around cond.Wait, so it is set
// while holding cond.L. Setting it without the lock, as before, allowed a
// worker to test the flag, miss the Broadcast and then wait forever.
func (g *getter) Close() error {
	if g.closed {
		return syscall.EINVAL
	}
	g.cond.L.Lock()
	g.closed = true
	g.cond.L.Unlock()

	close(g.sp.quit)
	close(g.quit)
	g.cond.Broadcast()

	if g.err != nil {
		return g.err
	}
	if g.bytesRead != g.contentLen {
		return fmt.Errorf("read error: %d bytes read. expected: %d", g.bytesRead, g.contentLen)
	}
	if g.c.Md5Check {
		if err := g.checkMd5(); err != nil {
			return err
		}
	}
	return nil
}
// checkMd5 compares the hex md5 of the bytes delivered so far with the
// checksum stored alongside the object. The stored checksum is fetched from
// the path ".md5" + <object path> + ".md5", resolved through the bucket.
// It returns an error if the checksum object cannot be retrieved or read,
// or if its contents differ from the calculated value.
func (g *getter) checkMd5() (err error) {
	calcMd5 := fmt.Sprintf("%x", g.md5.Sum(nil))
	md5Path := ".md5" + g.url.Path + ".md5"

	md5Url, err := g.b.url(md5Path, g.c)
	if err != nil {
		return err
	}
	logger.debugPrintln("md5: ", calcMd5)
	logger.debugPrintln("md5Path: ", md5Path)

	resp, err := g.retryRequest("GET", md5Url.String(), nil)
	if err != nil {
		return err
	}
	if resp.StatusCode != 200 {
		return fmt.Errorf("MD5 check failed: %s not found: %s", md5Url.String(), newRespError(resp))
	}
	defer checkClose(resp.Body, err)

	givenMd5, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if string(givenMd5) == calcMd5 {
		return nil
	}
	return fmt.Errorf("MD5 mismatch. given:%s calculated:%s", givenMd5, calcMd5)
}