snappy: add Reset methods to Reader and Writer.

LGTM=bradfitz
R=bradfitz
CC=golang-codereviews
https://codereview.appspot.com/202990043
This commit is contained in:
Nigel Tao 2015-02-10 14:07:10 +11:00
Родитель 4c08685702
Коммит eaed4addcd
3 изменённых файлов: 115 добавлений и 18 удалений

Просмотреть файл

@ -131,15 +131,16 @@ func Decode(dst, src []byte) ([]byte, error) {
// NewReader returns a new Reader that decompresses from r, using the framing // NewReader returns a new Reader that decompresses from r, using the framing
// format described at // format described at
// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
func NewReader(r io.Reader) io.Reader { func NewReader(r io.Reader) *Reader {
return &reader{ return &Reader{
r: r, r: r,
decoded: make([]byte, maxUncompressedChunkLen), decoded: make([]byte, maxUncompressedChunkLen),
buf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)+checksumSize), buf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)+checksumSize),
} }
} }
type reader struct { // Reader is an io.Reader than can read Snappy-compressed bytes.
type Reader struct {
r io.Reader r io.Reader
err error err error
decoded []byte decoded []byte
@ -149,7 +150,18 @@ type reader struct {
readHeader bool readHeader bool
} }
func (r *reader) readFull(p []byte) (ok bool) { // Reset discards any buffered data, resets all state, and switches the Snappy
// reader to read from r. This permits reusing a Reader rather than allocating
// a new one.
func (r *Reader) Reset(reader io.Reader) {
r.r = reader
r.err = nil
r.i = 0
r.j = 0
r.readHeader = false
}
func (r *Reader) readFull(p []byte) (ok bool) {
if _, r.err = io.ReadFull(r.r, p); r.err != nil { if _, r.err = io.ReadFull(r.r, p); r.err != nil {
if r.err == io.ErrUnexpectedEOF { if r.err == io.ErrUnexpectedEOF {
r.err = ErrCorrupt r.err = ErrCorrupt
@ -159,7 +171,8 @@ func (r *reader) readFull(p []byte) (ok bool) {
return true return true
} }
func (r *reader) Read(p []byte) (int, error) { // Read satisfies the io.Reader interface.
func (r *Reader) Read(p []byte) (int, error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }

Просмотреть файл

@ -177,14 +177,15 @@ func MaxEncodedLen(srcLen int) int {
// NewWriter returns a new Writer that compresses to w, using the framing // NewWriter returns a new Writer that compresses to w, using the framing
// format described at // format described at
// https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt // https://code.google.com/p/snappy/source/browse/trunk/framing_format.txt
func NewWriter(w io.Writer) io.Writer { func NewWriter(w io.Writer) *Writer {
return &writer{ return &Writer{
w: w, w: w,
enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)), enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
} }
} }
type writer struct { // Writer is an io.Writer than can write Snappy-compressed bytes.
type Writer struct {
w io.Writer w io.Writer
err error err error
enc []byte enc []byte
@ -192,7 +193,16 @@ type writer struct {
wroteHeader bool wroteHeader bool
} }
func (w *writer) Write(p []byte) (n int, errRet error) { // Reset discards the writer's state and switches the Snappy writer to write to
// w. This permits reusing a Writer rather than allocating a new one.
func (w *Writer) Reset(writer io.Writer) {
w.w = writer
w.err = nil
w.wroteHeader = false
}
// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (n int, errRet error) {
if w.err != nil { if w.err != nil {
return 0, w.err return 0, w.err
} }

Просмотреть файл

@ -79,8 +79,19 @@ func TestSmallRegular(t *testing.T) {
} }
} }
func cmp(a, b []byte) error {
if len(a) != len(b) {
return fmt.Errorf("got %d bytes, want %d", len(a), len(b))
}
for i := range a {
if a[i] != b[i] {
return fmt.Errorf("byte #%d: got 0x%02x, want 0x%02x", i, a[i], b[i])
}
}
return nil
}
func TestFramingFormat(t *testing.T) { func TestFramingFormat(t *testing.T) {
loop:
for _, tf := range testFiles { for _, tf := range testFiles {
if err := downloadTestdata(tf.filename); err != nil { if err := downloadTestdata(tf.filename); err != nil {
t.Fatalf("failed to download testdata: %s", err) t.Fatalf("failed to download testdata: %s", err)
@ -96,17 +107,80 @@ loop:
t.Errorf("%s: decoding: %v", tf.filename, err) t.Errorf("%s: decoding: %v", tf.filename, err)
continue continue
} }
if !bytes.Equal(dst, src) { if err := cmp(dst, src); err != nil {
if len(dst) != len(src) { t.Errorf("%s: %v", tf.filename, err)
t.Errorf("%s: got %d bytes, want %d", tf.filename, len(dst), len(src)) continue
}
}
}
func TestReaderReset(t *testing.T) {
gold := bytes.Repeat([]byte("All that is gold does not glitter,\n"), 10000)
buf := new(bytes.Buffer)
if _, err := NewWriter(buf).Write(gold); err != nil {
t.Fatalf("Write: %v", err)
}
encoded, invalid, partial := buf.String(), "invalid", "partial"
r := NewReader(nil)
for i, s := range []string{encoded, invalid, partial, encoded, partial, invalid, encoded, encoded} {
if s == partial {
r.Reset(strings.NewReader(encoded))
if _, err := r.Read(make([]byte, 101)); err != nil {
t.Errorf("#%d: %v", i, err)
continue continue
} }
for i := range dst { continue
if dst[i] != src[i] { }
t.Errorf("%s: byte #%d: got 0x%02x, want 0x%02x", tf.filename, i, dst[i], src[i]) r.Reset(strings.NewReader(s))
continue loop got, err := ioutil.ReadAll(r)
} switch s {
case encoded:
if err != nil {
t.Errorf("#%d: %v", i, err)
continue
} }
if err := cmp(got, gold); err != nil {
t.Errorf("%#d: %v", i, err)
continue
}
case invalid:
if err == nil {
t.Errorf("#%d: got nil error, want non-nil", i)
continue
}
}
}
}
func TestWriterReset(t *testing.T) {
gold := bytes.Repeat([]byte("Not all those who wander are lost;\n"), 10000)
var gots, wants [][]byte
const n = 20
w, failed := NewWriter(nil), false
for i := 0; i <= n; i++ {
buf := new(bytes.Buffer)
w.Reset(buf)
want := gold[:len(gold)*i/n]
if _, err := w.Write(want); err != nil {
t.Errorf("#%d: Write: %v", i, err)
failed = true
continue
}
got, err := ioutil.ReadAll(NewReader(buf))
if err != nil {
t.Errorf("#%d: ReadAll: %v", i, err)
failed = true
continue
}
gots = append(gots, got)
wants = append(wants, want)
}
if failed {
return
}
for i := range gots {
if err := cmp(gots[i], wants[i]); err != nil {
t.Errorf("#%d: %v", i, err)
} }
} }
} }