2016-07-07 02:07:55 +03:00
|
|
|
package download_test
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
2020-03-24 20:35:11 +03:00
|
|
|
"strings"
|
2016-07-07 02:07:55 +03:00
|
|
|
"testing"
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
"github.com/Azure/azure-extension-foundation/msi"
|
|
|
|
"github.com/go-kit/kit/log"
|
|
|
|
|
2016-08-01 23:44:25 +03:00
|
|
|
"github.com/Azure/custom-script-extension-linux/pkg/download"
|
2016-07-07 02:07:55 +03:00
|
|
|
"github.com/ahmetalpbalkan/go-httpbin"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
)
|
|
|
|
|
2016-07-13 21:36:14 +03:00
|
|
|
// badDownloader is a test double whose GetRequest always fails.
// calls records how many times GetRequest has been invoked.
type badDownloader struct {
	calls int
}
|
2016-07-07 02:07:55 +03:00
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
var (
|
|
|
|
testctx = log.NewContext(log.NewNopLogger())
|
|
|
|
)
|
|
|
|
|
2016-07-13 21:36:14 +03:00
|
|
|
func (b *badDownloader) GetRequest() (*http.Request, error) {
|
|
|
|
b.calls++
|
2016-07-07 02:07:55 +03:00
|
|
|
return nil, errors.New("expected error")
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDownload_wrapsGetRequestError(t *testing.T) {
|
2023-01-25 01:48:48 +03:00
|
|
|
_, _, err := download.Download(testctx, new(badDownloader))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.NotNil(t, err)
|
2020-01-08 00:46:39 +03:00
|
|
|
require.EqualError(t, err, "failed to create http request: expected error")
|
2016-07-07 02:07:55 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestDownload_wrapsHTTPError(t *testing.T) {
|
2023-01-25 01:48:48 +03:00
|
|
|
_, _, err := download.Download(testctx, download.NewURLDownload("bad url"))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.NotNil(t, err)
|
|
|
|
require.Contains(t, err.Error(), "http request failed:")
|
|
|
|
}
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
// This test is only to make sure that formatting of error messages for specific codes is correct
|
|
|
|
func TestDownload_wrapsCommonErrorCodes(t *testing.T) {
|
2016-07-07 02:07:55 +03:00
|
|
|
srv := httptest.NewServer(httpbin.GetMux())
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
for _, code := range []int{
|
|
|
|
http.StatusNotFound,
|
|
|
|
http.StatusForbidden,
|
|
|
|
http.StatusInternalServerError,
|
|
|
|
http.StatusBadGateway,
|
|
|
|
http.StatusBadRequest,
|
|
|
|
http.StatusUnauthorized,
|
|
|
|
} {
|
2023-01-25 01:48:48 +03:00
|
|
|
respCode, _, err := download.Download(testctx, download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.NotNil(t, err, "not failed for code:%d", code)
|
2023-01-25 01:48:48 +03:00
|
|
|
require.Equal(t, code, respCode)
|
|
|
|
switch respCode {
|
|
|
|
case http.StatusNotFound:
|
|
|
|
require.Contains(t, err.Error(), "because it does not exist")
|
|
|
|
case http.StatusForbidden:
|
|
|
|
require.Contains(t, err.Error(), "Please verify the machine has network connectivity")
|
|
|
|
case http.StatusInternalServerError:
|
|
|
|
require.Contains(t, err.Error(), "due to an issue with storage")
|
|
|
|
case http.StatusBadRequest:
|
|
|
|
require.Contains(t, err.Error(), "because parts of the request were incorrectly formatted, missing, and/or invalid")
|
|
|
|
case http.StatusUnauthorized:
|
|
|
|
require.Contains(t, err.Error(), "because access was denied")
|
|
|
|
}
|
2016-07-07 02:07:55 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDownload_statusOKSucceeds(t *testing.T) {
|
|
|
|
srv := httptest.NewServer(httpbin.GetMux())
|
|
|
|
defer srv.Close()
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/status/200"))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.Nil(t, err)
|
|
|
|
defer body.Close()
|
|
|
|
require.NotNil(t, body)
|
|
|
|
}
|
|
|
|
|
2020-03-24 20:35:11 +03:00
|
|
|
func TestDowload_msiDownloaderErrorMessage(t *testing.T) {
|
|
|
|
var mockMsiProvider download.MsiProvider = func() (msi.Msi, error) {
|
|
|
|
return msi.Msi{AccessToken: "fakeAccessToken"}, nil
|
|
|
|
}
|
|
|
|
srv := httptest.NewServer(httpbin.GetMux())
|
|
|
|
defer srv.Close()
|
|
|
|
|
|
|
|
msiDownloader404 := download.NewBlobWithMsiDownload(srv.URL+"/status/404", mockMsiProvider)
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
returnCode, body, err := download.Download(testctx, msiDownloader404)
|
2020-03-24 20:35:11 +03:00
|
|
|
require.True(t, strings.Contains(err.Error(), download.MsiDownload404ErrorString), "error string doesn't contain the correct message")
|
|
|
|
require.Nil(t, body, "body is not nil for failed download")
|
|
|
|
require.Equal(t, 404, returnCode, "return code was not 404")
|
|
|
|
|
|
|
|
msiDownloader403 := download.NewBlobWithMsiDownload(srv.URL+"/status/403", mockMsiProvider)
|
2023-01-25 01:48:48 +03:00
|
|
|
returnCode, body, err = download.Download(testctx, msiDownloader403)
|
2020-03-24 20:35:11 +03:00
|
|
|
require.True(t, strings.Contains(err.Error(), download.MsiDownload403ErrorString), "error string doesn't contain the correct message")
|
|
|
|
require.Nil(t, body, "body is not nil for failed download")
|
|
|
|
require.Equal(t, 403, returnCode, "return code was not 403")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2016-07-07 02:07:55 +03:00
|
|
|
func TestDownload_retrievesBody(t *testing.T) {
|
|
|
|
srv := httptest.NewServer(httpbin.GetMux())
|
|
|
|
defer srv.Close()
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/bytes/65536"))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.Nil(t, err)
|
|
|
|
defer body.Close()
|
|
|
|
b, err := ioutil.ReadAll(body)
|
|
|
|
require.Nil(t, err)
|
|
|
|
require.EqualValues(t, 65536, len(b))
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestDownload_bodyClosesWithoutError(t *testing.T) {
|
|
|
|
srv := httptest.NewServer(httpbin.GetMux())
|
|
|
|
defer srv.Close()
|
|
|
|
|
2023-01-25 01:48:48 +03:00
|
|
|
_, body, err := download.Download(testctx, download.NewURLDownload(srv.URL+"/get"))
|
2016-07-07 02:07:55 +03:00
|
|
|
require.Nil(t, err)
|
|
|
|
require.Nil(t, body.Close(), "body should close fine")
|
|
|
|
}
|