diff --git a/engine/env.go b/engine/env.go index a16dc35cd9..f370e95ed0 100644 --- a/engine/env.go +++ b/engine/env.go @@ -7,6 +7,8 @@ import ( "io" "strconv" "strings" + + "github.com/docker/docker/utils" ) type Env []string @@ -242,9 +244,10 @@ func (env *Env) Encode(dst io.Writer) error { return nil } -func (env *Env) WriteTo(dst io.Writer) (n int64, err error) { - // FIXME: return the number of bytes written to respect io.WriterTo - return 0, env.Encode(dst) +func (env *Env) WriteTo(dst io.Writer) (int64, error) { + wc := utils.NewWriteCounter(dst) + err := env.Encode(wc) + return wc.Count, err } func (env *Env) Import(src interface{}) (err error) { diff --git a/utils/utils.go b/utils/utils.go index a3e17b886d..6392298214 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -516,3 +516,24 @@ func ReadDockerIgnore(path string) ([]string, error) { } return excludes, nil } + +// Wrap a concrete io.Writer and hold a count of the number +// of bytes written to the writer during a "session". +// This can be convenient when write return is masked +// (e.g., json.Encoder.Encode()) +type WriteCounter struct { + Count int64 + Writer io.Writer +} + +func NewWriteCounter(w io.Writer) *WriteCounter { + return &WriteCounter{ + Writer: w, + } +} + +func (wc *WriteCounter) Write(p []byte) (count int, err error) { + count, err = wc.Writer.Write(p) + wc.Count += int64(count) + return +} diff --git a/utils/utils_test.go b/utils/utils_test.go index ce304482b8..ef1f7af03b 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,7 +1,9 @@ package utils import ( + "bytes" "os" + "strings" "testing" ) @@ -97,3 +99,26 @@ func TestReadSymlinkedDirectoryToFile(t *testing.T) { t.Errorf("failed to remove symlink: %s", err) } } + +func TestWriteCounter(t *testing.T) { + dummy1 := "This is a dummy string." + dummy2 := "This is another dummy string." + totalLength := int64(len(dummy1) + len(dummy2)) + + reader1 := strings.NewReader(dummy1) + reader2 := strings.NewReader(dummy2) + + var buffer bytes.Buffer + wc := NewWriteCounter(&buffer) + + reader1.WriteTo(wc) + reader2.WriteTo(wc) + + if wc.Count != totalLength { + t.Errorf("Wrong count: %d vs. %d", wc.Count, totalLength) + } + + if buffer.String() != dummy1+dummy2 { + t.Error("Wrong message written") + } +}