diff --git a/transport.go b/transport.go index 92ac7e25..c55bfa01 100644 --- a/transport.go +++ b/transport.go @@ -34,6 +34,15 @@ type Transport struct { // access token. If no token exists or token is expired, // tries to refresh/fetch a new token. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + reqBodyClosed := false + if req.Body != nil { + defer func() { + if !reqBodyClosed { + req.Body.Close() + } + }() + } + if t.Source == nil { return nil, errors.New("oauth2: Transport's Source is nil") } @@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { token.SetAuthHeader(req2) t.setModReq(req, req2) res, err := t.base().RoundTrip(req2) + + // req.Body is assumed to have been closed by the base RoundTripper. + reqBodyClosed = true + if err != nil { t.setModReq(req, nil) return nil, err diff --git a/transport_test.go b/transport_test.go index d6e8087d..faa87d51 100644 --- a/transport_test.go +++ b/transport_test.go @@ -1,6 +1,8 @@ package oauth2 import ( + "errors" + "io" "net/http" "net/http/httptest" "testing" @@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) { } } +type readCloseCounter struct { + CloseCount int + ReadErr error +} + +func (r *readCloseCounter) Read(b []byte) (int, error) { + return 0, r.ReadErr +} + +func (r *readCloseCounter) Close() error { + r.CloseCount++ + return nil +} + +func TestTransportCloseRequestBody(t *testing.T) { + tr := &Transport{} + server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) + defer server.Close() + client := &http.Client{Transport: tr} + body := &readCloseCounter{ + ReadErr: errors.New("readCloseCounter.Read not implemented"), + } + resp, err := client.Post(server.URL, "application/json", body) + if err == nil { + t.Errorf("got no errors, want an error with nil token source") + } + if resp != nil { + t.Errorf("Response = %v; want nil", resp) + } + if expected := 1; body.CloseCount != expected { + t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected) + } +} + +func TestTransportCloseRequestBodySuccess(t *testing.T) { + tr := &Transport{ + Source: StaticTokenSource(&Token{ + AccessToken: "abc", + }), + } + server := newMockServer(func(w http.ResponseWriter, r *http.Request) {}) + defer server.Close() + client := &http.Client{Transport: tr} + body := &readCloseCounter{ + ReadErr: io.EOF, + } + resp, err := client.Post(server.URL, "application/json", body) + if err != nil { + t.Errorf("got error %v; expected none", err) + } + if resp == nil { + t.Errorf("Response is nil; expected non-nil") + } + if expected := 1; body.CloseCount != expected { + t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected) + } +} + func TestTransportTokenSource(t *testing.T) { ts := &tokenSource{ token: &Token{