Fix scenario of using SAS urls with default stg acct (#12)

•	Fixes the special case where a trigger file's Workflow URL uses the default storage account, but also includes a SAS token
•	Fixes an issue where additional storage accounts added via the containers-to-mount file are not mounted
This commit is contained in:
Matt McLoughlin 2019-12-19 18:44:45 -05:00 коммит произвёл GitHub
Родитель 4d752f83bb
Коммит e3b8077319
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 159 добавлений и 42 удалений

Просмотреть файл

@ -24,6 +24,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TriggerService", "src\Trigg
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Common", "src\Common\Common.csproj", "{63C91E1C-640C-4281-8388-971908009CAE}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TriggerService.Tests", "TriggerService.Tests\TriggerService.Tests.csproj", "{76E9ECA1-AA42-4A59-9C8B-665A69A8FADF}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -58,6 +60,10 @@ Global
{63C91E1C-640C-4281-8388-971908009CAE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{63C91E1C-640C-4281-8388-971908009CAE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{63C91E1C-640C-4281-8388-971908009CAE}.Release|Any CPU.Build.0 = Release|Any CPU
{76E9ECA1-AA42-4A59-9C8B-665A69A8FADF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{76E9ECA1-AA42-4A59-9C8B-665A69A8FADF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{76E9ECA1-AA42-4A59-9C8B-665A69A8FADF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{76E9ECA1-AA42-4A59-9C8B-665A69A8FADF}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

Просмотреть файл

@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
namespace TriggerService.Tests
{
[TestClass]
public class CromwellOnAzureEnvironmentTests
{
private byte[] blobData = new byte[1] { 0 };
private byte[] httpClientData = new byte[1] { 1 };
[TestMethod]
public async Task GetBlobFileNameAndDataWithDefaultStorageAccount()
{
const string url = "https://fake.azure.storage.account/test/test.wdl";
var accountAuthority = new Uri(url).Authority;
(var name, var data) = await GetBlobFileNameAndDataUsingMocksAsync(url, accountAuthority);
Assert.IsNotNull(name);
Assert.IsNotNull(data);
Assert.IsTrue(data.Length > 0);
// Test if Azure credentials code path is used
Assert.AreEqual(data, blobData);
}
[TestMethod]
public async Task GetBlobFileNameAndDataWithDefaultStorageAccountWithSasToken()
{
const string url = "https://fake.azure.storage.account/test/test.wdl?sp=r&st=2019-12-18T18:55:41Z&se=2019-12-19T02:55:41Z&spr=https&sv=2019-02-02&sr=b&sig=EMJyBMOxdG2NvBqiwUsg71ZdYqwqMWda9242KU43%2F5Y%3D";
var accountAuthority = new Uri(url).Authority;
(var name, var data) = await GetBlobFileNameAndDataUsingMocksAsync(url, accountAuthority);
Assert.IsNotNull(name);
Assert.IsNotNull(data);
Assert.IsTrue(data.Length > 0);
// Test if HttpClient code path is used
Assert.AreEqual(data, httpClientData);
}
private async Task<(string, byte[])> GetBlobFileNameAndDataUsingMocksAsync(string url, string accountAuthority)
{
var serviceCollection = new ServiceCollection()
.AddLogging(loggingBuilder => loggingBuilder.AddConsole());
var serviceProvider = serviceCollection.BuildServiceProvider();
var azStorageMock = new Mock<IAzureStorage>();
azStorageMock.Setup(az => az
.DownloadBlockBlobAsync(It.IsAny<string>()))
.Returns(Task.FromResult(blobData));
azStorageMock.Setup(az => az
.DownloadFileUsingHttpClientAsync(It.IsAny<string>()))
.Returns(Task.FromResult(httpClientData));
azStorageMock.SetupGet(az => az.AccountAuthority).Returns(accountAuthority);
var environment = new CromwellOnAzureEnvironment(
serviceProvider.GetRequiredService<ILoggerFactory>(),
azStorageMock.Object,
new CromwellApiClient.CromwellApiClient("http://cromwell:8000"));
return await environment.GetBlobFileNameAndData(url);
}
}
}

Просмотреть файл

@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.2.0" />
<PackageReference Include="Moq" Version="4.13.1" />
<PackageReference Include="MSTest.TestAdapter" Version="2.0.0" />
<PackageReference Include="MSTest.TestFramework" Version="2.0.0" />
<PackageReference Include="coverlet.collector" Version="1.0.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\src\TriggerService\TriggerService.csproj" />
</ItemGroup>
</Project>

Просмотреть файл

@ -54,7 +54,7 @@ namespace TesApi.Web
defaultStorageAccountName = configuration["DefaultStorageAccountName"]; // This account contains the cromwell-executions container
usePreemptibleVmsOnly = bool.TryParse(configuration["UsePreemptibleVmsOnly"], out var temp) ? temp : false;
externalStorageContainers = configuration["ExternalStorageContainers"]?.Split(new[] { ',', ' ', '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries)
externalStorageContainers = configuration["ExternalStorageContainers"]?.Split(new[] { ',', ';', '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries)
.SelectMany(e => externalStorageContainerRegex.Matches(e).Cast<Match>()
.Select(m => new ExternalStorageContainerInfo { BlobEndpoint = m.Groups[1].Value, AccountName = m.Groups[2].Value, ContainerName = m.Groups[3].Value, SasToken = m.Groups[4].Value }))
.ToList();

Просмотреть файл

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Azure.Management.ApplicationInsights.Management;
using Microsoft.Azure.Management.Fluent;
@ -25,17 +26,20 @@ namespace TriggerService
private readonly ILogger<AzureStorage> logger;
private readonly CloudStorageAccount account;
private readonly CloudBlobClient blobClient;
private readonly HttpClient httpClient;
public AzureStorage(ILogger<AzureStorage> logger, string accountName)
public AzureStorage(ILogger<AzureStorage> logger, CloudStorageAccount account, HttpClient httpClient)
{
this.logger = logger;
this.AccountName = accountName;
ServicePointManager.DefaultConnectionLimit = Environment.ProcessorCount * 8;
ServicePointManager.Expect100Continue = false;
account = GetCloudStorageAccountUsingMsiAsync(accountName).Result;
this.logger = logger;
this.account = account;
this.httpClient = httpClient;
blobClient = account.CreateCloudBlobClient();
var host = account.BlobStorageUri.PrimaryUri.Host;
AccountName = host.Substring(0, host.IndexOf("."));
}
public string AccountName { get; }
@ -168,12 +172,12 @@ namespace TriggerService
await containerReference.DeleteIfExistsAsync();
}
public async Task<byte[]> DownloadFileAsync(string blobUrl)
public async Task<byte[]> DownloadBlockBlobAsync(string blobUrl)
{
// Supporting "http://account.blob.core.windows.net/container/blob", "/account/container/blob" and "account/container/blob" URLs
if (!blobUrl.StartsWith("http", StringComparison.OrdinalIgnoreCase) && blobUrl.TrimStart('/').StartsWith(this.AccountName + "/", StringComparison.OrdinalIgnoreCase))
{
blobUrl = blobUrl.TrimStart('/').Replace(this.AccountName, $"http://{this.AccountAuthority}", StringComparison.OrdinalIgnoreCase);
blobUrl = blobUrl.TrimStart('/').Replace(this.AccountName, $"https://{this.AccountAuthority}", StringComparison.OrdinalIgnoreCase);
}
var blob = new CloudBlockBlob(new Uri(blobUrl), account.Credentials);
@ -192,6 +196,11 @@ namespace TriggerService
}
}
public async Task<byte[]> DownloadFileUsingHttpClientAsync(string url)
{
return await httpClient.GetByteArrayAsync(url);
}
public enum WorkflowState { New, InProgress, Succeeded, Failed, Abort };
public class StorageAccountInfo
@ -237,7 +246,7 @@ namespace TriggerService
return blobList;
}
private static async Task<CloudStorageAccount> GetCloudStorageAccountUsingMsiAsync(string accountName)
public static async Task<CloudStorageAccount> GetCloudStorageAccountUsingMsiAsync(string accountName)
{
var accounts = await GetAccessibleStorageAccountsAsync();
var account = accounts.FirstOrDefault(s => s.Name == accountName);

Просмотреть файл

@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
@ -17,17 +16,17 @@ namespace TriggerService
{
public class CromwellOnAzureEnvironment
{
private static readonly Regex blobNameRegex = new Regex("^(?:https?://|/)?[^/]+/[^/]+/(.+)"); // Supporting "http://account.blob.core.windows.net/container/blob", "/account/container/blob" and "account/container/blob" URLs in the trigger file.
private readonly HttpClient httpClient = new HttpClient();
private static readonly Regex blobNameRegex = new Regex("^(?:https?://|/)?[^/]+/[^/]+/([^?.]+)"); // Supporting "http://account.blob.core.windows.net/container/blob", "/account/container/blob" and "account/container/blob" URLs in the trigger file.
private IAzureStorage storage { get; set; }
private ICromwellApiClient cromwellApiClient { get; set; }
private readonly ILogger<AzureStorage> logger;
public CromwellOnAzureEnvironment(ILoggerFactory loggerFactory, IAzureStorage storage, ICromwellApiClient cromwellApiClient)
{
logger = loggerFactory.CreateLogger<AzureStorage>();
this.storage = storage;
this.cromwellApiClient = cromwellApiClient;
this.logger = loggerFactory.CreateLogger<AzureStorage>();
logger.LogInformation($"Cromwell URL: {cromwellApiClient.GetUrl()}");
}
@ -191,9 +190,30 @@ namespace TriggerService
}
}
private static string GetBlobName(string url)
public async Task<(string, byte[])> GetBlobFileNameAndData(string url)
{
return blobNameRegex.Match(url)?.Groups[1].Value.Replace("/", "_");
if (string.IsNullOrWhiteSpace(url))
{
return (null, null);
}
var blobName = GetBlobName(url);
byte[] data;
if (((Uri.TryCreate(url, UriKind.Absolute, out var uri) && uri.Authority.Equals(storage.AccountAuthority, StringComparison.OrdinalIgnoreCase))
|| url.TrimStart('/').StartsWith(storage.AccountName + "/", StringComparison.OrdinalIgnoreCase))
&& uri.ParseQueryString().Get("sig") == null)
{
// use known credentials, unless the URL specifies a shared-access signature
data = await storage.DownloadBlockBlobAsync(url);
}
else
{
data = await storage.DownloadFileUsingHttpClientAsync(url);
}
return (blobName, data);
}
private static Guid ExtractWorkflowId(Microsoft.WindowsAzure.Storage.Blob.CloudBlockBlob blobTrigger, AzureStorage.WorkflowState currentState)
@ -211,6 +231,11 @@ namespace TriggerService
return withoutExtension.Substring(0, withoutExtension.LastIndexOf('.'));
}
private static string GetBlobName(string url)
{
return blobNameRegex.Match(url)?.Groups[1].Value.Replace("/", "_");
}
private async Task UploadOutputsAsync(Microsoft.WindowsAzure.Storage.Blob.CloudBlockBlob blobTrigger, Guid id, string sampleName)
{
const string outputsContainer = "outputs";
@ -252,30 +277,5 @@ namespace TriggerService
logger.LogWarning(exc, $"Getting timing threw an exception for Id: {id}");
}
}
private async Task<(string, byte[])> GetBlobFileNameAndData(string url)
{
if (string.IsNullOrWhiteSpace(url))
{
return (null, null);
}
var blobName = GetBlobName(url);
byte[] data;
if ((Uri.TryCreate(url, UriKind.Absolute, out var uri) && uri.Authority.Equals(storage.AccountAuthority, StringComparison.OrdinalIgnoreCase))
|| url.TrimStart('/').StartsWith(storage.AccountName + "/", StringComparison.OrdinalIgnoreCase))
{
// use known credentials
data = await storage.DownloadFileAsync(url);
}
else
{
data = await httpClient.GetByteArrayAsync(url);
}
return (blobName, data);
}
}
}

Просмотреть файл

@ -13,7 +13,7 @@ namespace TriggerService
string AccountName { get; }
string AccountAuthority { get; }
string GetBlobSasUrl(string blobUrl, TimeSpan sasTokenDuration);
Task<byte[]> DownloadFileAsync(string blobUrl);
Task<byte[]> DownloadBlockBlobAsync(string blobUrl);
Task<string> UploadFileFromPathAsync(string path, string container, string blobName);
Task<string> UploadFileTextAsync(string content, string container, string blobName);
Task MutateStateAsync(string container, string blobName, AzureStorage.WorkflowState newState);
@ -24,5 +24,6 @@ namespace TriggerService
Task<IEnumerable<CloudBlockBlob>> GetWorkflowsByStateAsync(AzureStorage.WorkflowState state);
Task<bool> IsSingleBlobExistsFromPrefixAsync(string container, string blobPrefix);
Task<bool> IsAvailableAsync();
Task<byte[]> DownloadFileUsingHttpClientAsync(string url);
}
}

Просмотреть файл

@ -39,9 +39,11 @@ namespace TriggerService
var serviceProvider = serviceCollection.BuildServiceProvider();
var cloudStorageAccount = await AzureStorage.GetCloudStorageAccountUsingMsiAsync(defaultStorageAccountName);
var environment = new CromwellOnAzureEnvironment(
serviceProvider.GetRequiredService<ILoggerFactory>(),
new AzureStorage(serviceProvider.GetRequiredService<ILoggerFactory>().CreateLogger<AzureStorage>(), defaultStorageAccountName),
new AzureStorage(serviceProvider.GetRequiredService<ILoggerFactory>().CreateLogger<AzureStorage>(), cloudStorageAccount, new System.Net.Http.HttpClient()),
new CromwellApiClient.CromwellApiClient(cromwellUrl));
serviceCollection.AddSingleton(s => new TriggerEngine(s.GetRequiredService<ILoggerFactory>(), environment));