Test for system identity vs otherwise

This commit is contained in:
vivlingaiah 2023-01-24 14:48:25 -08:00
Родитель 86e1e1da55
Коммит 0a95bbb7fd
4 изменённых файлов: 140 добавлений и 48 удалений

Просмотреть файл

@ -17,8 +17,6 @@ import (
)
var UseMockSASDownloadFailure bool = false
var UseMockGetDownloaders = false
var ReturnErrorForMockGetDownloaders = false
// downloadAndProcessURL downloads using the specified downloader and saves it to the
// specified existing directory, which must be the path to the saved file. Then
@ -52,7 +50,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl
//If there was an error downloading using SAS URI or SAS was not provided, download using managedIdentity or publicly.
if scriptSASDownloadErr != nil || scriptSAS == "" {
downloaders, getDownloadersError := getDownloaders(url, cfg.SourceManagedIdentity)
downloaders, getDownloadersError := getDownloaders(url, cfg.SourceManagedIdentity, download.ProdMsiDownloader{})
if getDownloadersError == nil {
const mode = 0500 // we assume users download scripts to execute
_, err = download.SaveTo(ctx, downloaders, targetFilePath, mode)
@ -76,7 +74,7 @@ func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, cfg *handl
// getDownloaders returns one or two downloaders (two if it is an Azure storage blob):
// 1. Downloader for script using public URI.
// 2. Downloader for script using managed identity.
func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity) ([]download.Downloader, error) {
func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity, msiDownloader download.MsiDownloader) ([]download.Downloader, error) {
if fileURL == "" {
return nil, fmt.Errorf("fileURL is empty.")
@ -86,23 +84,19 @@ func getDownloaders(fileURL string, managedIdentity *RunCommandManagedIdentity)
// if managed identity was specified in the configuration, try to use it to download the files
var msiProvider download.MsiProvider
if UseMockGetDownloaders {
msiProvider = download.GetMockMsiProvider(managedIdentity.ClientId, ReturnErrorForMockGetDownloaders)
} else {
switch {
case managedIdentity == nil || (managedIdentity.ClientId == "" && managedIdentity.ObjectId == ""):
// get msi Provider for blob url implicitly (uses system managed identity)
msiProvider = download.GetMsiProviderForStorageAccountsImplicitly(fileURL)
switch {
case managedIdentity == nil || (managedIdentity.ClientId == "" && managedIdentity.ObjectId == ""):
// get msi Provider for blob url implicitly (uses system managed identity)
msiProvider = msiDownloader.GetMsiProvider(fileURL)
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
// uses user-managed identity
msiProvider = download.GetMsiProviderForStorageAccountsWithClientId(fileURL, managedIdentity.ClientId)
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
// uses user-managed identity
msiProvider = download.GetMsiProviderForStorageAccountsWithObjectId(fileURL, managedIdentity.ObjectId)
default:
return nil, fmt.Errorf("Use either ClientId or ObjectId for managed identity. Not both.")
}
case managedIdentity.ClientId != "" && managedIdentity.ObjectId == "":
// uses user-managed identity
msiProvider = msiDownloader.GetMsiProviderByClientId(fileURL, managedIdentity.ClientId)
case managedIdentity.ClientId == "" && managedIdentity.ObjectId != "":
// uses user-managed identity
msiProvider = msiDownloader.GetMsiProviderByObjectId(fileURL, managedIdentity.ObjectId)
default:
return nil, fmt.Errorf("Use either ClientId or ObjectId for managed identity. Not both.")
}
_, msiError := msiProvider()

Просмотреть файл

@ -8,22 +8,33 @@ import (
"path/filepath"
"testing"
"github.com/Azure/run-command-handler-linux/pkg/download"
"github.com/ahmetalpbalkan/go-httpbin"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
)
var mockManagedIdentity = RunCommandManagedIdentity{
var mockManagedIdentityBoth = RunCommandManagedIdentity{
ClientId: "5d784f90-d7d9-4b04-bdf1-4ae4824d55b0",
ObjectId: "bed99fe3-1ad3-4a25-867d-7d48d68def6a",
}
func Test_getDownloader_externalUrl(t *testing.T) {
UseMockGetDownloaders = true
ReturnErrorForMockGetDownloaders = true
var mockManagedIdentityClientId = RunCommandManagedIdentity{
ClientId: "5d784f90-d7d9-4b04-bdf1-4ae4824d55b0",
}
var mockManagedIdentityObjectId = RunCommandManagedIdentity{
ObjectId: "bed99fe3-1ad3-4a25-867d-7d48d68def6a",
}
var mockManagedSystemIdentity = RunCommandManagedIdentity{}
func Test_getDownloaders_externalUrl(t *testing.T) {
download.MockReturnErrorForMockMsiDownloader = true
var mockMsiDownloder = download.MockMsiDownloader{}
// Case 0: Error getting Msi. It returns public URL downloader
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentity)
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityObjectId, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.NotEmpty(t, d)
@ -31,14 +42,54 @@ func Test_getDownloader_externalUrl(t *testing.T) {
require.Equal(t, "download.urlDownload", fmt.Sprintf("%T", d[0]), "got wrong type")
// Case 1: Valid Msi returned. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
ReturnErrorForMockGetDownloaders = false
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentity)
download.MockReturnErrorForMockMsiDownloader = false
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityClientId, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
UseMockGetDownloaders = false
download.MockReturnErrorForMockMsiDownloader = false
}
func Test_getDownloaders_SystemIdentityVersusByClientIdOrObjectId(t *testing.T) {
download.MockReturnErrorForMockMsiDownloader = true
var mockMsiDownloder = download.MockMsiDownloader{}
// Case 0: Provide both clientId and ObjectId getting Msi.
d, err := getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityBoth, mockMsiDownloder)
require.NotNil(t, err)
require.Equal(t, err.Error(), "Use either ClientId or ObjectId for managed identity. Not both.")
download.MockReturnErrorForMockMsiDownloader = false
// Case 1: Valid Msi returned by system identity. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedSystemIdentity, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
// Case 2: Valid Msi returned by system identity - nil identity passed. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
d, err = getDownloaders("http://acct.blob.core.windows.net/", nil, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
// Case 3: Valid Msi returned by clientId. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityClientId, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
// Case 4: Valid Msi returned by clientId. It returns both MSI downloader and public URL downloader. First downloader is MSI downloader
d, err = getDownloaders("http://acct.blob.core.windows.net/", &mockManagedIdentityObjectId, mockMsiDownloder)
require.Nil(t, err)
require.NotNil(t, d)
require.Equal(t, 2, len(d))
require.Equal(t, "*download.blobWithMsiToken", fmt.Sprintf("%T", d[0]), "got wrong type")
}
func Test_urlToFileName_badURL(t *testing.T) {

Просмотреть файл

@ -29,6 +29,18 @@ type blobWithMsiToken struct {
type MsiProvider func() (msi.Msi, error)
type MsiDownloader interface {
GetMsiProvider(blobUri string) MsiProvider
GetMsiProviderByClientId(blobUri, clientId string) MsiProvider
GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider
}
type ProdMsiDownloader struct{}
type MockMsiDownloader struct{} // Used only for test
var MockReturnErrorForMockMsiDownloader = false // Used only for test
func (self *blobWithMsiToken) GetRequest() (*http.Request, error) {
msi, err := self.msiProvider()
if err != nil {
@ -54,7 +66,8 @@ func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader {
return &blobWithMsiToken{url, msiProvider}
}
func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
// Uses system identity to get Msi token
func (prodMsiDownloader ProdMsiDownloader) GetMsiProvider(blobUri string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri))
@ -67,7 +80,24 @@ func GetMsiProviderForStorageAccountsImplicitly(blobUri string) MsiProvider {
}
}
func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiProvider {
// Mock implementation of GetMsiProvider
func (mockMsiDownloader MockMsiDownloader) GetMsiProvider(blobUri string) MsiProvider {
return func() (msi.Msi, error) {
mockMsi := msi.Msi{
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
Resource: "Msi by System Identity for blob " + blobUri,
}
if MockReturnErrorForMockMsiDownloader {
return mockMsi, errors.New("Error getting msi")
} else {
return mockMsi, nil
}
}
}
// Get Msi token by clientId
func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByClientId(blobUri, clientId string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri))
@ -79,7 +109,23 @@ func GetMsiProviderForStorageAccountsWithClientId(blobUri, clientId string) MsiP
}
}
func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiProvider {
// Mock implementation of GetMsiProviderByClientId
func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByClientId(blobUri string, clientId string) MsiProvider {
return func() (msi.Msi, error) {
mockMsi := msi.Msi{
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
Resource: "Msi by clientId for blob " + blobUri,
}
if MockReturnErrorForMockMsiDownloader {
return mockMsi, errors.New("Error getting msi")
} else {
return mockMsi, nil
}
}
}
// Get Msi token by objectId
func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider {
msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior))
return func() (msi.Msi, error) {
msi, err := msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri))
@ -91,6 +137,21 @@ func GetMsiProviderForStorageAccountsWithObjectId(blobUri, objectId string) MsiP
}
}
// Mock implementation of GetMsiProviderByObjectId
func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider {
return func() (msi.Msi, error) {
mockMsi := msi.Msi{
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
Resource: "Msi by objectId for blob " + blobUri,
}
if MockReturnErrorForMockMsiDownloader {
return mockMsi, errors.New("Error getting msi")
} else {
return mockMsi, nil
}
}
}
func GetResourceNameFromBlobUri(uri string) string {
// TODO: update this function as sovereign cloud blob resource strings become available
// resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change
@ -113,18 +174,3 @@ func IsAzureStorageBlobUri(url string) bool {
return false
}
func GetMockMsiProvider(clientId string, returnError bool) MsiProvider {
return func() (msi.Msi, error) {
mockMsi := msi.Msi{
AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas",
ClientID: clientId,
}
if returnError {
return mockMsi, errors.New("Error getting msi")
} else {
return mockMsi, nil
}
}
}

Просмотреть файл

@ -2,10 +2,11 @@ package download
import (
"encoding/json"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/stretchr/testify/require"
"io/ioutil"
"testing"
"github.com/Azure/azure-extension-foundation/msi"
"github.com/stretchr/testify/require"
)
// README