Tracks request and response ID and adds enhanced error messages for common status codes

This commit is contained in:
Deepti Vaidyanathan 2023-02-02 18:03:30 +00:00
Родитель 34c16bc772
Коммит 6e14e6e78b
10 изменённых файлов: 171 добавлений и 24 удалений

Просмотреть файл

@ -11,6 +11,7 @@ import (
"github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/run-command-handler-linux/pkg/blobutil"
"github.com/google/uuid"
"github.com/pkg/errors"
)
@ -30,7 +31,11 @@ func (b blobDownload) GetRequest() (*http.Request, error) {
if err != nil {
return nil, err
}
return http.NewRequest("GET", url, nil)
req, error := http.NewRequest("GET", url, nil)
if req != nil {
req.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String())
}
return req, error
}
// getURL returns publicly downloadable URL of the Azure Blob

Просмотреть файл

@ -11,9 +11,14 @@ import (
"github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/run-command-handler-linux/pkg/blobutil"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
)
var (
testctx = log.NewContext(log.NewNopLogger())
)
func Test_blobDownload_validateInputs(t *testing.T) {
type sas interface {
getURL() (string, error)
@ -72,9 +77,10 @@ func Test_blobDownload_fails_badCreds(t *testing.T) {
Container: "foocontainer",
})
status, _, err := Download(d)
status, _, err := Download(testctx, d)
require.NotNil(t, err)
require.Contains(t, err.Error(), "Status code 403 while downloading blob")
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
require.Contains(t, err.Error(), "403")
require.Equal(t, status, http.StatusForbidden)
}
@ -85,7 +91,7 @@ func Test_blobDownload_fails_urlNotFound(t *testing.T) {
Container: "foocontainer",
})
_, _, err := Download(d)
_, _, err := Download(testctx, d)
require.NotNil(t, err)
require.Contains(t, err.Error(), "http request failed:")
}
@ -145,7 +151,7 @@ func Test_blobDownload_actualBlob(t *testing.T) {
Blob: name,
StorageBase: base,
})
_, body, err := Download(d)
_, body, err := Download(testctx, d)
require.Nil(t, err)
defer body.Close()
b, err := ioutil.ReadAll(body)
@ -153,6 +159,57 @@ func Test_blobDownload_actualBlob(t *testing.T) {
require.EqualValues(t, chunk, b, "retrieved body is different body=%d chunk=%d", len(b), len(chunk))
}
func Test_blobDownload_fails_actualBlob404(t *testing.T) {
acct := os.Getenv("AZURE_STORAGE_ACCOUNT")
key := os.Getenv("AZURE_STORAGE_ACCESS_KEY")
if acct == "" || key == "" {
t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test")
}
base := storage.DefaultBaseURL
blobName := "<BLOB THAT DOESN'T EXIST>"
containerName := "<CONTAINER NAME>"
// Get the blob via downloader
d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{
Container: containerName,
Blob: blobName,
StorageBase: base,
})
code, _, err := Download(testctx, d)
require.NotNil(t, err)
require.Equal(t, code, http.StatusNotFound)
require.Contains(t, err.Error(), "because it does not exist")
require.Contains(t, err.Error(), "Not Found")
require.Contains(t, err.Error(), "Service request ID")
}
func Test_blobDownload_fails_actualBlob409(t *testing.T) {
// before running this test, go to your storage account on portal > Configuration and disable Blob public access
acct := os.Getenv("AZURE_STORAGE_ACCOUNT")
key := os.Getenv("AZURE_STORAGE_ACCESS_KEY")
if acct == "" || key == "" {
t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test")
}
base := storage.DefaultBaseURL
blobName := "<BLOB NAME>"
containerName := "<CONTAINER NAME>"
// Get the blob via downloader
d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{
Container: containerName,
Blob: blobName,
StorageBase: base,
})
code, _, err := Download(testctx, d)
require.NotNil(t, err)
require.Equal(t, code, http.StatusConflict)
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
require.Contains(t, err.Error(), "Public access is not permitted on this storage account")
require.Contains(t, err.Error(), "Service request ID")
}
func Test_blobAppend_actualBlob(t *testing.T) {
// Before running the test locally prepare storage account and set the following env variables:
// export AZURE_STORAGE_BLOB="https://atanas.blob.core.windows.net/con1/output5.txt"

Просмотреть файл

@ -8,6 +8,7 @@ import (
"github.com/Azure/azure-extension-foundation/httputil"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/google/uuid"
"github.com/pkg/errors"
)
@ -55,6 +56,7 @@ func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
return nil, err
}
request.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String())
if IsAzureStorageBlobUri(self.url) {
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken))
request.Header.Set(xMsVersionHeaderName, xMsVersionValue)

Просмотреть файл

@ -3,6 +3,7 @@ package download
import (
"encoding/json"
"io/ioutil"
"net/http"
"testing"
"github.com/Azure/azure-extension-foundation/msi"
@ -45,7 +46,7 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
err := json.Unmarshal([]byte(msiJson), &msi)
return msi, err
}}
_, stream, err := Download(&downloader)
_, stream, err := Download(testctx, &downloader)
require.NoError(t, err, "File download failed")
defer stream.Close()
@ -54,6 +55,23 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
require.Contains(t, string(bytes), stringToLookFor)
}
func Test_realDownloadBlobWithMsiToken404(t *testing.T) {
if msiJson == "" || blobUri == "" || stringToLookFor == "" {
t.Skip()
}
var badBlobUri = blobUri[0 : len(blobUri)-1]
downloader := blobWithMsiToken{badBlobUri, func() (msi.Msi, error) {
msi := msi.Msi{}
err := json.Unmarshal([]byte(msiJson), &msi)
return msi, err
}}
code, _, err := Download(testctx, &downloader)
require.NotNil(t, err, "File download succeeded but was not supposed to")
require.Equal(t, http.StatusNotFound, code)
require.Contains(t, err.Error(), MsiDownload404ErrorString)
require.Contains(t, err.Error(), "Service request ID:") // should have a service request ID since downloading from Azure Storage
}
func Test_isAzureStorageBlobUri(t *testing.T) {
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname"))
require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn"))

Просмотреть файл

@ -8,6 +8,7 @@ import (
"time"
"github.com/Azure/run-command-handler-linux/pkg/urlutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
@ -45,12 +46,17 @@ var (
// Download retrieves a response body and checks the response status code to see
// if it is 200 OK and then returns the response body. It issues a new request
// every time called. It is caller's responsibility to close the response body.
func Download(downloader Downloader) (int, io.ReadCloser, error) {
func Download(ctx *log.Context, downloader Downloader) (int, io.ReadCloser, error) {
request, err := downloader.GetRequest()
if err != nil {
return -1, nil, errors.Wrapf(err, "failed to create http request")
}
requestID := request.Header.Get(xMsClientRequestIdHeaderName)
if len(requestID) > 0 {
ctx.Log("info", fmt.Sprintf("starting download with client request ID %s", requestID))
}
response, err := httpClient.Do(request)
if err != nil {
err = urlutil.RemoveUrlFromErr(err)
@ -61,7 +67,8 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) {
return response.StatusCode, response.Body, nil
}
err = fmt.Errorf("Status code %d while downloading blob '%s'. Use either a public script URI that points to .sh file, Azure storage blob SAS URI or storage blob accessible by a managed identity and retry. For more info, refer https://aka.ms/RunCommandManagedLinux", response.StatusCode, request.URL.Opaque)
errString := ""
requestId := response.Header.Get(xMsServiceRequestIdHeaderName)
switch downloader.(type) {
case *blobWithMsiToken:
switch response.StatusCode {
@ -75,6 +82,35 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) {
forbiddenError := fmt.Errorf("Make sure managed identity has been given access to container of storage blob '%s' with 'Storage Blob Data Reader' role assignment. In case of user assigned identity, make sure you add it under VM's identity. For more info, refer https://aka.ms/RunCommandManagedLinux", request.URL.Opaque)
return response.StatusCode, nil, errors.Wrapf(forbiddenError, MsiDownload403ErrorString)
}
default:
hostname := request.URL.Host
switch response.StatusCode {
case http.StatusUnauthorized:
errString = fmt.Sprintf("RunCommand failed to download the file from %s because access was denied. Please fix the blob permissions and try again, the response code and message returned were: %q",
hostname,
response.Status)
case http.StatusNotFound:
errString = fmt.Sprintf("RunCommand failed to download the file from %s because it does not exist. Please create the blob and try again, the response code and message returned were: %q",
hostname,
response.Status)
case http.StatusBadRequest:
errString = fmt.Sprintf("RunCommand failed to download the file from %s because parts of the request were incorrectly formatted, missing, and/or invalid. The response code and message returned were: %q",
hostname,
response.Status)
case http.StatusInternalServerError:
errString = fmt.Sprintf("RunCommand failed to download the file from %s due to an issue with storage. The response code and message returned were: %q",
hostname,
response.Status)
default:
errString = fmt.Sprintf("RunCommand failed to download the file from %s because the server returned a response code and message of %q Please verify the machine has network connectivity.",
hostname,
response.Status)
}
}
return response.StatusCode, nil, err
if len(requestId) > 0 {
errString += fmt.Sprintf(" (Service request ID: %s)", requestId)
}
return response.StatusCode, nil, fmt.Errorf(errString)
}

Просмотреть файл

@ -10,32 +10,36 @@ import (
"testing"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/Azure/run-command-handler-linux/pkg/download"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
)
type badDownloader struct{ calls int }
var (
testctx = log.NewContext(log.NewNopLogger())
)
func (b *badDownloader) GetRequest() (*http.Request, error) {
b.calls++
return nil, errors.New("expected error")
}
func TestDownload_wrapsGetRequestError(t *testing.T) {
_, _, err := download.Download(new(badDownloader))
_, _, err := download.Download(testctx, new(badDownloader))
require.NotNil(t, err)
require.EqualError(t, err, "failed to create http request: expected error")
}
func TestDownload_wrapsHTTPError(t *testing.T) {
_, _, err := download.Download(download.NewURLDownload("bad url"))
_, _, err := download.Download(testctx, download.NewURLDownload("bad url"))
require.NotNil(t, err)
require.Contains(t, err.Error(), "http request failed:")
}
func TestDownload_badStatusCodeFails(t *testing.T) {
func TestDownload_wrapsCommonErrorCodes(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
@ -47,9 +51,21 @@ func TestDownload_badStatusCodeFails(t *testing.T) {
http.StatusBadRequest,
http.StatusUnauthorized,
} {
_, _, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
require.NotNil(t, err, "not failed for code:%d", code)
require.Contains(t, err.Error(), fmt.Sprintf("Status code %d while downloading blob", code))
require.Equal(t, code, respCode)
switch respCode {
case http.StatusNotFound:
require.Contains(t, err.Error(), "because it does not exist")
case http.StatusForbidden:
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
case http.StatusInternalServerError:
require.Contains(t, err.Error(), "due to an issue with storage")
case http.StatusBadRequest:
require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid")
case http.StatusUnauthorized:
require.Contains(t, err.Error(), "because access was denied")
}
}
}
@ -57,7 +73,7 @@ func TestDownload_statusOKSucceeds(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL + "/status/200"))
require.Nil(t, err)
defer body.Close()
require.NotNil(t, body)
@ -72,13 +88,13 @@ func TestDowload_msiDownloaderErrorMessage(t *testing.T) {
msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider)
returnCode, body, err := download.Download(msiDownloader404)
returnCode, body, err := download.Download(testctx, msiDownloader404)
require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message")
require.Nil(t, body, "body is not nil for failed download")
require.Equal(t, 404, returnCode, "return code was not 404")
msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider)
returnCode, body, err = download.Download(msiDownloader403)
returnCode, body, err = download.Download(testctx, msiDownloader403)
require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message")
require.Nil(t, body, "body is not nil for failed download")
require.Equal(t, 403, returnCode, "return code was not 403")
@ -89,7 +105,7 @@ func TestDownload_retrievesBody(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536"))
require.Nil(t, err)
defer body.Close()
b, err := ioutil.ReadAll(body)
@ -101,7 +117,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get"))
require.Nil(t, err)
require.Nil(t, body.Close(), "body should close fine")
}

Просмотреть файл

@ -37,7 +37,7 @@ func WithRetries(ctx *log.Context, downloaders []Downloader, sf SleepFunc) (io.R
for _, d := range downloaders {
for n := 0; n < expRetryN; n++ {
ctx := ctx.With("retry", n)
status, out, err := Download(d)
status, out, err := Download(ctx, d)
if err == nil {
return out, nil
}

Просмотреть файл

@ -60,7 +60,8 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) {
sr := new(sleepRecorder)
_, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
require.Contains(t, err.Error(), "Status code 429 while downloading blob")
require.Contains(t, err.Error(), "429 Too Many Requests")
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
require.Equal(t, sleepSchedule, []time.Duration(*sr))
}
@ -104,7 +105,7 @@ func TestRetriesWith_SwitchDownloaderThenFailWithCorrectErrorMessage(t *testing.
require.NotNil(t, err, "download with retries should fail")
require.Nil(t, resp, "response body should be nil for failed download with retries")
require.Equal(t, d404.timesCalled, 1)
require.Contains(t, err.Error(), "Status code 403 while downloading blob")
require.Contains(t, err.Error(), "403 Forbidden")
d404 = mockDownloader{0, svr.URL + "/status/404"}
msiDownloader404 := download.NewBlobWithMsiDownload(svr.URL+"/status/404", mockMsiProvider)

Просмотреть файл

@ -3,6 +3,13 @@ package download
import (
"net/http"
"net/url"
"github.com/google/uuid"
)
const (
xMsClientRequestIdHeaderName = "x-ms-client-request-id"
xMsServiceRequestIdHeaderName = "x-ms-request-id"
)
// urlDownload describes a URL to download.
@ -17,7 +24,11 @@ func NewURLDownload(url string) Downloader {
// GetRequest returns a new request to download the URL
func (u urlDownload) GetRequest() (*http.Request, error) {
return http.NewRequest("GET", u.url, nil)
req, err := http.NewRequest("GET", u.url, nil)
if req != nil {
req.Header.Add(xMsClientRequestIdHeaderName, uuid.New().String())
}
return req, err
}
// Scrub query. Used to remove the query parts like SAS token.

Просмотреть файл

@ -25,6 +25,7 @@ func Test_urlDownload_GetRequest_goodURL(t *testing.T) {
r, err := d.GetRequest()
require.Nil(t, err, u)
require.NotNil(t, r, u)
require.NotNil(t, r.Header.Get(xMsClientRequestIdHeaderName))
}
func Test_GetUriForLogging_ScrubsQuery(t *testing.T) {