From 6e14e6e78b723435a9dbfcff797c11d6fc0d8111 Mon Sep 17 00:00:00 2001 From: Deepti Vaidyanathan Date: Thu, 2 Feb 2023 18:03:30 +0000 Subject: [PATCH] Tracks request and response ID and adds enhanced error messages for common status codes --- pkg/download/blob.go | 7 ++- pkg/download/blob_test.go | 65 +++++++++++++++++++++++++-- pkg/download/blobwithmsitoken.go | 2 + pkg/download/blobwithmsitoken_test.go | 20 ++++++++- pkg/download/downloader.go | 42 +++++++++++++++-- pkg/download/downloader_test.go | 38 +++++++++++----- pkg/download/retry.go | 2 +- pkg/download/retry_test.go | 5 ++- pkg/download/url.go | 13 +++++- pkg/download/url_test.go | 1 + 10 files changed, 171 insertions(+), 24 deletions(-) diff --git a/pkg/download/blob.go b/pkg/download/blob.go index a53c3de..a0baeac 100644 --- a/pkg/download/blob.go +++ b/pkg/download/blob.go @@ -11,6 +11,7 @@ import ( "github.com/Azure/azure-sdk-for-go/storage" "github.com/Azure/run-command-handler-linux/pkg/blobutil" + "github.com/google/uuid" "github.com/pkg/errors" ) @@ -30,7 +31,11 @@ func (b blobDownload) GetRequest() (*http.Request, error) { if err != nil { return nil, err } - return http.NewRequest("GET", url, nil) + req, error := http.NewRequest("GET", url, nil) + if req != nil { + req.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String()) + } + return req, error } // getURL returns publicly downloadable URL of the Azure Blob diff --git a/pkg/download/blob_test.go b/pkg/download/blob_test.go index 8ca9f50..1ccc154 100644 --- a/pkg/download/blob_test.go +++ b/pkg/download/blob_test.go @@ -11,9 +11,14 @@ import ( "github.com/Azure/azure-sdk-for-go/storage" "github.com/Azure/run-command-handler-linux/pkg/blobutil" + "github.com/go-kit/kit/log" "github.com/stretchr/testify/require" ) +var ( + testctx = log.NewContext(log.NewNopLogger()) +) + func Test_blobDownload_validateInputs(t *testing.T) { type sas interface { getURL() (string, error) @@ -72,9 +77,10 @@ func Test_blobDownload_fails_badCreds(t *testing.T) { Container: "foocontainer", }) - status, _, err := Download(d) + status, _, err := Download(testctx, d) require.NotNil(t, err) - require.Contains(t, err.Error(), "Status code 403 while downloading blob") + require.Contains(t, err.Error(), "Please verify the machine has network connectivity") + require.Contains(t, err.Error(), "403") require.Equal(t, status, http.StatusForbidden) } @@ -85,7 +91,7 @@ func Test_blobDownload_fails_urlNotFound(t *testing.T) { Container: "foocontainer", }) - _, _, err := Download(d) + _, _, err := Download(testctx, d) require.NotNil(t, err) require.Contains(t, err.Error(), "http request failed:") } @@ -145,7 +151,7 @@ func Test_blobDownload_actualBlob(t *testing.T) { Blob: name, StorageBase: base, }) - _, body, err := Download(d) + _, body, err := Download(testctx, d) require.Nil(t, err) defer body.Close() b, err := ioutil.ReadAll(body) @@ -153,6 +159,57 @@ func Test_blobDownload_actualBlob(t *testing.T) { require.EqualValues(t, chunk, b, "retrieved body is different body=%d chunk=%d", len(b), len(chunk)) } +func Test_blobDownload_fails_actualBlob404(t *testing.T) { + acct := os.Getenv("AZURE_STORAGE_ACCOUNT") + key := os.Getenv("AZURE_STORAGE_ACCESS_KEY") + if acct == "" || key == "" { + t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test") + } + base := storage.DefaultBaseURL + + blobName := "" + containerName := "" + + // Get the blob via downloader + d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{ + Container: containerName, + Blob: blobName, + StorageBase: base, + }) + code, _, err := Download(testctx, d) + require.NotNil(t, err) + require.Equal(t, code, http.StatusNotFound) + require.Contains(t, err.Error(), "because it does not exist") + require.Contains(t, err.Error(), "Not Found") + require.Contains(t, err.Error(), "Service request ID") +} + +func Test_blobDownload_fails_actualBlob409(t *testing.T) { + // before running this test, go to your storage account on portal > Configuration and disable Blob public access + acct := os.Getenv("AZURE_STORAGE_ACCOUNT") + key := os.Getenv("AZURE_STORAGE_ACCESS_KEY") + if acct == "" || key == "" { + t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test") + } + base := storage.DefaultBaseURL + + blobName := "" + containerName := "" + + // Get the blob via downloader + d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{ + Container: containerName, + Blob: blobName, + StorageBase: base, + }) + code, _, err := Download(testctx, d) + require.NotNil(t, err) + require.Equal(t, code, http.StatusConflict) + require.Contains(t, err.Error(), "Please verify the machine has network connectivity") + require.Contains(t, err.Error(), "Public access is not permitted on this storage account") + require.Contains(t, err.Error(), "Service request ID") +} + func Test_blobAppend_actualBlob(t *testing.T) { // Before running the test locally prepare storage account and set the following env variables: // export AZURE_STORAGE_BLOB="https://atanas.blob.core.windows.net/con1/output5.txt" diff --git a/pkg/download/blobwithmsitoken.go b/pkg/download/blobwithmsitoken.go index 23d09df..ede13af 100644 --- a/pkg/download/blobwithmsitoken.go +++ b/pkg/download/blobwithmsitoken.go @@ -8,6 +8,7 @@ import ( "github.com/Azure/azure-extension-foundation/httputil" "github.com/Azure/azure-extension-foundation/msi" + "github.com/google/uuid" "github.com/pkg/errors" ) @@ -55,6 +56,7 @@ func (self *blobWithMsiToken) GetRequest() (*http.Request, error) { return nil, err } + request.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String()) if IsAzureStorageBlobUri(self.url) { request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken)) request.Header.Set(xMsVersionHeaderName, xMsVersionValue) diff --git a/pkg/download/blobwithmsitoken_test.go b/pkg/download/blobwithmsitoken_test.go index 0eac802..ec33db0 100644 --- a/pkg/download/blobwithmsitoken_test.go +++ b/pkg/download/blobwithmsitoken_test.go @@ -3,6 +3,7 @@ package download import ( "encoding/json" "io/ioutil" + "net/http" "testing" "github.com/Azure/azure-extension-foundation/msi" @@ -45,7 +46,7 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) { err := json.Unmarshal([]byte(msiJson), &msi) return msi, err }} - _, stream, err := Download(&downloader) + _, stream, err := Download(testctx, &downloader) require.NoError(t, err, "File download failed") defer stream.Close() @@ -54,6 +55,23 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) { require.Contains(t, string(bytes), stringToLookFor) } +func Test_realDownloadBlobWithMsiToken404(t *testing.T) { + if msiJson == "" || blobUri == "" || stringToLookFor == "" { + t.Skip() + } + var badBlobUri = blobUri[0 : len(blobUri)-1] + downloader := blobWithMsiToken{badBlobUri, func() (msi.Msi, error) { + msi := msi.Msi{} + err := json.Unmarshal([]byte(msiJson), &msi) + return msi, err + }} + code, _, err := Download(testctx, &downloader) + require.NotNil(t, err, "File download succeeded but was not supposed to") + require.Equal(t, http.StatusNotFound, code) + require.Contains(t, err.Error(), MsiDownload404ErrorString) + require.Contains(t, err.Error(), "Service request ID:") // should have a service request ID since downloading from Azure Storage +} + func Test_isAzureStorageBlobUri(t *testing.T) { require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname")) require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn")) diff --git a/pkg/download/downloader.go b/pkg/download/downloader.go index 1dd0ba6..9a78910 100644 --- a/pkg/download/downloader.go +++ b/pkg/download/downloader.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Azure/run-command-handler-linux/pkg/urlutil" + "github.com/go-kit/kit/log" "github.com/pkg/errors" ) @@ -45,12 +46,17 @@ var ( // Download retrieves a response body and checks the response status code to see // if it is 200 OK and then returns the response body. It issues a new request // every time called. It is caller's responsibility to close the response body. -func Download(downloader Downloader) (int, io.ReadCloser, error) { +func Download(ctx *log.Context, downloader Downloader) (int, io.ReadCloser, error) { request, err := downloader.GetRequest() if err != nil { return -1, nil, errors.Wrapf(err, "failed to create http request") } + requestID := request.Header.Get(xMsClientRequestIdHeaderName) + if len(requestID) > 0 { + ctx.Log("info", fmt.Sprintf("starting download with client request ID %s", requestID)) + } + response, err := httpClient.Do(request) if err != nil { err = urlutil.RemoveUrlFromErr(err) @@ -61,7 +67,8 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) { return response.StatusCode, response.Body, nil } - err = fmt.Errorf("Status code %d while downloading blob '%s'. Use either a public script URI that points to .sh file, Azure storage blob SAS URI or storage blob accessible by a managed identity and retry. For more info, refer https://aka.ms/RunCommandManagedLinux", response.StatusCode, request.URL.Opaque) + errString := "" + requestId := response.Header.Get(xMsServiceRequestIdHeaderName) switch downloader.(type) { case *blobWithMsiToken: switch response.StatusCode { @@ -75,6 +82,35 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) { forbiddenError := fmt.Errorf("Make sure managed identity has been given access to container of storage blob '%s' with 'Storage Blob Data Reader' role assignment. In case of user assigned identity, make sure you add it under VM's identity. For more info, refer https://aka.ms/RunCommandManagedLinux", request.URL.Opaque) return response.StatusCode, nil, errors.Wrapf(forbiddenError, MsiDownload403ErrorString) } + default: + hostname := request.URL.Host + switch response.StatusCode { + case http.StatusUnauthorized: + errString = fmt.Sprintf("RunCommand failed to download the file from %s because access was denied. Please fix the blob permissions and try again, the response code and message returned were: %q", + hostname, + response.Status) + case http.StatusNotFound: + errString = fmt.Sprintf("RunCommand failed to download the file from %s because it does not exist. Please create the blob and try again, the response code and message returned were: %q", + hostname, + response.Status) + + case http.StatusBadRequest: + errString = fmt.Sprintf("RunCommand failed to download the file from %s because parts of the request were incorrectly formatted, missing, and/or invalid. The response code and message returned were: %q", + hostname, + response.Status) + + case http.StatusInternalServerError: + errString = fmt.Sprintf("RunCommand failed to download the file from %s due to an issue with storage. The response code and message returned were: %q", + hostname, + response.Status) + default: + errString = fmt.Sprintf("RunCommand failed to download the file from %s because the server returned a response code and message of %q Please verify the machine has network connectivity.", + hostname, + response.Status) + } } - return response.StatusCode, nil, err + if len(requestId) > 0 { + errString += fmt.Sprintf(" (Service request ID: %s)", requestId) + } + return response.StatusCode, nil, fmt.Errorf(errString) } diff --git a/pkg/download/downloader_test.go b/pkg/download/downloader_test.go index 7d1c618..e27825e 100644 --- a/pkg/download/downloader_test.go +++ b/pkg/download/downloader_test.go @@ -10,32 +10,36 @@ import ( "testing" "github.com/Azure/azure-extension-foundation/msi" - "github.com/Azure/run-command-handler-linux/pkg/download" "github.com/ahmetalpbalkan/go-httpbin" + "github.com/go-kit/kit/log" "github.com/stretchr/testify/require" ) type badDownloader struct{ calls int } +var ( + testctx = log.NewContext(log.NewNopLogger()) +) + func (b *badDownloader) GetRequest() (*http.Request, error) { b.calls++ return nil, errors.New("expected error") } func TestDownload_wrapsGetRequestError(t *testing.T) { - _, _, err := download.Download(new(badDownloader)) + _, _, err := download.Download(testctx, new(badDownloader)) require.NotNil(t, err) require.EqualError(t, err, "failed to create http request: expected error") } func TestDownload_wrapsHTTPError(t *testing.T) { - _, _, err := download.Download(download.NewURLDownload("bad url")) + _, _, err := download.Download(testctx, download.NewURLDownload("bad url")) require.NotNil(t, err) require.Contains(t, err.Error(), "http request failed:") } -func TestDownload_badStatusCodeFails(t *testing.T) { +func TestDownload_wrapsCommonErrorCodes(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() @@ -47,9 +51,21 @@ func TestDownload_badStatusCodeFails(t *testing.T) { http.StatusBadRequest, http.StatusUnauthorized, } { - _, _, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code))) + respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code))) require.NotNil(t, err, "not failed for code:%d", code) - require.Contains(t, err.Error(), fmt.Sprintf("Status code %d while downloading blob", code)) + require.Equal(t, code, respCode) + switch respCode { + case http.StatusNotFound: + require.Contains(t, err.Error(), "because it does not exist") + case http.StatusForbidden: + require.Contains(t, err.Error(), "Please verify the machine has network connectivity") + case http.StatusInternalServerError: + require.Contains(t, err.Error(), "due to an issue with storage") + case http.StatusBadRequest: + require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid") + case http.StatusUnauthorized: + require.Contains(t, err.Error(), "because access was denied") + } } } @@ -57,7 +73,7 @@ func TestDownload_statusOKSucceeds(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200")) + _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL + "/status/200")) require.Nil(t, err) defer body.Close() require.NotNil(t, body) @@ -72,13 +88,13 @@ func TestDowload_msiDownloaderErrorMessage(t *testing.T) { msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider) - returnCode, body, err := download.Download(msiDownloader404) + returnCode, body, err := download.Download(testctx, msiDownloader404) require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message") require.Nil(t, body, "body is not nil for failed download") require.Equal(t, 404, returnCode, "return code was not 404") msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider) - returnCode, body, err = download.Download(msiDownloader403) + returnCode, body, err = download.Download(testctx, msiDownloader403) require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message") require.Nil(t, body, "body is not nil for failed download") require.Equal(t, 403, returnCode, "return code was not 403") @@ -89,7 +105,7 @@ func TestDownload_retrievesBody(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536")) + _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536")) require.Nil(t, err) defer body.Close() b, err := ioutil.ReadAll(body) @@ -101,7 +117,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) { srv := httptest.NewServer(httpbin.GetMux()) defer srv.Close() - _, body, err := download.Download(download.NewURLDownload(srv.URL + "/get")) + _, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get")) require.Nil(t, err) require.Nil(t, body.Close(), "body should close fine") } diff --git a/pkg/download/retry.go b/pkg/download/retry.go index fe9d3ec..97fb314 100644 --- a/pkg/download/retry.go +++ b/pkg/download/retry.go @@ -37,7 +37,7 @@ func WithRetries(ctx *log.Context, downloaders []Downloader, sf SleepFunc) (io.R for _, d := range downloaders { for n := 0; n < expRetryN; n++ { ctx := ctx.With("retry", n) - status, out, err := Download(d) + status, out, err := Download(ctx, d) if err == nil { return out, nil } diff --git a/pkg/download/retry_test.go b/pkg/download/retry_test.go index e21e1b0..c39a900 100644 --- a/pkg/download/retry_test.go +++ b/pkg/download/retry_test.go @@ -60,7 +60,8 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) { sr := new(sleepRecorder) _, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep) - require.Contains(t, err.Error(), "Status code 429 while downloading blob") + require.Contains(t, err.Error(), "429 Too Many Requests") + require.Contains(t, err.Error(), "Please verify the machine has network connectivity") require.Equal(t, sleepSchedule, []time.Duration(*sr)) } @@ -104,7 +105,7 @@ func TestRetriesWith_SwitchDownloaderThenFailWithCorrectErrorMessage(t *testing. require.NotNil(t, err, "download with retries should fail") require.Nil(t, resp, "response body should be nil for failed download with retries") require.Equal(t, d404.timesCalled, 1) - require.Contains(t, err.Error(), "Status code 403 while downloading blob") + require.Contains(t, err.Error(), "403 Forbidden") d404 = mockDownloader{0, svr.URL + "/status/404"} msiDownloader404 := download.NewBlobWithMsiDownload(svr.URL+"/status/404", mockMsiProvider) diff --git a/pkg/download/url.go b/pkg/download/url.go index dc8658c..3dec7ea 100644 --- a/pkg/download/url.go +++ b/pkg/download/url.go @@ -3,6 +3,13 @@ package download import ( "net/http" "net/url" + + "github.com/google/uuid" +) + +const ( + xMsClientRequestIdHeaderName = "x-ms-client-request-id" + xMsServiceRequestIdHeaderName = "x-ms-request-id" ) // urlDownload describes a URL to download. @@ -17,7 +24,11 @@ func NewURLDownload(url string) Downloader { // GetRequest returns a new request to download the URL func (u urlDownload) GetRequest() (*http.Request, error) { - return http.NewRequest("GET", u.url, nil) + req, err := http.NewRequest("GET", u.url, nil) + if req != nil { + req.Header.Add(xMsClientRequestIdHeaderName, uuid.New().String()) + } + return req, err } // Scrub query. Used to remove the query parts like SAS token. diff --git a/pkg/download/url_test.go b/pkg/download/url_test.go index aca7f11..caf6670 100644 --- a/pkg/download/url_test.go +++ b/pkg/download/url_test.go @@ -25,6 +25,7 @@ func Test_urlDownload_GetRequest_goodURL(t *testing.T) { r, err := d.GetRequest() require.Nil(t, err, u) require.NotNil(t, r, u) + require.NotNil(t, r.Header.Get(xMsClientRequestIdHeaderName)) } func Test_GetUriForLogging_ScrubsQuery(t *testing.T) {