diff --git a/google/google.go b/google/google.go index 81de32b3..3d49d607 100644 --- a/google/google.go +++ b/google/google.go @@ -159,13 +159,49 @@ func ComputeTokenSource(account string, scope ...string) oauth2.TokenSource { return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope}) } +// ComputeTokenSourceOnGCE forces the token fetcher to run as it should on a GCE +// VM, without performing the "is on GCE" check. +func ComputeTokenSourceOnGCE(account string, scope ...string) oauth2.TokenSource { + return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope, gce: true}) +} + +// ComuputeTokenSourceWithClient allows the metadata.Client to be customized, +// for use cases where the default metadata.Client is not suitable. Note that +// this will still try to use the default metadata.Client to check whether we +// are running on GCE, unless GCE_METADATA_HOST is set in the environment. +func ComputeTokenSourceWithClient(client *metadata.Client, account string, scope ...string) oauth2.TokenSource { + return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope, client: client}) +} + +// ComputeTokenSourceWithClientOnGCE allows the metadata.Client to be customized, +// and also forces the token fetcher to run as it should on a GCE VM. +func ComputeTokenSourceWithClientOnGCE(client *metadata.Client, account string, scope ...string) oauth2.TokenSource { + return oauth2.ReuseTokenSource(nil, computeSource{account: account, scopes: scope, client: client, gce: true}) +} + type computeSource struct { account string scopes []string + client *metadata.Client + gce bool +} + +func (cs computeSource) fetchToken(tokenURL string) (string, error) { + if cs.client != nil { + return cs.client.Get(tokenURI) + } + return metadata.Get(tokenURI) +} + +func (cs computeSource) onGCE() bool { + if cs.gce { + return true + } + return metadata.OnGCE() } func (cs computeSource) Token() (*oauth2.Token, error) { - if !metadata.OnGCE() { + if !cs.onGCE() { return nil, errors.New("oauth2/google: can't get a token from the metadata service; not running on GCE") } acct := cs.account @@ -178,7 +214,7 @@ func (cs computeSource) Token() (*oauth2.Token, error) { v.Set("scopes", strings.Join(cs.scopes, ",")) tokenURI = tokenURI + "?" + v.Encode() } - tokenJSON, err := metadata.Get(tokenURI) + tokenJSON, err := cs.fetchToken(tokenURI) if err != nil { return nil, err }