Increase test coverage, add datasets to submission code (#22)

This commit is contained in:
Anton Schwaighofer 2021-07-24 09:26:50 +01:00 committed by GitHub
Parent d52270a4d8
Commit be41c9be18
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 195 additions and 95 deletions

View file

@@ -0,0 +1,7 @@
<component name="CopyrightManager">
<copyright>
<option name="keyword" value="Copyright .* Microsoft Corporation" />
<option name="notice" value="------------------------------------------------------------------------------------------&#10;Copyright (c) Microsoft Corporation. All rights reserved.&#10;Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.&#10;------------------------------------------------------------------------------------------" />
<option name="myName" value="MIT license" />
</copyright>
</component>

View file

@@ -0,0 +1,7 @@
<component name="CopyrightManager">
<settings>
<module2copyright>
<element module="All" copyright="MIT license" />
</module2copyright>
</settings>
</component>

View file

@@ -1,3 +1,6 @@
[pytest]
log_cli = True
log_cli_level = DEBUG
addopts=--strict-markers
markers=
fast: Tests that should run very fast, and can act as smoke tests to see if something goes terribly wrong.
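
For illustration, a minimal sketch of how a test opts into the new marker (the test body is a placeholder); with --strict-markers, applying a marker that is not declared in pytest.ini fails at collection time:

import pytest

@pytest.mark.fast
def test_smoke() -> None:
    # Selected via `pytest -m fast`; rejected by --strict-markers if undeclared.
    assert 1 + 1 == 2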

View file

@ -1,13 +1,9 @@
import logging
from typing import List
from typing import Optional
from typing import Union
from pathlib import Path
from typing import List, Optional, Union
from azureml.core import Dataset
from azureml.core import Datastore
from azureml.core import Workspace
from azureml.data import FileDataset
from azureml.data import OutputFileDatasetConfig
from azureml.core import Dataset, Datastore, Workspace
from azureml.data import FileDataset, OutputFileDatasetConfig
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
@@ -76,7 +72,7 @@ class DatasetConfig:
version: Optional[int] = None,
use_mounting: Optional[bool] = None,
target_folder: str = "",
local_folder: str = ""):
local_folder: Optional[Path] = None):
"""
Creates a new configuration for using an AzureML dataset.
:param name: The name of the dataset, as it was registered in the AzureML workspace. For output datasets,
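
To show the effect of the signature change, a sketch of constructing a DatasetConfig for a local run, assuming hypothetical dataset and datastore names:

from pathlib import Path
from health.azure.datasets import DatasetConfig

# local_folder is now an Optional[Path] instead of a plain string;
# None (the new default) means the dataset has no local substitute.
config = DatasetConfig(name="my_dataset",         # hypothetical dataset name
                       datastore="my_datastore",  # hypothetical datastore name
                       local_folder=Path("/data/my_dataset"))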

View file

@@ -15,21 +15,12 @@ import sys
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Dict, Generator, List, Optional
from azureml.core import Experiment
from azureml.core import Run
from azureml.core import RunConfiguration
from azureml.core import ScriptRunConfig
from azureml.core import Workspace
from azureml.core import Experiment, Run, RunConfiguration, ScriptRunConfig, Workspace
from health.azure.datasets import StrOrDatasetConfig
from src.health.azure.himl_configs import SourceConfig
from src.health.azure.himl_configs import WorkspaceConfig
from src.health.azure.himl_configs import get_authentication
from health.azure.datasets import StrOrDatasetConfig, _input_dataset_key, _output_dataset_key, _replace_string_datasets
from src.health.azure.himl_configs import SourceConfig, WorkspaceConfig, get_authentication
logger = logging.getLogger('health.azure')
logger.setLevel(logging.DEBUG)
@@ -38,12 +29,13 @@ logger.setLevel(logging.DEBUG)
RUN_CONTEXT = Run.get_context()
WORKSPACE_CONFIG_JSON = "config.json"
AZUREML_COMMANDLINE_FLAG = "--azureml"
@dataclass
class AzureRunInformation:
input_datasets: List[Path]
output_datasets: List[Path]
input_datasets: List[Optional[Path]]
output_datasets: List[Optional[Path]]
run: Run
is_running_in_azure: bool
# In Azure, this would be the "outputs" folder. In local runs: "." or create a timestamped folder.
@@ -61,21 +53,19 @@ def is_running_in_azure(run: Run = RUN_CONTEXT) -> bool:
return hasattr(run, 'experiment')
def submit_to_azure_if_needed(
workspace_config: Optional[WorkspaceConfig],
workspace_config_path: Optional[Path],
compute_cluster_name: str,
snapshot_root_directory: Path,
entry_script: Path,
conda_environment_file: Path,
script_params: List[str] = [],
environment_variables: Dict[str, str] = {},
ignored_folders: List[Path] = [],
default_datastore: str = "",
input_datasets: Optional[List[StrOrDatasetConfig]] = None,
output_datasets: Optional[List[StrOrDatasetConfig]] = None,
num_nodes: int = 1,
) -> Run:
def submit_to_azure_if_needed(entry_script: Path, # type: ignore
compute_cluster_name: str,
conda_environment_file: Path,
workspace_config: Optional[WorkspaceConfig] = None,
workspace_config_path: Optional[Path] = None,
snapshot_root_directory: Optional[Path] = None,
environment_variables: Optional[Dict[str, str]] = None,
ignored_folders: Optional[List[Path]] = None,
default_datastore: str = "",
input_datasets: Optional[List[StrOrDatasetConfig]] = None,
output_datasets: Optional[List[StrOrDatasetConfig]] = None,
num_nodes: int = 1,
) -> AzureRunInformation:
"""
Submit a folder to Azure, if needed, and run it there.
@@ -85,9 +75,33 @@ def submit_to_azure_if_needed(
:param workspace_config_path: Optional path to the workspace config file.
:return: Information about the run, as an AzureRunInformation object.
"""
if all(["azureml" not in arg for arg in sys.argv]):
logging.info("The flag azureml is not set, and so not submitting to AzureML")
return
cleaned_input_datasets = _replace_string_datasets(input_datasets or [],
default_datastore_name=default_datastore)
cleaned_output_datasets = _replace_string_datasets(output_datasets or [],
default_datastore_name=default_datastore)
in_azure = is_running_in_azure()
if in_azure:
returned_input_datasets = [RUN_CONTEXT.input_datasets[_input_dataset_key(index)]
for index in range(len(cleaned_input_datasets))]
returned_output_datasets = [RUN_CONTEXT.output_datasets[_output_dataset_key(index)]
for index in range(len(cleaned_output_datasets))]
return AzureRunInformation(
input_datasets=returned_input_datasets,
output_datasets=returned_output_datasets,
run=RUN_CONTEXT,
is_running_in_azure=True,
output_folder=Path.cwd() / "outputs",
log_folder=Path.cwd() / "logs"
)
if AZUREML_COMMANDLINE_FLAG not in sys.argv[1:]:
return AzureRunInformation(
input_datasets=[d.local_folder for d in cleaned_input_datasets],
output_datasets=[d.local_folder for d in cleaned_output_datasets],
run=RUN_CONTEXT,
is_running_in_azure=False,
output_folder=Path.cwd() / "outputs",
log_folder=Path.cwd() / "logs"
)
if workspace_config_path and workspace_config_path.is_file():
auth = get_authentication()
workspace = Workspace.from_config(path=workspace_config_path, auth=auth)
@@ -101,12 +115,12 @@ def submit_to_azure_if_needed(
snapshot_root_directory=snapshot_root_directory,
conda_environment_file=conda_environment_file,
entry_script=entry_script,
script_params=script_params,
script_params=[p for p in sys.argv[1:] if p != AZUREML_COMMANDLINE_FLAG],
environment_variables=environment_variables)
with append_to_amlignore(
dirs_to_append=ignored_folders,
snapshot_root_directory=snapshot_root_directory):
dirs_to_append=ignored_folders or [],
snapshot_root_directory=snapshot_root_directory or Path.cwd()):
# TODO: InnerEye.azure.azure_runner.submit_to_azureml does work here with interrupt handlers to kill interrupted
# jobs. We'll do that later if still required.
@@ -115,6 +129,16 @@ def submit_to_azure_if_needed(
run_config = RunConfiguration(
script=entry_script_relative_path,
arguments=source_config.script_params)
inputs = {}
for index, d in enumerate(cleaned_input_datasets):
consumption = d.to_input_dataset(workspace=workspace, dataset_index=index)
inputs[consumption.name] = consumption
outputs = {}
for index, d in enumerate(cleaned_output_datasets):
out = d.to_output_dataset(workspace=workspace, dataset_index=index)
outputs[out.name] = out
run_config.data = inputs
run_config.output_data = outputs
script_run_config = ScriptRunConfig(
source_directory=str(source_config.snapshot_root_directory),
run_config=run_config,
@@ -134,7 +158,7 @@ def submit_to_azure_if_needed(
logging.info("Experiment URL: {}".format(experiment.get_portal_url()))
logging.info("Run URL: {}".format(run.get_portal_url()))
logging.info("==============================================================================\n")
return run
exit(0)
@contextmanager
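
A minimal sketch of how a caller might consume the new AzureRunInformation return value (cluster, script, and dataset names are placeholders). Without the --azureml flag on the command line, the call returns immediately with information about the local run:

from pathlib import Path
from health.azure.himl import submit_to_azure_if_needed

run_info = submit_to_azure_if_needed(entry_script=Path("main.py"),
                                     compute_cluster_name="gpu-cluster",
                                     conda_environment_file=Path("environment.yml"),
                                     input_datasets=["my_dataset"])
# Inside AzureML these are the mounted/downloaded dataset paths;
# locally they fall back to each dataset's local_folder (possibly None).
input_folder = run_info.input_datasets[0]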

View file

@@ -2,7 +2,8 @@ from pathlib import Path
from torchvision.datasets import MNIST
from health.azure.aml import submit_to_azure_if_needed, DatasetConfig
from health.azure.datasets import DatasetConfig
from health.azure.himl import submit_to_azure_if_needed
def main() -> None:

View file

@@ -3,20 +3,28 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from unittest import mock
import pytest
from health.azure.himl import AzureRunInformation
from health.azure.himl import submit_to_azure_if_needed
logger = logging.getLogger('test.health.azure')
logger.setLevel(logging.DEBUG)
@pytest.mark.fast
def test_submit_to_azure_if_needed() -> None:
"""
Test that submit_to_azure_if_needed can be called.
Test that submit_to_azure_if_needed can be called, and returns immediately.
"""
with pytest.raises(Exception) as ex:
submit_to_azure_if_needed(
workspace_config=None,
workspace_config_path=None)
with mock.patch("sys.argv", [""]):
result = submit_to_azure_if_needed(entry_script=Path(__file__),
compute_cluster_name="foo",
conda_environment_file=Path("env.yml"),
)
assert isinstance(result, AzureRunInformation)
assert not result.is_running_in_azure

View file

@@ -1,14 +1,14 @@
from unittest import mock
import pytest
from azureml.data import OutputFileDatasetConfig
from azureml.core import Dataset
from azureml.data import FileDataset, OutputFileDatasetConfig
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from health.azure.datasets import DatasetConfig
from health.azure.datasets import get_datastore
from testhiml.health.azure.utils import DEFAULT_DATASTORE
from testhiml.health.azure.utils import DEFAULT_WORKSPACE
from health.azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key, _replace_string_datasets,
get_datastore, get_or_create_dataset)
from testhiml.health.azure.utils import DEFAULT_DATASTORE, default_aml_workspace
def test_datasetconfig_init() -> None:
@@ -17,40 +17,34 @@ def test_datasetconfig_init() -> None:
assert "name of the dataset must be a non-empty string" in str(ex)
def test_get_datastore_fails() -> None:
"""
Retrieving a datastore that does not exist should fail
"""
does_not_exist = "does_not_exist"
with pytest.raises(ValueError) as ex:
get_datastore(workspace=DEFAULT_WORKSPACE, datastore_name=does_not_exist)
assert f"Datastore {does_not_exist} was not found" in str(ex)
def test_get_datastore_without_name() -> None:
"""
Trying to get a datastore without name should only work if there is a single datastore
"""
assert len(DEFAULT_WORKSPACE.datastores) > 1
with pytest.raises(ValueError) as ex:
get_datastore(workspace=DEFAULT_WORKSPACE, datastore_name="")
assert "No datastore name provided" in str(ex)
def test_get_datastore() -> None:
"""
Tests getting a datastore by name.
Test retrieving a datastore from the AML workspace.
"""
# Retrieving a datastore that does not exist should fail
does_not_exist = "does_not_exist"
workspace = default_aml_workspace()
with pytest.raises(ValueError) as ex:
get_datastore(workspace=workspace, datastore_name=does_not_exist)
assert f"Datastore {does_not_exist} was not found" in str(ex)
# Trying to get a datastore without a name should only work if there is a single datastore
assert len(workspace.datastores) > 1
with pytest.raises(ValueError) as ex:
get_datastore(workspace=workspace, datastore_name="")
assert "No datastore name provided" in str(ex)
# Retrieve a datastore by name
name = DEFAULT_DATASTORE
datastore = get_datastore(workspace=DEFAULT_WORKSPACE, datastore_name=name)
datastore = get_datastore(workspace=workspace, datastore_name=name)
assert isinstance(datastore, AzureBlobDatastore)
assert datastore.name == name
assert len(DEFAULT_WORKSPACE.datastores) > 1
assert len(workspace.datastores) > 1
# Now mock the datastores property of the workspace, to pretend there is only a single datastore.
# With that in place, we can get the datastore without the name
faked_stores = {name: datastore}
with mock.patch("azureml.core.Workspace.datastores", faked_stores):
single_store = get_datastore(workspace=DEFAULT_WORKSPACE, datastore_name="")
single_store = get_datastore(workspace=workspace, datastore_name="")
assert isinstance(single_store, AzureBlobDatastore)
assert single_store.name == name
@@ -59,20 +53,22 @@ def test_dataset_input() -> None:
"""
Test turning a dataset setup object to an actual AML input dataset.
"""
workspace = default_aml_workspace()
# This dataset must exist in the workspace already, or at least in blob storage.
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE)
aml_dataset = dataset_config.to_input_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.path_on_compute is None
assert aml_dataset.mode == "download"
# Downloading or mounting to a given path
target_folder = "/tmp/foo"
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, target_folder=target_folder)
aml_dataset = dataset_config.to_input_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.path_on_compute == target_folder
# Use mounting instead of downloading
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, use_mounting=True)
aml_dataset = dataset_config.to_input_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.mode == "mount"
@@ -82,8 +78,9 @@ def test_dataset_output() -> None:
Test turning a dataset setup object to an actual AML output dataset.
"""
name = "new_dataset"
workspace = default_aml_workspace()
dataset_config = DatasetConfig(name=name, datastore=DEFAULT_DATASTORE)
aml_dataset = dataset_config.to_output_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
aml_dataset = dataset_config.to_output_dataset(workspace=workspace, dataset_index=1)
assert isinstance(aml_dataset, OutputFileDatasetConfig)
assert isinstance(aml_dataset.destination, tuple)
assert aml_dataset.destination[0].name == DEFAULT_DATASTORE
@@ -91,10 +88,70 @@ def test_dataset_output() -> None:
assert aml_dataset.mode == "mount"
# Use downloading instead of mounting
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, use_mounting=False)
aml_dataset = dataset_config.to_input_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
assert aml_dataset.mode == "download"
aml_dataset = dataset_config.to_output_dataset(workspace=workspace, dataset_index=1)
assert isinstance(aml_dataset, OutputFileDatasetConfig)
assert aml_dataset.mode == "upload"
# Mounting at a fixed folder is not possible
with pytest.raises(ValueError) as ex:
dataset_config = DatasetConfig(name=name, datastore=DEFAULT_DATASTORE, target_folder="something")
dataset_config.to_output_dataset(workspace=DEFAULT_WORKSPACE, dataset_index=1)
dataset_config.to_output_dataset(workspace=workspace, dataset_index=1)
assert "Output datasets can't have a target_folder set" in str(ex)
def test_datasets_from_string() -> None:
"""
Test the conversion of datasets that are only specified as strings.
"""
dataset1 = "foo"
dataset2 = "bar"
store = "store"
default_store = "default"
original = [dataset1, DatasetConfig(name=dataset2, datastore=store)]
replaced = _replace_string_datasets(original, default_datastore_name=default_store)
assert len(replaced) == len(original)
for d in replaced:
assert isinstance(d, DatasetConfig)
assert replaced[0].name == dataset1
assert replaced[0].datastore == default_store
assert replaced[1] == original[1]
def test_get_dataset() -> None:
"""
Test if a dataset that does not yet exist can be created from a folder in blob storage
"""
# A folder with a single tiny file
tiny_dataset = "himl-tiny_dataset"
workspace = default_aml_workspace()
# When creating a dataset, we need a non-empty name
with pytest.raises(ValueError) as ex:
get_or_create_dataset(workspace=workspace,
datastore_name=DEFAULT_DATASTORE,
dataset_name="")
assert "No dataset name" in str(ex)
# Check first that there is no dataset yet of that name. If there is, delete that dataset (it would come
# from previous runs of this test)
try:
existing_dataset = Dataset.get_by_name(workspace, name=tiny_dataset)
existing_dataset.unregister_all_versions()
except Exception as ex:
assert "Cannot find dataset registered" in str(ex)
dataset = get_or_create_dataset(workspace=workspace,
datastore_name=DEFAULT_DATASTORE,
dataset_name=tiny_dataset)
assert isinstance(dataset, FileDataset)
# We should now be able to get that dataset without special means
dataset2 = Dataset.get_by_name(workspace, name=tiny_dataset)
# Delete the dataset again
dataset2.unregister_all_versions()
def test_dataset_keys() -> None:
"""
Check that dataset keys are non-empty strings, and that inputs and outputs have different keys.
"""
in1 = _input_dataset_key(1)
out1 = _output_dataset_key(1)
assert in1
assert out1
assert in1 != out1

View file

@@ -3,9 +3,7 @@ from pathlib import Path
from azureml.core import Workspace
from health.azure.himl import WORKSPACE_CONFIG_JSON
from health.azure.himl_configs import SUBSCRIPTION_ID
from health.azure.himl_configs import get_secret_from_environment
from health.azure.himl_configs import get_authentication
from health.azure.himl_configs import SUBSCRIPTION_ID, get_authentication, get_secret_from_environment
def repository_root() -> Path:
@@ -15,7 +13,7 @@ def repository_root() -> Path:
return Path(__file__).parent.parent.parent.parent
def aml_workspace() -> Workspace:
def default_aml_workspace() -> Workspace:
"""
Gets the default AzureML workspace that is used for testing.
"""
@@ -31,5 +29,4 @@ def aml_workspace() -> Workspace:
resource_group="InnerEye-DeepLearning")
DEFAULT_WORKSPACE = aml_workspace()
DEFAULT_DATASTORE = "innereyedatasets"
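
With the module-level DEFAULT_WORKSPACE constant gone, tests now create the workspace explicitly, roughly like this sketch:

from testhiml.health.azure.utils import DEFAULT_DATASTORE, default_aml_workspace

workspace = default_aml_workspace()
# Datastores are looked up by name on the freshly created workspace.
datastore = workspace.datastores[DEFAULT_DATASTORE]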