diff --git a/oauth2.go b/oauth2.go index 62f1a6a9..31cae9ad 100644 --- a/oauth2.go +++ b/oauth2.go @@ -33,7 +33,7 @@ type tokenRespBody struct { type TokenFetcher interface { // FetchToken retrieves a new access token for the provider. // If the implementation doesn't know how to retrieve a new token, - // it returns an error. + // it returns an error. Existing token could be nil. FetchToken(existing *Token) (*Token, error) } @@ -173,12 +173,11 @@ func (c *Config) FetchToken(existing *Token) (*Token, error) { if existing == nil || existing.RefreshToken == "" { return nil, errors.New("cannot fetch access token without refresh token.") } - err := c.updateToken(existing, url.Values{ + return c.retrieveToken(url.Values{ "grant_type": {"refresh_token"}, "client_secret": {c.opts.ClientSecret}, "refresh_token": {existing.RefreshToken}, }) - return existing, err } // Checks if all required configuration fields have non-zero values. @@ -195,7 +194,6 @@ func (c *Config) validate() error { // Exchange exchanges the exchange code with the OAuth 2.0 provider // to retrieve a new access token. func (c *Config) Exchange(exchangeCode string) (*Token, error) { - token := &Token{} vals := url.Values{ "grant_type": {"authorization_code"}, "client_secret": {c.opts.ClientSecret}, @@ -207,15 +205,10 @@ func (c *Config) Exchange(exchangeCode string) (*Token, error) { if c.opts.RedirectURL != "" { vals.Set("redirect_uri", c.opts.RedirectURL) } - err := c.updateToken(token, vals) - if err != nil { - return nil, err - } - return token, nil + return c.retrieveToken(vals) } -// updateToken mutates both tok and v. -func (c *Config) updateToken(tok *Token, v url.Values) error { +func (c *Config) retrieveToken(v url.Values) (*Token, error) { v.Set("client_id", c.opts.ClientID) // Note that we're not setting v's client_secret to t.ClientSecret, due // to https://code.google.com/p/goauth2/issues/detail?id=31 @@ -225,12 +218,12 @@ func (c *Config) updateToken(tok *Token, v url.Values) error { // so that's all we use. r, err := http.DefaultClient.PostForm(c.tokenURL.String(), v) if err != nil { - return err + return nil, err } defer r.Body.Close() if r.StatusCode != 200 { // TODO(jbd): Add status code or error message - return errors.New("Error during updating token.") + return nil, errors.New("error during updating token") } resp := &tokenRespBody{} @@ -239,11 +232,11 @@ func (c *Config) updateToken(tok *Token, v url.Values) error { case "application/x-www-form-urlencoded", "text/plain": body, err := ioutil.ReadAll(r.Body) if err != nil { - return err + return nil, err } vals, err := url.ParseQuery(string(body)) if err != nil { - return err + return nil, err } resp.AccessToken = vals.Get("access_token") resp.TokenType = vals.Get("token_type") @@ -252,25 +245,29 @@ func (c *Config) updateToken(tok *Token, v url.Values) error { resp.IdToken = vals.Get("id_token") default: if err = json.NewDecoder(r.Body).Decode(&resp); err != nil { - return err + return nil, err } } - tok.AccessToken = resp.AccessToken - tok.TokenType = resp.TokenType + token := &Token{ + AccessToken: resp.AccessToken, + TokenType: resp.TokenType, + RefreshToken: resp.RefreshToken, + } // Don't overwrite `RefreshToken` with an empty value - if resp.RefreshToken != "" { - tok.RefreshToken = resp.RefreshToken + // if this was a token refreshing request. + if resp.RefreshToken == "" { + token.RefreshToken = v.Get("refresh_token") } if resp.ExpiresIn == 0 { - tok.Expiry = time.Time{} + token.Expiry = time.Time{} } else { - tok.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) } if resp.IdToken != "" { - if tok.Extra == nil { - tok.Extra = make(map[string]string) + if token.Extra == nil { + token.Extra = make(map[string]string) } - tok.Extra["id_token"] = resp.IdToken + token.Extra["id_token"] = resp.IdToken } - return nil + return token, nil }