From 09f842cdcba63e66ca99cc397b823ceb209d5ecf Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 6 Sep 2024 13:39:34 -0700 Subject: [PATCH] Remove redundant azidentity error content (#23407) --- sdk/azidentity/CHANGELOG.md | 1 + sdk/azidentity/azidentity_test.go | 71 +++++++++++++++++++ sdk/azidentity/azure_cli_credential_test.go | 2 +- .../azure_developer_cli_credential_test.go | 2 +- sdk/azidentity/azure_pipelines_credential.go | 12 ++-- sdk/azidentity/chained_token_credential.go | 16 +++-- .../chained_token_credential_test.go | 6 +- sdk/azidentity/confidential_client.go | 12 ++-- sdk/azidentity/errors.go | 26 +++++-- sdk/azidentity/errors_test.go | 6 +- sdk/azidentity/managed_identity_client.go | 30 ++++---- sdk/azidentity/public_client.go | 3 +- 12 files changed, 139 insertions(+), 48 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 51b874e3ad..3cf2dca290 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -10,6 +10,7 @@ ### Bugs Fixed ### Other Changes +* Removed redundant content from error messages ## 1.8.0-beta.2 (2024-08-06) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 27847576b3..0eadd5ae29 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -101,6 +101,77 @@ func (t *tokenRequestCountingPolicy) Do(req *policy.Request) (*http.Response, er return req.Next() } +func TestResponseErrors(t *testing.T) { + // compact removes whitespace from errors to simplify validation + compact := func(s string) string { + return strings.Map(func(r rune) rune { + if r == ' ' || r == '\n' || r == '\t' { + return -1 + } + return r + }, s) + } + content := "no tokens here" + statusCode := http.StatusTeapot + validate := func(t *testing.T, err error) { + require.Error(t, err) + flatErr := compact(err.Error()) + actual := strings.Count(flatErr, compact(http.StatusText(statusCode))) + require.Equal(t, 1, actual, "error message should include response exactly once:\n%s", err.Error()) + actual = strings.Count(flatErr, compact(content)) + require.Equal(t, 1, actual, "error message should include body exactly once:\n%s", err.Error()) + } + + for _, client := range []struct { + name string + ctor func(co policy.ClientOptions) (azcore.TokenCredential, error) + }{ + { + name: "confidential", + ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) { + return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &ClientSecretCredentialOptions{ClientOptions: co}) + }, + }, + { + name: "managed identity", + ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) { + return NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ClientOptions: co}) + }, + }, + { + name: "public", + ctor: func(co policy.ClientOptions) (azcore.TokenCredential, error) { + return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, "username", "password", &UsernamePasswordCredentialOptions{ClientOptions: co}) + }, + }, + } { + t.Run(client.name, func(t *testing.T) { + cred, err := client.ctor(policy.ClientOptions{ + Retry: policy.RetryOptions{MaxRetries: -1}, + Transport: &mockSTS{ + tokenRequestCallback: func(*http.Request) *http.Response { + return &http.Response{ + Body: io.NopCloser(bytes.NewBufferString(content)), + Status: http.StatusText(statusCode), + StatusCode: statusCode, + } + }, + }, + }) + require.NoError(t, err) + _, err = cred.GetToken(ctx, testTRO) + validate(t, err) + + t.Run("ChainedTokenCredential", func(t *testing.T) { + chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred}, nil) + require.NoError(t, err) + _, err = chain.GetToken(ctx, testTRO) + validate(t, err) + }) + }) + } +} + func TestTenantID(t *testing.T) { type tc struct { name string diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index d3841c2ff6..083156ece5 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -37,7 +37,7 @@ func azTokenOutput(expiresOn string, expires_on int64) []byte { } func mockAzTokenProviderFailure(context.Context, []string, string, string) ([]byte, error) { - return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil) + return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil) } func mockAzTokenProviderSuccess(context.Context, []string, string, string) ([]byte, error) { diff --git a/sdk/azidentity/azure_developer_cli_credential_test.go b/sdk/azidentity/azure_developer_cli_credential_test.go index f452ccfed5..c83602066b 100644 --- a/sdk/azidentity/azure_developer_cli_credential_test.go +++ b/sdk/azidentity/azure_developer_cli_credential_test.go @@ -22,7 +22,7 @@ var ( `), nil } mockAzdTokenProviderFailure = func(context.Context, []string, string) ([]byte, error) { - return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil) + return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil) } ) diff --git a/sdk/azidentity/azure_pipelines_credential.go b/sdk/azidentity/azure_pipelines_credential.go index 320551ffb7..6a0fb4a3c0 100644 --- a/sdk/azidentity/azure_pipelines_credential.go +++ b/sdk/azidentity/azure_pipelines_credential.go @@ -114,33 +114,33 @@ func (a *AzurePipelinesCredential) getAssertion(ctx context.Context) (string, er url := a.oidcURI + "?api-version=" + oidcAPIVersion + "&serviceConnectionId=" + a.connectionID url, err := runtime.EncodeQueryParams(url) if err != nil { - return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't encode OIDC URL: "+err.Error(), nil, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't encode OIDC URL: "+err.Error(), nil) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { - return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't create OIDC token request: "+err.Error(), nil, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't create OIDC token request: "+err.Error(), nil) } req.Header.Set("Authorization", "Bearer "+a.systemAccessToken) res, err := doForClient(a.cred.client.azClient, req) if err != nil { - return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't send OIDC token request: "+err.Error(), nil, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't send OIDC token request: "+err.Error(), nil) } if res.StatusCode != http.StatusOK { msg := res.Status + " response from the OIDC endpoint. Check service connection ID and Pipeline configuration" // include the response because its body, if any, probably contains an error message. // OK responses aren't included with errors because they probably contain secrets - return "", newAuthenticationFailedError(credNameAzurePipelines, msg, res, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, msg, res) } b, err := runtime.Payload(res) if err != nil { - return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't read OIDC response content: "+err.Error(), nil, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, "couldn't read OIDC response content: "+err.Error(), nil) } var r struct { OIDCToken string `json:"oidcToken"` } err = json.Unmarshal(b, &r) if err != nil { - return "", newAuthenticationFailedError(credNameAzurePipelines, "unexpected response from OIDC endpoint", nil, nil) + return "", newAuthenticationFailedError(credNameAzurePipelines, "unexpected response from OIDC endpoint", nil) } return r.OIDCToken, nil } diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 6c35a941b9..2460f66ec1 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -113,11 +113,19 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token if err != nil { // return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise msg := createChainedErrorMessage(errs) - if errors.As(err, &unavailableErr) { + var authFailedErr *AuthenticationFailedError + switch { + case errors.As(err, &authFailedErr): + err = newAuthenticationFailedError(c.name, msg, authFailedErr.RawResponse) + if af, ok := err.(*AuthenticationFailedError); ok { + // stop Error() printing the response again; it's already in msg + af.omitResponse = true + } + case errors.As(err, &unavailableErr): err = newCredentialUnavailableError(c.name, msg) - } else { + default: res := getResponseFromError(err) - err = newAuthenticationFailedError(c.name, msg, res, err) + err = newAuthenticationFailedError(c.name, msg, res) } } return token, err @@ -126,7 +134,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token func createChainedErrorMessage(errs []error) string { msg := "failed to acquire a token.\nAttempted credentials:" for _, err := range errs { - msg += fmt.Sprintf("\n\t%s", err.Error()) + msg += fmt.Sprintf("\n\t%s", strings.ReplaceAll(err.Error(), "\n", "\n\t\t")) } return msg } diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 04cb1eb74a..be26c1f149 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -113,7 +113,7 @@ func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) { func TestChainedTokenCredential_GetTokenFail(t *testing.T) { c := NewFakeCredential() - c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil, nil)) + c.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("test", "something went wrong", nil)) cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c}, nil) if err != nil { t.Fatal(err) @@ -158,7 +158,7 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenAuthenticationFailed( c2 := NewFakeCredential() c2.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("unavailableCredential2", "Unavailable expected error")) c3 := NewFakeCredential() - c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil, nil)) + c3.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("authenticationFailedCredential3", "Authentication failed expected error", nil)) cred, err := NewChainedTokenCredential([]azcore.TokenCredential{c1, c2, c3}, nil) if err != nil { t.Fatal(err) @@ -259,7 +259,7 @@ func TestChainedTokenCredential_Race(t *testing.T) { successFake := NewFakeCredential() successFake.SetResponse(azcore.AccessToken{Token: "*", ExpiresOn: time.Now().Add(time.Hour)}, nil) authFailFake := NewFakeCredential() - authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil, nil)) + authFailFake.SetResponse(azcore.AccessToken{}, newAuthenticationFailedError("", "", nil)) unavailableFake := NewFakeCredential() unavailableFake.SetResponse(azcore.AccessToken{}, newCredentialUnavailableError("", "")) diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 4e3e5da4bc..7059a510c2 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -107,12 +107,12 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque } } if err != nil { - // We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code. - // We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError. - var unavailableErr credentialUnavailable - if !errors.As(err, &unavailableErr) { - res := getResponseFromError(err) - err = newAuthenticationFailedError(c.name, err.Error(), res, err) + var ( + authFailedErr *AuthenticationFailedError + unavailableErr credentialUnavailable + ) + if !(errors.As(err, &unavailableErr) || errors.As(err, &authFailedErr)) { + err = newAuthenticationFailedErrorFromMSAL(c.name, err) } } else { msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", c.name, strings.Join(ar.GrantedScopes, ", ")) diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go index f2b0f2e24a..b05cb035a8 100644 --- a/sdk/azidentity/errors.go +++ b/sdk/azidentity/errors.go @@ -38,18 +38,30 @@ type AuthenticationFailedError struct { // RawResponse is the HTTP response motivating the error, if available. RawResponse *http.Response - credType string - message string - err error + credType, message string + omitResponse bool } -func newAuthenticationFailedError(credType string, message string, resp *http.Response, err error) error { - return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp, err: err} +func newAuthenticationFailedError(credType, message string, resp *http.Response) error { + return &AuthenticationFailedError{credType: credType, message: message, RawResponse: resp} +} + +// newAuthenticationFailedErrorFromMSAL creates an AuthenticationFailedError from an MSAL error. +// If the error is an MSAL CallErr, the new error includes an HTTP response and not the MSAL error +// message, because that message is redundant given the response. If the original error isn't a +// CallErr, the returned error incorporates its message. +func newAuthenticationFailedErrorFromMSAL(credType string, err error) error { + msg := "" + res := getResponseFromError(err) + if res == nil { + msg = err.Error() + } + return newAuthenticationFailedError(credType, msg, res) } // Error implements the error interface. Note that the message contents are not contractual and can change over time. func (e *AuthenticationFailedError) Error() string { - if e.RawResponse == nil { + if e.RawResponse == nil || e.omitResponse { return e.credType + ": " + e.message } msg := &bytes.Buffer{} @@ -62,7 +74,7 @@ func (e *AuthenticationFailedError) Error() string { fmt.Fprintln(msg, "Request information not available") } fmt.Fprintln(msg, "--------------------------------------------------------------------------------") - fmt.Fprintf(msg, "RESPONSE %s\n", e.RawResponse.Status) + fmt.Fprintf(msg, "RESPONSE %d: %s\n", e.RawResponse.StatusCode, e.RawResponse.Status) fmt.Fprintln(msg, "--------------------------------------------------------------------------------") body, err := runtime.Payload(e.RawResponse) switch { diff --git a/sdk/azidentity/errors_test.go b/sdk/azidentity/errors_test.go index 98791b7505..b06d26a183 100644 --- a/sdk/azidentity/errors_test.go +++ b/sdk/azidentity/errors_test.go @@ -31,7 +31,7 @@ func TestAuthenticationFailedErrorInterface(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(resBodyString)), Request: req, } - err = newAuthenticationFailedError(credNameAzureCLI, "error message", res, nil) + err = newAuthenticationFailedError(credNameAzureCLI, "error message", res) if e, ok := err.(*AuthenticationFailedError); ok { if e.RawResponse == nil { t.Fatal("expected a non-nil RawResponse") @@ -61,7 +61,7 @@ func TestAuthenticationFailedErrorInterface(t *testing.T) { } func TestAuthenticationFailedErrorWithoutResponse(t *testing.T) { - err := newAuthenticationFailedError(credNameAzureCLI, "error message", nil, nil) + err := newAuthenticationFailedError(credNameAzureCLI, "error message", nil) if _, ok := err.(*AuthenticationFailedError); !ok { t.Fatalf("expected AuthenticationFailedError, received %T", err) } @@ -79,7 +79,7 @@ func TestAuthenticationFailedErrorWithoutRequest(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(resBodyString)), Request: nil, } - err := newAuthenticationFailedError(credNameAzureCLI, "error message", res, nil) + err := newAuthenticationFailedError(credNameAzureCLI, "error message", res) if e, ok := err.(*AuthenticationFailedError); ok { if e.RawResponse == nil { t.Fatal("expected a non-nil RawResponse") diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index 7109c7a22b..4c657a92ec 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -250,7 +250,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi resp, err := c.azClient.Pipeline().Do(msg) if err != nil { - return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil) } if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { @@ -261,7 +261,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi switch resp.StatusCode { case http.StatusBadRequest: if id != nil { - return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp) } msg := "failed to authenticate a system assigned identity" if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 { @@ -278,7 +278,7 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi } } - return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp, nil) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "", resp) } func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.AccessToken, error) { @@ -306,10 +306,10 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.Ac if expiresOn, err := strconv.Atoi(v); err == nil { return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil } - return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res, nil) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res) default: msg := fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v) - return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res, nil) + return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res) } } @@ -324,7 +324,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage key, err := c.getAzureArcSecretKey(ctx, scopes) if err != nil { msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err) - return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err) + return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil) } return c.createAzureArcAuthRequest(ctx, scopes, key) case msiTypeAzureML: @@ -399,9 +399,9 @@ func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id case miClientID: q.Set("clientid", id.String()) case miObjectID: - return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by object ID", nil, nil) + return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by object ID", nil) case miResourceID: - return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil, nil) + return nil, newAuthenticationFailedError(credNameManagedIdentity, "Azure ML doesn't support specifying a managed identity by resource ID", nil) } } request.Raw().URL.RawQuery = q.Encode() @@ -442,34 +442,34 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour // of the secret key file. Any other status code indicates an error in the request. if response.StatusCode != 401 { msg := fmt.Sprintf("expected a 401 response, received %d", response.StatusCode) - return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response) } header := response.Header.Get("WWW-Authenticate") if len(header) == 0 { - return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil) } // the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key _, p, found := strings.Cut(header, "=") if !found { - return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil) } expected, err := arcKeyDirectory() if err != nil { return "", err } if filepath.Dir(p) != expected || !strings.HasSuffix(p, ".key") { - return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil) } f, err := os.Stat(p) if err != nil { - return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil) } if s := f.Size(); s > 4096 { - return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil) } key, err := os.ReadFile(p) if err != nil { - return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil, nil) + return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil) } return string(key), nil } diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index 5669ee9b1e..73363e1c9e 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -244,8 +244,7 @@ func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToke if err == nil { p.record, err = newAuthenticationRecord(ar) } else { - res := getResponseFromError(err) - err = newAuthenticationFailedError(p.name, err.Error(), res, err) + err = newAuthenticationFailedErrorFromMSAL(p.name, err) } return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err }