From 10710fb27d7fb5061d9193f3f6edbb8b82b1a5a8 Mon Sep 17 00:00:00 2001 From: Bhaskar Brahma Date: Mon, 29 Jul 2019 13:51:40 -0700 Subject: [PATCH] Added domain name validation for sovereign regions and additional tests --- main/files.go | 6 ++-- pkg/download/blobwithmsitoken.go | 40 +++++++++++++++++++++------ pkg/download/blobwithmsitoken_test.go | 4 ++- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/main/files.go b/main/files.go index c52e0d6..71dbb64 100644 --- a/main/files.go +++ b/main/files.go @@ -60,11 +60,11 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string switch { case managedIdentity.ClientId == "" && managedIdentity.ObjectId == "": // get msi using clientId or objectId or implicitly - msiProvider = download.GetMsiProviderForStorageAccountsImplicitly() + msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL) case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "": - msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(managedIdentity.ClientId) + msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId) case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "": - msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(managedIdentity.ObjectId) + msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId) default: return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found") } diff --git a/pkg/download/blobwithmsitoken.go b/pkg/download/blobwithmsitoken.go index 0996714..970a2b7 100644 --- a/pkg/download/blobwithmsitoken.go +++ b/pkg/download/blobwithmsitoken.go @@ -13,10 +13,16 @@ import ( const ( xMsVersionHeaderName = "x-ms-version" xMsVersionValue = "2018-03-28" - azureBlobDomainName = ".blob.core.windows.net" storageResourceName = "https://storage.azure.com/" ) +var azureBlobDomains = map[string]interface{}{ // golang doesn't have builtin hash sets, so this is a workaround for that + "blob.core.windows.net": nil, + "blob.core.chinacloudapi.cn": nil, + "blob.core.usgovcloudapi.net": nil, + "blob.core.couldapi.de": nil, +} + type blobWithMsiToken struct { url string msiProvider MsiProvider @@ -49,19 +55,29 @@ func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader { return &blobWithMsiToken{url, msiProvider} } -func GetMsiProviderForStorageAccountsImplicitly() MsiProvider { +func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) - return func() (msi.Msi, error) { return msiProvider.GetMsiForResource(storageResourceName) } + return func() (msi.Msi, error) { return msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri)) } } -func GetMsiProviderForStorageAccountsWithClientId(clientId string) MsiProvider { +func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) - return func() (msi.Msi, error) { return msiProvider.GetMsiUsingClientId(clientId, storageResourceName) } + return func() (msi.Msi, error) { + return msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri)) + } } -func GetMsiProviderForStorageAccountsWithObjectId(objectId string) MsiProvider { +func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) - return func() (msi.Msi, error) { return msiProvider.GetMsiUsingObjectId(objectId, storageResourceName) } + return func() (msi.Msi, error) { + return msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri)) + } +} + +func GetResourceNameFromBlobUri(uri string) string { + // TODO: update this function as sovereign cloud blob resource strings become available + // resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change + return storageResourceName } func IsAzureStorageBlobUri(url string) bool { @@ -70,5 +86,13 @@ func IsAzureStorageBlobUri(url string) bool { if err != nil { return false } - return strings.HasSuffix(parsedUrl.Host, azureBlobDomainName) + s := strings.Split(parsedUrl.Hostname(), ".") + if len(s) < 2 { + return false + } + + domainName := strings.Join(s[1:], ".") + _, foundDomain := azureBlobDomains[domainName] + return foundDomain + } diff --git a/pkg/download/blobwithmsitoken_test.go b/pkg/download/blobwithmsitoken_test.go index a2b2b27..e7d7b5c 100644 --- a/pkg/download/blobwithmsitoken_test.go +++ b/pkg/download/blobwithmsitoken_test.go @@ -54,6 +54,8 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) { } func Test_isAzureStorageBlobUri(t *testing.T) { - require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net")) + require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname")) + require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn")) + require.True(t, IsAzureStorageBlobUri("https://blackforestsa.blob.core.couldapi.de/c/b/x")) require.False(t, IsAzureStorageBlobUri("https://github.com/Azure-Samples/storage-blobs-go-quickstart/blob/master/README.md")) }