BUG: Fix for duplicate authentication (#878)

Resolve issues with AML SDK v1/v2: Even in the v2 code path we are
creating v1 workspaces, leading to duplicate authentication requests.
To achieve that, a couple of other changes were necessary:
* Deprecate the use of the default datastore, which was read out of an
SDK v1 Workspace object even if SDK v2 was chosen
* No longer allowing SDK v2 when mounting datasets for local runs (v2
Datasets can't be mounted at all).

Also added more detailed logging for dataset creation, and a commandline
flag to control logging level.
This commit is contained in:
Anton Schwaighofer 2023-05-16 20:26:48 +01:00 коммит произвёл GitHub
Родитель 683def950a
Коммит f46f60e7fa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
16 изменённых файлов: 689 добавлений и 422 удалений

Просмотреть файл

@ -20,7 +20,9 @@ from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.dataprep.fuse.daemon import MountContext
from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.utils import PathOrString, get_workspace, get_ml_client
from health_azure.utils import PathOrString, get_ml_client
logger = logging.getLogger(__name__)
V1OrV2DataType = Union[FileDataset, Data]
@ -128,11 +130,14 @@ def _get_or_create_v1_dataset(datastore_name: str, dataset_name: str, workspace:
try:
azureml_dataset = _retrieve_v1_dataset(dataset_name, workspace)
except UserErrorException:
logger.warning(f"Dataset '{dataset_name}' was not found, or is not an AzureML SDK v1 dataset.")
logger.info(f"Trying to create a new dataset '{dataset_name}' from files in folder '{dataset_name}'")
if datastore_name == "":
raise ValueError(
"When creating a new dataset, a datastore name must be provided. Please specify a datastore name using "
"the --datastore flag"
)
logger.info(f"Trying to create a new dataset '{dataset_name}' in datastore '{datastore_name}'")
azureml_dataset = _create_v1_dataset(datastore_name, dataset_name, workspace)
return azureml_dataset
@ -352,10 +357,8 @@ class DatasetConfig:
def to_input_dataset_local(
self,
strictly_aml_v1: bool,
workspace: Workspace = None,
ml_client: Optional[MLClient] = None,
) -> Tuple[Optional[Path], Optional[MountContext]]:
workspace: Workspace,
) -> Tuple[Path, Optional[MountContext]]:
"""
Return a local path to the dataset when outside of an AzureML run.
If local_folder is supplied, then this is assumed to be a local dataset, and this is returned.
@ -364,9 +367,6 @@ class DatasetConfig:
therefore a tuple of Nones will be returned.
:param workspace: The AzureML workspace to read from.
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and register the dataset.
Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:return: Tuple of (path to dataset, optional mountcontext)
"""
status = f"Dataset '{self.name}' will be "
@ -381,12 +381,10 @@ class DatasetConfig:
f"Unable to make dataset '{self.name} available for a local run because no AzureML "
"workspace has been provided. Provide a workspace, or set a folder for local execution."
)
azureml_dataset = get_or_create_dataset(
azureml_dataset = _get_or_create_v1_dataset(
datastore_name=self.datastore,
dataset_name=self.name,
workspace=workspace,
strictly_aml_v1=strictly_aml_v1,
ml_client=ml_client,
)
if isinstance(azureml_dataset, FileDataset):
target_path = self.target_folder or Path(tempfile.mkdtemp())
@ -404,7 +402,7 @@ class DatasetConfig:
print(status)
return result
else:
return None, None
raise ValueError(f"Don't know how to handle dataset '{self.name}' of type {type(azureml_dataset)}")
def to_input_dataset(
self,
@ -556,38 +554,10 @@ def create_dataset_configs(
return datasets
def find_workspace_for_local_datasets(
aml_workspace: Optional[Workspace], workspace_config_path: Optional[Path], dataset_configs: List[DatasetConfig]
) -> Optional[Workspace]:
"""
If any of the dataset_configs require an AzureML workspace then try to get one, otherwise return None.
:param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
to pass it in as a parameter.
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
Azure portal from which we can retrieve the existing Workspace.
:param dataset_configs: List of DatasetConfig describing the input datasets.
:return: Workspace if required, None otherwise.
"""
workspace: Workspace = None
# Check whether an attempt will be made to mount or download a dataset when running locally.
# If so, try to get the AzureML workspace.
if any(dc.local_folder is None for dc in dataset_configs):
try:
workspace = get_workspace(aml_workspace, workspace_config_path)
logging.info(f"Found workspace for datasets: {workspace.name}")
except Exception as ex:
logging.info(f"Could not find workspace for datasets. Exception: {ex}")
return workspace
def setup_local_datasets(
dataset_configs: List[DatasetConfig],
strictly_aml_v1: bool,
aml_workspace: Optional[Workspace] = None,
ml_client: Optional[MLClient] = None,
workspace_config_path: Optional[Path] = None,
) -> Tuple[List[Optional[Path]], List[MountContext]]:
workspace: Optional[Workspace],
) -> Tuple[List[Path], List[MountContext]]:
"""
When running outside of AzureML, setup datasets to be used locally.
@ -595,21 +565,20 @@ def setup_local_datasets(
used. Otherwise the dataset is mounted or downloaded to either the target folder or a temporary folder and that is
used.
:param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
to pass it in as a parameter.
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
Azure portal from which we can retrieve the existing Workspace.
If a dataset does not exist, an AzureML SDK v1 dataset will be created, assuming that the dataset is given
in a folder of the same name (for example, if a dataset is given as "mydataset", then it is created from the files
in folder "mydataset" in the datastore).
:param workspace: The AzureML workspace to work with. Can be None if the list of datasets is empty, or if
the datasets are available locally.
:param dataset_configs: List of DatasetConfig describing the input data assets.
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An MLClient object for interacting with AML v2 datastores.
:return: Pair of: list of optional paths to the input datasets, list of mountcontexts, one for each mounted dataset.
:return: Pair of: list of paths to the input datasets, list of mountcontexts, one for each mounted dataset.
"""
workspace = find_workspace_for_local_datasets(aml_workspace, workspace_config_path, dataset_configs)
mounted_input_datasets: List[Optional[Path]] = []
mounted_input_datasets: List[Path] = []
mount_contexts: List[MountContext] = []
for data_config in dataset_configs:
target_path, mount_context = data_config.to_input_dataset_local(strictly_aml_v1, workspace, ml_client)
target_path, mount_context = data_config.to_input_dataset_local(workspace)
mounted_input_datasets.append(target_path)

Просмотреть файл

@ -442,22 +442,20 @@ def effective_experiment_name(experiment_name: Optional[str], entry_script: Opti
def submit_run_v2(
workspace: Optional[Workspace],
ml_client: MLClient,
environment: EnvironmentV2,
entry_script: PathOrString,
script_params: List[str],
compute_target: str,
environment_variables: Optional[Dict[str, str]] = None,
experiment_name: Optional[str] = None,
input_datasets_v2: Optional[Dict[str, Input]] = None,
output_datasets_v2: Optional[Dict[str, Output]] = None,
snapshot_root_directory: Optional[Path] = None,
entry_script: Optional[PathOrString] = None,
script_params: Optional[List[str]] = None,
compute_target: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
docker_shm_size: str = "",
wait_for_completion: bool = False,
identity_based_auth: bool = False,
workspace_config_path: Optional[PathOrString] = None,
ml_client: Optional[MLClient] = None,
hyperparam_args: Optional[Dict[str, Any]] = None,
num_nodes: int = 1,
pytorch_processes_per_node: Optional[int] = None,
@ -466,8 +464,11 @@ def submit_run_v2(
"""
Starts a v2 AML Job on a given workspace by submitting a command
:param workspace: The AzureML workspace to use.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:param environment: An AML v2 Environment object.
:param entry_script: The script that should be run in AzureML.
:param script_params: A list of parameters to pass on to the script as it runs in AzureML.
:param compute_target: The name of a compute target in Azure ML to submit the job to.
:param environment_variables: The environment variables that should be set when running in AzureML.
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
characters that are not valid in Azure, those will be removed.
@ -475,18 +476,11 @@ def submit_run_v2(
:param output_datasets_v2: An optional dictionary of Outputs to pass in to the command.
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
All Python code that the script uses must be copied over.
:param entry_script: The script that should be run in AzureML.
:param script_params: A list of parameters to pass on to the script as it runs in AzureML.
:param compute_target: Optional name of a compute target in Azure ML to submit the job to. If None, will run
locally.
:param tags: A dictionary of string key/value pairs, that will be added as metadata to the run. If set to None,
a default metadata field will be added that only contains the commandline arguments that started the run.
:param docker_shm_size: The Docker shared memory size that should be used when creating a new Docker image.
:param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
the completion of this run (if True).
:param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
config
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
:param num_nodes: The number of nodes to use for the job in AzureML. The value must be 1 or greater.
:param pytorch_processes_per_node: For plain PyTorch multi-GPU processing: The number of processes per node.
@ -496,20 +490,6 @@ def submit_run_v2(
display name will be generated by AzureML.
:return: An AzureML Run object.
"""
if ml_client is None:
if workspace is not None:
ml_client = get_ml_client(
subscription_id=workspace.subscription_id,
resource_group=workspace.resource_group,
workspace_name=workspace.name,
)
elif workspace_config_path is not None:
ml_client = get_ml_client(workspace_config_path=workspace_config_path)
else:
raise ValueError("Either workspace or workspace_config_path must be specified to connect to the Workspace")
assert compute_target is not None, "No compute_target has been provided"
assert entry_script is not None, "No entry_script has been provided"
snapshot_root_directory = snapshot_root_directory or Path.cwd()
root_dir = Path(snapshot_root_directory)
@ -592,7 +572,11 @@ def submit_run_v2(
job_to_submit = create_command_job(cmd)
returned_job = ml_client.jobs.create_or_update(job_to_submit)
print(f"URL to job: {returned_job.services['Studio'].endpoint}") # type: ignore
print("\n==============================================================================")
# The ID field looks like /subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.MachineLearningServices/..
print(f"Successfully queued run {(returned_job.id or '').split('/')[-1]}")
print(f"Run URL: {returned_job.services['Studio'].endpoint}") # type: ignore
print("==============================================================================\n")
if wait_for_completion:
print("Waiting for the completion of the AzureML job.")
wait_for_job_completion(ml_client, job_name=returned_job.name)
@ -671,7 +655,7 @@ def submit_run(
# These need to be 'print' not 'logging.info' so that the calling script sees them outside AzureML
print("\n==============================================================================")
print(f"Successfully queued run number {run.number} (ID {run.id}) in experiment {run.experiment.name}")
print(f"Successfully queued run {run.id} in experiment {run.experiment.name}")
print(f"Experiment name and run ID are available in file {RUN_RECOVERY_FILE}")
print(f"Experiment URL: {run.experiment.get_portal_url()}")
print(f"Run URL: {run.get_portal_url()}")
@ -885,6 +869,18 @@ def submit_to_azure_if_needed( # type: ignore
# is necessary. If not, return to the caller for local execution.
if submit_to_azureml is None:
submit_to_azureml = AZUREML_FLAG in sys.argv[1:]
has_input_datasets = len(cleaned_input_datasets) > 0
if submit_to_azureml or has_input_datasets:
if strictly_aml_v1:
aml_workspace = get_workspace(aml_workspace, workspace_config_path)
assert aml_workspace is not None
print(f"Loaded AzureML workspace {aml_workspace.name}")
else:
ml_client = get_ml_client(ml_client=ml_client, workspace_config_path=workspace_config_path)
assert ml_client is not None
print(f"Created MLClient for AzureML workspace {ml_client.workspace_name}")
if not submit_to_azureml:
# Set the environment variables for local execution.
environment_variables = {**DEFAULT_ENVIRONMENT_VARIABLES, **(environment_variables or {})}
@ -898,16 +894,24 @@ def submit_to_azure_if_needed( # type: ignore
logs_folder = Path.cwd() / LOGS_FOLDER
logs_folder.mkdir(exist_ok=True)
any_local_folders_missing = any(dataset.local_folder is None for dataset in cleaned_input_datasets)
if has_input_datasets and any_local_folders_missing and not strictly_aml_v1:
raise ValueError(
"AzureML SDK v2 does not support downloading datasets from AzureML for local execution. "
"Please switch to AzureML SDK v1 by setting strictly_aml_v1=True, or use "
"--strictly_aml_v1 on the commandline, or provide a local folder for each input dataset. "
"Note that you will not be able use AzureML datasets for runs outside AzureML if the datasets were "
"created via SDK v2."
)
mounted_input_datasets, mount_contexts = setup_local_datasets(
cleaned_input_datasets,
strictly_aml_v1,
aml_workspace=aml_workspace,
ml_client=ml_client,
workspace_config_path=workspace_config_path,
workspace=aml_workspace,
)
return AzureRunInfo(
input_datasets=mounted_input_datasets,
input_datasets=mounted_input_datasets, # type: ignore
output_datasets=[d.local_folder for d in cleaned_output_datasets],
mount_contexts=mount_contexts,
run=None,
@ -920,9 +924,6 @@ def submit_to_azure_if_needed( # type: ignore
print(f"No snapshot root directory given. Uploading all files in the current directory {Path.cwd()}")
snapshot_root_directory = Path.cwd()
workspace = get_workspace(aml_workspace, workspace_config_path)
print(f"Loaded AzureML workspace {workspace.name}")
if conda_environment_file is None:
conda_environment_file = find_file_in_parent_to_pythonpath(CONDA_ENVIRONMENT_FILE)
if conda_environment_file is None:
@ -938,8 +939,9 @@ def submit_to_azure_if_needed( # type: ignore
with append_to_amlignore(amlignore=amlignore_path, lines_to_append=lines_to_append):
if strictly_aml_v1:
assert aml_workspace is not None, "An AzureML workspace should have been created already."
run_config = create_run_configuration(
workspace=workspace,
workspace=aml_workspace,
compute_cluster_name=compute_cluster_name,
aml_environment_name=aml_environment_name,
conda_environment_file=conda_environment_file,
@ -968,7 +970,7 @@ def submit_to_azure_if_needed( # type: ignore
config_to_submit = script_run_config
run = submit_run(
workspace=workspace,
workspace=aml_workspace,
experiment_name=effective_experiment_name(experiment_name, script_run_config.script),
script_run_config=config_to_submit,
tags=tags,
@ -979,6 +981,7 @@ def submit_to_azure_if_needed( # type: ignore
if after_submission is not None:
after_submission(run) # type: ignore
else:
assert ml_client is not None, "An AzureML MLClient should have been created already."
if conda_environment_file is None:
raise ValueError("Argument 'conda_environment_file' must be specified when using AzureML v2")
environment = create_python_environment_v2(
@ -987,13 +990,12 @@ def submit_to_azure_if_needed( # type: ignore
if entry_script is None:
entry_script = Path(sys.argv[0])
ml_client = get_ml_client(ml_client=ml_client, aml_workspace=workspace)
registered_env = register_environment_v2(environment, ml_client)
input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
output_datasets_v2 = create_v2_outputs(ml_client, cleaned_output_datasets)
job = submit_run_v2(
workspace=workspace,
ml_client=ml_client,
input_datasets_v2=input_datasets_v2,
output_datasets_v2=output_datasets_v2,
experiment_name=experiment_name,

Просмотреть файл

@ -39,12 +39,7 @@ def main() -> None: # pragma: no cover
files_to_download = download_config.files_to_download
workspace = get_workspace()
ml_client = get_ml_client(
subscription_id=workspace.subscription_id,
resource_group=workspace.resource_group,
workspace_name=workspace.name,
)
ml_client = get_ml_client()
for run_id in download_config.run:
download_job_outputs_logs(ml_client, run_id, file_to_download_path=files_to_download, download_dir=output_dir)
print("Successfully downloaded output and log files")

Просмотреть файл

@ -13,7 +13,6 @@ from typing import Generator, Optional, Union
from health_azure.utils import ENV_LOCAL_RANK, check_is_any_of, is_global_rank_zero
logging_stdout_handler: Optional[logging.StreamHandler] = None
logging_to_file_handler: Optional[logging.StreamHandler] = None
def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:

Просмотреть файл

@ -288,6 +288,30 @@ def find_file_in_parent_to_pythonpath(file_name: str) -> Optional[Path]:
return find_file_in_parent_folders(file_name=file_name, stop_at_path=pythonpaths)
def resolve_workspace_config_path(workspace_config_path: Optional[Path] = None) -> Optional[Path]:
"""Retrieve the path to the workspace config file, either from the argument, or from the current working directory.
:param workspace_config_path: A path to a workspace config file that was provided on the commandline, defaults to
None
:return: The path to the workspace config file, or None if it cannot be found.
:raises FileNotFoundError: If the workspace config file that was provided as an argument does not exist.
"""
if workspace_config_path is None:
logger.info(
f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
"and its parent folders"
)
result = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
if result:
logger.info(f"Using the workspace config file {str(result.absolute())}")
else:
logger.debug("No workspace config file found")
return result
if not workspace_config_path.is_file():
raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
return workspace_config_path
def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_path: Optional[Path] = None) -> Workspace:
"""
Retrieve an Azure ML Workspace by going through the following steps:
@ -320,26 +344,16 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
if aml_workspace:
return aml_workspace
if workspace_config_path is None:
logging.info(
f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
"and its parent folders"
)
workspace_config_path = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
if workspace_config_path:
logging.info(f"Using the workspace config file {str(workspace_config_path.absolute())}")
workspace_config_path = resolve_workspace_config_path(workspace_config_path)
auth = get_authentication()
if workspace_config_path is not None:
if not workspace_config_path.is_file():
raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
workspace = Workspace.from_config(path=str(workspace_config_path), auth=auth)
logging.info(
logger.info(
f"Logged into AzureML workspace {workspace.name} as specified in config file " f"{workspace_config_path}"
)
return workspace
logging.info("Trying to load the environment variables that define the workspace.")
logger.info("Trying to load the environment variables that define the workspace.")
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)
@ -347,7 +361,7 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
workspace = Workspace.get(
name=workspace_name, auth=auth, subscription_id=subscription_id, resource_group=resource_group
)
logging.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
logger.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
return workspace
raise ValueError(
@ -1747,7 +1761,8 @@ class UnitTestWorkspaceWrapper:
"""
Init.
"""
self._workspace: Workspace = None
self._workspace: Optional[Workspace] = None
self._ml_client: Optional[MLClient] = None
@property
def workspace(self) -> Workspace:
@ -1758,6 +1773,15 @@ class UnitTestWorkspaceWrapper:
self._workspace = get_workspace()
return self._workspace
@property
def ml_client(self) -> MLClient:
"""
Lazily load the ML Client.
"""
if self._ml_client is None:
self._ml_client = get_ml_client()
return self._ml_client
@contextmanager
def check_config_json(script_folder: Path, shared_config_json: Path) -> Generator:
@ -1895,7 +1919,7 @@ def _get_legitimate_interactive_browser_credential() -> Optional[TokenCredential
def get_credential() -> Optional[TokenCredential]:
"""
Get a credential for authenticating with Azure.There are multiple ways to retrieve a credential.
Get a credential for authenticating with Azure. There are multiple ways to retrieve a credential.
If environment variables pertaining to details of a Service Principal are available, those will be used
to authenticate. If no environment variables exist, and the script is not currently
running inside of Azure ML or another Azure agent, will attempt to retrieve a credential via a
@ -1910,6 +1934,7 @@ def get_credential() -> Optional[TokenCredential]:
tenant_id = get_secret_from_environment(ENV_TENANT_ID, allow_missing=True)
service_principal_password = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_PASSWORD, allow_missing=True)
if service_principal_id and tenant_id and service_principal_password:
logger.debug("Found environment variables for Service Principal authentication")
return _get_legitimate_service_principal_credential(tenant_id, service_principal_id, service_principal_password)
try:
@ -1927,66 +1952,76 @@ def get_credential() -> Optional[TokenCredential]:
raise ValueError(
"Unable to generate and validate a credential. Please see Azure ML documentation"
"for instructions on diffrent options to get a credential"
"for instructions on different options to get a credential"
)
def get_ml_client(
ml_client: Optional[MLClient] = None,
aml_workspace: Optional[Workspace] = None,
workspace_config_path: Optional[PathOrString] = None,
subscription_id: Optional[str] = None,
resource_group: Optional[str] = None,
workspace_name: str = "",
workspace_config_path: Optional[Path] = None,
) -> MLClient:
"""
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK.
If a ml_client is provided, return that. Otherwise, create one using workspace details
coming from either an existing Workspace object, a config.json file or passed in as an argument.
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK. The following ways of
creating the client are tried out:
1. If an MLClient object has been provided in the `ml_client` argument, return that.
2. If a path to a workspace config file has been provided, load the MLClient according to that config file.
3. If a workspace config file is present in the current working directory or one of its parents, load the
MLClient according to that config file.
4. If 3 environment variables are found, use them to identify the workspace (`HIML_RESOURCE_GROUP`,
`HIML_SUBSCRIPTION_ID`, `HIML_WORKSPACE_NAME`)
If none of the above succeeds, an exception is raised.
:param ml_client: An optional existing MLClient object to be returned.
:param aml_workspace: An optional Workspace object to take connection details from.
:param workspace_config_path: An optional path to a config.json file containing details of the Workspace.
:param subscription_id: An optional subscription ID.
:param resource_group: An optional resource group name.
:param workspace_name: An optional workspace name.
:return: An instance of MLClient to interact with Azure resources.
"""
if ml_client:
if ml_client is not None:
return ml_client
logger.debug("Getting credentials")
credential = get_credential()
if credential is None:
raise ValueError("Can't connect to MLClient without a valid credential")
if aml_workspace is not None:
ml_client = MLClient(
subscription_id=aml_workspace.subscription_id,
resource_group_name=aml_workspace.resource_group,
workspace_name=aml_workspace.name,
credential=credential,
) # type: ignore
elif workspace_config_path:
workspace_config_path = resolve_workspace_config_path(workspace_config_path)
if workspace_config_path is not None:
logger.debug(f"Retrieving MLClient from workspace config {workspace_config_path}")
ml_client = MLClient.from_config(credential=credential, path=str(workspace_config_path)) # type: ignore
elif subscription_id and resource_group and workspace_name:
logger.info(
f"Using MLClient for AzureML workspace {ml_client.workspace_name} as specified in config file"
f"{workspace_config_path}"
)
return ml_client
logger.info("Trying to load the environment variables that define the workspace.")
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)
if workspace_name and subscription_id and resource_group:
logger.debug(
"Retrieving MLClient via subscription ID, resource group and workspace name retrieved from "
"environment variables."
)
ml_client = MLClient(
subscription_id=subscription_id,
resource_group_name=resource_group,
workspace_name=workspace_name,
credential=credential,
) # type: ignore
else:
try:
workspace = get_workspace()
ml_client = MLClient(
subscription_id=workspace.subscription_id,
resource_group_name=workspace.resource_group,
workspace_name=workspace.name,
credential=credential,
) # type: ignore
except ValueError as e:
raise ValueError(f"Couldn't connect to MLClient: {e}")
logging.info(f"Logged into AzureML workspace {ml_client.workspace_name}")
return ml_client
logger.info(f"Using MLClient for AzureML workspace {workspace_name} as specified by environment variables")
return ml_client
raise ValueError(
"Tried all ways of identifying the MLClient, but failed. Please provide a workspace config "
f"file {WORKSPACE_CONFIG_JSON} or set the environment variables {ENV_RESOURCE_GROUP}, "
f"{ENV_SUBSCRIPTION_ID}, and {ENV_WORKSPACE_NAME}."
)
def retrieve_workspace_from_client(ml_client: MLClient, workspace_name: Optional[str] = None) -> WorkspaceV2:

Просмотреть файл

@ -13,7 +13,7 @@ import time
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Union
from unittest import mock
from unittest.mock import DEFAULT, MagicMock, patch
from unittest.mock import MagicMock, patch
from uuid import uuid4
from xmlrpc.client import Boolean
@ -23,13 +23,12 @@ import pandas as pd
import param
import pytest
from _pytest.logging import LogCaptureFixture
from azure.identity import ClientSecretCredential, DeviceCodeCredential, DefaultAzureCredential
from azure.storage.blob import ContainerClient
from azureml._restclient.constants import RunStatus
from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
from azureml.core.run import _OfflineRun
from azureml.core.environment import CondaDependencies
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
from azure.core.exceptions import ResourceNotFoundError
from azureml.data.azure_storage_datastore import AzureBlobDatastore
import health_azure.utils as util
@ -41,8 +40,8 @@ from health_azure.utils import (
MASTER_PORT_DEFAULT,
PackageDependency,
download_files_by_suffix,
get_credential,
download_file_if_necessary,
resolve_workspace_config_path,
)
from testazure.test_himl import RunTarget, render_and_run_test_script
from testazure.utils_testazure import (
@ -1996,136 +1995,6 @@ def test_create_run() -> None:
run.complete()
def test_get_credential() -> None:
def _mock_validation_error() -> None:
raise ClientAuthenticationError("")
# test the case where service principal credentials are set as environment variables
mock_env_vars = {
util.ENV_SERVICE_PRINCIPAL_ID: "foo",
util.ENV_TENANT_ID: "bar",
util.ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
}
with patch.object(os.environ, "get", return_value=mock_env_vars):
with patch.multiple(
"health_azure.utils",
is_running_in_azure_ml=DEFAULT,
is_running_on_azure_agent=DEFAULT,
_get_legitimate_service_principal_credential=DEFAULT,
_get_legitimate_device_code_credential=DEFAULT,
_get_legitimate_default_credential=DEFAULT,
_get_legitimate_interactive_browser_credential=DEFAULT,
) as mocks:
mocks["is_running_in_azure_ml"].return_value = False
mocks["is_running_on_azure_agent"].return_value = False
_ = get_credential()
mocks["_get_legitimate_service_principal_credential"].assert_called_once()
mocks["_get_legitimate_device_code_credential"].assert_not_called()
mocks["_get_legitimate_default_credential"].assert_not_called()
mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
# if the environment variables are not set and we are running on a local machine, a
# DefaultAzureCredential should be attempted first
with patch.object(os.environ, "get", return_value={}):
with patch.multiple(
"health_azure.utils",
is_running_in_azure_ml=DEFAULT,
is_running_on_azure_agent=DEFAULT,
_get_legitimate_service_principal_credential=DEFAULT,
_get_legitimate_device_code_credential=DEFAULT,
_get_legitimate_default_credential=DEFAULT,
_get_legitimate_interactive_browser_credential=DEFAULT,
) as mocks:
mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
mock_get_default_cred = mocks["_get_legitimate_default_credential"]
mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
mocks["is_running_in_azure_ml"].return_value = False
mocks["is_running_on_azure_agent"].return_value = False
_ = get_credential()
mock_get_sp_cred.assert_not_called()
mock_get_device_cred.assert_not_called()
mock_get_default_cred.assert_called_once()
mock_get_browser_cred.assert_not_called()
# if that fails, a DeviceCode credential should be attempted
mock_get_default_cred.side_effect = _mock_validation_error
_ = get_credential()
mock_get_sp_cred.assert_not_called()
mock_get_device_cred.assert_called_once()
assert mock_get_default_cred.call_count == 2
mock_get_browser_cred.assert_not_called()
# if None of the previous credentials work, an InteractiveBrowser credential should be tried
mock_get_device_cred.return_value = None
_ = get_credential()
mock_get_sp_cred.assert_not_called()
assert mock_get_device_cred.call_count == 2
assert mock_get_default_cred.call_count == 3
mock_get_browser_cred.assert_called_once()
# finally, if none of the methods work, an Exception should be raised
mock_get_browser_cred.return_value = None
with pytest.raises(Exception) as e:
get_credential()
assert (
"Unable to generate and validate a credential. Please see Azure ML documentation"
"for instructions on different options to get a credential" in str(e)
)
def test_get_legitimate_service_principal_credential() -> None:
    """Creating a credential from fake service principal information should fail with an informative
    error message; with validation mocked out, a ClientSecretCredential should be returned."""
    mock_service_principal_id = "foo"
    mock_service_principal_password = "bar"
    mock_tenant_id = "baz"
    # BUG FIX: these two f-strings were two separate statements, so the second literal was a no-op and
    # the assertion below only ever checked the first half of the message. Parenthesize so that implicit
    # string concatenation builds the full expected message.
    expected_error_msg = (
        f"Found environment variables for {util.ENV_SERVICE_PRINCIPAL_ID}, "
        f"{util.ENV_SERVICE_PRINCIPAL_PASSWORD}, and {util.ENV_TENANT_ID} but was not able to authenticate"
    )
    with pytest.raises(Exception) as e:
        util._get_legitimate_service_principal_credential(
            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
        )
    # NOTE(review): str(e) on an ExceptionInfo is indirect; str(e.value) is the canonical way to get the
    # exception message — confirm the full message is visible via str(e) before relying on this check.
    assert expected_error_msg in str(e)
    # now mock the case where validating the credential succeeds and check the value of that
    with patch("health_azure.utils._validate_credential"):
        cred = util._get_legitimate_service_principal_credential(
            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
        )
        assert isinstance(cred, ClientSecretCredential)
def test_get_legitimate_device_code_credential() -> None:
    """Without a real login, the device-code helper should return None; with validation mocked out,
    it should return a DeviceCodeCredential instance."""

    def _fast_device_code_credential(timeout: int) -> DeviceCodeCredential:
        # Replace the default (long) timeout with 1 second so the expected failure is quick.
        return DeviceCodeCredential(timeout=1)

    with patch("health_azure.utils.DeviceCodeCredential", new=_fast_device_code_credential):
        assert util._get_legitimate_device_code_credential() is None
    # When credential validation is mocked to succeed, the credential object itself comes back.
    with patch("health_azure.utils._validate_credential"):
        credential = util._get_legitimate_device_code_credential()
        assert isinstance(credential, DeviceCodeCredential)
def test_get_legitimate_default_credential() -> None:
    """A DefaultAzureCredential that cannot authenticate should raise ClientAuthenticationError;
    with validation mocked out, a DefaultAzureCredential instance should be returned."""

    def _fast_default_credential(timeout: int) -> DefaultAzureCredential:
        # Use a 1 second timeout so that the expected authentication failure happens quickly.
        return DefaultAzureCredential(timeout=1)

    with patch("health_azure.utils.DefaultAzureCredential", new=_fast_default_credential):
        failure_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
        with pytest.raises(ClientAuthenticationError, match=failure_message):
            util._get_legitimate_default_credential()
    with patch("health_azure.utils._validate_credential"):
        credential = util._get_legitimate_default_credential()
        assert isinstance(credential, DefaultAzureCredential)
def test_filter_v2_input_output_args() -> None:
def _compare_args(expected: List[str], actual: List[str]) -> None:
assert len(actual) == len(expected)
@ -2244,3 +2113,33 @@ def test_download_files_by_suffix(tmp_path: Path, files: List[str], expected_dow
assert f.is_file()
downloaded_filenames = [f.name for f in downloaded_list]
assert downloaded_filenames == expected_downloaded
def test_resolve_workspace_config_path_no_argument(tmp_path: Path) -> None:
    """Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
    If the file exists, it should return the path"""
    fake_config = tmp_path / "foo.json"
    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=fake_config):
        assert resolve_workspace_config_path() == fake_config
def test_resolve_workspace_config_path_no_argument_no_file() -> None:
    """Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
    If the file does not exist, return None"""
    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=None):
        assert resolve_workspace_config_path() is None
def test_resolve_workspace_config_path_file_exists(tmp_path: Path) -> None:
    """When given a path to an existing file, resolve_workspace_config_path should return it unchanged."""
    existing_file = tmp_path / "foo.json"
    existing_file.touch()
    assert resolve_workspace_config_path(existing_file) == existing_file
def test_resolve_workspace_config_path_missing(tmp_path: Path) -> None:
    """When given a path to a file that does not exist, resolve_workspace_config_path should raise."""
    missing_file = tmp_path / "foo.json"
    with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
        resolve_workspace_config_path(missing_file)

Просмотреть файл

@ -20,6 +20,7 @@ from azureml.data import FileDataset, OutputFileDatasetConfig
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.himl import submit_to_azure_if_needed
from testazure.utils_testazure import (
DEFAULT_DATASTORE,
DEFAULT_WORKSPACE,
@ -27,7 +28,6 @@ from testazure.utils_testazure import (
TEST_DATA_ASSET_NAME,
TEST_INVALID_DATA_ASSET_NAME,
TEST_DATASTORE_NAME,
get_test_ml_client,
)
from health_azure.datasets import (
@ -46,10 +46,7 @@ from health_azure.datasets import (
get_or_create_dataset,
_get_latest_v2_asset_version,
)
from health_azure.utils import PathOrString, get_ml_client
TEST_ML_CLIENT = get_test_ml_client()
from health_azure.utils import PathOrString
def test_datasetconfig_init() -> None:
@ -234,12 +231,11 @@ def test_get_or_create_dataset() -> None:
data_asset_name = "himl_tiny_data_asset"
workspace = DEFAULT_WORKSPACE.workspace
ml_client = get_ml_client(aml_workspace=workspace)
# When creating a dataset, we need a non-empty name
with pytest.raises(ValueError) as ex:
get_or_create_dataset(
workspace=workspace,
ml_client=ml_client,
ml_client=DEFAULT_WORKSPACE.ml_client,
datastore_name="himldatasetsv2",
dataset_name="",
strictly_aml_v1=True,
@ -254,7 +250,7 @@ def test_get_or_create_dataset() -> None:
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
dataset = get_or_create_dataset(
workspace=workspace,
ml_client=ml_client,
ml_client=DEFAULT_WORKSPACE.ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=True,
@ -268,7 +264,7 @@ def test_get_or_create_dataset() -> None:
mocks["_get_or_create_v2_data_asset"].return_value = mock_v2_dataset
dataset = get_or_create_dataset(
workspace=workspace,
ml_client=ml_client,
ml_client=DEFAULT_WORKSPACE.ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=False,
@ -281,7 +277,7 @@ def test_get_or_create_dataset() -> None:
mocks["_get_or_create_v2_data_asset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
dataset = get_or_create_dataset(
workspace=workspace,
ml_client=ml_client,
ml_client=DEFAULT_WORKSPACE.ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=False,
@ -417,7 +413,7 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
mock_get_v2_asset_version.side_effect = _get_latest_v2_asset_version
try:
data_asset = _retrieve_v2_data_asset(
ml_client=TEST_ML_CLIENT,
ml_client=DEFAULT_WORKSPACE.ml_client,
data_asset_name=asset_name,
version=asset_version,
)
@ -445,10 +441,12 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
def test_retrieve_v2_data_asset_invalid_version() -> None:
invalid_asset_version = str(int(_get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)) + 1)
invalid_asset_version = str(
int(_get_latest_v2_asset_version(DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME)) + 1
)
with pytest.raises(ResourceNotFoundError) as ex:
_retrieve_v2_data_asset(
ml_client=TEST_ML_CLIENT,
ml_client=DEFAULT_WORKSPACE.ml_client,
data_asset_name=TEST_DATA_ASSET_NAME,
version=invalid_asset_version,
)
@ -459,15 +457,19 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
"""Test if calling the get_or_create_data_asset on an existing asset does not increment the version number."""
with patch("health_azure.datasets._create_v2_data_asset") as mock_create_v2_data_asset:
asset_version_before_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
asset_version_before_get_or_create = _get_latest_v2_asset_version(
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
)
get_or_create_dataset(
TEST_DATASTORE_NAME,
TEST_DATA_ASSET_NAME,
DEFAULT_WORKSPACE,
strictly_aml_v1=False,
ml_client=TEST_ML_CLIENT,
ml_client=DEFAULT_WORKSPACE.ml_client,
)
asset_version_after_get_or_create = _get_latest_v2_asset_version(
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
)
asset_version_after_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
mock_create_v2_data_asset.assert_not_called()
assert asset_version_before_get_or_create == asset_version_after_get_or_create
@ -485,7 +487,7 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
def test_create_v2_data_asset(asset_name: str, datastore_name: str, version: Optional[str]) -> None:
try:
data_asset = _create_v2_data_asset(
ml_client=TEST_ML_CLIENT,
ml_client=DEFAULT_WORKSPACE.ml_client,
datastore_name=TEST_DATASTORE_NAME,
data_asset_name=asset_name,
version=version,
@ -558,3 +560,36 @@ def test_create_dataset_configs() -> None:
with pytest.raises(Exception) as e:
create_dataset_configs(azure_datasets, dataset_mountpoints, local_datasets, datastore, use_mounting)
assert "Invalid dataset setup" in str(e)
def test_local_datasets() -> None:
    """Test if Azure datasets can be mounted for local runs"""
    # Dataset hello_world must exist in the test AzureML workspace
    run_info = submit_to_azure_if_needed(
        input_datasets=[DatasetConfig(name="hello_world")],
        strictly_aml_v1=True,
    )
    assert len(run_info.input_datasets) == 1
    dataset_path = run_info.input_datasets[0]
    assert isinstance(dataset_path, Path)
    assert dataset_path.is_dir()
    # The mounted/downloaded dataset folder must contain at least one file
    assert len(list(dataset_path.glob("*"))) > 0
def test_local_datasets_fails_with_v2() -> None:
    """Azure datasets can't be used when using SDK v2"""
    # Mounting/downloading datasets for local runs is only possible via SDK v1, hence expect an error.
    with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from AzureML"):
        submit_to_azure_if_needed(
            input_datasets=[DatasetConfig(name="himl-tiny_dataset")],
            strictly_aml_v1=False,
        )
def test_local_datasets_fail_with_v2() -> None:
    """If no datasets are specified, we can still run with SDK v2"""
    # NOTE(review): the function name says "fail", but this tests the success path (no datasets, SDK v2),
    # and it nearly collides with test_local_datasets_fails_with_v2 above. Consider renaming to e.g.
    # test_local_run_without_datasets_succeeds_with_v2.
    run_info = submit_to_azure_if_needed(
        input_datasets=[],
        strictly_aml_v1=False,
    )
    # No input datasets were requested, so none should be returned
    assert len(run_info.input_datasets) == 0

Просмотреть файл

@ -0,0 +1,249 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""
Tests for health_azure.azure_get_workspace and related functions.
"""
import os
from pathlib import Path
from unittest.mock import DEFAULT, MagicMock, patch
import pytest
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import ClientSecretCredential, DefaultAzureCredential, DeviceCodeCredential
from health_azure.utils import (
ENV_RESOURCE_GROUP,
ENV_SERVICE_PRINCIPAL_ID,
ENV_SERVICE_PRINCIPAL_PASSWORD,
ENV_SUBSCRIPTION_ID,
ENV_TENANT_ID,
ENV_WORKSPACE_NAME,
_get_legitimate_default_credential,
_get_legitimate_device_code_credential,
_get_legitimate_service_principal_credential,
get_credential,
get_ml_client,
)
@pytest.mark.fast
def test_get_credential() -> None:
    """Test the fallback order of get_credential: service principal credentials from environment
    variables first; otherwise DefaultAzureCredential, then DeviceCodeCredential, then
    InteractiveBrowserCredential; an exception if all of those fail."""

    def _mock_validation_error() -> None:
        # Side effect used to make a credential helper look like it failed validation.
        raise ClientAuthenticationError("")

    # test the case where service principal credentials are set as environment variables
    mock_env_vars = {
        ENV_SERVICE_PRINCIPAL_ID: "foo",
        ENV_TENANT_ID: "bar",
        ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
    }
    with patch.object(os.environ, "get", return_value=mock_env_vars):
        with patch.multiple(
            "health_azure.utils",
            is_running_in_azure_ml=DEFAULT,
            is_running_on_azure_agent=DEFAULT,
            _get_legitimate_service_principal_credential=DEFAULT,
            _get_legitimate_device_code_credential=DEFAULT,
            _get_legitimate_default_credential=DEFAULT,
            _get_legitimate_interactive_browser_credential=DEFAULT,
        ) as mocks:
            # Simulate running on a local machine (not inside AzureML, not on a build agent)
            mocks["is_running_in_azure_ml"].return_value = False
            mocks["is_running_on_azure_agent"].return_value = False
            _ = get_credential()
            # With service principal environment variables present, only that path should be used
            mocks["_get_legitimate_service_principal_credential"].assert_called_once()
            mocks["_get_legitimate_device_code_credential"].assert_not_called()
            mocks["_get_legitimate_default_credential"].assert_not_called()
            mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
    # if the environment variables are not set and we are running on a local machine, a
    # DefaultAzureCredential should be attempted first
    with patch.object(os.environ, "get", return_value={}):
        with patch.multiple(
            "health_azure.utils",
            is_running_in_azure_ml=DEFAULT,
            is_running_on_azure_agent=DEFAULT,
            _get_legitimate_service_principal_credential=DEFAULT,
            _get_legitimate_device_code_credential=DEFAULT,
            _get_legitimate_default_credential=DEFAULT,
            _get_legitimate_interactive_browser_credential=DEFAULT,
        ) as mocks:
            mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
            mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
            mock_get_default_cred = mocks["_get_legitimate_default_credential"]
            mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
            mocks["is_running_in_azure_ml"].return_value = False
            mocks["is_running_on_azure_agent"].return_value = False
            _ = get_credential()
            mock_get_sp_cred.assert_not_called()
            mock_get_device_cred.assert_not_called()
            mock_get_default_cred.assert_called_once()
            mock_get_browser_cred.assert_not_called()
            # if that fails, a DeviceCode credential should be attempted
            mock_get_default_cred.side_effect = _mock_validation_error
            _ = get_credential()
            mock_get_sp_cred.assert_not_called()
            mock_get_device_cred.assert_called_once()
            # Default credential was tried again (cumulative count across the two calls above)
            assert mock_get_default_cred.call_count == 2
            mock_get_browser_cred.assert_not_called()
            # if None of the previous credentials work, an InteractiveBrowser credential should be tried
            mock_get_device_cred.return_value = None
            _ = get_credential()
            mock_get_sp_cred.assert_not_called()
            assert mock_get_device_cred.call_count == 2
            assert mock_get_default_cred.call_count == 3
            mock_get_browser_cred.assert_called_once()
            # finally, if none of the methods work, an Exception should be raised
            mock_get_browser_cred.return_value = None
            with pytest.raises(Exception) as e:
                get_credential()
            # NOTE(review): the two implicitly concatenated literals join with no space, producing
            # "...documentationfor instructions...". This only matches if the real error message also
            # lacks the space — confirm against the message raised in health_azure.utils.
            assert (
                "Unable to generate and validate a credential. Please see Azure ML documentation"
                "for instructions on different options to get a credential" in str(e)
            )
@pytest.mark.fast
def test_get_legitimate_service_principal_credential() -> None:
    """Creating a credential from fake service principal information should fail with an informative
    error message; with validation mocked out, a ClientSecretCredential should be returned."""
    mock_service_principal_id = "foo"
    mock_service_principal_password = "bar"
    mock_tenant_id = "baz"
    # BUG FIX: these two f-strings were two separate statements, so the second literal was a no-op and
    # the assertion below only ever checked the first half of the message. Parenthesize so that implicit
    # string concatenation builds the full expected message.
    expected_error_msg = (
        f"Found environment variables for {ENV_SERVICE_PRINCIPAL_ID}, "
        f"{ENV_SERVICE_PRINCIPAL_PASSWORD}, and {ENV_TENANT_ID} but was not able to authenticate"
    )
    with pytest.raises(Exception) as e:
        _get_legitimate_service_principal_credential(
            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
        )
    # NOTE(review): str(e) on an ExceptionInfo is indirect; str(e.value) is the canonical way to get the
    # exception message — confirm the full message is visible via str(e) before relying on this check.
    assert expected_error_msg in str(e)
    # now mock the case where validating the credential succeeds and check the value of that
    with patch("health_azure.utils._validate_credential"):
        cred = _get_legitimate_service_principal_credential(
            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
        )
        assert isinstance(cred, ClientSecretCredential)
@pytest.mark.fast
def test_get_legitimate_device_code_credential() -> None:
    """Without a real login, the device-code helper should return None; with validation mocked out,
    it should return a DeviceCodeCredential instance."""

    def _fast_device_code_credential(timeout: int) -> DeviceCodeCredential:
        # Replace the default (long) timeout with 1 second so the expected failure is quick.
        return DeviceCodeCredential(timeout=1)

    with patch("health_azure.utils.DeviceCodeCredential", new=_fast_device_code_credential):
        assert _get_legitimate_device_code_credential() is None
    # When credential validation is mocked to succeed, the credential object itself comes back.
    with patch("health_azure.utils._validate_credential"):
        credential = _get_legitimate_device_code_credential()
        assert isinstance(credential, DeviceCodeCredential)
@pytest.mark.fast
def test_get_legitimate_default_credential() -> None:
    """A DefaultAzureCredential that cannot authenticate should raise ClientAuthenticationError;
    with validation mocked out, a DefaultAzureCredential instance should be returned."""

    def _fast_default_credential(timeout: int) -> DefaultAzureCredential:
        # Use a 1 second timeout so that the expected authentication failure happens quickly.
        return DefaultAzureCredential(timeout=1)

    with patch("health_azure.utils.DefaultAzureCredential", new=_fast_default_credential):
        failure_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
        with pytest.raises(ClientAuthenticationError, match=failure_message):
            _get_legitimate_default_credential()
    with patch("health_azure.utils._validate_credential"):
        credential = _get_legitimate_default_credential()
        assert isinstance(credential, DefaultAzureCredential)
@pytest.mark.fast
def test_get_ml_client_with_existing_client() -> None:
    """When passing an existing ml_client, it should be returned"""
    existing_client = "mock_ml_client"
    returned_client = get_ml_client(ml_client=existing_client)  # type: ignore
    assert returned_client == existing_client
@pytest.mark.fast
def test_get_ml_client_without_credentials() -> None:
    """When no credentials are available, an exception should be raised"""
    with patch("health_azure.utils.get_credential", return_value=None), pytest.raises(
        ValueError, match="Can't connect to MLClient without a valid credential"
    ):
        get_ml_client()
@pytest.mark.fast
def test_get_ml_client_from_config_file() -> None:
    """If a workspace config file is found, it should be used to create the MLClient"""
    mock_credentials = "mock_credentials"
    mock_config_path = Path("foo")
    # The workspace_name attribute is read by get_ml_client when reporting the chosen workspace
    mock_ml_client = MagicMock(workspace_name="workspace")
    mock_from_config = MagicMock(return_value=mock_ml_client)
    mock_resolve_config_path = MagicMock(return_value=mock_config_path)
    with patch.multiple(
        "health_azure.utils",
        get_credential=MagicMock(return_value=mock_credentials),
        resolve_workspace_config_path=mock_resolve_config_path,
        MLClient=MagicMock(from_config=mock_from_config),
    ):
        config_file = Path("foo")
        result = get_ml_client(workspace_config_path=config_file)
        assert result == mock_ml_client
        # The config file argument must be passed through to the path resolution logic
        mock_resolve_config_path.assert_called_once_with(config_file)
        # The client must be built via MLClient.from_config, pointing at the resolved config file
        mock_from_config.assert_called_once_with(
            credential=mock_credentials,
            path=str(mock_config_path),
        )
@pytest.mark.fast
def test_get_ml_client_from_environment_variables() -> None:
    """When no workspace config file is found, the MLClient should be created from environment variables"""
    mock_credentials = "mock_credentials"
    the_client = "the_client"
    mock_ml_client = MagicMock(return_value=the_client)
    workspace = "workspace"
    subscription = "subscription"
    resource_group = "resource_group"
    with patch.multiple(
        "health_azure.utils",
        get_credential=MagicMock(return_value=mock_credentials),
        # No config file anywhere: force the fallback to environment variables
        resolve_workspace_config_path=MagicMock(return_value=None),
        MLClient=mock_ml_client,
    ):
        with patch.dict(
            os.environ,
            {ENV_WORKSPACE_NAME: workspace, ENV_SUBSCRIPTION_ID: subscription, ENV_RESOURCE_GROUP: resource_group},
        ):
            result = get_ml_client()
            assert result == the_client
            # The MLClient constructor must receive exactly the values from the environment variables
            mock_ml_client.assert_called_once_with(
                subscription_id=subscription,
                resource_group_name=resource_group,
                workspace_name=workspace,
                credential=mock_credentials,
            )
@pytest.mark.fast
def test_get_ml_client_fails() -> None:
    """If neither a workspace config file nor environment variables are found, an exception should be raised"""
    mock_credentials = "mock_credentials"
    the_client = "the_client"
    mock_ml_client = MagicMock(return_value=the_client)
    with patch.multiple(
        "health_azure.utils",
        get_credential=MagicMock(return_value=mock_credentials),
        # No config file anywhere, so get_ml_client must fall back to environment variables
        resolve_workspace_config_path=MagicMock(return_value=None),
        MLClient=mock_ml_client,
    ):
        # In the GitHub runner, the environment variables are set. We need to unset them to test the exception
        with patch.dict(os.environ, {ENV_WORKSPACE_NAME: ""}):
            with pytest.raises(ValueError, match="Tried all ways of identifying the MLClient, but failed"):
                get_ml_client()

Просмотреть файл

@ -10,6 +10,7 @@ from pathlib import Path
from uuid import uuid4
from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.exceptions._azureml_exception import UserErrorException
from _pytest.logging import LogCaptureFixture
import pytest
from unittest.mock import MagicMock, patch
@ -22,7 +23,6 @@ from health_azure.utils import (
get_workspace,
)
from health_azure.utils import (
WORKSPACE_CONFIG_JSON,
ENV_SERVICE_PRINCIPAL_ID,
ENV_SERVICE_PRINCIPAL_PASSWORD,
ENV_TENANT_ID,
@ -141,13 +141,16 @@ def test_get_workspace_with_given_workspace() -> None:
@pytest.mark.fast
def test_get_workspace_searches_for_file() -> None:
def test_get_workspace_searches_for_file(tmp_path: Path) -> None:
"""get_workspace should try to load a config.json file if not provided with one"""
found_file = Path("does_not_exist")
with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=found_file) as mock_find:
with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
get_workspace(None, None)
mock_find.assert_called_once_with(WORKSPACE_CONFIG_JSON)
with change_working_directory(tmp_path):
found_file = Path("does_not_exist")
with patch("health_azure.utils.resolve_workspace_config_path", return_value=found_file) as mock_find:
with pytest.raises(
UserErrorException, match="workspace configuration file config.json, could not be found"
):
get_workspace(None, None)
mock_find.assert_called_once_with(None)
@pytest.mark.fast

Просмотреть файл

@ -773,7 +773,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
with patch("azure.ai.ml.MLClient") as mock_ml_client:
with patch("health_azure.himl.command") as mock_command:
himl.submit_run_v2(
workspace=None,
ml_client=mock_ml_client,
experiment_name=dummy_experiment_name,
environment=dummy_environment,
input_datasets_v2=dummy_inputs,
@ -784,8 +784,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
compute_target=dummy_compute_target,
tags=dummy_tags,
docker_shm_size=dummy_docker_shm_size,
workspace_config_path=None,
ml_client=mock_ml_client,
hyperparam_args=None,
display_name=dummy_display_name,
)
@ -835,7 +833,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
expected_command += " --learning_rate=${{inputs.learning_rate}}"
himl.submit_run_v2(
workspace=None,
experiment_name=dummy_experiment_name,
environment=dummy_environment,
input_datasets_v2=dummy_inputs,
@ -846,7 +843,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
compute_target=dummy_compute_target,
tags=dummy_tags,
docker_shm_size=dummy_docker_shm_size,
workspace_config_path=None,
ml_client=mock_ml_client,
hyperparam_args=dummy_hyperparam_args,
display_name=dummy_display_name,
@ -882,7 +878,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
expected_command = f"python {dummy_entry_script_for_module} {expected_arg_str}"
himl.submit_run_v2(
workspace=None,
ml_client=mock_ml_client,
experiment_name=dummy_experiment_name,
environment=dummy_environment,
input_datasets_v2=dummy_inputs,
@ -893,8 +889,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
compute_target=dummy_compute_target,
tags=dummy_tags,
docker_shm_size=dummy_docker_shm_size,
workspace_config_path=None,
ml_client=mock_ml_client,
hyperparam_args=None,
display_name=dummy_display_name,
)
@ -1308,9 +1302,7 @@ def test_mounting_and_downloading_dataset(tmp_path: Path) -> None:
target_path = tmp_path / action
dataset_config = DatasetConfig(name="hello_world", use_mounting=use_mounting, target_folder=target_path)
logging.info(f"ready to {action}")
paths, mount_contexts = setup_local_datasets(
dataset_configs=[dataset_config], strictly_aml_v1=True, aml_workspace=workspace
)
paths, mount_contexts = setup_local_datasets(dataset_configs=[dataset_config], workspace=workspace)
logging.info(f"{action} done")
path = paths[0]
assert path is not None
@ -1372,7 +1364,7 @@ class TestOutputDataset:
@pytest.mark.parametrize(
["run_target", "local_folder", "strictly_aml_v1"],
[
(RunTarget.LOCAL, True, False),
(RunTarget.LOCAL, True, True),
(RunTarget.AZUREML, False, True),
],
)

Просмотреть файл

@ -10,7 +10,6 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, Optional
from azure.ai.ml import MLClient
from azureml.core import Run
from health_azure.utils import (
ENV_EXPERIMENT_NAME,
@ -20,7 +19,6 @@ from health_azure.utils import (
)
from health_azure import create_aml_run_object
from health_azure.himl import effective_experiment_name
from health_azure.utils import get_ml_client, get_workspace
DEFAULT_DATASTORE = "himldatasets"
@ -111,16 +109,6 @@ def create_unittest_run_object(snapshot_directory: Optional[Path] = None) -> Run
)
def get_test_ml_client() -> MLClient:
"""Generates an MLClient object for use in tests.
:return: MLClient object
"""
workspace = get_workspace()
return get_ml_client(aml_workspace=workspace)
def current_test_name() -> str:
"""Get the name of the currently executed test. This is read off an environment variable. If that
is not found, the function returns an empty string."""

Просмотреть файл

@ -21,7 +21,7 @@ def mount_dataset(dataset_id: str, tmp_root: str = "/tmp/datasets", aml_workspac
ws = get_workspace(aml_workspace)
target_folder = "/".join([tmp_root, dataset_id])
dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
_, mount_ctx = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
_, mount_ctx = dataset.to_input_dataset_local(workspace=ws)
assert mount_ctx is not None # for mypy
mount_ctx.start()
return mount_ctx

Просмотреть файл

@ -21,5 +21,5 @@ def download_azure_dataset(tmp_path: Path, dataset_id: str) -> None:
with check_config_json(script_folder=tmp_path, shared_config_json=get_shared_config_json()):
ws = get_workspace(workspace_config_path=tmp_path / WORKSPACE_CONFIG_JSON)
dataset = DatasetConfig(name=dataset_id, target_folder=tmp_path, use_mounting=False)
dataset_dl_folder = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
dataset_dl_folder = dataset.to_input_dataset_local(workspace=ws)
logging.info(f"Dataset saved in {dataset_dl_folder}")

Просмотреть файл

@ -15,6 +15,14 @@ class RunnerMode(Enum):
EVAL_FULL = "eval_full"
class LogLevel(Enum):
    """Logging levels that can be chosen via the runner's log_level commandline flag."""

    ERROR = "ERROR"
    WARNING = "WARNING"
    # WARN is a separate member with the same intent as WARNING — presumably kept as an accepted
    # shorthand; confirm both are handled identically by the logging setup.
    WARN = "WARN"
    INFO = "INFO"
    DEBUG = "DEBUG"
DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"
@ -87,10 +95,22 @@ class ExperimentConfig(param.Parameterized):
doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
"point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
)
mode: str = param.ClassSelector(
mode: RunnerMode = param.ClassSelector(
class_=RunnerMode,
default=RunnerMode.TRAIN,
doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
"'get_eval_data_module' method of the container.",
)
log_level: Optional[RunnerMode] = param.ClassSelector(
class_=LogLevel,
default=None,
doc=f"The log level to use. Can be one of {list(map(str, LogLevel))}",
)
@property
def submit_to_azure_ml(self) -> bool:
"""Returns True if the experiment should be submitted to AzureML, False if it should be run locally.
:return: True if the experiment should be submitted to AzureML, False if it should be run locally."""
return self.cluster != ""

Просмотреть файл

@ -15,8 +15,6 @@ import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from azureml.core import Workspace
# Add hi-ml packages to sys.path so that AML can find them if we are using the runner directly from the git repo
himl_root = Path(__file__).resolve().parent.parent.parent.parent
folders_to_add = [himl_root / "hi-ml" / "src", himl_root / "hi-ml-azure" / "src", himl_root / "hi-ml-cpath" / "src"]
@ -34,8 +32,6 @@ from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
from health_azure.utils import ( # noqa: E402
ENV_LOCAL_RANK,
ENV_NODE_RANK,
get_workspace,
get_ml_client,
is_local_rank_zero,
is_running_in_azure_ml,
set_environment_variables_for_multi_node,
@ -122,6 +118,11 @@ class Runner:
parser1_result = parse_arguments(parser1, args=filtered_args)
experiment_config = ExperimentConfig(**parser1_result.args)
from health_azure.logging import logging_stdout_handler # noqa: E402
if logging_stdout_handler is not None and experiment_config.log_level is not None:
print(f"Setting custom logging level to {experiment_config.log_level}")
logging_stdout_handler.setLevel(experiment_config.log_level.value)
self.experiment_config = experiment_config
if not experiment_config.model:
raise ValueError("Parameter 'model' needs to be set to specify which model to run.")
@ -150,7 +151,7 @@ class Runner:
"""
Runs sanity checks on the whole experiment.
"""
if not self.experiment_config.cluster:
if not self.experiment_config.submit_to_azure_ml:
if self.lightning_container.hyperdrive:
raise ValueError(
"HyperDrive for hyperparameters tuning is only supported when submitting the job to "
@ -214,47 +215,26 @@ class Runner:
script_params = sys.argv[1:]
environment_variables = self.additional_environment_variables()
# Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
# that if we are really submitting to AzureML.
workspace: Optional[Workspace] = None
if self.experiment_config.cluster:
try:
workspace = get_workspace(workspace_config_path=self.experiment_config.workspace_config_path)
except ValueError:
raise ValueError(
"Unable to submit the script to AzureML because no workspace configuration file "
"(config.json) was found."
)
if self.lightning_container.datastore:
datastore = self.lightning_container.datastore
elif workspace:
datastore = workspace.get_default_datastore().name
else:
datastore = ""
local_datasets = self.lightning_container.local_datasets
all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
# When running in AzureML, respect the commandline flag for mounting. Outside of AML, we always mount
# datasets to be quicker.
use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.cluster else True
use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.submit_to_azure_ml else True
input_datasets = create_dataset_configs(
all_azure_dataset_ids=self.lightning_container.azure_datasets,
all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
all_local_datasets=all_local_datasets, # type: ignore
datastore=datastore,
datastore=self.lightning_container.datastore,
use_mounting=use_mounting,
)
if self.experiment_config.cluster and not is_running_in_azure_ml():
if self.experiment_config.submit_to_azure_ml and not is_running_in_azure_ml():
if self.experiment_config.strictly_aml_v1:
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
hyperparam_args = None
else:
hyperparam_args = self.lightning_container.get_hyperparam_args()
hyperdrive_config = None
ml_client = get_ml_client(aml_workspace=workspace) if not self.experiment_config.strictly_aml_v1 else None
env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
logging.info(f"Using this Conda environment definition: {env_file}")
@ -265,18 +245,15 @@ class Runner:
snapshot_root_directory=root_folder,
script_params=script_params,
conda_environment_file=env_file,
aml_workspace=workspace,
ml_client=ml_client,
compute_cluster_name=self.experiment_config.cluster,
environment_variables=environment_variables,
default_datastore=datastore,
experiment_name=self.lightning_container.effective_experiment_name,
input_datasets=input_datasets, # type: ignore
num_nodes=self.experiment_config.num_nodes,
wait_for_completion=self.experiment_config.wait_for_completion,
max_run_duration=self.experiment_config.max_run_duration,
ignored_folders=[],
submit_to_azureml=bool(self.experiment_config.cluster),
submit_to_azureml=self.experiment_config.submit_to_azure_ml,
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
docker_shm_size=self.experiment_config.docker_shm_size,
hyperdrive_config=hyperdrive_config,
@ -292,7 +269,6 @@ class Runner:
submit_to_azureml=False,
environment_variables=environment_variables,
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
default_datastore=datastore,
)
if azure_run_info.run:
# This code is only reached inside Azure. Set display name again - this will now affect

Просмотреть файл

@ -26,6 +26,7 @@ from health_ml.lightning_container import LightningContainer
from health_ml.runner import Runner, create_logging_filename, run_with_logging
from health_ml.utils.common_utils import change_working_directory
from health_ml.utils.fixed_paths import repository_root_directory
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
@contextmanager
@ -132,10 +133,9 @@ def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None
model_name = "HelloWorld"
arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
@ -144,12 +144,9 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo"]
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.check_conda_environment"):
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.get_ml_client"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
@ -162,9 +159,6 @@ def test_additional_environment_variables(mock_runner: Runner) -> None:
with patch.multiple(
"health_ml.runner",
submit_to_azure_if_needed=DEFAULT,
check_conda_environment=DEFAULT,
get_workspace=DEFAULT,
get_ml_client=DEFAULT,
) as mocks:
with patch("health_ml.runner.Runner.run_in_situ"):
with patch("health_ml.runner.Runner.parse_and_load_model"):
@ -185,9 +179,8 @@ def test_run(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}"]
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
model_config, azure_run_info = mock_runner.run()
with patch.object(sys, "argv", arguments):
model_config, azure_run_info = mock_runner.run()
mock_run_in_situ.assert_called_once()
assert model_config is not None # for pyright
@ -197,17 +190,13 @@ def test_run(mock_runner: Runner) -> None:
@patch("health_ml.runner.choose_conda_env_file")
@patch("health_ml.runner.get_workspace")
@pytest.mark.fast
def test_submit_to_azureml_if_needed(
mock_get_workspace: MagicMock, mock_get_env_files: MagicMock, mock_runner: Runner
) -> None:
def test_submit_to_azureml_if_needed(mock_get_env_files: MagicMock, mock_runner: Runner) -> None:
def _mock_dont_submit_to_aml(
input_datasets: List[DatasetConfig],
submit_to_azureml: bool,
strictly_aml_v1: bool, # type: ignore
environment_variables: Dict[str, Any], # type: ignore
default_datastore: Optional[str], # type: ignore
) -> AzureRunInfo:
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
return AzureRunInfo(
@ -222,10 +211,6 @@ def test_submit_to_azureml_if_needed(
mock_get_env_files.return_value = Path("some_file.txt")
mock_default_datastore = MagicMock()
mock_default_datastore.name.return_value = "dummy_datastore"
mock_get_workspace.get_default_datastore.return_value = mock_default_datastore
with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
mock_create_datasets.return_value = []
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
@ -334,11 +319,9 @@ def _test_hyperdrive_submission(
# start in that temp folder.
with change_working_folder_and_add_environment(mock_runner.project_root):
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.get_ml_client"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
mock_run_in_situ.assert_called_once()
mock_submit_to_aml.assert_called_once()
# call_args is a tuple of (args, kwargs)
@ -364,11 +347,9 @@ def test_submit_to_azure_docker(mock_runner: Runner) -> None:
# start in that temp folder.
with change_working_folder_and_add_environment(mock_runner.project_root):
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_ml_client"):
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
mock_run_in_situ.assert_called_once()
mock_submit_to_aml.assert_called_once()
# call_args is a tuple of (args, kwargs)
@ -393,16 +374,12 @@ def test_run_hello_world(mock_runner: Runner) -> None:
"""Test running a model end-to-end via the commandline runner"""
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}"]
with patch("health_ml.runner.get_workspace") as mock_get_workspace:
with patch.object(sys, "argv", arguments):
mock_runner.run()
# get_workspace should not be called when using the runner outside AzureML, to not go through the
# time-consuming auth
mock_get_workspace.assert_not_called()
# Summary.txt is written at start, the other files during inference
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
for file in expected_files:
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
with patch.object(sys, "argv", arguments):
mock_runner.run()
# Summary.txt is written at start, the other files during inference
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
for file in expected_files:
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
def test_invalid_args(mock_runner: Runner) -> None:
@ -425,17 +402,37 @@ def test_invalid_profiler(mock_runner: Runner) -> None:
mock_runner.run()
def test_datastore_argument(mock_runner: Runner) -> None:
    """The --datastore argument should be attached to each created input dataset config."""
    model_name = "HelloWorld"
    datastore = "foo"
    dataset = "bar"
    arguments = ["", f"--datastore={datastore}", f"--model={model_name}", f"--azure_datasets={dataset}"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.Runner.run_in_situ"):
            with patch.object(sys, "argv", arguments):
                mock_runner.run()
    mock_submit_to_azure_if_needed.assert_called_once()
    # The datastore is no longer passed via the deprecated "default_datastore" argument.
    # Instead, it must be set on every individual input dataset config.
    input_datasets = mock_submit_to_azure_if_needed.call_args[1]["input_datasets"]
    assert len(input_datasets) == 1
    assert input_datasets[0].datastore == datastore
    assert input_datasets[0].name == dataset
def test_no_authentication_outside_azureml(mock_runner: Runner) -> None:
    """No authentication should happen for a model that runs locally and needs no datasets."""
    model_name = "HelloWorld"
    # Plain string: the argument has no placeholder, so no f-prefix is needed (ruff F541)
    arguments = ["", "--datastore=datastore", f"--model={model_name}"]
    with patch("health_ml.runner.Runner.run_in_situ"):
        with patch.object(sys, "argv", arguments):
            mock_get_workspace = MagicMock()
            mock_get_ml_client = MagicMock()
            with patch.multiple(
                "health_azure.himl", get_workspace=mock_get_workspace, get_ml_client=mock_get_ml_client
            ):
                mock_runner.run()
    # A local run without datasets must not trigger any authentication, neither v1 nor v2.
    mock_get_workspace.assert_not_called()
    mock_get_ml_client.assert_not_called()
@pytest.mark.fast
@ -512,3 +509,111 @@ def test_run_without_logging(tmp_path: Path) -> None:
run_with_logging(tmp_path)
mock_create_filename.assert_not_called()
mock_run.assert_called_once()
@pytest.mark.fast
def test_runner_does_not_use_get_workspace() -> None:
    """Test that the runner does not itself import get_workspace or get_ml_client (otherwise we would need to check
    them in the tests below that count calls to those methods)"""
    # If health_ml.runner re-exported these helpers, the tests below that patch
    # "health_azure.himl" would miss calls going through the runner's own copy,
    # making their call-count assertions unreliable.
    with pytest.raises(ImportError):
        from health_ml.runner import get_workspace  # type: ignore
    with pytest.raises(ImportError):
        from health_ml.runner import get_ml_client  # type: ignore
def test_runner_authenticates_once_v1() -> None:
    """Test that the runner requires authentication only once when doing a job submission with the V1 SDK"""
    runner = Runner(project_root=repository_root_directory())
    workspace_mock = MagicMock()
    ml_client_mock = MagicMock()
    argv = ["src/health_ml/runner.py", "--model=HelloWorld", "--cluster=pr-gpu", "--strictly_aml_v1"]
    himl_patches = patch.multiple(
        "health_azure.himl",
        get_workspace=workspace_mock,
        get_ml_client=ml_client_mock,
        Experiment=MagicMock(),
        register_environment=MagicMock(return_value="env"),
        validate_compute_cluster=MagicMock(),
    )
    with himl_patches, patch.object(sys, "argv", argv):
        # Job submission should trigger a system exit
        with pytest.raises(SystemExit):
            runner.run()
    # A v1 submission authenticates via get_workspace exactly once; the v2 client stays untouched.
    workspace_mock.assert_called_once()
    ml_client_mock.assert_not_called()
def test_runner_authenticates_once_v2() -> None:
    """Test that the runner requires authentication only once when doing a job submission with the V2 SDK"""
    runner = Runner(project_root=repository_root_directory())
    workspace_mock = MagicMock()
    ml_client_mock = MagicMock()
    himl_patches = patch.multiple(
        "health_azure.himl",
        get_workspace=workspace_mock,
        get_ml_client=ml_client_mock,
        command=MagicMock(),
    )
    with himl_patches, patch.object(sys, "argv", ["", "--model=HelloWorld", "--cluster=pr-gpu"]):
        # Job submission should trigger a system exit
        with pytest.raises(SystemExit):
            runner.run()
    # A v2 submission authenticates only via the v2 ML client, never via a v1 workspace.
    workspace_mock.assert_not_called()
    ml_client_mock.assert_called_once()
def test_runner_with_local_dataset_v1() -> None:
    """Test that the runner requires authentication only once when doing a local run and a dataset has to be mounted"""
    runner = Runner(project_root=repository_root_directory())
    workspace_mock = MagicMock(return_value=DEFAULT_WORKSPACE.workspace)
    ml_client_mock = MagicMock()
    argv = [
        "src/health_ml/runner.py",
        "--model=HelloWorld",
        "--strictly_aml_v1",
        "--azure_datasets=hello_world",
    ]
    with patch.multiple("health_azure.himl", get_workspace=workspace_mock, get_ml_client=ml_client_mock):
        with patch.object(sys, "argv", argv):
            runner.run()
    # Mounting a v1 dataset for a local run needs the workspace exactly once; no v2 client is created.
    workspace_mock.assert_called_once()
    ml_client_mock.assert_not_called()
@pytest.mark.parametrize("use_local_dataset", [True, False])
def test_runner_with_local_dataset_v2(use_local_dataset: bool, tmp_path: Path) -> None:
    """Test that the runner requires authentication only once when doing a local run with SDK v2"""
    runner = Runner(project_root=repository_root_directory())
    mock_get_workspace = MagicMock()
    mock_get_ml_client = MagicMock(return_value=DEFAULT_WORKSPACE.ml_client)
    with patch.multiple(
        "health_azure.himl",
        get_workspace=mock_get_workspace,
        get_ml_client=mock_get_ml_client,
    ):
        args = [
            "src/health_ml/runner.py",
            "--model=HelloWorld",
            # Plain string: no placeholder, so no f-prefix is needed (ruff F541)
            "--strictly_aml_v1=False",
            "--azure_datasets=hello_world",
        ]
        if use_local_dataset:
            args.append(f"--local_datasets={tmp_path}")
        with patch.object(sys, "argv", args):
            if use_local_dataset:
                # With a local dataset folder supplied, no download is attempted and the run succeeds.
                runner.run()
            else:
                # v2 Datasets cannot be downloaded/mounted for local runs, so this must fail.
                with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from"):
                    runner.run()
    # Only the v2 client is created, exactly once; no v1 workspace authentication happens.
    mock_get_workspace.assert_not_called()
    mock_get_ml_client.assert_called_once()