diff --git a/snappy/decode.go b/snappy/decode.go index 846b96c..552a17b 100644 --- a/snappy/decode.go +++ b/snappy/decode.go @@ -131,15 +131,16 @@ func Decode(dst, src []byte) ([]byte, error) { // NewReader returns a new Reader that decompresses from r, using the framing // format described at // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt -func NewReader(r io.Reader) io.Reader { - return &reader{ +func NewReader(r io.Reader) *Reader { + return &Reader{ r: r, decoded: make([]byte, maxUncompressedChunkLen), buf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)+checksumSize), } } -type reader struct { +// Reader is an io.Reader that can read Snappy-compressed bytes. +type Reader struct { r io.Reader err error decoded []byte @@ -149,7 +150,18 @@ type reader struct { readHeader bool } -func (r *reader) readFull(p []byte) (ok bool) { +// Reset discards any buffered data, resets all state, and switches the Snappy +// reader to read from reader. This permits reusing a Reader rather than allocating +// a new one. +func (r *Reader) Reset(reader io.Reader) { + r.r = reader + r.err = nil + r.i = 0 + r.j = 0 + r.readHeader = false +} + +func (r *Reader) readFull(p []byte) (ok bool) { if _, r.err = io.ReadFull(r.r, p); r.err != nil { if r.err == io.ErrUnexpectedEOF { r.err = ErrCorrupt @@ -159,7 +171,8 @@ func (r *reader) readFull(p []byte) (ok bool) { return true } -func (r *reader) Read(p []byte) (int, error) { +// Read satisfies the io.Reader interface. 
+func (r *Reader) Read(p []byte) (int, error) { if r.err != nil { return 0, r.err } diff --git a/snappy/encode.go b/snappy/encode.go index d4713ad..dda3724 100644 --- a/snappy/encode.go +++ b/snappy/encode.go @@ -177,14 +177,15 @@ func MaxEncodedLen(srcLen int) int { // NewWriter returns a new Writer that compresses to w, using the framing // format described at // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt -func NewWriter(w io.Writer) io.Writer { - return &writer{ +func NewWriter(w io.Writer) *Writer { + return &Writer{ w: w, enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)), } } -type writer struct { +// Writer is an io.Writer that can write Snappy-compressed bytes. +type Writer struct { w io.Writer err error enc []byte @@ -192,7 +193,16 @@ type writer struct { wroteHeader bool } -func (w *writer) Write(p []byte) (n int, errRet error) { +// Reset discards the writer's state and switches the Snappy writer to write to +// writer. This permits reusing a Writer rather than allocating a new one. +func (w *Writer) Reset(writer io.Writer) { + w.w = writer + w.err = nil + w.wroteHeader = false +} + +// Write satisfies the io.Writer interface. 
+func (w *Writer) Write(p []byte) (n int, errRet error) { if w.err != nil { return 0, w.err } diff --git a/snappy/snappy_test.go b/snappy/snappy_test.go index c61198a..c76c475 100644 --- a/snappy/snappy_test.go +++ b/snappy/snappy_test.go @@ -79,8 +79,19 @@ func TestSmallRegular(t *testing.T) { } } +func cmp(a, b []byte) error { + if len(a) != len(b) { + return fmt.Errorf("got %d bytes, want %d", len(a), len(b)) + } + for i := range a { + if a[i] != b[i] { + return fmt.Errorf("byte #%d: got 0x%02x, want 0x%02x", i, a[i], b[i]) + } + } + return nil +} + func TestFramingFormat(t *testing.T) { -loop: for _, tf := range testFiles { if err := downloadTestdata(tf.filename); err != nil { t.Fatalf("failed to download testdata: %s", err) @@ -96,17 +107,80 @@ loop: t.Errorf("%s: decoding: %v", tf.filename, err) continue } - if !bytes.Equal(dst, src) { - if len(dst) != len(src) { - t.Errorf("%s: got %d bytes, want %d", tf.filename, len(dst), len(src)) + if err := cmp(dst, src); err != nil { + t.Errorf("%s: %v", tf.filename, err) + continue + } + } +} + +func TestReaderReset(t *testing.T) { + gold := bytes.Repeat([]byte("All that is gold does not glitter,\n"), 10000) + buf := new(bytes.Buffer) + if _, err := NewWriter(buf).Write(gold); err != nil { + t.Fatalf("Write: %v", err) + } + encoded, invalid, partial := buf.String(), "invalid", "partial" + r := NewReader(nil) + for i, s := range []string{encoded, invalid, partial, encoded, partial, invalid, encoded, encoded} { + if s == partial { + r.Reset(strings.NewReader(encoded)) + if _, err := r.Read(make([]byte, 101)); err != nil { + t.Errorf("#%d: %v", i, err) continue } - for i := range dst { - if dst[i] != src[i] { - t.Errorf("%s: byte #%d: got 0x%02x, want 0x%02x", tf.filename, i, dst[i], src[i]) - continue loop - } + continue + } + r.Reset(strings.NewReader(s)) + got, err := ioutil.ReadAll(r) + switch s { + case encoded: + if err != nil { + t.Errorf("#%d: %v", i, err) + continue } + if err := cmp(got, gold); err != nil { + 
 t.Errorf("#%d: %v", i, err) + continue + } + case invalid: + if err == nil { + t.Errorf("#%d: got nil error, want non-nil", i) + continue + } + } + } +} + +func TestWriterReset(t *testing.T) { + gold := bytes.Repeat([]byte("Not all those who wander are lost;\n"), 10000) + var gots, wants [][]byte + const n = 20 + w, failed := NewWriter(nil), false + for i := 0; i <= n; i++ { + buf := new(bytes.Buffer) + w.Reset(buf) + want := gold[:len(gold)*i/n] + if _, err := w.Write(want); err != nil { + t.Errorf("#%d: Write: %v", i, err) + failed = true + continue + } + got, err := ioutil.ReadAll(NewReader(buf)) + if err != nil { + t.Errorf("#%d: ReadAll: %v", i, err) + failed = true + continue + } + gots = append(gots, got) + wants = append(wants, want) + } + if failed { + return + } + for i := range gots { + if err := cmp(gots[i], wants[i]); err != nil { + t.Errorf("#%d: %v", i, err) } } }