Added domain name validation for sovereign regions and additional tests
This commit is contained in:
Родитель
2ea9cc88e6
Коммит
10710fb27d
|
@ -60,11 +60,11 @@ func getDownloaders(fileURL string, storageAccountName, storageAccountKey string
|
|||
switch {
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId == "":
|
||||
// get msi using clientId or objectId or implicitly
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly()
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL)
|
||||
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(managedIdentity.ClientId)
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId)
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(managedIdentity.ObjectId)
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected combination of ClientId and ObjectId found")
|
||||
}
|
||||
|
|
|
@ -13,10 +13,16 @@ import (
|
|||
const (
|
||||
xMsVersionHeaderName = "x-ms-version"
|
||||
xMsVersionValue = "2018-03-28"
|
||||
azureBlobDomainName = ".blob.core.windows.net"
|
||||
storageResourceName = "https://storage.azure.com/"
|
||||
)
|
||||
|
||||
var azureBlobDomains = map[string]interface{}{ // golang doesn't have builtin hash sets, so this is a workaround for that
|
||||
"blob.core.windows.net": nil,
|
||||
"blob.core.chinacloudapi.cn": nil,
|
||||
"blob.core.usgovcloudapi.net": nil,
|
||||
"blob.core.couldapi.de": nil,
|
||||
}
|
||||
|
||||
type blobWithMsiToken struct {
|
||||
url string
|
||||
msiProvider MsiProvider
|
||||
|
@ -49,19 +55,29 @@ func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader {
|
|||
return &blobWithMsiToken{url, msiProvider}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsImplicitly() MsiProvider {
|
||||
func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) { return msiProvider.GetMsiForResource(storageResourceName) }
|
||||
return func() (msi.Msi, error) { return msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri)) }
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithClientId(clientId string) MsiProvider {
|
||||
func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) { return msiProvider.GetMsiUsingClientId(clientId, storageResourceName) }
|
||||
return func() (msi.Msi, error) {
|
||||
return msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri))
|
||||
}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithObjectId(objectId string) MsiProvider {
|
||||
func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) { return msiProvider.GetMsiUsingObjectId(objectId, storageResourceName) }
|
||||
return func() (msi.Msi, error) {
|
||||
return msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri))
|
||||
}
|
||||
}
|
||||
|
||||
func GetResourceNameFromBlobUri(uri string) string {
|
||||
// TODO: update this function as sovereign cloud blob resource strings become available
|
||||
// resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change
|
||||
return storageResourceName
|
||||
}
|
||||
|
||||
func IsAzureStorageBlobUri(url string) bool {
|
||||
|
@ -70,5 +86,13 @@ func IsAzureStorageBlobUri(url string) bool {
|
|||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(parsedUrl.Host, azureBlobDomainName)
|
||||
s := strings.Split(parsedUrl.Hostname(), ".")
|
||||
if len(s) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domainName := strings.Join(s[1:], ".")
|
||||
_, foundDomain := azureBlobDomains[domainName]
|
||||
return foundDomain
|
||||
|
||||
}
|
||||
|
|
|
@ -54,6 +54,8 @@ func Test_realDownloadBlobWithMsiToken(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_isAzureStorageBlobUri(t *testing.T) {
|
||||
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net"))
|
||||
require.True(t, IsAzureStorageBlobUri("https://a.blob.core.windows.net/container/blobname"))
|
||||
require.True(t, IsAzureStorageBlobUri("http://mystorageaccountcn.blob.core.chinacloudapi.cn"))
|
||||
require.True(t, IsAzureStorageBlobUri("https://blackforestsa.blob.core.couldapi.de/c/b/x"))
|
||||
require.False(t, IsAzureStorageBlobUri("https://github.com/Azure-Samples/storage-blobs-go-quickstart/blob/master/README.md"))
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче