Merge pull request #151 from Azure/bhbrahma/msi

RBAC support for downloading scripts with managed identities
This commit is contained in:
Bhaskar Brahma 2019-08-14 11:39:19 -07:00 коммит произвёл GitHub
Родитель 1f9c51c15e 2b26ef7bfb
Коммит fc5210a6d8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
19 изменённых файлов: 520 добавлений и 95 удалений

Просмотреть файл

@ -13,6 +13,7 @@ install:
- sudo npm install -g azure-cli
- go get -u golang.org/x/lint/golint
- go get -u github.com/ahmetalpbalkan/govvv
- go get -u -d github.com/Azure/azure-extension-foundation/...
before_script:
- docker version
- docker info

Просмотреть файл

@ -15,6 +15,7 @@ binary: clean
echo "GOPATH is not set"; \
exit 1; \
fi
go get -d -u -f github.com/Azure/azure-extension-foundation/...
GOOS=linux GOARCH=amd64 govvv build -v \
-ldflags "-X main.Version=`grep -E -m 1 -o '<Version>(.*)</Version>' misc/manifest.xml | awk -F">" '{print $$2}' | awk -F"<" '{print $$1}'`" \
-o $(BINDIR)/$(BIN) ./main

Просмотреть файл

@ -230,7 +230,7 @@ func downloadFiles(ctx *log.Context, dir string, cfg handlerSettings) error {
for i, f := range cfg.fileUrls() {
ctx := ctx.With("file", i)
ctx.Log("event", "download start")
if err := downloadAndProcessURL(ctx, f, dir, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.publicSettings.SkipDos2Unix); err != nil {
if err := downloadAndProcessURL(ctx, f, dir, &cfg); err != nil {
ctx.Log("event", "download failed", "error", err)
return errors.Wrapf(err, "failed to download file[%d]", i)
}

Просмотреть файл

@ -19,7 +19,7 @@ import (
// downloadAndProcessURL downloads using the specified downloader and saves it to the
// specified existing directory, which must be the path to the saved file. Then
// it post-processes file based on heuristics.
func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountName, storageAccountKey string, skipDos2Unix bool) error {
func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handlerSettings) error {
fn, err := urlToFileName(url)
if err != nil {
return err
@ -29,7 +29,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
return fmt.Errorf("[REDACTED] is not a valid url")
}
dl, err := getDownloader(url, storageAccountName, storageAccountKey)
dl, err := getDownloaders(url, cfg.StorageAccountName, cfg.StorageAccountKey, cfg.ManagedIdentity)
if err != nil {
return err
}
@ -40,7 +40,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
return err
}
if skipDos2Unix == false {
if cfg.SkipDos2Unix == false {
err = postProcessFile(fp)
}
return errors.Wrapf(err, "failed to post-process '%s'", fn)
@ -48,19 +48,46 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir, storageAccountNam
// getDownloader returns a downloader for the given URL based on whether the
// storage credentials are empty or not.
func getDownloader(fileURL string, storageAccountName, storageAccountKey string) (
download.Downloader, error) {
func getDownloaders(fileURL string, storageAccountName, storageAccountKey string, managedIdentity *clientOrObjectId) (
[]download.Downloader, error) {
if storageAccountName == "" || storageAccountKey == "" {
return download.NewURLDownload(fileURL), nil
// storage account name and key cannot be specified with managed identity, handler settings validation won't allow that
// handler settings validation will also not allow storageAccountName XOR storageAccountKey == 1
// in this case, we can be sure that storage account name and key was not specified
if download.IsAzureStorageBlobUri(fileURL) && managedIdentity != nil {
// if managed identity was specified in the configuration, try to use it to download the files
var msiProvider download.MsiProvider
switch {
case managedIdentity.ClientId == "" && managedIdentity.ObjectId == "":
// get msi using clientId or objectId or implicitly
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL)
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId)
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
default:
return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found")
}
return []download.Downloader{
// try downloading without MSI token first, but attempt with MSI if the download fails
download.NewURLDownload(fileURL),
download.NewBlobWithMsiDownload(fileURL, msiProvider),
}, nil
} else {
// do not use MSI downloader if the uri is not azure storage blob, or managedIdentity isn't specified
return []download.Downloader{download.NewURLDownload(fileURL)}, nil
}
} else {
// if storage name account and key are specified, use that for all files
// this preserves old behavior
blob, err := blobutil.ParseBlobURL(fileURL)
if err != nil {
return nil, err
}
return []download.Downloader{download.NewBlobDownload(
storageAccountName, storageAccountKey,
blob)}, nil
}
blob, err := blobutil.ParseBlobURL(fileURL)
if err != nil {
return nil, err
}
return download.NewBlobDownload(
storageAccountName, storageAccountKey,
blob), nil
}
// urlToFileName parses given URL and returns the section after the last slash

Просмотреть файл

@ -15,31 +15,44 @@ import (
func Test_getDownloader_azureBlob(t *testing.T) {
// error condition
_, err := getDownloader("http://acct.blob.core.windows.net/", "acct", "key")
_, err := getDownloaders("http://acct.blob.core.windows.net/", "acct", "key", nil)
require.NotNil(t, err)
// valid input
d, err := getDownloader("http://acct.blob.core.windows.net/container/blob", "acct", "key")
d, err := getDownloaders("http://acct.blob.core.windows.net/container/blob", "acct", "key", nil)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, "download.blobDownload", fmt.Sprintf("%T", d), "got wrong type")
require.Equal(t, 1, len(d))
require.Equal(t, "download.blobDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
}
func Test_getDownloader_externalUrl(t *testing.T) {
d, err := getDownloader("http://acct.blob.core.windows.net/", "", "")
d, err := getDownloaders("http://acct.blob.core.windows.net/", "", "", nil)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
require.NotEmpty(t, d)
require.Equal(t, 1, len(d))
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
d, err = getDownloader("http://acct.blob.core.windows.net/", "foo", "")
d, err = getDownloaders("http://acct.blob.core.windows.net/", "", "", &clientOrObjectId{"", "dummyclientid"})
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
require.NotEmpty(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[1]), "got wrong type")
d, err = getDownloader("http://acct.blob.core.windows.net/", "", "bar")
d, err = getDownloaders("http://acct.blob.core.windows.net/", "foo", "", nil)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d), "got wrong type")
require.Equal(t, 1, len(d))
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
d, err = getDownloaders("http://acct.blob.core.windows.net/", "", "bar", nil)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 1, len(d))
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
}
func Test_urlToFileName_badURL(t *testing.T) {
@ -110,7 +123,8 @@ func Test_downloadAndProcessURL(t *testing.T) {
require.Nil(t, err)
defer os.RemoveAll(tmpDir)
err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, "", "", false)
cfg := handlerSettings{publicSettings{}, protectedSettings{StorageAccountName: "", StorageAccountKey: ""}}
err = downloadAndProcessURL(log.NewContext(log.NewNopLogger()), srv.URL+"/bytes/256", tmpDir, &cfg)
require.Nil(t, err)
fp := filepath.Join(tmpDir, "256")

Просмотреть файл

@ -9,11 +9,13 @@ import (
)
var (
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
errCmdMissing = errors.New("'commandToExecute' is not specified")
errStoragePartialCredentials = errors.New("both 'storageAccountName' and 'storageAccountKey' must be specified")
errCmdTooMany = errors.New("'commandToExecute' was specified both in public and protected settings; it must be specified only once")
errScriptTooMany = errors.New("'script' was specified both in public and protected settings; it must be specified only once")
errCmdAndScript = errors.New("'commandToExecute' and 'script' were both specified, but only one is validate at a time")
errCmdMissing = errors.New("'commandToExecute' is not specified")
errUsingBothKeyAndMsi = errors.New("'storageAccountName' or 'storageAccountKey' must not be specified with 'managedServiceIdentity'")
errUsingBothClientIdAndObjectId = errors.New("only one of 'clientId' or 'objectId' must be specified with 'managedServiceIdentity'")
)
// handlerSettings holds the configuration of the extension handler.
@ -66,6 +68,16 @@ func (h handlerSettings) validate() error {
return errStoragePartialCredentials
}
if (h.protectedSettings.StorageAccountKey != "" || h.protectedSettings.StorageAccountName != "") && h.protectedSettings.ManagedIdentity != nil {
return errUsingBothKeyAndMsi
}
if h.protectedSettings.ManagedIdentity != nil {
if h.protectedSettings.ManagedIdentity.ClientId != "" && h.protectedSettings.ManagedIdentity.ObjectId != "" {
return errUsingBothClientIdAndObjectId
}
}
return nil
}
@ -81,11 +93,21 @@ type publicSettings struct {
// protectedSettings is the type decoded and deserialized from protected
// configuration section. This should be in sync with protectedSettingsSchema.
type protectedSettings struct {
CommandToExecute string `json:"commandToExecute"`
Script string `json:"script"`
FileURLs []string `json:"fileUris"`
StorageAccountName string `json:"storageAccountName"`
StorageAccountKey string `json:"storageAccountKey"`
CommandToExecute string `json:"commandToExecute"`
Script string `json:"script"`
FileURLs []string `json:"fileUris"`
StorageAccountName string `json:"storageAccountName"`
StorageAccountKey string `json:"storageAccountKey"`
ManagedIdentity *clientOrObjectId `json:"managedIdentity"`
}
type clientOrObjectId struct {
ObjectId string `json:"objectId"`
ClientId string `json:"clientId"`
}
func (self *clientOrObjectId) isEmpty() bool {
return self.ClientId == "" && self.ObjectId == ""
}
// parseAndValidateSettings reads configuration from configFolder, decrypts it,

Просмотреть файл

@ -1,6 +1,9 @@
package main
import "testing"
import (
"encoding/json"
"testing"
)
import "github.com/stretchr/testify/require"
func Test_handlerSettingsValidate(t *testing.T) {
@ -86,6 +89,44 @@ func Test_skipDos2UnixDefaultsToFalse(t *testing.T) {
require.Equal(t, false, testSubject.SkipDos2Unix)
}
func Test_managedIdentityVerification(t *testing.T) {
require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
CommandToExecute: "echo hi",
FileURLs: []string{"file1", "file2"},
ManagedIdentity: &clientOrObjectId{
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
}}.validate(), "validation failed for settings with MSI")
require.NoError(t, handlerSettings{publicSettings{}, protectedSettings{
CommandToExecute: "echo hi",
ManagedIdentity: &clientOrObjectId{
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
}}.validate(), "validation failed for settings with MSI")
require.Equal(t, errUsingBothKeyAndMsi,
handlerSettings{publicSettings{},
protectedSettings{
CommandToExecute: "echo hi",
StorageAccountName: "name",
StorageAccountKey: "key",
ManagedIdentity: &clientOrObjectId{
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
}}.validate(), "validation didn't fail for settings with both MSI and storage account")
require.Equal(t, errUsingBothClientIdAndObjectId,
handlerSettings{publicSettings{},
protectedSettings{
CommandToExecute: "echo hi",
ManagedIdentity: &clientOrObjectId{
ObjectId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
ClientId: "31b403aa-c364-4240-a7ff-d85fb6cd7232",
},
}}.validate(), "validation didn't fail for settings with both MSI and storage account")
}
func Test_toJSON_empty(t *testing.T) {
s, err := toJSON(nil)
require.Nil(t, err)
@ -98,3 +139,58 @@ func Test_toJSON(t *testing.T) {
require.Nil(t, err)
require.Equal(t, `{"a":3}`, s)
}
func Test_toJSONUmarshallForManagedIdentity(t *testing.T) {
testString := `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"]}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
protSettings := new(protectedSettings)
err := json.Unmarshal([]byte(testString), protSettings)
require.NoError(t, err, "error while deserializing json")
require.Nil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to be nil")
h := handlerSettings{publicSettings{}, *protSettings}
require.NoError(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { }}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
protSettings = new(protectedSettings)
err = json.Unmarshal([]byte(testString), protSettings)
require.NoError(t, err, "error while deserializing json")
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
h = handlerSettings{publicSettings{}, *protSettings}
require.NoError(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
protSettings = new(protectedSettings)
err = json.Unmarshal([]byte(testString), protSettings)
require.NoError(t, err, "error while deserializing json")
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "")
h = handlerSettings{publicSettings{}, *protSettings}
require.NoError(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt"], "managedIdentity": { "objectId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
protSettings = new(protectedSettings)
err = json.Unmarshal([]byte(testString), protSettings)
require.NoError(t, err, "error while deserializing json")
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
require.Equal(t, protSettings.ManagedIdentity.ClientId, "")
h = handlerSettings{publicSettings{}, *protSettings}
require.NoError(t, h.validate(), "settings should be valid")
testString = `{"commandToExecute" : "echo hello", "fileUris":["https://a.com/file.txt", "https://b.com/file2.txt"], "managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232", "objectId": "41b403aa-c364-4240-a7ff-d85fb6cd7232"}}`
require.NoError(t, validateProtectedSettings(testString), "protected settings should be valid")
protSettings = new(protectedSettings)
err = json.Unmarshal([]byte(testString), protSettings)
require.NoError(t, err, "error while deserializing json")
require.NotNil(t, protSettings.ManagedIdentity, "ProtectedSettings.ManagedIdentity was expected to not be nil")
require.Equal(t, protSettings.ManagedIdentity.ClientId, "31b403aa-c364-4240-a7ff-d85fb6cd7232")
require.Equal(t, protSettings.ManagedIdentity.ObjectId, "41b403aa-c364-4240-a7ff-d85fb6cd7232")
h = handlerSettings{publicSettings{}, *protSettings}
require.Error(t, h.validate(), "settings should be invalid")
}

Просмотреть файл

@ -19,14 +19,14 @@ const (
"description": "Command to be executed",
"type": "string"
},
"script": {
"description": "Script to be executed",
"type": "string"
},
"skipDos2Unix": {
"description": "Skip DOS2UNIX and BOM removal for download files and script",
"type": "boolean"
},
"script": {
"description": "Script to be executed",
"type": "string"
},
"skipDos2Unix": {
"description": "Skip DOS2UNIX and BOM removal for download files and script",
"type": "boolean"
},
"fileUris": {
"description": "List of files to be downloaded",
"type": "array",
@ -52,7 +52,7 @@ const (
"description": "Command to be executed",
"type": "string"
},
"fileUris": {
"fileUris": {
"description": "List of files to be downloaded",
"type": "array",
"items": {
@ -60,10 +60,10 @@ const (
"format": "uri"
}
},
"script": {
"description": "Script to be executed",
"type": "string"
},
"script": {
"description": "Script to be executed",
"type": "string"
},
"storageAccountName": {
"description": "Name of the Azure Storage Account (3-24 characters of lowercase letters or digits)",
"type": "string",
@ -73,6 +73,22 @@ const (
"description": "Key for the Azure Storage Account (a base64 encoded string)",
"type": "string",
"pattern": "^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{4})$"
},
"managedIdentity": {
"description": "Setting to use Managed Service Identity to try to download fileUri from azure blob",
"type": "object",
"properties": {
"objectId": {
"description": "Object id that identifies the user created managed identity",
"type": "string",
"pattern": "^(?:[0-9A-Fa-f]{8}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{12})$"
},
"clientId": {
"description": "Client id that identifies the user created managed identity",
"type": "string",
"pattern": "^(?:[0-9A-Fa-f]{8}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{4}[-][0-9A-Fa-f]{12})$"
}
}
}
},
"additionalProperties": false

Просмотреть файл

@ -144,3 +144,15 @@ func TestValidateProtectedSettings_storageAccountKey(t *testing.T) {
require.Nil(t, validateProtectedSettings(`{"storageAccountKey": "A+hMRrsZQ6COPXTYX/EiKiF2HVtfhCfLDo3Dkc3ekKoX3jA58zXVG2QRe/C1+zdEFSrVX6FZsKyivsSlnwmWOw=="}`), "ok")
require.Nil(t, validateProtectedSettings(`{"storageAccountKey": "/yGnx6KyxQ8Pjzk0QXeY+66Du0BeTWaCt83la59w72hu/81e6TzskXXvL/IlO3q6g0k0kJrR9MYQNi+cNR3SXA=="}`), "ok")
}
func TestValidateProtectedSettings_managedServiceIdentity(t *testing.T) {
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { "clientId": "31b403aa-c364-4240-a7ff-d85fb6cd7232"}}`),
"couldn't parse msi proprety with lowercase guid")
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { "objectId": "31B403AA-C364-4240-A7FF-D85FB6CD7232"}}`),
"couldn't parse msi property with uppercase guid")
require.NoError(t, validateProtectedSettings(`{"managedIdentity": { }}`),
"couldn't parse msi property without clientId or objectId")
require.Error(t, validateProtectedSettings(`{"managedIdentity": { "clientId": "notaguid"}}`),
"guid validation succeded when expected to fail")
}

Просмотреть файл

@ -2,7 +2,7 @@
<ExtensionImage xmlns="http://schemas.microsoft.com/windowsazure">
<ProviderNameSpace>Microsoft.Azure.Extensions</ProviderNameSpace>
<Type>CustomScript</Type>
<Version>2.0.7</Version>
<Version>2.1.0</Version>
<Label>Microsoft Azure Custom Script Extension for Linux Virtual Machines</Label>
<HostingResources>VmRole</HostingResources>
<MediaLink></MediaLink>

Просмотреть файл

@ -4,6 +4,7 @@ import (
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"os"
"testing"
"time"
@ -70,9 +71,10 @@ func Test_blobDownload_fails_badCreds(t *testing.T) {
Container: "foocontainer",
})
_, err := Download(d)
status, _, err := Download(d)
require.NotNil(t, err)
require.Contains(t, err.Error(), "unexpected status code: got=403")
require.Contains(t, err.Error(), "unexpected status code: actual=403")
require.Equal(t, status, http.StatusForbidden)
}
func Test_blobDownload_fails_urlNotFound(t *testing.T) {
@ -82,7 +84,7 @@ func Test_blobDownload_fails_urlNotFound(t *testing.T) {
Container: "foocontainer",
})
_, err := Download(d)
_, _, err := Download(d)
require.NotNil(t, err)
require.Contains(t, err.Error(), "http request failed:")
}
@ -121,7 +123,7 @@ func Test_blobDownload_actualBlob(t *testing.T) {
Blob: name,
StorageBase: base,
})
body, err := Download(d)
_, body, err := Download(d)
require.Nil(t, err)
defer body.Close()
b, err := ioutil.ReadAll(body)

Просмотреть файл

@ -0,0 +1,116 @@
package download
import (
"fmt"
"github.com/Azure/azure-extension-foundation/httputil"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/pkg/errors"
"net/http"
url2 "net/url"
"strings"
)
const (
xMsVersionHeaderName = "x-ms-version"
xMsVersionValue = "2018-03-28"
storageResourceName = "https://storage.azure.com/"
)
var azureBlobDomains = map[string]interface{}{ // golang doesn't have builtin hash sets, so this is a workaround for that
"blob.core.windows.net": nil,
"blob.core.chinacloudapi.cn": nil,
"blob.core.usgovcloudapi.net": nil,
"blob.core.couldapi.de": nil,
}
type blobWithMsiToken struct {
url string
msiProvider MsiProvider
}
type MsiProvider func() (msi.Msi, error)
func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
msi, err := self.msiProvider()
if err != nil {
return nil, err
}
if msi.AccessToken == "" {
return nil, errors.New("MSI token was empty")
}
request, err := http.NewRequest(http.MethodGet, self.url, nil)
if err != nil {
return nil, err
}
if IsAzureStorageBlobUri(self.url) {
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken))
request.Header.Set(xMsVersionHeaderName, xMsVersionValue)
}
return request, nil
}
func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader {
return &blobWithMsiToken{url, msiProvider}
}
func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri))
if err != nil {
return msi, errors.Wrapf(err, "Unable to get managed identity. "+
"Please make sure that system assigned managed identity is enabled on the VM "+
"or user assigned identity is added to the system.")
}
return msi, nil
}
}
func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri))
if err != nil {
return msi, errors.Wrapf(err, "Unable to get managed identity with client id %s. "+
"Please make sure that the user assigned managed identity is added to the VM ", clientId)
}
return msi, nil
}
}
func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri))
if err != nil {
return msi, errors.Wrapf(err, "Unable to get managed identity with object id %s. "+
"Please make sure that the user assigned managed identity is added to the VM ", objectId)
}
return msi, nil
}
}
func GetResourceNameFromBlobUri(uri string) string {
// TODO: update this function as sovereign cloud blob resource strings become available
// resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change
return storageResourceName
}
func IsAzureStorageBlobUri(url string) bool {
// TODO update this function for sovereign regions
parsedUrl, err := url2.Parse(url)
if err != nil {
return false
}
s := strings.Split(parsedUrl.Hostname(), ".")
if len(s) < 2 {
return false
}
domainName := strings.Join(s[1:], ".")
_, foundDomain := azureBlobDomains[domainName]
return foundDomain
}

Просмотреть файл

@ -0,0 +1,61 @@
package download
import (
"encoding/json"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/stretchr/testify/require"
"io/ioutil"
"testing"
)
// README
// to run this test, assign/create an azure VM with system assigned or user assigned identity
// this is the machine that you'll get the msiJson from
// assign "Storage Blob Data Reader" permissions to managed identity on a blob
var msiJson = `` // place the msi json here e.g.
// {"access_token":<access token>","client_id":"31b403aa-c364-4240-a7ff-d85fb6cd7232","expires_in":"28799",
// "expires_on":"1563607134","ext_expires_in":"28799","not_before":"1563578034","resource":"https://storage.azure.com/",
// "token_type":"Bearer"}
// Linux command to get msi
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F' -H Metadata:true
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&client_id=<client_id>' -H Metadata:true
// curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&object_id=<object_id>' -H Metadata:true
// Powershell command to get msi
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F" | ConvertTo-Json
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&client_id=<client_id>" | ConvertTo-Json
// Invoke-RestMethod -Method "GET" -Headers @{"Metadata"=$true} "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fstorage.azure.com%2F&object_id=<object_id>" | ConvertTo-Json
// the first command gets the system managed identity, or the user assigned identity if the VM has only one user assigned identity and no system assigned identity
// the second command gets user assigned identity with its client id
// the third command gets user assigned identity with its object id
var blobUri = "" // set the blob to download here e.g. https://storageaccount.blob.core.windows.net/container/blobname
var stringToLookFor = "" // the string to look for in you blob
func Test_realDownloadBlobWithMsiToken(t *testing.T) {
if msiJson == "" || blobUri == "" || stringToLookFor == "" {
t.Skip()
}
downloader := blobWithMsiToken{blobUri, func() (msi.Msi, error) {
msi := msi.Msi{}
err := json.Unmarshal([]byte(msiJson), &msi)
return msi, err
}}
_, stream, err := Download(&downloader)
require.NoError(t, err, "File download failed")
defer stream.Close()
bytes, err := ioutil.ReadAll(stream)
require.NoError(t, err, "saving file stream to memory failed")
require.Contains(t, string(bytes), stringToLookFor)
}
func Test_isAzureStorageBlobUri(t *testing.T) {
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname"))
require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn"))
require.True(t, IsAzureStorageBlobUri("https://blackforestsa.blob.core.couldapi.de/c/b/x"))
require.False(t, IsAzureStorageBlobUri("https://github.com/Azure-Samples/storage-blobs-go-quickstart/blob/master/README.md"))
}

Просмотреть файл

@ -37,20 +37,28 @@ var (
// Download retrieves a response body and checks the response status code to see
// if it is 200 OK and then returns the response body. It issues a new request
// every time called. It is caller's responsibility to close the response body.
func Download(d Downloader) (io.ReadCloser, error) {
func Download(d Downloader) (int, io.ReadCloser, error) {
req, err := d.GetRequest()
if err != nil {
return nil, errors.Wrapf(err, "failed to create the request")
return -1, nil, errors.Wrapf(err, "failed to create the request")
}
resp, err := httpClient.Do(req)
if err != nil {
err = urlutil.RemoveUrlFromErr(err)
return nil, errors.Wrapf(err, "http request failed")
return -1, nil, errors.Wrapf(err, "http request failed")
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: got=%d expected=%d", resp.StatusCode, http.StatusOK)
if resp.StatusCode == http.StatusOK {
return resp.StatusCode, resp.Body, nil
}
return resp.Body, nil
err = fmt.Errorf("unexpected status code: actual=%d expected=%d", resp.StatusCode, http.StatusOK)
switch d.(type) {
case *blobWithMsiToken:
if resp.StatusCode == http.StatusNotFound {
return resp.StatusCode, nil, errors.Wrapf(err, "please ensure that the blob location in the fileUri setting exists and the specified Managed Identity has read permissions to the storage blob")
}
}
return resp.StatusCode, nil, err
}

Просмотреть файл

@ -21,13 +21,13 @@ func (b *badDownloader) GetRequest() (*http.Request, error) {
}
func TestDownload_wrapsGetRequestError(t *testing.T) {
_, err := download.Download(new(badDownloader))
_, _, err := download.Download(new(badDownloader))
require.NotNil(t, err)
require.EqualError(t, err, "failed to create the request: expected error")
}
func TestDownload_wrapsHTTPError(t *testing.T) {
_, err := download.Download(download.NewURLDownload("bad url"))
_, _, err := download.Download(download.NewURLDownload("bad url"))
require.NotNil(t, err)
require.Contains(t, err.Error(), "http request failed:")
}
@ -44,7 +44,7 @@ func TestDownload_badStatusCodeFails(t *testing.T) {
http.StatusBadRequest,
http.StatusUnauthorized,
} {
_, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
_, _, err := download.Download(download.NewURLDownload(fmt.Sprintf("%s/status/%d", srv.URL, code)))
require.NotNil(t, err, "not failed for code:%d", code)
require.Contains(t, err.Error(), "unexpected status code", "wrong message for code %d", code)
}
@ -54,7 +54,7 @@ func TestDownload_statusOKSucceeds(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/status/200"))
require.Nil(t, err)
defer body.Close()
require.NotNil(t, body)
@ -64,7 +64,7 @@ func TestDownload_retrievesBody(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/bytes/65536"))
require.Nil(t, err)
defer body.Close()
b, err := ioutil.ReadAll(body)
@ -76,7 +76,7 @@ func TestDownload_bodyClosesWithoutError(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
_, body, err := download.Download(download.NewURLDownload(srv.URL + "/get"))
require.Nil(t, err)
require.Nil(t, body.Close(), "body should close fine")
}

Просмотреть файл

@ -1,8 +1,10 @@
package download
import (
"fmt"
"io"
"math"
"net/http"
"time"
"github.com/go-kit/kit/log"
@ -29,27 +31,51 @@ const (
// closed on failures). If the retries do not succeed, the last error is returned.
//
// It sleeps in exponentially increasing durations between retries.
func WithRetries(ctx *log.Context, d Downloader, sf SleepFunc) (io.ReadCloser, error) {
func WithRetries(ctx *log.Context, downloaders []Downloader, sf SleepFunc) (io.ReadCloser, error) {
var lastErr error
for n := 0; n < expRetryN; n++ {
ctx := ctx.With("retry", n)
out, err := Download(d)
if err == nil {
return out, nil
}
lastErr = err
ctx.Log("error", err)
for _, d := range downloaders {
for n := 0; n < expRetryN; n++ {
ctx := ctx.With("retry", n)
status, out, err := Download(d)
if err == nil {
return out, nil
}
if out != nil { // we are not going to read this response body
out.Close()
}
lastErr = err
ctx.Log("error", err)
if n != expRetryN-1 {
// have more retries to go, sleep before retrying
slp := expRetryK * time.Duration(int(math.Pow(float64(expRetryM), float64(n))))
ctx.Log("sleep", slp)
sf(slp)
if out != nil { // we are not going to read this response body
out.Close()
}
// status == -1 the value when there was no http request
if status != -1 && !isTransientHttpStatusCode(status) {
ctx.Log("info", fmt.Sprintf("downloader %T returned %v, skipping retries", d, status))
break
}
if n != expRetryN-1 {
// have more retries to go, sleep before retrying
slp := expRetryK * time.Duration(int(math.Pow(float64(expRetryM), float64(n))))
ctx.Log("sleep", slp)
sf(slp)
}
}
}
return nil, lastErr
}
func isTransientHttpStatusCode(statusCode int) bool {
switch statusCode {
case
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true // timeout and too many requests
default:
return false
}
}

Просмотреть файл

@ -37,7 +37,7 @@ func TestWithRetries_noRetries(t *testing.T) {
d := download.NewURLDownload(srv.URL + "/status/200")
sr := new(sleepRecorder)
resp, err := download.WithRetries(nopLog(), d, sr.Sleep)
resp, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
require.Nil(t, err, "should not fail")
require.NotNil(t, resp, "response body exists")
require.Equal(t, []time.Duration(nil), []time.Duration(*sr), "sleep should not be called")
@ -48,7 +48,7 @@ func TestWithRetries_failing_validateNumberOfCalls(t *testing.T) {
defer srv.Close()
bd := new(badDownloader)
_, err := download.WithRetries(nopLog(), bd, new(sleepRecorder).Sleep)
_, err := download.WithRetries(nopLog(), []download.Downloader{bd}, new(sleepRecorder).Sleep)
require.Contains(t, err.Error(), "expected error", "error is preserved")
require.EqualValues(t, 7, bd.calls, "calls exactly expRetryN times")
}
@ -57,11 +57,11 @@ func TestWithRetries_failingBadStatusCode_validateSleeps(t *testing.T) {
srv := httptest.NewServer(httpbin.GetMux())
defer srv.Close()
d := download.NewURLDownload(srv.URL + "/status/404")
d := download.NewURLDownload(srv.URL + "/status/429")
sr := new(sleepRecorder)
_, err := download.WithRetries(nopLog(), d, sr.Sleep)
require.EqualError(t, err, "unexpected status code: got=404 expected=200")
_, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
require.EqualError(t, err, "unexpected status code: actual=429 expected=200")
require.Equal(t, sleepSchedule, []time.Duration(*sr))
}
@ -72,15 +72,38 @@ func TestWithRetries_healingServer(t *testing.T) {
d := download.NewURLDownload(srv.URL)
sr := new(sleepRecorder)
resp, err := download.WithRetries(nopLog(), d, sr.Sleep)
resp, err := download.WithRetries(nopLog(), []download.Downloader{d}, sr.Sleep)
require.Nil(t, err, "should eventually succeed")
require.NotNil(t, resp, "response body exists")
require.Equal(t, sleepSchedule[:3], []time.Duration(*sr))
}
func TestRetriesWith_SwitchDownloaderOn404(t *testing.T) {
svr := httptest.NewServer(httpbin.GetMux())
hSvr := httptest.NewServer(new(healingServer))
defer svr.Close()
d404 := mockDownloader{0, svr.URL + "/status/404"}
d200 := mockDownloader{0, hSvr.URL}
resp, err := download.WithRetries(nopLog(), []download.Downloader{&d404, &d200}, func(d time.Duration) { return })
require.Nil(t, err, "should eventually succeed")
require.NotNil(t, resp, "response body exists")
require.Equal(t, d404.timesCalled, 1)
require.Equal(t, d200.timesCalled, 4)
}
// Test Utilities:
type mockDownloader struct {
timesCalled int
url string
}
func (self *mockDownloader) GetRequest() (*http.Request, error) {
self.timesCalled++
return http.NewRequest("GET", self.url, nil)
}
// sleepRecorder keeps track of the durations of Sleep calls
type sleepRecorder []time.Duration

Просмотреть файл

@ -16,7 +16,7 @@ const (
// given file. Directory of dst is not created by this function. If a file at
// dst exists, it will be truncated. If a new file is created, mode is used to
// set the permission bits. Written number of bytes are returned on success.
func SaveTo(ctx *log.Context, d Downloader, dst string, mode os.FileMode) (int64, error) {
func SaveTo(ctx *log.Context, d []Downloader, dst string, mode os.FileMode) (int64, error) {
f, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, mode)
if err != nil {
return 0, errors.Wrap(err, "failed to open file for writing")

Просмотреть файл

@ -19,7 +19,7 @@ func TestSaveTo_invalidDir(t *testing.T) {
d := download.NewURLDownload(srv.URL + "/bytes/65536")
_, err := download.SaveTo(nopLog(), d, "/nonexistent-dir/dst", 0600)
_, err := download.SaveTo(nopLog(), []download.Downloader{d}, "/nonexistent-dir/dst", 0600)
require.Contains(t, err.Error(), "failed to open file for writing")
}
@ -33,7 +33,7 @@ func TestSave(t *testing.T) {
d := download.NewURLDownload(srv.URL + "/bytes/65536")
path := filepath.Join(dir, "test-file")
n, err := download.SaveTo(nopLog(), d, path, 0600)
n, err := download.SaveTo(nopLog(), []download.Downloader{d}, path, 0600)
require.Nil(t, err)
require.EqualValues(t, 65536, n)
@ -52,9 +52,9 @@ func TestSave_truncates(t *testing.T) {
defer os.RemoveAll(dir)
path := filepath.Join(dir, "test-file")
_, err = download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/65536"), path, 0600)
_, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/65536")}, path, 0600)
require.Nil(t, err)
_, err = download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/128"), path, 0777)
_, err = download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/128")}, path, 0777)
require.Nil(t, err)
fi, err := os.Stat(path)
@ -74,7 +74,7 @@ func TestSave_largeFile(t *testing.T) {
size := 1024 * 1024 * 128 // 128 mb
path := filepath.Join(dir, "large-file")
n, err := download.SaveTo(nopLog(), download.NewURLDownload(srv.URL+"/bytes/"+fmt.Sprintf("%d", size)), path, 0600)
n, err := download.SaveTo(nopLog(), []download.Downloader{download.NewURLDownload(srv.URL + "/bytes/" + fmt.Sprintf("%d", size))}, path, 0600)
require.Nil(t, err)
require.EqualValues(t, size, n)