Test for system identity vs otherwise
This commit is contained in:
Родитель
86e1e1da55
Коммит
0a95bbb7fd
|
@ -17,8 +17,6 @@ import (
|
|||
)
|
||||
|
||||
var UseMockSASDownloadFailure bool = false
|
||||
var UseMockGetDownloaders = false
|
||||
var ReturnErrorForMockGetDownloaders = false
|
||||
|
||||
// downloadAndProcessURL downloads using the specified downloader and saves it to the
|
||||
// specified existing directory, which must be the path to the saved file. Then
|
||||
|
@ -52,7 +50,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl
|
|||
|
||||
//If there was an error downloading using SAS URI or SAS was not provided, download using managedIdentity or publicly.
|
||||
if scriptSASDownloadErr != nil || scriptSAS == "" {
|
||||
downloaders, getDownloadersError := getDownloaders(url, cfg.SourceManagedIdentity)
|
||||
downloaders, getDownloadersError := getDownloaders(url, cfg.SourceManagedIdentity, download.ProdMsiDownloader{})
|
||||
if getDownloadersError == nil {
|
||||
const mode = 0500 // we assume users download scripts to execute
|
||||
_, err = download.SaveTo(ctx, downloaders, targetFilePath, mode)
|
||||
|
@ -76,7 +74,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl
|
|||
// getDownloaders returns one or two downloaders (two if it is an Azure storage blob):
|
||||
// 1. Downloader for script using public URI.
|
||||
// 2. Downloader for script using managed identity.
|
||||
func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity) ([]download.Downloader, error) {
|
||||
func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity, msiDownloader download.MsiDownloader) ([]download.Downloader, error) {
|
||||
|
||||
if fileURL == "" {
|
||||
return nil, fmt.Errorf("fileURL is empty.")
|
||||
|
@ -86,23 +84,19 @@ func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity)
|
|||
// if managed identity was specified in the configuration, try to use it to download the files
|
||||
var msiProvider download.MsiProvider
|
||||
|
||||
if UseMockGetDownloaders {
|
||||
msiProvider = download.GetMockMsiProvider(managedIdentity.ClientId, ReturnErrorForMockGetDownloaders)
|
||||
} else {
|
||||
switch {
|
||||
case managedIdentity == nil || (managedIdentity.ClientId == "" && managedIdentity.ObjectId == ""):
|
||||
// get msi Provider for blob url implicitly (uses system managed identity)
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL)
|
||||
switch {
|
||||
case managedIdentity == nil || (managedIdentity.ClientId == "" && managedIdentity.ObjectId == ""):
|
||||
// get msi Provider for blob url implicitly (uses system managed identity)
|
||||
msiProvider = msiDownloader.GetMsiProvider(fileURL)
|
||||
|
||||
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
|
||||
// uses user-managed identity
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId)
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
|
||||
// uses user-managed identity
|
||||
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
|
||||
default:
|
||||
return nil, fmt.Errorf("Use either ClientId or ObjectId for managed identity. Not both.")
|
||||
}
|
||||
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
|
||||
// uses user-managed identity
|
||||
msiProvider = msiDownloader.GetMsiProviderByClientId(fileURL, managedIdentity.ClientId)
|
||||
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
|
||||
// uses user-managed identity
|
||||
msiProvider = msiDownloader.GetMsiProviderByObjectId(fileURL, managedIdentity.ObjectId)
|
||||
default:
|
||||
return nil, fmt.Errorf("Use either ClientId or ObjectId for managed identity. Not both.")
|
||||
}
|
||||
|
||||
_, msiError := msiProvider()
|
||||
|
|
|
@ -8,22 +8,33 @@ import (
|
|||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/run-command-handler-linux/pkg/download"
|
||||
"github.com/ahmetalpbalkan/go-httpbin"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var mockManagedIdentity = RunCommandManagedIdentity{
|
||||
var mockManagedIdentityBoth = RunCommandManagedIdentity{
|
||||
ClientId: "5d784f90-d7d9-4b04-bdf1-4ae4824d55b0",
|
||||
ObjectId: "bed99fe3-1ad3-4a25-867d-7d48d68def6a",
|
||||
}
|
||||
|
||||
func Test_getDownloader_externalUrl(t *testing.T) {
|
||||
UseMockGetDownloaders = true
|
||||
ReturnErrorForMockGetDownloaders = true
|
||||
var mockManagedIdentityClientId = RunCommandManagedIdentity{
|
||||
ClientId: "5d784f90-d7d9-4b04-bdf1-4ae4824d55b0",
|
||||
}
|
||||
|
||||
var mockManagedIdentityObjectId = RunCommandManagedIdentity{
|
||||
ObjectId: "bed99fe3-1ad3-4a25-867d-7d48d68def6a",
|
||||
}
|
||||
|
||||
var mockManagedSystemIdentity = RunCommandManagedIdentity{}
|
||||
|
||||
func Test_getDownloaders_externalUrl(t *testing.T) {
|
||||
download.MockReturnErrorForMockMsiDownloader = true
|
||||
var mockMsiDownloder = download.MockMsiDownloader{}
|
||||
|
||||
// Case 0: Error getting Msi. It returns public URL downloader
|
||||
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentity)
|
||||
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityObjectId, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.NotEmpty(t, d)
|
||||
|
@ -31,14 +42,54 @@ func Test_getDownloader_externalUrl(t *testing.T) {
|
|||
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
// Case 1: Valid Msi returned. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
|
||||
ReturnErrorForMockGetDownloaders = false
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentity)
|
||||
download.MockReturnErrorForMockMsiDownloader = false
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityClientId, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
UseMockGetDownloaders = false
|
||||
download.MockReturnErrorForMockMsiDownloader = false
|
||||
}
|
||||
|
||||
func Test_getDownloaders_SystemIdentityVersusByClientIdOrObjectId(t *testing.T) {
|
||||
download.MockReturnErrorForMockMsiDownloader = true
|
||||
var mockMsiDownloder = download.MockMsiDownloader{}
|
||||
|
||||
// Case 0: Provide both clientId and ObjectId getting Msi.
|
||||
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityBoth, mockMsiDownloder)
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, err.Error(), "Use either ClientId or ObjectId for managed identity. Not both.")
|
||||
|
||||
download.MockReturnErrorForMockMsiDownloader = false
|
||||
|
||||
// Case 1: Valid Msi returned by system identity. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedSystemIdentity, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
// Case 2: Valid Msi returned by system identity - nil identity passed. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", nil, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
// Case 3: Valid Msi returned by clientId. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityClientId, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
|
||||
// Case 4: Valid Msi returned by clientId. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
|
||||
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityObjectId, mockMsiDownloder)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, d)
|
||||
require.Equal(t, 2, len(d))
|
||||
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
|
||||
}
|
||||
|
||||
func Test_urlToFileName_badURL(t *testing.T) {
|
||||
|
|
|
@ -29,6 +29,18 @@ type blobWithMsiToken struct {
|
|||
|
||||
type MsiProvider func() (msi.Msi, error)
|
||||
|
||||
type MsiDownloader interface {
|
||||
GetMsiProvider(blobUri string) MsiProvider
|
||||
GetMsiProviderByClientId(blobUri, clientId string) MsiProvider
|
||||
GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider
|
||||
}
|
||||
|
||||
type ProdMsiDownloader struct{}
|
||||
|
||||
type MockMsiDownloader struct{} // Used only for test
|
||||
|
||||
var MockReturnErrorForMockMsiDownloader = false // Used only for test
|
||||
|
||||
func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
|
||||
msi, err := self.msiProvider()
|
||||
if err != nil {
|
||||
|
@ -54,7 +66,8 @@ func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader {
|
|||
return &blobWithMsiToken{url, msiProvider}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
|
||||
// Uses system identity to get Msi token
|
||||
func (prodMsiDownloader ProdMsiDownloader) GetMsiProvider(blobUri string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri))
|
||||
|
@ -67,7 +80,24 @@ func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
|
|||
}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider {
|
||||
// Mock implementation of GetMsiProvider
|
||||
func (mockMsiDownloader MockMsiDownloader) GetMsiProvider(blobUri string) MsiProvider {
|
||||
return func() (msi.Msi, error) {
|
||||
mockMsi := msi.Msi{
|
||||
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
|
||||
Resource: "Msi by System Identity for blob " + blobUri,
|
||||
}
|
||||
if MockReturnErrorForMockMsiDownloader {
|
||||
return mockMsi, errors.New("Error getting msi")
|
||||
} else {
|
||||
return mockMsi, nil
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Get Msi token by clientId
|
||||
func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByClientId(blobUri, clientId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri))
|
||||
|
@ -79,7 +109,23 @@ func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiP
|
|||
}
|
||||
}
|
||||
|
||||
func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider {
|
||||
// Mock implementation of GetMsiProviderByClientId
|
||||
func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByClientId(blobUri string, clientId string) MsiProvider {
|
||||
return func() (msi.Msi, error) {
|
||||
mockMsi := msi.Msi{
|
||||
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
|
||||
Resource: "Msi by clientId for blob " + blobUri,
|
||||
}
|
||||
if MockReturnErrorForMockMsiDownloader {
|
||||
return mockMsi, errors.New("Error getting msi")
|
||||
} else {
|
||||
return mockMsi, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get Msi token by objectId
|
||||
func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider {
|
||||
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
|
||||
return func() (msi.Msi, error) {
|
||||
msi, err := msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri))
|
||||
|
@ -91,6 +137,21 @@ func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiP
|
|||
}
|
||||
}
|
||||
|
||||
// Mock implementation of GetMsiProviderByObjectId
|
||||
func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider {
|
||||
return func() (msi.Msi, error) {
|
||||
mockMsi := msi.Msi{
|
||||
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
|
||||
Resource: "Msi by objectId for blob " + blobUri,
|
||||
}
|
||||
if MockReturnErrorForMockMsiDownloader {
|
||||
return mockMsi, errors.New("Error getting msi")
|
||||
} else {
|
||||
return mockMsi, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetResourceNameFromBlobUri(uri string) string {
|
||||
// TODO: update this function as sovereign cloud blob resource strings become available
|
||||
// resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change
|
||||
|
@ -113,18 +174,3 @@ func IsAzureStorageBlobUri(url string) bool {
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
func GetMockMsiProvider(clientId string, returnError bool) MsiProvider {
|
||||
return func() (msi.Msi, error) {
|
||||
mockMsi := msi.Msi{
|
||||
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
|
||||
ClientID: clientId,
|
||||
}
|
||||
if returnError {
|
||||
return mockMsi, errors.New("Error getting msi")
|
||||
} else {
|
||||
return mockMsi, nil
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,11 @@ package download
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-extension-foundation/msi"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// README
|
||||
|
|
Загрузка…
Ссылка в новой задаче