package s3gof3r

import (
	"crypto/md5"
	"fmt"
	"hash"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"net/url"
	"sync"
	"syscall"
	"time"
)

// qWaitMax bounds how many fetched-but-unconsumed chunks may accumulate in
// qWait. Workers block in getChunk once the backlog reaches this size, which
// throttles download-ahead memory to roughly qWaitMax extra buffers.
const qWaitMax = 2

// getter implements io.ReadCloser over an S3 object. Worker goroutines fetch
// byte-range chunks concurrently; Read reassembles them in order.
//
// NOTE(review): err and closed are read and written from multiple goroutines
// without synchronization (worker/retryGetChunk vs. Read/Close). This mirrors
// the original code but is a data race under `go test -race`; confirm whether
// upstream guards these elsewhere before tightening.
type getter struct {
	url   url.URL
	b     *Bucket
	bufsz int64 // chunk size in bytes (from Config.PartSize)
	err   error // first fatal chunk error, surfaced by Read/Close

	chunkID    int    // id of the next chunk Read expects
	rChunk     *chunk // chunk currently being copied out by Read
	contentLen int64  // object size from the initial GET
	bytesRead  int64  // total bytes delivered to Read callers
	chunkTotal int    // total number of chunks for the object

	readCh        chan *chunk   // workers -> reader: completed chunks
	getCh         chan *chunk   // initChunks -> workers: chunks to fetch
	quit          chan struct{} // closed by Close to stop everything
	workerAborted chan struct{} // closed once when a worker hits a fatal error
	abortOnce     sync.Once     // guards the single close of workerAborted
	qWait         map[int]*chunk // out-of-order completed chunks keyed by id
	qWaitLen      uint           // len(qWait), guarded by cond.L
	cond          sync.Cond      // signals workers when qWait drains

	sp     *bp // buffer pool sized to bufsz
	closed bool
	c      *Config

	md5  hash.Hash // running MD5 of delivered bytes (when Md5Check)
	cIdx int64     // read offset within rChunk.b
}

// chunk is a single byte range of the object.
type chunk struct {
	id     int
	header http.Header // carries the Range header for the fetch
	start  int64
	size   int64
	b      []byte // backing buffer, leased from the pool
}

// newGetter opens getURL for reading. It issues an initial GET to learn the
// content length and headers, then starts Concurrency workers plus a chunk
// scheduler. Returns the reader, the response headers of the initial GET, and
// any setup error. Objects served with chunked transfer encoding (unknown
// content length) are rejected.
func newGetter(getURL url.URL, c *Config, b *Bucket) (io.ReadCloser, http.Header, error) {
	g := new(getter)
	g.url = getURL
	g.c, g.b = new(Config), new(Bucket)
	*g.c, *g.b = *c, *b
	g.bufsz = max64(c.PartSize, 1)
	g.c.NTry = max(c.NTry, 1)
	g.c.Concurrency = max(c.Concurrency, 1)
	g.getCh = make(chan *chunk)
	g.readCh = make(chan *chunk)
	g.quit = make(chan struct{})
	g.workerAborted = make(chan struct{})
	g.qWait = make(map[int]*chunk)
	// NOTE(review): this reassignment discards the defensive *b copy made
	// above and shares the caller's Bucket pointer; confirm which behavior
	// is intended before removing either line.
	g.b = b
	g.md5 = md5.New()
	g.cond = sync.Cond{L: &sync.Mutex{}}

	// use get instead of head for error messaging
	resp, err := g.retryRequest("GET", g.url.String(), nil)
	if err != nil {
		return nil, nil, err
	}
	if resp.StatusCode != 200 {
		return nil, nil, newRespError(resp)
	}
	defer checkClose(resp.Body, err)

	// Go reports ContentLength == -1 for chunked transfer encoding or
	// EOF-close responses; chunked downloads need a known size.
	if resp.ContentLength == -1 {
		return nil, nil, fmt.Errorf("Retrieving objects with undefined content-length " +
			"responses (chunked transfer encoding / EOF close) is not supported")
	}
	g.contentLen = resp.ContentLength
	g.chunkTotal = int((g.contentLen + g.bufsz - 1) / g.bufsz) // round up, integer division
	logger.debugPrintf("object size: %3.2g MB", float64(g.contentLen)/float64((1*mb)))

	g.sp = bufferPool(g.bufsz)
	for i := 0; i < g.c.Concurrency; i++ {
		go g.worker()
	}
	go g.initChunks()
	return g, resp.Header, nil
}

// retryRequest issues a signed HTTP request, retrying up to NTry times on
// transport errors and HTTP 500s with exponential back-off. body, when
// non-nil, is rewound between attempts so it can be re-sent.
func (g *getter) retryRequest(method, urlStr string, body io.ReadSeeker) (resp *http.Response, err error) {
	for i := 0; i < g.c.NTry; i++ {
		var req *http.Request
		req, err = http.NewRequest(method, urlStr, body)
		if err != nil {
			return
		}
		if body != nil {
			req.Header.Set(sha256Header, shaReader(body))
		}
		g.b.Sign(req)
		resp, err = g.c.Client.Do(req)
		if err == nil && resp.StatusCode == 500 {
			// Drain and close the body before retrying so the transport can
			// reuse the connection (previously leaked on every 500).
			io.Copy(ioutil.Discard, resp.Body)
			resp.Body.Close()
			time.Sleep(time.Duration(math.Exp2(float64(i))) * 100 * time.Millisecond) // exponential back-off
			continue
		}
		if err == nil {
			return
		}
		logger.debugPrintln(err)
		if body != nil {
			if _, err = body.Seek(0, 0); err != nil {
				return
			}
		}
	}
	return
}

// initChunks slices the object into bufsz-sized byte ranges and feeds them to
// the workers, closing getCh when the whole object has been scheduled.
func (g *getter) initChunks() {
	id := 0
	for i := int64(0); i < g.contentLen; {
		size := min64(g.bufsz, g.contentLen-i)
		c := &chunk{
			id: id,
			header: http.Header{
				"Range": {fmt.Sprintf("bytes=%d-%d", i, i+size-1)},
			},
			start: i,
			size:  size,
			b:     nil,
		}
		i += size
		id++
		g.getCh <- c
	}
	close(g.getCh)
}

// worker fetches chunks from getCh until the channel closes or a chunk
// permanently fails, in which case it signals abort so Read can bail out.
func (g *getter) worker() {
	for c := range g.getCh {
		g.retryGetChunk(c)
		if g.err != nil {
			// tell Read() caller that 1 or more chunks can't be read; abort
			g.abortOnce.Do(func() { close(g.workerAborted) })
			break
		}
	}
}

// retryGetChunk leases a buffer for c and attempts the range GET up to NTry
// times with exponential back-off, recording the final error in g.err unless
// the getter is already shutting down.
func (g *getter) retryGetChunk(c *chunk) {
	var err error
	c.b = <-g.sp.get
	for i := 0; i < g.c.NTry; i++ {
		err = g.getChunk(c)
		if err == nil {
			return
		}
		logger.debugPrintf("error on attempt %d: retrying chunk: %v, error: %s", i, c.id, err)
		time.Sleep(time.Duration(math.Exp2(float64(i))) * 100 * time.Millisecond) // exponential back-off
	}
	select {
	case <-g.quit: // check for closed quit channel before setting error
		return
	default:
		g.err = err
	}
}

// getChunk performs one signed range GET for c, validates the byte count,
// hands the chunk to the reader, then blocks until the reorder backlog
// (qWait) drains below qWaitMax or the getter closes.
func (g *getter) getChunk(c *chunk) error {
	// ensure buffer is empty
	r, err := http.NewRequest("GET", g.url.String(), nil)
	if err != nil {
		return err
	}
	r.Header = c.header
	g.b.Sign(r)
	resp, err := g.c.Client.Do(r)
	if err != nil {
		return err
	}
	if resp.StatusCode != 206 && resp.StatusCode != 200 {
		return newRespError(resp)
	}
	defer checkClose(resp.Body, err)
	n, err := io.ReadAtLeast(resp.Body, c.b, int(c.size))
	if err != nil {
		return err
	}
	// Close eagerly so the connection is released before we block on the
	// backlog below; the deferred checkClose then double-closes, which is
	// harmless for http response bodies.
	if err := resp.Body.Close(); err != nil {
		return err
	}
	if int64(n) != c.size {
		return fmt.Errorf("chunk %d: Expected %d bytes, received %d", c.id, c.size, n)
	}
	g.readCh <- c
	// wait for qWait to drain before starting next chunk
	g.cond.L.Lock()
	for g.qWaitLen >= qWaitMax && !g.closed {
		g.cond.Wait()
	}
	// Bug fix: the original returned while still holding cond.L when closed
	// was observed, leaking the mutex and deadlocking other goroutines.
	g.cond.L.Unlock()
	return nil
}

// Read implements io.Reader, delivering the object's bytes strictly in order
// by draining chunks as they arrive. Returns io.EOF once contentLen bytes
// have been delivered, and any fatal worker error otherwise.
func (g *getter) Read(p []byte) (int, error) {
	var err error
	if g.closed {
		return 0, syscall.EINVAL
	}
	if g.err != nil {
		return 0, g.err
	}
	nw := 0
	for nw < len(p) {
		if g.bytesRead == g.contentLen {
			return nw, io.EOF
		} else if g.bytesRead > g.contentLen {
			// Here for robustness / completeness
			// Should not occur as golang uses LimitedReader up to content-length
			return nw, fmt.Errorf("Expected %d bytes, received %d (too many bytes)",
				g.contentLen, g.bytesRead)
		}

		// If for some reason no more chunks to be read and bytes are off, error, incomplete result
		if g.chunkID >= g.chunkTotal {
			return nw, fmt.Errorf("Expected %d bytes, received %d and chunkID %d >= chunkTotal %d (no more chunks remaining)",
				g.contentLen, g.bytesRead, g.chunkID, g.chunkTotal)
		}

		if g.rChunk == nil {
			g.rChunk, err = g.nextChunk()
			if err != nil {
				return 0, err
			}
			g.cIdx = 0
		}

		n := copy(p[nw:], g.rChunk.b[g.cIdx:g.rChunk.size])
		g.cIdx += int64(n)
		nw += n
		g.bytesRead += int64(n)

		if g.cIdx >= g.rChunk.size { // chunk complete
			g.sp.give <- g.rChunk.b
			g.chunkID++
			g.rChunk = nil
		}
	}
	return nw, nil
}

// nextChunk returns the chunk with id g.chunkID, buffering out-of-order
// arrivals in qWait until the wanted id shows up. It also folds the chunk's
// bytes into the running MD5 when Md5Check is enabled. Returns g.err if a
// worker aborted or Close was called while waiting.
func (g *getter) nextChunk() (*chunk, error) {
	for {
		// first check qWait
		c := g.qWait[g.chunkID]
		if c != nil {
			delete(g.qWait, g.chunkID)
			g.cond.L.Lock()
			g.qWaitLen--
			g.cond.L.Unlock()
			g.cond.Signal() // wake up waiting worker goroutine
			if g.c.Md5Check {
				if _, err := g.md5.Write(c.b[:c.size]); err != nil {
					return nil, err
				}
			}
			return c, nil
		}
		// if next chunk not in qWait, read from channel
		select {
		case c := <-g.readCh:
			g.qWait[c.id] = c
			g.cond.L.Lock()
			g.qWaitLen++
			g.cond.L.Unlock()
		case <-g.workerAborted:
			return nil, g.err // worker aborted, quit
		case <-g.quit:
			return nil, g.err // fatal error, quit.
		}
	}
}

// Close shuts down the getter: stops the buffer pool and workers, wakes any
// worker parked on the backlog condition, and verifies byte count and
// (optionally) the MD5 checksum of everything delivered. Subsequent calls
// return EINVAL.
func (g *getter) Close() error {
	if g.closed {
		return syscall.EINVAL
	}
	g.closed = true
	close(g.sp.quit)
	close(g.quit)
	g.cond.Broadcast() // release workers blocked in getChunk's wait loop
	if g.err != nil {
		return g.err
	}
	if g.bytesRead != g.contentLen {
		return fmt.Errorf("read error: %d bytes read. expected: %d", g.bytesRead, g.contentLen)
	}
	if g.c.Md5Check {
		if err := g.checkMd5(); err != nil {
			return err
		}
	}
	return nil
}

// checkMd5 compares the MD5 accumulated during Read against the sidecar
// checksum object stored at /.md5/<object path>.md5 in the bucket.
func (g *getter) checkMd5() (err error) {
	calcMd5 := fmt.Sprintf("%x", g.md5.Sum(nil))
	md5Path := fmt.Sprint(".md5", g.url.Path, ".md5")
	md5Url, err := g.b.url(md5Path, g.c)
	if err != nil {
		return err
	}
	logger.debugPrintln("md5: ", calcMd5)
	logger.debugPrintln("md5Path: ", md5Path)
	resp, err := g.retryRequest("GET", md5Url.String(), nil)
	if err != nil {
		return
	}
	if resp.StatusCode != 200 {
		return fmt.Errorf("MD5 check failed: %s not found: %s", md5Url.String(), newRespError(resp))
	}
	defer checkClose(resp.Body, err)
	givenMd5, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return
	}
	if calcMd5 != string(givenMd5) {
		return fmt.Errorf("MD5 mismatch. given:%s calculated:%s", givenMd5, calcMd5)
	}
	return
}