Tracks request and response ID and adds enhanced error messages for common status codes
This commit is contained in:
Родитель
34c16bc772
Коммит
6e14e6e78b
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/Azure/azure-sdk-for-go/storage"
|
||||
"github.com/Azure/run-command-handler-linux/pkg/blobutil"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -30,7 +31,11 @@ func (b blobDownload) GetRequest() (*http.Request, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return http.NewRequest("GET", url, nil)
|
||||
req, error := http.NewRequest("GET", url, nil)
|
||||
if req != nil {
|
||||
req.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String())
|
||||
}
|
||||
return req, error
|
||||
}
|
||||
|
||||
// getURL returns publicly downloadable URL of the Azure Blob
|
||||
|
|
|
@ -11,9 +11,14 @@ import (
|
|||
|
||||
"github.com/Azure/azure-sdk-for-go/storage"
|
||||
"github.com/Azure/run-command-handler-linux/pkg/blobutil"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
testctx = log.NewContext(log.NewNopLogger())
|
||||
)
|
||||
|
||||
func Test_blobDownload_validateInputs(t *testing.T) {
|
||||
type sas interface {
|
||||
getURL() (string, error)
|
||||
|
@ -72,9 +77,10 @@ func Test_blobDownload_fails_badCreds(t *testing.T) {
|
|||
Container: "foocontainer",
|
||||
})
|
||||
|
||||
status, _, err := Download(d)
|
||||
status, _, err := Download(testctx, d)
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "Status code 403 while downloading blob")
|
||||
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
|
||||
require.Contains(t, err.Error(), "403")
|
||||
require.Equal(t, status, http.StatusForbidden)
|
||||
}
|
||||
|
||||
|
@ -85,7 +91,7 @@ func Test_blobDownload_fails_urlNotFound(t *testing.T) {
|
|||
Container: "foocontainer",
|
||||
})
|
||||
|
||||
_, _, err := Download(d)
|
||||
_, _, err := Download(testctx, d)
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "http request failed:")
|
||||
}
|
||||
|
@ -145,7 +151,7 @@ func Test_blobDownload_actualBlob(t *testing.T) {
|
|||
Blob: name,
|
||||
StorageBase: base,
|
||||
})
|
||||
_, body, err := Download(d)
|
||||
_, body, err := Download(testctx, d)
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
b, err := ioutil.ReadAll(body)
|
||||
|
@ -153,6 +159,57 @@ func Test_blobDownload_actualBlob(t *testing.T) {
|
|||
require.EqualValues(t, chunk, b, "retrieved body is different body=%d chunk=%d", len(b), len(chunk))
|
||||
}
|
||||
|
||||
func Test_blobDownload_fails_actualBlob404(t *testing.T) {
|
||||
acct := os.Getenv("AZURE_STORAGE_ACCOUNT")
|
||||
key := os.Getenv("AZURE_STORAGE_ACCESS_KEY")
|
||||
if acct == "" || key == "" {
|
||||
t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test")
|
||||
}
|
||||
base := storage.DefaultBaseURL
|
||||
|
||||
blobName := "<BLOB THAT DOESN'T EXIST>"
|
||||
containerName := "<CONTAINER NAME>"
|
||||
|
||||
// Get the blob via downloader
|
||||
d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{
|
||||
Container: containerName,
|
||||
Blob: blobName,
|
||||
StorageBase: base,
|
||||
})
|
||||
code, _, err := Download(testctx, d)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, code, http.StatusNotFound)
|
||||
require.Contains(t, err.Error(), "because it does not exist")
|
||||
require.Contains(t, err.Error(), "Not Found")
|
||||
require.Contains(t, err.Error(), "Service request ID")
|
||||
}
|
||||
|
||||
func Test_blobDownload_fails_actualBlob409(t *testing.T) {
|
||||
// before running this test, go to your storage account on portal > Configuration and disable Blob public access
|
||||
acct := os.Getenv("AZURE_STORAGE_ACCOUNT")
|
||||
key := os.Getenv("AZURE_STORAGE_ACCESS_KEY")
|
||||
if acct == "" || key == "" {
|
||||
t.Skipf("Skipping: AZURE_STORAGE_ACCOUNT or AZURE_STORAGE_ACCESS_KEY not specified to run this test")
|
||||
}
|
||||
base := storage.DefaultBaseURL
|
||||
|
||||
blobName := "<BLOB NAME>"
|
||||
containerName := "<CONTAINER NAME>"
|
||||
|
||||
// Get the blob via downloader
|
||||
d := NewBlobDownload(acct, key, blobutil.AzureBlobRef{
|
||||
Container: containerName,
|
||||
Blob: blobName,
|
||||
StorageBase: base,
|
||||
})
|
||||
code, _, err := Download(testctx, d)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, code, http.StatusConflict)
|
||||
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
|
||||
require.Contains(t, err.Error(), "Public access is not permitted on this storage account")
|
||||
require.Contains(t, err.Error(), "Service request ID")
|
||||
}
|
||||
|
||||
func Test_blobAppend_actualBlob(t *testing.T) {
|
||||
// Before running the test locally prepare storage account and set the following env variables:
|
||||
// export AZURE_STORAGE_BLOB="https://atanas.blob.core.windows.net/con1/output5.txt"
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/Azure/azure-extension-foundation/httputil"
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -55,6 +56,7 @@ func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
request.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String())
|
||||
if IsAzureStorageBlobUri(self.url) {
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken))
|
||||
request.Header.Set(xMsVersionHeaderName, xMsVersionValue)
|
||||
|
|
|
@ -3,6 +3,7 @@ package download
|
|||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
|
@ -45,7 +46,7 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
|
|||
err := json.Unmarshal([]byte(msiJson), &msi)
|
||||
return msi, err
|
||||
}}
|
||||
_, stream, err := Download(&downloader)
|
||||
_, stream, err := Download(testctx, &downloader)
|
||||
require.NoError(t, err, "File download failed")
|
||||
defer stream.Close()
|
||||
|
||||
|
@ -54,6 +55,23 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
|
|||
require.Contains(t, string(bytes), stringToLookFor)
|
||||
}
|
||||
|
||||
func Test_realDownloadBlobWithMsiToken404(t *testing.T) {
|
||||
if msiJson == "" || blobUri == "" || stringToLookFor == "" {
|
||||
t.Skip()
|
||||
}
|
||||
var badBlobUri = blobUri[0 : len(blobUri)-1]
|
||||
downloader := blobWithMsiToken{badBlobUri, func() (msi.Msi, error) {
|
||||
msi := msi.Msi{}
|
||||
err := json.Unmarshal([]byte(msiJson), &msi)
|
||||
return msi, err
|
||||
}}
|
||||
code, _, err := Download(testctx, &downloader)
|
||||
require.NotNil(t, err, "File download succeeded but was not supposed to")
|
||||
require.Equal(t, http.StatusNotFound, code)
|
||||
require.Contains(t, err.Error(), MsiDownload404ErrorString)
|
||||
require.Contains(t, err.Error(), "Service request ID:") // should have a service request ID since downloading from Azure Storage
|
||||
}
|
||||
|
||||
func Test_isAzureStorageBlobUri(t *testing.T) {
|
||||
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname"))
|
||||
require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn"))
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/Azure/run-command-handler-linux/pkg/urlutil"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -45,12 +46,17 @@ var (
|
|||
// Download retrieves a response body and checks the response status code to see
|
||||
// if it is 200 OK and then returns the response body. It issues a new request
|
||||
// every time called. It is caller's responsibility to close the response body.
|
||||
func Download(downloader Downloader) (int, io.ReadCloser, error) {
|
||||
func Download(ctx *log.Context, downloader Downloader) (int, io.ReadCloser, error) {
|
||||
request, err := downloader.GetRequest()
|
||||
if err != nil {
|
||||
return -1, nil, errors.Wrapf(err, "failed to create http request")
|
||||
}
|
||||
|
||||
requestID := request.Header.Get(xMsClientRequestIdHeaderName)
|
||||
if len(requestID) > 0 {
|
||||
ctx.Log("info", fmt.Sprintf("starting download with client request ID %s", requestID))
|
||||
}
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
err = urlutil.RemoveUrlFromErr(err)
|
||||
|
@ -61,7 +67,8 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) {
|
|||
return response.StatusCode, response.Body, nil
|
||||
}
|
||||
|
||||
err = fmt.Errorf("Status code %d while downloading blob '%s'. Use either a public script URI that points to .sh file, Azure storage blob SAS URI or storage blob accessible by a managed identity and retry. For more info, refer https://aka.ms/RunCommandManagedLinux", response.StatusCode, request.URL.Opaque)
|
||||
errString := ""
|
||||
requestId := response.Header.Get(xMsServiceRequestIdHeaderName)
|
||||
switch downloader.(type) {
|
||||
case *blobWithMsiToken:
|
||||
switch response.StatusCode {
|
||||
|
@ -75,6 +82,35 @@ func Download(downloader Downloader) (int, io.ReadCloser, error) {
|
|||
forbiddenError := fmt.Errorf("Make sure managed identity has been given access to container of storage blob '%s' with 'Storage Blob Data Reader' role assignment. In case of user assigned identity, make sure you add it under VM's identity. For more info, refer https://aka.ms/RunCommandManagedLinux", request.URL.Opaque)
|
||||
return response.StatusCode, nil, errors.Wrapf(forbiddenError, MsiDownload403ErrorString)
|
||||
}
|
||||
default:
|
||||
hostname := request.URL.Host
|
||||
switch response.StatusCode {
|
||||
case http.StatusUnauthorized:
|
||||
errString = fmt.Sprintf("RunCommand failed to download the file from %s because access was denied. Please fix the blob permissions and try again, the response code and message returned were: %q",
|
||||
hostname,
|
||||
response.Status)
|
||||
case http.StatusNotFound:
|
||||
errString = fmt.Sprintf("RunCommand failed to download the file from %s because it does not exist. Please create the blob and try again, the response code and message returned were: %q",
|
||||
hostname,
|
||||
response.Status)
|
||||
|
||||
case http.StatusBadRequest:
|
||||
errString = fmt.Sprintf("RunCommand failed to download the file from %s because parts of the request were incorrectly formatted, missing, and/or invalid. The response code and message returned were: %q",
|
||||
hostname,
|
||||
response.Status)
|
||||
|
||||
case http.StatusInternalServerError:
|
||||
errString = fmt.Sprintf("RunCommand failed to download the file from %s due to an issue with storage. The response code and message returned were: %q",
|
||||
hostname,
|
||||
response.Status)
|
||||
default:
|
||||
errString = fmt.Sprintf("RunCommand failed to download the file from %s because the server returned a response code and message of %q Please verify the machine has network connectivity.",
|
||||
hostname,
|
||||
response.Status)
|
||||
}
|
||||
}
|
||||
return response.StatusCode, nil, err
|
||||
if len(requestId) > 0 {
|
||||
errString += fmt.Sprintf(" (Service request ID: %s)", requestId)
|
||||
}
|
||||
return response.StatusCode, nil, fmt.Errorf(errString)
|
||||
}
|
||||
|
|
|
@ -10,32 +10,36 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
|
||||
"github.com/Azure/run-command-handler-linux/pkg/download"
|
||||
"github.com/ahmetalpbalkan/go-httpbin"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type badDownloader struct{ calls int }
|
||||
|
||||
var (
|
||||
testctx = log.NewContext(log.NewNopLogger())
|
||||
)
|
||||
|
||||
func (b *badDownloader) GetRequest() (*http.Request, error) {
|
||||
b.calls++
|
||||
return nil, errors.New("expected error")
|
||||
}
|
||||
|
||||
func TestDownload_wrapsGetRequestError(t *testing.T) {
|
||||
_, _, err := download.Download(new(badDownloader))
|
||||
_, _, err := download.Download(testctx, new(badDownloader))
|
||||
require.NotNil(t, err)
|
||||
require.EqualError(t, err, "failed to create http request: expected error")
|
||||
}
|
||||
|
||||
func TestDownload_wrapsHTTPError(t *testing.T) {
|
||||
_, _, err := download.Download(download.NewURLDownload("bad url"))
|
||||
_, _, err := download.Download(testctx, download.NewURLDownload("bad url"))
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "http request failed:")
|
||||
}
|
||||
|
||||
func TestDownload_badStatusCodeFails(t *testing.T) {
|
||||
func TestDownload_wrapsCommonErrorCodes(t *testing.T) {
|
||||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
|
@ -47,9 +51,21 @@ func TestDownload_badStatusCodeFails(t *testing.T) {
|
|||
http.StatusBadRequest,
|
||||
http.StatusUnauthorized,
|
||||
} {
|
||||
_, _, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
|
||||
respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
|
||||
require.NotNil(t, err, "not failed for code:%d", code)
|
||||
require.Contains(t, err.Error(), fmt.Sprintf("Status code %d while downloading blob", code))
|
||||
require.Equal(t, code, respCode)
|
||||
switch respCode {
|
||||
case http.StatusNotFound:
|
||||
require.Contains(t, err.Error(), "because it does not exist")
|
||||
case http.StatusForbidden:
|
||||
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
|
||||
case http.StatusInternalServerError:
|
||||
require.Contains(t, err.Error(), "due to an issue with storage")
|
||||
case http.StatusBadRequest:
|
||||
require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid")
|
||||
case http.StatusUnauthorized:
|
||||
require.Contains(t, err.Error(), "because access was denied")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +73,7 @@ func TestDownload_statusOKSucceeds(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
|
||||
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL + "/status/200"))
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
require.NotNil(t, body)
|
||||
|
@ -72,13 +88,13 @@ func TestDowload_msiDownloaderErrorMessage(t *testing.T) {
|
|||
|
||||
msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider)
|
||||
|
||||
returnCode, body, err := download.Download(msiDownloader404)
|
||||
returnCode, body, err := download.Download(testctx, msiDownloader404)
|
||||
require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message")
|
||||
require.Nil(t, body, "body is not nil for failed download")
|
||||
require.Equal(t, 404, returnCode, "return code was not 404")
|
||||
|
||||
msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider)
|
||||
returnCode, body, err = download.Download(msiDownloader403)
|
||||
returnCode, body, err = download.Download(testctx, msiDownloader403)
|
||||
require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message")
|
||||
require.Nil(t, body, "body is not nil for failed download")
|
||||
require.Equal(t, 403, returnCode, "return code was not 403")
|
||||
|
@ -89,7 +105,7 @@ func TestDownload_retrievesBody(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
|
||||
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536"))
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
b, err := ioutil.ReadAll(body)
|
||||
|
@ -101,7 +117,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
|
||||
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get"))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, body.Close(), "body should close fine")
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ func WithRetries(ctx *log.Context, downloaders []Downloader, sf SleepFunc) (io.R
|
|||
for _, d := range downloaders {
|
||||
for n := 0; n < expRetryN; n++ {
|
||||
ctx := ctx.With("retry", n)
|
||||
status, out, err := Download(d)
|
||||
status, out, err := Download(ctx, d)
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
|
|
@ -60,7 +60,8 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) {
|
|||
|
||||
sr := new(sleepRecorder)
|
||||
_, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
|
||||
require.Contains(t, err.Error(), "Status code 429 while downloading blob")
|
||||
require.Contains(t, err.Error(), "429 Too Many Requests")
|
||||
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
|
||||
|
||||
require.Equal(t, sleepSchedule, []time.Duration(*sr))
|
||||
}
|
||||
|
@ -104,7 +105,7 @@ func TestRetriesWith_SwitchDownloaderThenFailWithCorrectErrorMessage(t *testing.
|
|||
require.NotNil(t, err, "download with retries should fail")
|
||||
require.Nil(t, resp, "response body should be nil for failed download with retries")
|
||||
require.Equal(t, d404.timesCalled, 1)
|
||||
require.Contains(t, err.Error(), "Status code 403 while downloading blob")
|
||||
require.Contains(t, err.Error(), "403 Forbidden")
|
||||
|
||||
d404 = mockDownloader{0, svr.URL + "/status/404"}
|
||||
msiDownloader404 := download.NewBlobWithMsiDownload(svr.URL+"/status/404", mockMsiProvider)
|
||||
|
|
|
@ -3,6 +3,13 @@ package download
|
|||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
xMsClientRequestIdHeaderName = "x-ms-client-request-id"
|
||||
xMsServiceRequestIdHeaderName = "x-ms-request-id"
|
||||
)
|
||||
|
||||
// urlDownload describes a URL to download.
|
||||
|
@ -17,7 +24,11 @@ func NewURLDownload(url string) Downloader {
|
|||
|
||||
// GetRequest returns a new request to download the URL
|
||||
func (u urlDownload) GetRequest() (*http.Request, error) {
|
||||
return http.NewRequest("GET", u.url, nil)
|
||||
req, err := http.NewRequest("GET", u.url, nil)
|
||||
if req != nil {
|
||||
req.Header.Add(xMsClientRequestIdHeaderName, uuid.New().String())
|
||||
}
|
||||
return req, err
|
||||
}
|
||||
|
||||
// Scrub query. Used to remove the query parts like SAS token.
|
||||
|
|
|
@ -25,6 +25,7 @@ func Test_urlDownload_GetRequest_goodURL(t *testing.T) {
|
|||
r, err := d.GetRequest()
|
||||
require.Nil(t, err, u)
|
||||
require.NotNil(t, r, u)
|
||||
require.NotNil(t, r.Header.Get(xMsClientRequestIdHeaderName))
|
||||
}
|
||||
|
||||
func Test_GetUriForLogging_ScrubsQuery(t *testing.T) {
|
||||
|
|
Загрузка…
Ссылка в новой задаче