Merge pull request #151 from Azure/bhbrahma/msi
RBAC support for downloading scripts with managed identities
This commit is contained in:
Коммит
fc5210a6d8
|
@ -13,6 +13,7 @@ install:
|
|||
- sudo npm install -g azure-cli
|
||||
- go get -u golang.org/x/lint/golint
|
||||
- go get -u github.com/ahmetalpbalkan/govvv
|
||||
- go get -u -d github.com/Azure/azure-extension-foundation/...
|
||||
before_script:
|
||||
- docker version
|
||||
- docker info
|
||||
|
|
1
Makefile
1
Makefile
|
@ -15,6 +15,7 @@ binary: clean
|
|||
echo "GOPATH is not set"; \
|
||||
exit 1; \
|
||||
fi
|
||||
go get -d -u -f github.com/Azure/azure-extension-foundation/...
|
||||
GOOS=linux GOARCH=amd64 govvv build -v \
|
||||
-ldflags "-X main.Version=`grep -E -m 1 -o '<Version>(.*)</Version>' misc/manifest.xml | awk -F">" '{print $$2}' | awk -F"<" '{print $$1}'`" \
|
||||
-o $(BINDIR)/$(BIN) ./main
|
||||
|
|
|
@ -230,7 +230,7 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
|
|||
for i, f := range cfg.fileUrls() {
|
||||
ctx := ctx.With("file", i)
|
||||
ctx.Log("event", "download start")
|
||||
if err := downloadAndProcessURL(ctx, f, dir, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.publicSettings.SkipDos2Unix); err != nil {
|
||||
if err := downloadAndProcessURL(ctx, f, dir, &cfg); err != nil {
|
||||
ctx.Log("event", "download failed", "error", err)
|
||||
return errors.Wrapf(err, "failed to download file[%d]", i)
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
// downloadAndProcessURL downloads using the specified downloader and saves it to the
|
||||
// specified existing directory, which must be the path to the saved file. Then
|
||||
// it post-processes file based on heuristics.
|
||||
func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountName, storageAccountKey string, skipDos2Unix bool) error {
|
||||
func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) error {
|
||||
fn, err := urlToFileName(url)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -29,7 +29,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
|
|||
return fmt.Errorf("[REDACTED] is not a valid url")
|
||||
}
|
||||
|
||||
dl, err := getDownloader(url, storageAccountName, storageAccountKey)
|
||||
dl, err := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
|
|||
return err
|
||||
}
|
||||
|
||||
if skipDos2Unix == false {
|
||||
if cfg.SkipDos2Unix == false {
|
||||
err = postProcessFile(fp)
|
||||
}
|
||||
return errors.Wrapf(err, "failed to post-process '%s'", fn)
|
||||
|
@ -48,19 +48,46 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
|
|||
|
||||
// getDownloader returns a downloader for the given URL based on whether the
|
||||
// storage credentials are empty or not.
|
||||
func getDownloader(fileURL string, storageAccountName, storageAccountKey string) (
|
||||
download.Downloader, error) {
|
||||
func getDownloaders(fileURL string, storageAccountName, storageAccountKey string, managedIdentity *clientOrObjectId) (
|
||||
[]download.Downloader, error) {
|
||||
if storageAccountName == "" || storageAccountKey == "" {
|
||||
return download.NewURLDownload(fileURL), nil
|
||||
// storage account name and key cannot be specified with managed identity, handler settings validation won't allow that
|
||||
// handler settings validation will also not allow storageAccountName XOR storageAccountKey == 1
|
||||
// in this case, we can be sure that storage account name and key was not specified
|
||||
if download.IsAzureStorageBlobUri(fileURL) && managedIdentity != nil {
|
||||
// if managed identity was specified in the configuration, try to use it to download the files
|
||||
var msiProvider download.MsiProvider
|
||||
switch {
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId == "":
|
||||
// get msi using clientId or objectId or implicitly
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL)
|
||||
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId)
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found")
|
||||
}
|
||||
return []download.Downloader{
|
||||
// try downloading without MSI token first, but attempt with MSI if the download fails
|
||||
download.NewURLDownload(fileURL),
|
||||
download.NewBlobWithMsiDownload(fileURL, msiProvider),
|
||||
}, nil
|
||||
} else {
|
||||
// do not use MSI downloader if the uri is not azure storage blob, or managedIdentity isn't specified
|
||||
return []download.Downloader{download.NewURLDownload(fileURL)}, nil
|
||||
}
|
||||
} else {
|
||||
// if storage name account and key are specified, use that for all files
|
||||
// this preserves old behavior
|
||||
blob, err := blobutil.ParseBlobURL(fileURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []download.Downloader{download.NewBlobDownload(
|
||||
storageAccountName, storageAccountKey,
|
||||
blob)}, nil
|
||||
}
|
||||
|
||||
blob, err := blobutil.ParseBlobURL(fileURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return download.NewBlobDownload(
|
||||
storageAccountName, storageAccountKey,
|
||||
blob), nil
|
||||
}
|
||||
|
||||
// urlToFileName parses given URL and returns the section after the last slash
|
||||
|
|
|
@ -15,31 +15,44 @@ import (
|
|||
|
||||
func Test_getDownloader_azureBlob(t *testing.T) {
|
||||
// error condition
|
||||
_, err := getDownloader("http://acct.blob.core.windows.net/", "acct", "key")
|
||||
_, err := getDownloaders("http://acct.blob.core.windows.net/", "acct", "key", nil)
|
||||
require.NotNil(t, err)
|
||||
|
||||
// valid input
|
||||
d, err := getDownloader("http://acct.blob.core.windows.net/container/blob", "acct", "key")
|
||||
d, err := getDownloaders("http://acct.blob.core.windows.net/container/blob", "acct", "key", nil)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, "download.blobDownload", fmt.Sprintf("%T", d), "got wrong type")
|
||||
require.Equal(t, 1, len(d))
|
||||
require.Equal(t, "download.blobDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
}
|
||||
|
||||
func Test_getDownloader_externalUrl(t *testing.T) {
|
||||
d, err := getDownloader("http://acct.blob.core.windows.net/", "", "")
|
||||
d, err := getDownloaders("http://acct.blob.core.windows.net/", "", "", nil)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
|
||||
require.NotEmpty(t, d)
|
||||
require.Equal(t, 1, len(d))
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
d, err = getDownloader("http://acct.blob.core.windows.net/", "foo", "")
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", "", "", &clientOrObjectId{"", "dummyclientid"})
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
|
||||
require.NotEmpty(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[1]), "got wrong type")
|
||||
|
||||
d, err = getDownloader("http://acct.blob.core.windows.net/", "", "bar")
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", "foo", "", nil)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
|
||||
require.Equal(t, 1, len(d))
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", "", "bar", nil)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 1, len(d))
|
||||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
}
|
||||
|
||||
func Test_urlToFileName_badURL(t *testing.T) {
|
||||
|
@ -110,7 +123,8 @@ func Test_downloadAndProcessURL(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, "", "", false)
|
||||
cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}}
|
||||
err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg)
|
||||
require.Nil(t, err)
|
||||
|
||||
fp := filepath.Join(tmpDir, "256")
|
||||
|
|
|
@ -9,11 +9,13 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
|
||||
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
|
||||
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
|
||||
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
|
||||
errCmdMissing = errors.New("'commandToExecute' is not specified")
|
||||
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
|
||||
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
|
||||
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
|
||||
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
|
||||
errCmdMissing = errors.New("'commandToExecute' is not specified")
|
||||
errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'")
|
||||
errUsingBothClientIdAndObjectId = errors.New("only one of 'clientId' or 'objectId' must be specified with 'managedServiceIdentity'")
|
||||
)
|
||||
|
||||
// handlerSettings holds the configuration of the extension handler.
|
||||
|
@ -66,6 +68,16 @@ func (h handlerSettings) validate() error {
|
|||
return errStoragePartialCredentials
|
||||
}
|
||||
|
||||
if (h.protectedSettings.StorageAccountKey != "" || h.protectedSettings.StorageAccountName != "") && h.protectedSettings.ManagedIdentity != nil {
|
||||
return errUsingBothKeyAndMsi
|
||||
}
|
||||
|
||||
if h.protectedSettings.ManagedIdentity != nil {
|
||||
if h.protectedSettings.ManagedIdentity.ClientId != "" && h.protectedSettings.ManagedIdentity.ObjectId != "" {
|
||||
return errUsingBothClientIdAndObjectId
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -81,11 +93,21 @@ type publicSettings struct {
|
|||
// protectedSettings is the type decoded and deserialized from protected
|
||||
// configuration section. This should be in sync with protectedSettingsSchema.
|
||||
type protectedSettings struct {
|
||||
CommandToExecute string `json:"commandToExecute"`
|
||||
Script string `json:"script"`
|
||||
FileURLs []string `json:"fileUris"`
|
||||
StorageAccountName string `json:"storageAccountName"`
|
||||
StorageAccountKey string `json:"storageAccountKey"`
|
||||
CommandToExecute string `json:"commandToExecute"`
|
||||
Script string `json:"script"`
|
||||
FileURLs []string `json:"fileUris"`
|
||||
StorageAccountName string `json:"storageAccountName"`
|
||||
StorageAccountKey string `json:"storageAccountKey"`
|
||||
ManagedIdentity *clientOrObjectId `json:"managedIdentity"`
|
||||
}
|
||||
|
||||
type clientOrObjectId struct {
|
||||
ObjectId string `json:"objectId"`
|
||||
ClientId string `json:"clientId"`
|
||||
}
|
||||
|
||||
func (self *clientOrObjectId) isEmpty() bool {
|
||||
return self.ClientId == "" && self.ObjectId == ""
|
||||
}
|
||||
|
||||
// parseAndValidateSettings reads configuration from configFolder, decrypts it,
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package main
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
import "github.com/stretchr/testify/require"
|
||||
|
||||
func Test_handlerSettingsValidate(t *testing.T) {
|
||||
|
@ -86,6 +89,44 @@ func Test_skipDos2UnixDefaultsToFalse(t *testing.T) {
|
|||
require.Equal(t, false, testSubject.SkipDos2Unix)
|
||||
}
|
||||
|
||||
func Test_managedIdentityVerification(t *testing.T) {
|
||||
require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
|
||||
CommandToExecute: "echo hi",
|
||||
FileURLs: []string{"file1", "file2"},
|
||||
ManagedIdentity: &clientOrObjectId{
|
||||
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
|
||||
},
|
||||
}}.validate(), "validation failed for settings with MSI")
|
||||
|
||||
require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
|
||||
CommandToExecute: "echo hi",
|
||||
ManagedIdentity: &clientOrObjectId{
|
||||
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
|
||||
},
|
||||
}}.validate(), "validation failed for settings with MSI")
|
||||
|
||||
require.Equal(t, errUsingBothKeyAndMsi,
|
||||
handlerSettings{publicSettings{},
|
||||
protectedSettings{
|
||||
CommandToExecute: "echo hi",
|
||||
StorageAccountName: "name",
|
||||
StorageAccountKey: "key",
|
||||
ManagedIdentity: &clientOrObjectId{
|
||||
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
|
||||
},
|
||||
}}.validate(), "validation didn't fail for settings with both MSI and storage account")
|
||||
|
||||
require.Equal(t, errUsingBothClientIdAndObjectId,
|
||||
handlerSettings{publicSettings{},
|
||||
protectedSettings{
|
||||
CommandToExecute: "echo hi",
|
||||
ManagedIdentity: &clientOrObjectId{
|
||||
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
|
||||
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
|
||||
},
|
||||
}}.validate(), "validation didn't fail for settings with both MSI and storage account")
|
||||
}
|
||||
|
||||
func Test_toJSON_empty(t *testing.T) {
|
||||
s, err := toJSON(nil)
|
||||
require.Nil(t, err)
|
||||
|
@ -98,3 +139,58 @@ func Test_toJSON(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
require.Equal(t, `{"a":3}`, s)
|
||||
}
|
||||
|
||||
func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
|
||||
testString := `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"]}`
|
||||
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
|
||||
protSettings := new(protectedSettings)
|
||||
err := json.Unmarshal([]byte(testString), protSettings)
|
||||
require.NoError(t, err, "error while deserializing json")
|
||||
require.Nil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to be nil")
|
||||
h := handlerSettings{publicSettings{}, *protSettings}
|
||||
require.NoError(t, h.validate(), "settings should be valid")
|
||||
|
||||
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { }}`
|
||||
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
|
||||
protSettings = new(protectedSettings)
|
||||
err = json.Unmarshal([]byte(testString), protSettings)
|
||||
require.NoError(t, err, "error while deserializing json")
|
||||
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
|
||||
h = handlerSettings{publicSettings{}, *protSettings}
|
||||
require.NoError(t, h.validate(), "settings should be valid")
|
||||
|
||||
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
|
||||
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
|
||||
protSettings = new(protectedSettings)
|
||||
err = json.Unmarshal([]byte(testString), protSettings)
|
||||
require.NoError(t, err, "error while deserializing json")
|
||||
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
|
||||
h = handlerSettings{publicSettings{}, *protSettings}
|
||||
require.NoError(t, h.validate(), "settings should be valid")
|
||||
|
||||
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { "objectId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
|
||||
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
|
||||
protSettings = new(protectedSettings)
|
||||
err = json.Unmarshal([]byte(testString), protSettings)
|
||||
require.NoError(t, err, "error while deserializing json")
|
||||
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
|
||||
h = handlerSettings{publicSettings{}, *protSettings}
|
||||
require.NoError(t, h.validate(), "settings should be valid")
|
||||
|
||||
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232", "objectId": "41b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
|
||||
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
|
||||
protSettings = new(protectedSettings)
|
||||
err = json.Unmarshal([]byte(testString), protSettings)
|
||||
require.NoError(t, err, "error while deserializing json")
|
||||
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
|
||||
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "41b403aa-c364-4240-a7ff-d85fb6cd7232")
|
||||
h = handlerSettings{publicSettings{}, *protSettings}
|
||||
require.Error(t, h.validate(), "settings should be invalid")
|
||||
}
|
||||
|
|
|
@ -19,14 +19,14 @@ const (
|
|||
"description": "Command to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"script": {
|
||||
"description": "Script to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"skipDos2Unix": {
|
||||
"description": "Skip DOS2UNIX and BOM removal for download files and script",
|
||||
"type": "boolean"
|
||||
},
|
||||
"script": {
|
||||
"description": "Script to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"skipDos2Unix": {
|
||||
"description": "Skip DOS2UNIX and BOM removal for download files and script",
|
||||
"type": "boolean"
|
||||
},
|
||||
"fileUris": {
|
||||
"description": "List of files to be downloaded",
|
||||
"type": "array",
|
||||
|
@ -52,7 +52,7 @@ const (
|
|||
"description": "Command to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"fileUris": {
|
||||
"fileUris": {
|
||||
"description": "List of files to be downloaded",
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
@ -60,10 +60,10 @@ const (
|
|||
"format": "uri"
|
||||
}
|
||||
},
|
||||
"script": {
|
||||
"description": "Script to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"script": {
|
||||
"description": "Script to be executed",
|
||||
"type": "string"
|
||||
},
|
||||
"storageAccountName": {
|
||||
"description": "Name of the Azure Storage Account (3-24 characters of lowercase letters or digits)",
|
||||
"type": "string",
|
||||
|
@ -73,6 +73,22 @@ const (
|
|||
"description": "Key for the Azure Storage Account (a base64 encoded string)",
|
||||
"type": "string",
|
||||
"pattern": "^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{4})$"
|
||||
},
|
||||
"managedIdentity": {
|
||||
"description": "Setting to use Managed Service Identity to try to download fileUri from azure blob",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"objectId": {
|
||||
"description": "Object id that identifies the user created managed identity",
|
||||
"type": "string",
|
||||
"pattern": "^(?:[0-9A-Fa-f]{8}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{12})$"
|
||||
},
|
||||
"clientId": {
|
||||
"description": "Client id that identifies the user created managed identity",
|
||||
"type": "string",
|
||||
"pattern": "^(?:[0-9A-Fa-f]{8}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{12})$"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
|
|
|
@ -144,3 +144,15 @@ func TestValidateProtectedSettings_storageAccountKey(t *testing.T) {
|
|||
require.Nil(t, validateProtectedSettings(`{"storageAccountKey": "A+hMRrsZQ6COPXTYX/EiKiF2HVtfhCfLDo3Dkc3ekKoX3jA58zXVG2QRe/C1+zdEFSrVX6FZsKyivsSlnwmWOw=="}`), "ok")
|
||||
require.Nil(t, validateProtectedSettings(`{"storageAccountKey": "/yGnx6KyxQ8Pjzk0QXeY+66Du0BeTWaCt83la59w72hu/81e6TzskXXvL/IlO3q6g0k0kJrR9MYQNi+cNR3SXA=="}`), "ok")
|
||||
}
|
||||
|
||||
func TestValidateProtectedSettings_managedServiceIdentity(t *testing.T) {
|
||||
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`),
|
||||
"couldn't parse msi proprety with lowercase guid")
|
||||
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { "objectId": "31B403AA-C364-4240-A7FF-D85FB6CD7232"}}`),
|
||||
"couldn't parse msi property with uppercase guid")
|
||||
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { }}`),
|
||||
"couldn't parse msi property without clientId or objectId")
|
||||
|
||||
require.Error(t, validateProtectedSettings(`{"managedIdentity": { "clientId": "notaguid"}}`),
|
||||
"guid validation succeded when expected to fail")
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
<ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure">
|
||||
<ProviderNameSpace>Microsoft.Azure.Extensions</ProviderNameSpace>
|
||||
<Type>CustomScript</Type>
|
||||
<Version>2.0.7</Version>
|
||||
<Version>2.1.0</Version>
|
||||
<Label>Microsoft Azure Custom Script Extension for Linux Virtual Machines</Label>
|
||||
<HostingResources>VmRole</HostingResources>
|
||||
<MediaLink></MediaLink>
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -70,9 +71,10 @@ func Test_blobDownload_fails_badCreds(t *testing.T) {
|
|||
Container: "foocontainer",
|
||||
})
|
||||
|
||||
_, err := Download(d)
|
||||
status, _, err := Download(d)
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "unexpected status code: got=403")
|
||||
require.Contains(t, err.Error(), "unexpected status code: actual=403")
|
||||
require.Equal(t, status, http.StatusForbidden)
|
||||
}
|
||||
|
||||
func Test_blobDownload_fails_urlNotFound(t *testing.T) {
|
||||
|
@ -82,7 +84,7 @@ func Test_blobDownload_fails_urlNotFound(t *testing.T) {
|
|||
Container: "foocontainer",
|
||||
})
|
||||
|
||||
_, err := Download(d)
|
||||
_, _, err := Download(d)
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "http request failed:")
|
||||
}
|
||||
|
@ -121,7 +123,7 @@ func Test_blobDownload_actualBlob(t *testing.T) {
|
|||
Blob: name,
|
||||
StorageBase: base,
|
||||
})
|
||||
body, err := Download(d)
|
||||
_, body, err := Download(d)
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
b, err := ioutil.ReadAll(body)
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Azure/azure-extension-foundation/httputil"
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
url2 "net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
xMsVersionHeaderName = "x-ms-version"
|
||||
xMsVersionValue = "2018-03-28"
|
||||
storageResourceName = "https://storage.azure.com/"
|
||||
)
|
||||
|
||||
var azureBlobDomains = map[string]interface{}{ // golang doesn't have builtin hash sets, so this is a workaround for that
|
||||
"blob.core.windows.net": nil,
|
||||
"blob.core.chinacloudapi.cn": nil,
|
||||
"blob.core.usgovcloudapi.net": nil,
|
||||
"blob.core.couldapi.de": nil,
|
||||
}
|
||||
|
||||
type blobWithMsiToken struct {
|
||||
url string
|
||||
msiProvider MsiProvider
|
||||
}
|
||||
|
||||
type MsiProvider func() (msi.Msi, error)
|
||||
|
||||
func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
|
||||
msi, err := self.msiProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msi.AccessToken == "" {
|
||||
return nil, errors.New("MSI token was empty")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, self.url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if IsAzureStorageBlobUri(self.url) {
|
||||
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken))
|
||||
request.Header.Set(xMsVersionHeaderName, xMsVersionValue)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader {
|
||||
return &blobWithMsiToken{url, msiProvider}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri))
|
||||
if err != nil {
|
||||
return msi, errors.Wrapf(err, "Unable to get managed identity. "+
|
||||
"Please make sure that system assigned managed identity is enabled on the VM "+
|
||||
"or user assigned identity is added to the system.")
|
||||
}
|
||||
return msi, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri))
|
||||
if err != nil {
|
||||
return msi, errors.Wrapf(err, "Unable to get managed identity with client id %s. "+
|
||||
"Please make sure that the user assigned managed identity is added to the VM ", clientId)
|
||||
}
|
||||
return msi, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri))
|
||||
if err != nil {
|
||||
return msi, errors.Wrapf(err, "Unable to get managed identity with object id %s. "+
|
||||
"Please make sure that the user assigned managed identity is added to the VM ", objectId)
|
||||
}
|
||||
return msi, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetResourceNameFromBlobUri(uri string) string {
|
||||
// TODO: update this function as sovereign cloud blob resource strings become available
|
||||
// resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change
|
||||
return storageResourceName
|
||||
}
|
||||
|
||||
func IsAzureStorageBlobUri(url string) bool {
|
||||
// TODO update this function for sovereign regions
|
||||
parsedUrl, err := url2.Parse(url)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s := strings.Split(parsedUrl.Hostname(), ".")
|
||||
if len(s) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domainName := strings.Join(s[1:], ".")
|
||||
_, foundDomain := azureBlobDomains[domainName]
|
||||
return foundDomain
|
||||
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// README
|
||||
// to run this test, assign/create an azure VM with system assigned or user assigned identity
|
||||
// this is the machine that you'll get the msiJson from
|
||||
// assign "Storage Blob Data Reader" permissions to managed identity on a blob
|
||||
|
||||
var msiJson = `` // place the msi json here e.g.
|
||||
// {"access_token":<access token>","client_id":"31b403aa-c364-4240-a7ff-d85fb6cd7232","expires_in":"28799",
|
||||
// "expires_on":"1563607134","ext_expires_in":"28799","not_before":"1563578034","resource":"https://storage.azure.com/",
|
||||
// "token_type":"Bearer"}
|
||||
|
||||
// Linux command to get msi
|
||||
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F' -H Metadata:true
|
||||
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&client_id=<client_id>' -H Metadata:true
|
||||
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&object_id=<object_id>' -H Metadata:true
|
||||
|
||||
// Powershell command to get msi
|
||||
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F" | ConvertTo-Json
|
||||
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&client_id=<client_id>" | ConvertTo-Json
|
||||
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&object_id=<object_id>" | ConvertTo-Json
|
||||
|
||||
// the first command gets the system managed identity, or the user assigned identity if the VM has only one user assigned identity and no system assigned identity
|
||||
// the second command gets user assigned identity with its client id
|
||||
// the third command gets user assigned identity with its object id
|
||||
|
||||
var blobUri = "" // set the blob to download here e.g. https://storageaccount.blob.core.windows.net/container/blobname
|
||||
var stringToLookFor = "" // the string to look for in you blob
|
||||
|
||||
func Test_realDownloadBlobWithMsiToken(t *testing.T) {
|
||||
if msiJson == "" || blobUri == "" || stringToLookFor == "" {
|
||||
t.Skip()
|
||||
}
|
||||
downloader := blobWithMsiToken{blobUri, func() (msi.Msi, error) {
|
||||
msi := msi.Msi{}
|
||||
err := json.Unmarshal([]byte(msiJson), &msi)
|
||||
return msi, err
|
||||
}}
|
||||
_, stream, err := Download(&downloader)
|
||||
require.NoError(t, err, "File download failed")
|
||||
defer stream.Close()
|
||||
|
||||
bytes, err := ioutil.ReadAll(stream)
|
||||
require.NoError(t, err, "saving file stream to memory failed")
|
||||
require.Contains(t, string(bytes), stringToLookFor)
|
||||
}
|
||||
|
||||
func Test_isAzureStorageBlobUri(t *testing.T) {
|
||||
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname"))
|
||||
require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn"))
|
||||
require.True(t, IsAzureStorageBlobUri("https://blackforestsa.blob.core.couldapi.de/c/b/x"))
|
||||
require.False(t, IsAzureStorageBlobUri("https://github.com/Azure-Samples/storage-blobs-go-quickstart/blob/master/README.md"))
|
||||
}
|
|
@ -37,20 +37,28 @@ var (
|
|||
// Download retrieves a response body and checks the response status code to see
|
||||
// if it is 200 OK and then returns the response body. It issues a new request
|
||||
// every time called. It is caller's responsibility to close the response body.
|
||||
func Download(d Downloader) (io.ReadCloser, error) {
|
||||
func Download(d Downloader) (int, io.ReadCloser, error) {
|
||||
req, err := d.GetRequest()
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to create the request")
|
||||
return -1, nil, errors.Wrapf(err, "failed to create the request")
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
err = urlutil.RemoveUrlFromErr(err)
|
||||
return nil, errors.Wrapf(err, "http request failed")
|
||||
return -1, nil, errors.Wrapf(err, "http request failed")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: got=%d expected=%d", resp.StatusCode, http.StatusOK)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return resp.StatusCode, resp.Body, nil
|
||||
}
|
||||
return resp.Body, nil
|
||||
|
||||
err = fmt.Errorf("unexpected status code: actual=%d expected=%d", resp.StatusCode, http.StatusOK)
|
||||
switch d.(type) {
|
||||
case *blobWithMsiToken:
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return resp.StatusCode, nil, errors.Wrapf(err, "please ensure that the blob location in the fileUri setting exists and the specified Managed Identity has read permissions to the storage blob")
|
||||
}
|
||||
}
|
||||
return resp.StatusCode, nil, err
|
||||
}
|
||||
|
|
|
@ -21,13 +21,13 @@ func (b *badDownloader) GetRequest() (*http.Request, error) {
|
|||
}
|
||||
|
||||
func TestDownload_wrapsGetRequestError(t *testing.T) {
|
||||
_, err := download.Download(new(badDownloader))
|
||||
_, _, err := download.Download(new(badDownloader))
|
||||
require.NotNil(t, err)
|
||||
require.EqualError(t, err, "failed to create the request: expected error")
|
||||
}
|
||||
|
||||
func TestDownload_wrapsHTTPError(t *testing.T) {
|
||||
_, err := download.Download(download.NewURLDownload("bad url"))
|
||||
_, _, err := download.Download(download.NewURLDownload("bad url"))
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "http request failed:")
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func TestDownload_badStatusCodeFails(t *testing.T) {
|
|||
http.StatusBadRequest,
|
||||
http.StatusUnauthorized,
|
||||
} {
|
||||
_, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
|
||||
_, _, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
|
||||
require.NotNil(t, err, "not failed for code:%d", code)
|
||||
require.Contains(t, err.Error(), "unexpected status code", "wrong message for code %d", code)
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ func TestDownload_statusOKSucceeds(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
require.NotNil(t, body)
|
||||
|
@ -64,7 +64,7 @@ func TestDownload_retrievesBody(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
|
||||
require.Nil(t, err)
|
||||
defer body.Close()
|
||||
b, err := ioutil.ReadAll(body)
|
||||
|
@ -76,7 +76,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
|
||||
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, body.Close(), "body should close fine")
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-kit/kit/log"
|
||||
|
@ -29,27 +31,51 @@ const (
|
|||
// closed on failures). If the retries do not succeed, the last error is returned.
|
||||
//
|
||||
// It sleeps in exponentially increasing durations between retries.
|
||||
func WithRetries(ctx *log.Context, d Downloader, sf SleepFunc) (io.ReadCloser, error) {
|
||||
func WithRetries(ctx *log.Context, downloaders []Downloader, sf SleepFunc) (io.ReadCloser, error) {
|
||||
var lastErr error
|
||||
for n := 0; n < expRetryN; n++ {
|
||||
ctx := ctx.With("retry", n)
|
||||
out, err := Download(d)
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
lastErr = err
|
||||
ctx.Log("error", err)
|
||||
for _, d := range downloaders {
|
||||
for n := 0; n < expRetryN; n++ {
|
||||
ctx := ctx.With("retry", n)
|
||||
status, out, err := Download(d)
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
if out != nil { // we are not going to read this response body
|
||||
out.Close()
|
||||
}
|
||||
lastErr = err
|
||||
ctx.Log("error", err)
|
||||
|
||||
if n != expRetryN-1 {
|
||||
// have more retries to go, sleep before retrying
|
||||
slp := expRetryK * time.Duration(int(math.Pow(float64(expRetryM), float64(n))))
|
||||
ctx.Log("sleep", slp)
|
||||
sf(slp)
|
||||
if out != nil { // we are not going to read this response body
|
||||
out.Close()
|
||||
}
|
||||
|
||||
// status == -1 the value when there was no http request
|
||||
if status != -1 && !isTransientHttpStatusCode(status) {
|
||||
ctx.Log("info", fmt.Sprintf("downloader %T returned %v, skipping retries", d, status))
|
||||
break
|
||||
}
|
||||
|
||||
if n != expRetryN-1 {
|
||||
// have more retries to go, sleep before retrying
|
||||
slp := expRetryK * time.Duration(int(math.Pow(float64(expRetryM), float64(n))))
|
||||
ctx.Log("sleep", slp)
|
||||
sf(slp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func isTransientHttpStatusCode(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case
|
||||
http.StatusRequestTimeout, // 408
|
||||
http.StatusTooManyRequests, // 429
|
||||
http.StatusInternalServerError, // 500
|
||||
http.StatusBadGateway, // 502
|
||||
http.StatusServiceUnavailable, // 503
|
||||
http.StatusGatewayTimeout: // 504
|
||||
return true // timeout and too many requests
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ func TestWithRetries_noRetries(t *testing.T) {
|
|||
d := download.NewURLDownload(srv.URL + "/status/200")
|
||||
|
||||
sr := new(sleepRecorder)
|
||||
resp, err := download.WithRetries(nopLog(), d, sr.Sleep)
|
||||
resp, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
|
||||
require.Nil(t, err, "should not fail")
|
||||
require.NotNil(t, resp, "response body exists")
|
||||
require.Equal(t, []time.Duration(nil), []time.Duration(*sr), "sleep should not be called")
|
||||
|
@ -48,7 +48,7 @@ func TestWithRetries_failing_validateNumberOfCalls(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
bd := new(badDownloader)
|
||||
_, err := download.WithRetries(nopLog(), bd, new(sleepRecorder).Sleep)
|
||||
_, err := download.WithRetries(nopLog(), []download.Downloader{bd}, new(sleepRecorder).Sleep)
|
||||
require.Contains(t, err.Error(), "expected error", "error is preserved")
|
||||
require.EqualValues(t, 7, bd.calls, "calls exactly expRetryN times")
|
||||
}
|
||||
|
@ -57,11 +57,11 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) {
|
|||
srv := httptest.NewServer(httpbin.GetMux())
|
||||
defer srv.Close()
|
||||
|
||||
d := download.NewURLDownload(srv.URL + "/status/404")
|
||||
d := download.NewURLDownload(srv.URL + "/status/429")
|
||||
|
||||
sr := new(sleepRecorder)
|
||||
_, err := download.WithRetries(nopLog(), d, sr.Sleep)
|
||||
require.EqualError(t, err, "unexpected status code: got=404 expected=200")
|
||||
_, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
|
||||
require.EqualError(t, err, "unexpected status code: actual=429 expected=200")
|
||||
|
||||
require.Equal(t, sleepSchedule, []time.Duration(*sr))
|
||||
}
|
||||
|
@ -72,15 +72,38 @@ func TestWithRetries_healingServer(t *testing.T) {
|
|||
|
||||
d := download.NewURLDownload(srv.URL)
|
||||
sr := new(sleepRecorder)
|
||||
resp, err := download.WithRetries(nopLog(), d, sr.Sleep)
|
||||
resp, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
|
||||
require.Nil(t, err, "should eventually succeed")
|
||||
require.NotNil(t, resp, "response body exists")
|
||||
|
||||
require.Equal(t, sleepSchedule[:3], []time.Duration(*sr))
|
||||
}
|
||||
|
||||
func TestRetriesWith_SwitchDownloaderOn404(t *testing.T) {
|
||||
svr := httptest.NewServer(httpbin.GetMux())
|
||||
hSvr := httptest.NewServer(new(healingServer))
|
||||
defer svr.Close()
|
||||
d404 := mockDownloader{0, svr.URL + "/status/404"}
|
||||
d200 := mockDownloader{0, hSvr.URL}
|
||||
resp, err := download.WithRetries(nopLog(), []download.Downloader{&d404, &d200}, func(d time.Duration) { return })
|
||||
require.Nil(t, err, "should eventually succeed")
|
||||
require.NotNil(t, resp, "response body exists")
|
||||
require.Equal(t, d404.timesCalled, 1)
|
||||
require.Equal(t, d200.timesCalled, 4)
|
||||
}
|
||||
|
||||
// Test Utilities:
|
||||
|
||||
type mockDownloader struct {
|
||||
timesCalled int
|
||||
url string
|
||||
}
|
||||
|
||||
func (self *mockDownloader) GetRequest() (*http.Request, error) {
|
||||
self.timesCalled++
|
||||
return http.NewRequest("GET", self.url, nil)
|
||||
}
|
||||
|
||||
// sleepRecorder keeps track of the durations of Sleep calls
|
||||
type sleepRecorder []time.Duration
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ const (
|
|||
// given file. Directory of dst is not created by this function. If a file at
|
||||
// dst exists, it will be truncated. If a new file is created, mode is used to
|
||||
// set the permission bits. Written number of bytes are returned on success.
|
||||
func SaveTo(ctx *log.Context, d Downloader, dst string, mode os.FileMode) (int64, error) {
|
||||
func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, error) {
|
||||
f, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, mode)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to open file for writing")
|
||||
|
|
|
@ -19,7 +19,7 @@ func TestSaveTo_invalidDir(t *testing.T) {
|
|||
|
||||
d := download.NewURLDownload(srv.URL + "/bytes/65536")
|
||||
|
||||
_, err := download.SaveTo(nopLog(), d, "/nonexistent-dir/dst", 0600)
|
||||
_, err := download.SaveTo(nopLog(), []download.Downloader{d}, "/nonexistent-dir/dst", 0600)
|
||||
require.Contains(t, err.Error(), "failed to open file for writing")
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ func TestSave(t *testing.T) {
|
|||
|
||||
d := download.NewURLDownload(srv.URL + "/bytes/65536")
|
||||
path := filepath.Join(dir, "test-file")
|
||||
n, err := download.SaveTo(nopLog(), d, path, 0600)
|
||||
n, err := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600)
|
||||
require.Nil(t, err)
|
||||
require.EqualValues(t, 65536, n)
|
||||
|
||||
|
@ -52,9 +52,9 @@ func TestSave_truncates(t *testing.T) {
|
|||
defer os.RemoveAll(dir)
|
||||
|
||||
path := filepath.Join(dir, "test-file")
|
||||
_, err = download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/65536"), path, 0600)
|
||||
_, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600)
|
||||
require.Nil(t, err)
|
||||
_, err = download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/128"), path, 0777)
|
||||
_, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777)
|
||||
require.Nil(t, err)
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
|
@ -74,7 +74,7 @@ func TestSave_largeFile(t *testing.T) {
|
|||
size := 1024 * 1024 * 128 // 128 mb
|
||||
|
||||
path := filepath.Join(dir, "large-file")
|
||||
n, err := download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/"+fmt.Sprintf("%d", size)), path, 0600)
|
||||
n, err := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600)
|
||||
require.Nil(t, err)
|
||||
require.EqualValues(t, size, n)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче