diff --git a/jwt/jwt.go b/jwt/jwt.go index 99f3e0a3..b2bf1829 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -66,6 +66,14 @@ type Config struct { // request. If empty, the value of TokenURL is used as the // intended audience. Audience string + + // PrivateClaims optionally specifies custom private claims in the JWT. + // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 + PrivateClaims map[string]interface{} + + // UseIDToken optionally specifies whether ID token should be used instead + // of access token when the server returns both. + UseIDToken bool } // TokenSource returns a JWT TokenSource using the configuration @@ -97,9 +105,10 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } hc := oauth2.NewClient(js.ctx, nil) claimSet := &jws.ClaimSet{ - Iss: js.conf.Email, - Scope: strings.Join(js.conf.Scopes, " "), - Aud: js.conf.TokenURL, + Iss: js.conf.Email, + Scope: strings.Join(js.conf.Scopes, " "), + Aud: js.conf.TokenURL, + PrivateClaims: js.conf.PrivateClaims, } if subject := js.conf.Subject; subject != "" { claimSet.Sub = subject @@ -166,5 +175,11 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } token.Expiry = time.Unix(claimSet.Exp, 0) } + if js.conf.UseIDToken { + if tokenRes.IDToken == "" { + return nil, fmt.Errorf("oauth2: response doesn't have JWT token") + } + token.AccessToken = tokenRes.IDToken + } return token, nil } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 9dfa3b35..9772dc52 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "strings" "testing" @@ -221,6 +222,16 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { TokenURL: ts.URL, Audience: "https://example.com", }, + { + Email: "aaa2@xxx.com", + PrivateKey: dummyPrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + PrivateClaims: map[string]interface{}{ + "private0": "claim0", + "private1": "claim1", + }, + }, } { t.Run(conf.Email, func(t *testing.T) { _, err := conf.TokenSource(context.Background()).Token() @@ -261,6 +272,18 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { if got, want := claimSet.Prn, conf.Subject; got != want { t.Errorf("payload prn = %q; want %q", got, want) } + if len(conf.PrivateClaims) > 0 { + var got interface{} + if err := json.Unmarshal(gotjson, &got); err != nil { + t.Errorf("failed to parse payload; err = %q", err) + } + m := got.(map[string]interface{}) + for v, k := range conf.PrivateClaims { + if !reflect.DeepEqual(m[v], k) { + t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k) + } + } + } }) } }