diff --git a/google/default.go b/google/default.go index 04ebdc05..02ccd08a 100644 --- a/google/default.go +++ b/google/default.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "runtime" + "sync" "time" "cloud.google.com/go/compute/metadata" @@ -41,12 +42,20 @@ type Credentials struct { // running on Google Cloud Platform. JSON []byte + udMu sync.Mutex // guards universeDomain // universeDomain is the default service domain for a given Cloud universe. universeDomain string } // UniverseDomain returns the default service domain for a given Cloud universe. +// // The default value is "googleapis.com". +// +// Deprecated: Use instead (*Credentials).GetUniverseDomain(), which supports +// obtaining the universe domain when authenticating via the GCE metadata server. +// Unlike GetUniverseDomain, this method, UniverseDomain, will always return the +// default value when authenticating via the GCE metadata server. +// See also [The attached service account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa). func (c *Credentials) UniverseDomain() string { if c.universeDomain == "" { return universeDomainDefault @@ -54,6 +63,55 @@ func (c *Credentials) UniverseDomain() string { return c.universeDomain } +// GetUniverseDomain returns the default service domain for a given Cloud +// universe. +// +// The default value is "googleapis.com". +// +// It obtains the universe domain from the attached service account on GCE when +// authenticating via the GCE metadata server. See also [The attached service +// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa). +// If the GCE metadata server returns a 404 error, the default value is +// returned. If the GCE metadata server returns an error other than 404, the +// error is returned. +func (c *Credentials) GetUniverseDomain() (string, error) { + c.udMu.Lock() + defer c.udMu.Unlock() + if c.universeDomain == "" && metadata.OnGCE() { + // If we're on Google Compute Engine, an App Engine standard second + // generation runtime, or App Engine flexible, use the metadata server. + err := c.computeUniverseDomain() + if err != nil { + return "", err + } + } + // If not on Google Compute Engine, or in case of any non-error path in + // computeUniverseDomain that did not set universeDomain, set the default + // universe domain. + if c.universeDomain == "" { + c.universeDomain = universeDomainDefault + } + return c.universeDomain, nil +} + +// computeUniverseDomain fetches the default service domain for a given Cloud +// universe from Google Compute Engine (GCE)'s metadata server. It's only valid +// to use this method if your program is running on a GCE instance. +func (c *Credentials) computeUniverseDomain() error { + var err error + c.universeDomain, err = metadata.Get("universe/universe_domain") + if err != nil { + if _, ok := err.(metadata.NotDefinedError); ok { + // http.StatusNotFound (404) + c.universeDomain = universeDomainDefault + return nil + } else { + return err + } + } + return nil +} + // DefaultCredentials is the old name of Credentials. // // Deprecated: use Credentials instead. diff --git a/google/default_test.go b/google/default_test.go index 439887ac..7352ffcc 100644 --- a/google/default_test.go +++ b/google/default_test.go @@ -6,6 +6,9 @@ package google import ( "context" + "net/http" + "net/http/httptest" + "strings" "testing" ) @@ -74,6 +77,9 @@ func TestCredentialsFromJSONWithParams_SA(t *testing.T) { if want := "googleapis.com"; creds.UniverseDomain() != want { t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) } + if want := "googleapis.com"; creds.UniverseDomain() != want { + t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) + } } func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) { @@ -94,6 +100,9 @@ func TestCredentialsFromJSONWithParams_SA_Params_UniverseDomain(t *testing.T) { if creds.UniverseDomain() != universeDomain2 { t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2) } + if creds.UniverseDomain() != universeDomain2 { + t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2) + } } func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) { @@ -113,6 +122,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain(t *testing.T) { if creds.UniverseDomain() != universeDomain { t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if got != universeDomain { + t.Fatalf("got %q, want %q", got, universeDomain) + } } func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t *testing.T) { @@ -133,6 +149,13 @@ func TestCredentialsFromJSONWithParams_SA_UniverseDomain_Params_UniverseDomain(t if creds.UniverseDomain() != universeDomain2 { t.Fatalf("got %q, want %q", creds.UniverseDomain(), universeDomain2) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if got != universeDomain2 { + t.Fatalf("got %q, want %q", got, universeDomain2) + } } func TestCredentialsFromJSONWithParams_User(t *testing.T) { @@ -149,6 +172,13 @@ func TestCredentialsFromJSONWithParams_User(t *testing.T) { if want := "googleapis.com"; creds.UniverseDomain() != want { t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if want := "googleapis.com"; got != want { + t.Fatalf("got %q, want %q", got, want) + } } func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) { @@ -166,6 +196,13 @@ func TestCredentialsFromJSONWithParams_User_Params_UniverseDomain(t *testing.T) if want := "googleapis.com"; creds.UniverseDomain() != want { t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if want := "googleapis.com"; got != want { + t.Fatalf("got %q, want %q", got, want) + } } func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) { @@ -182,6 +219,13 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain(t *testing.T) { if want := "googleapis.com"; creds.UniverseDomain() != want { t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if want := "googleapis.com"; got != want { + t.Fatalf("got %q, want %q", got, want) + } } func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain(t *testing.T) { @@ -199,4 +243,55 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain if want := "googleapis.com"; creds.UniverseDomain() != want { t.Fatalf("got %q, want %q", creds.UniverseDomain(), want) } + got, err := creds.GetUniverseDomain() + if err != nil { + t.Fatal(err) + } + if want := "googleapis.com"; got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestComputeUniverseDomain(t *testing.T) { + universeDomainPath := "/computeMetadata/v1/universe/universe_domain" + universeDomainResponseBody := "example.com" + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != universeDomainPath { + t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath) + } + w.Write([]byte(universeDomainResponseBody)) + })) + defer s.Close() + t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://")) + + scope := "https://www.googleapis.com/auth/cloud-platform" + params := CredentialsParams{ + Scopes: []string{scope}, + } + // Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block + creds := &Credentials{ + ProjectID: "fake_project", + TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), + universeDomain: params.UniverseDomain, // empty + } + c := make(chan bool) + go func() { + got, err := creds.GetUniverseDomain() // First conflicting access. + if err != nil { + t.Error(err) + } + if want := universeDomainResponseBody; got != want { + t.Errorf("got %q, want %q", got, want) + } + c <- true + }() + got, err := creds.GetUniverseDomain() // Second conflicting access. + <-c + if err != nil { + t.Error(err) + } + if want := universeDomainResponseBody; got != want { + t.Errorf("got %q, want %q", got, want) + } + } diff --git a/google/google_test.go b/google/google_test.go index ea010494..7078d429 100644 --- a/google/google_test.go +++ b/google/google_test.go @@ -5,6 +5,8 @@ package google import ( + "net/http" + "net/http/httptest" "strings" "testing" ) @@ -137,3 +139,21 @@ func TestJWTConfigFromJSONNoAudience(t *testing.T) { t.Errorf("Audience = %q; want %q", got, want) } } + +func TestComputeTokenSource(t *testing.T) { + tokenPath := "/computeMetadata/v1/instance/service-accounts/default/token" + tokenResponseBody := `{"access_token":"Sample.Access.Token","token_type":"Bearer","expires_in":3600}` + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != tokenPath { + t.Errorf("got %s, want %s", r.URL.Path, tokenPath) + } + w.Write([]byte(tokenResponseBody)) + })) + defer s.Close() + t.Setenv("GCE_METADATA_HOST", strings.TrimPrefix(s.URL, "http://")) + ts := ComputeTokenSource("") + _, err := ts.Token() + if err != nil { + t.Errorf("ts.Token() = %v", err) + } +}