Mirror of https://github.com/microsoft/hi-ml.git
BUG: Fix for duplicate authentication (#878)

Resolve issues with AML SDK v1/v2: even in the v2 code path we were creating v1 workspaces, leading to duplicate authentication requests. Fixing that required a couple of other changes:

* Deprecate the use of the default datastore, which was read out of an SDK v1 Workspace object even if SDK v2 was chosen.
* No longer allow SDK v2 when mounting datasets for local runs (v2 datasets can't be mounted at all).

Also added more detailed logging for dataset creation, and a commandline flag to control the logging level.
Parent: 683def950a
Commit: f46f60e7fa
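As a rough usage sketch of the behaviour described in the commit message (the dataset and datastore names below are placeholders, not part of this commit): a local run that needs an AzureML dataset mounted or downloaded has to stay on SDK v1.

```python
from health_azure.datasets import DatasetConfig
from health_azure.himl import submit_to_azure_if_needed

# Hypothetical names: the dataset either already exists in the workspace, or is
# created from the files in folder "mydataset" in the given datastore.
dataset = DatasetConfig(name="mydataset", datastore="mydatastore")

run_info = submit_to_azure_if_needed(
    input_datasets=[dataset],
    strictly_aml_v1=True,  # v2 data assets cannot be mounted/downloaded for local runs
)
print(f"Dataset is available locally at {run_info.input_datasets[0]}")
```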
@ -20,7 +20,9 @@ from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
|||
from azureml.dataprep.fuse.daemon import MountContext
|
||||
from azureml.exceptions._azureml_exception import UserErrorException
|
||||
|
||||
from health_azure.utils import PathOrString, get_workspace, get_ml_client
|
||||
from health_azure.utils import PathOrString, get_ml_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
V1OrV2DataType = Union[FileDataset, Data]
|
||||
|
@ -128,11 +130,14 @@ def _get_or_create_v1_dataset(datastore_name: str, dataset_name: str, workspace:
|
|||
try:
|
||||
azureml_dataset = _retrieve_v1_dataset(dataset_name, workspace)
|
||||
except UserErrorException:
|
||||
logger.warning(f"Dataset '{dataset_name}' was not found, or is not an AzureML SDK v1 dataset.")
|
||||
logger.info(f"Trying to create a new dataset '{dataset_name}' from files in folder '{dataset_name}'")
|
||||
if datastore_name == "":
|
||||
raise ValueError(
|
||||
"When creating a new dataset, a datastore name must be provided. Please specify a datastore name using "
|
||||
"the --datastore flag"
|
||||
)
|
||||
logger.info(f"Trying to create a new dataset '{dataset_name}' in datastore '{datastore_name}'")
|
||||
azureml_dataset = _create_v1_dataset(datastore_name, dataset_name, workspace)
|
||||
return azureml_dataset
|
||||
|
||||
|
@ -352,10 +357,8 @@ class DatasetConfig:
|
|||
|
||||
def to_input_dataset_local(
|
||||
self,
|
||||
strictly_aml_v1: bool,
|
||||
workspace: Workspace = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
) -> Tuple[Optional[Path], Optional[MountContext]]:
|
||||
workspace: Workspace,
|
||||
) -> Tuple[Path, Optional[MountContext]]:
|
||||
"""
|
||||
Return a local path to the dataset when outside of an AzureML run.
|
||||
If local_folder is supplied, then this is assumed to be a local dataset, and this is returned.
|
||||
|
@ -364,9 +367,6 @@ class DatasetConfig:
|
|||
therefore a tuple of Nones will be returned.
|
||||
|
||||
:param workspace: The AzureML workspace to read from.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and register the dataset.
|
||||
Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:return: Tuple of (path to dataset, optional mountcontext)
|
||||
"""
|
||||
status = f"Dataset '{self.name}' will be "
|
||||
|
@ -381,12 +381,10 @@ class DatasetConfig:
|
|||
f"Unable to make dataset '{self.name} available for a local run because no AzureML "
|
||||
"workspace has been provided. Provide a workspace, or set a folder for local execution."
|
||||
)
|
||||
azureml_dataset = get_or_create_dataset(
|
||||
azureml_dataset = _get_or_create_v1_dataset(
|
||||
datastore_name=self.datastore,
|
||||
dataset_name=self.name,
|
||||
workspace=workspace,
|
||||
strictly_aml_v1=strictly_aml_v1,
|
||||
ml_client=ml_client,
|
||||
)
|
||||
if isinstance(azureml_dataset, FileDataset):
|
||||
target_path = self.target_folder or Path(tempfile.mkdtemp())
|
||||
|
@ -404,7 +402,7 @@ class DatasetConfig:
|
|||
print(status)
|
||||
return result
|
||||
else:
|
||||
return None, None
|
||||
raise ValueError(f"Don't know how to handle dataset '{self.name}' of type {type(azureml_dataset)}")
|
||||
|
||||
def to_input_dataset(
|
||||
self,
|
||||
|
@ -556,38 +554,10 @@ def create_dataset_configs(
|
|||
return datasets
|
||||
|
||||
|
||||
def find_workspace_for_local_datasets(
|
||||
aml_workspace: Optional[Workspace], workspace_config_path: Optional[Path], dataset_configs: List[DatasetConfig]
|
||||
) -> Optional[Workspace]:
|
||||
"""
|
||||
If any of the dataset_configs require an AzureML workspace then try to get one, otherwise return None.
|
||||
|
||||
:param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
|
||||
to pass it in as a parameter.
|
||||
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
|
||||
Azure portal from which we can retrieve the existing Workspace.
|
||||
:param dataset_configs: List of DatasetConfig describing the input datasets.
|
||||
:return: Workspace if required, None otherwise.
|
||||
"""
|
||||
workspace: Workspace = None
|
||||
# Check whether an attempt will be made to mount or download a dataset when running locally.
|
||||
# If so, try to get the AzureML workspace.
|
||||
if any(dc.local_folder is None for dc in dataset_configs):
|
||||
try:
|
||||
workspace = get_workspace(aml_workspace, workspace_config_path)
|
||||
logging.info(f"Found workspace for datasets: {workspace.name}")
|
||||
except Exception as ex:
|
||||
logging.info(f"Could not find workspace for datasets. Exception: {ex}")
|
||||
return workspace
|
||||
|
||||
|
||||
def setup_local_datasets(
|
||||
dataset_configs: List[DatasetConfig],
|
||||
strictly_aml_v1: bool,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
workspace_config_path: Optional[Path] = None,
|
||||
) -> Tuple[List[Optional[Path]], List[MountContext]]:
|
||||
workspace: Optional[Workspace],
|
||||
) -> Tuple[List[Path], List[MountContext]]:
|
||||
"""
|
||||
When running outside of AzureML, setup datasets to be used locally.
|
||||
|
||||
|
@ -595,21 +565,20 @@ def setup_local_datasets(
|
|||
used. Otherwise the dataset is mounted or downloaded to either the target folder or a temporary folder and that is
|
||||
used.
|
||||
|
||||
:param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
|
||||
to pass it in as a parameter.
|
||||
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
|
||||
Azure portal from which we can retrieve the existing Workspace.
|
||||
If a dataset does not exist, an AzureML SDK v1 dataset will be created, assuming that the dataset is given
|
||||
in a folder of the same name (for example, if a dataset is given as "mydataset", then it is created from the files
|
||||
in folder "mydataset" in the datastore).
|
||||
|
||||
:param workspace: The AzureML workspace to work with. Can be None if the list of datasets is empty, or if
|
||||
the datasets are available locally.
|
||||
:param dataset_configs: List of DatasetConfig describing the input data assets.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An MLClient object for interacting with AML v2 datastores.
|
||||
:return: Pair of: list of optional paths to the input datasets, list of mountcontexts, one for each mounted dataset.
|
||||
:return: Pair of: list of paths to the input datasets, list of mountcontexts, one for each mounted dataset.
|
||||
"""
|
||||
workspace = find_workspace_for_local_datasets(aml_workspace, workspace_config_path, dataset_configs)
|
||||
mounted_input_datasets: List[Optional[Path]] = []
|
||||
mounted_input_datasets: List[Path] = []
|
||||
mount_contexts: List[MountContext] = []
|
||||
|
||||
for data_config in dataset_configs:
|
||||
target_path, mount_context = data_config.to_input_dataset_local(strictly_aml_v1, workspace, ml_client)
|
||||
target_path, mount_context = data_config.to_input_dataset_local(workspace)
|
||||
|
||||
mounted_input_datasets.append(target_path)
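To illustrate the reworked local-dataset path, here is a hedged sketch of calling `setup_local_datasets` with its new signature; the dataset names are placeholders, and the `stop()` call assumes the usual SDK v1 `MountContext` API.

```python
from health_azure.datasets import DatasetConfig, setup_local_datasets
from health_azure.utils import get_workspace

configs = [DatasetConfig(name="dataset1"), DatasetConfig(name="dataset2")]
workspace = get_workspace()  # resolved from a config.json file or environment variables

paths, mount_contexts = setup_local_datasets(configs, workspace=workspace)
try:
    for path in paths:
        print(f"Input dataset available at {path}")
finally:
    # Datasets that were mounted rather than downloaded must be unmounted again.
    for context in mount_contexts:
        context.stop()
```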
@ -442,22 +442,20 @@ def effective_experiment_name(experiment_name: Optional[str], entry_script: Opti
|
|||
|
||||
|
||||
def submit_run_v2(
|
||||
workspace: Optional[Workspace],
|
||||
ml_client: MLClient,
|
||||
environment: EnvironmentV2,
|
||||
entry_script: PathOrString,
|
||||
script_params: List[str],
|
||||
compute_target: str,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
input_datasets_v2: Optional[Dict[str, Input]] = None,
|
||||
output_datasets_v2: Optional[Dict[str, Output]] = None,
|
||||
snapshot_root_directory: Optional[Path] = None,
|
||||
entry_script: Optional[PathOrString] = None,
|
||||
script_params: Optional[List[str]] = None,
|
||||
compute_target: Optional[str] = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
docker_shm_size: str = "",
|
||||
wait_for_completion: bool = False,
|
||||
identity_based_auth: bool = False,
|
||||
workspace_config_path: Optional[PathOrString] = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
hyperparam_args: Optional[Dict[str, Any]] = None,
|
||||
num_nodes: int = 1,
|
||||
pytorch_processes_per_node: Optional[int] = None,
|
||||
|
@ -466,8 +464,11 @@ def submit_run_v2(
|
|||
"""
|
||||
Starts a v2 AML Job on a given workspace by submitting a command
|
||||
|
||||
:param workspace: The AzureML workspace to use.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:param environment: An AML v2 Environment object.
|
||||
:param entry_script: The script that should be run in AzureML.
|
||||
:param script_params: A list of parameters to pass on to the script as it runs in AzureML.
|
||||
:param compute_target: The name of a compute target in Azure ML to submit the job to.
|
||||
:param environment_variables: The environment variables that should be set when running in AzureML.
|
||||
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
|
||||
characters that are not valid in Azure, those will be removed.
|
||||
|
@ -475,18 +476,11 @@ def submit_run_v2(
|
|||
:param output_datasets_v2: An optional dictionary of Outputs to pass in to the command.
|
||||
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
|
||||
All Python code that the script uses must be copied over.
|
||||
:param entry_script: The script that should be run in AzureML.
|
||||
:param script_params: A list of parameter to pass on to the script as it runs in AzureML.
|
||||
:param compute_target: Optional name of a compute target in Azure ML to submit the job to. If None, will run
|
||||
locally.
|
||||
:param tags: A dictionary of string key/value pairs, that will be added as metadata to the run. If set to None,
|
||||
a default metadata field will be added that only contains the commandline arguments that started the run.
|
||||
:param docker_shm_size: The Docker shared memory size that should be used when creating a new Docker image.
|
||||
:param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
|
||||
the completion of this run (if True).
|
||||
:param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
|
||||
config
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
|
||||
:param num_nodes: The number of nodes to use for the job in AzureML. The value must be 1 or greater.
|
||||
:param pytorch_processes_per_node: For plain PyTorch multi-GPU processing: The number of processes per node.
|
||||
|
@ -496,20 +490,6 @@ def submit_run_v2(
|
|||
display name will be generated by AzureML.
|
||||
:return: An AzureML Run object.
|
||||
"""
|
||||
if ml_client is None:
|
||||
if workspace is not None:
|
||||
ml_client = get_ml_client(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group=workspace.resource_group,
|
||||
workspace_name=workspace.name,
|
||||
)
|
||||
elif workspace_config_path is not None:
|
||||
ml_client = get_ml_client(workspace_config_path=workspace_config_path)
|
||||
else:
|
||||
raise ValueError("Either workspace or workspace_config_path must be specified to connect to the Workspace")
|
||||
|
||||
assert compute_target is not None, "No compute_target has been provided"
|
||||
assert entry_script is not None, "No entry_script has been provided"
|
||||
snapshot_root_directory = snapshot_root_directory or Path.cwd()
|
||||
root_dir = Path(snapshot_root_directory)
|
||||
|
||||
|
@ -592,7 +572,11 @@ def submit_run_v2(
|
|||
job_to_submit = create_command_job(cmd)
|
||||
|
||||
returned_job = ml_client.jobs.create_or_update(job_to_submit)
|
||||
print(f"URL to job: {returned_job.services['Studio'].endpoint}") # type: ignore
|
||||
print("\n==============================================================================")
|
||||
# The ID field looks like /subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.MachineLearningServices/..
|
||||
print(f"Successfully queued run {(returned_job.id or '').split('/')[-1]}")
|
||||
print(f"Run URL: {returned_job.services['Studio'].endpoint}") # type: ignore
|
||||
print("==============================================================================\n")
|
||||
if wait_for_completion:
|
||||
print("Waiting for the completion of the AzureML job.")
|
||||
wait_for_job_completion(ml_client, job_name=returned_job.name)
|
||||
|
@ -671,7 +655,7 @@ def submit_run(
|
|||
|
||||
# These need to be 'print' not 'logging.info' so that the calling script sees them outside AzureML
|
||||
print("\n==============================================================================")
|
||||
print(f"Successfully queued run number {run.number} (ID {run.id}) in experiment {run.experiment.name}")
|
||||
print(f"Successfully queued run {run.id} in experiment {run.experiment.name}")
|
||||
print(f"Experiment name and run ID are available in file {RUN_RECOVERY_FILE}")
|
||||
print(f"Experiment URL: {run.experiment.get_portal_url()}")
|
||||
print(f"Run URL: {run.get_portal_url()}")
|
||||
|
@ -885,6 +869,18 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
# is necessary. If not, return to the caller for local execution.
|
||||
if submit_to_azureml is None:
|
||||
submit_to_azureml = AZUREML_FLAG in sys.argv[1:]
|
||||
|
||||
has_input_datasets = len(cleaned_input_datasets) > 0
|
||||
if submit_to_azureml or has_input_datasets:
|
||||
if strictly_aml_v1:
|
||||
aml_workspace = get_workspace(aml_workspace, workspace_config_path)
|
||||
assert aml_workspace is not None
|
||||
print(f"Loaded AzureML workspace {aml_workspace.name}")
|
||||
else:
|
||||
ml_client = get_ml_client(ml_client=ml_client, workspace_config_path=workspace_config_path)
|
||||
assert ml_client is not None
|
||||
print(f"Created MLClient for AzureML workspace {ml_client.workspace_name}")
|
||||
|
||||
if not submit_to_azureml:
|
||||
# Set the environment variables for local execution.
|
||||
environment_variables = {**DEFAULT_ENVIRONMENT_VARIABLES, **(environment_variables or {})}
|
||||
|
@ -898,16 +894,24 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
logs_folder = Path.cwd() / LOGS_FOLDER
|
||||
logs_folder.mkdir(exist_ok=True)
|
||||
|
||||
any_local_folders_missing = any(dataset.local_folder is None for dataset in cleaned_input_datasets)
|
||||
|
||||
if has_input_datasets and any_local_folders_missing and not strictly_aml_v1:
|
||||
raise ValueError(
|
||||
"AzureML SDK v2 does not support downloading datasets from AzureML for local execution. "
|
||||
"Please switch to AzureML SDK v1 by setting strictly_aml_v1=True, or use "
|
||||
"--strictly_aml_v1 on the commandline, or provide a local folder for each input dataset. "
|
||||
"Note that you will not be able use AzureML datasets for runs outside AzureML if the datasets were "
|
||||
"created via SDK v2."
|
||||
)
|
||||
|
||||
mounted_input_datasets, mount_contexts = setup_local_datasets(
|
||||
cleaned_input_datasets,
|
||||
strictly_aml_v1,
|
||||
aml_workspace=aml_workspace,
|
||||
ml_client=ml_client,
|
||||
workspace_config_path=workspace_config_path,
|
||||
workspace=aml_workspace,
|
||||
)
|
||||
|
||||
return AzureRunInfo(
|
||||
input_datasets=mounted_input_datasets,
|
||||
input_datasets=mounted_input_datasets, # type: ignore
|
||||
output_datasets=[d.local_folder for d in cleaned_output_datasets],
|
||||
mount_contexts=mount_contexts,
|
||||
run=None,
|
||||
|
@ -920,9 +924,6 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
print(f"No snapshot root directory given. Uploading all files in the current directory {Path.cwd()}")
|
||||
snapshot_root_directory = Path.cwd()
|
||||
|
||||
workspace = get_workspace(aml_workspace, workspace_config_path)
|
||||
print(f"Loaded AzureML workspace {workspace.name}")
|
||||
|
||||
if conda_environment_file is None:
|
||||
conda_environment_file = find_file_in_parent_to_pythonpath(CONDA_ENVIRONMENT_FILE)
|
||||
if conda_environment_file is None:
|
||||
|
@ -938,8 +939,9 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
|
||||
with append_to_amlignore(amlignore=amlignore_path, lines_to_append=lines_to_append):
|
||||
if strictly_aml_v1:
|
||||
assert aml_workspace is not None, "An AzureML workspace should have been created already."
|
||||
run_config = create_run_configuration(
|
||||
workspace=workspace,
|
||||
workspace=aml_workspace,
|
||||
compute_cluster_name=compute_cluster_name,
|
||||
aml_environment_name=aml_environment_name,
|
||||
conda_environment_file=conda_environment_file,
|
||||
|
@ -968,7 +970,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
config_to_submit = script_run_config
|
||||
|
||||
run = submit_run(
|
||||
workspace=workspace,
|
||||
workspace=aml_workspace,
|
||||
experiment_name=effective_experiment_name(experiment_name, script_run_config.script),
|
||||
script_run_config=config_to_submit,
|
||||
tags=tags,
|
||||
|
@ -979,6 +981,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
if after_submission is not None:
|
||||
after_submission(run) # type: ignore
|
||||
else:
|
||||
assert ml_client is not None, "An AzureML MLClient should have been created already."
|
||||
if conda_environment_file is None:
|
||||
raise ValueError("Argument 'conda_environment_file' must be specified when using AzureML v2")
|
||||
environment = create_python_environment_v2(
|
||||
|
@ -987,13 +990,12 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
if entry_script is None:
|
||||
entry_script = Path(sys.argv[0])
|
||||
|
||||
ml_client = get_ml_client(ml_client=ml_client, aml_workspace=workspace)
|
||||
registered_env = register_environment_v2(environment, ml_client)
|
||||
input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
|
||||
output_datasets_v2 = create_v2_outputs(ml_client, cleaned_output_datasets)
|
||||
|
||||
job = submit_run_v2(
|
||||
workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
input_datasets_v2=input_datasets_v2,
|
||||
output_datasets_v2=output_datasets_v2,
|
||||
experiment_name=experiment_name,
@ -39,12 +39,7 @@ def main() -> None: # pragma: no cover
|
|||
|
||||
files_to_download = download_config.files_to_download
|
||||
|
||||
workspace = get_workspace()
|
||||
ml_client = get_ml_client(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group=workspace.resource_group,
|
||||
workspace_name=workspace.name,
|
||||
)
|
||||
ml_client = get_ml_client()
|
||||
for run_id in download_config.run:
|
||||
download_job_outputs_logs(ml_client, run_id, file_to_download_path=files_to_download, download_dir=output_dir)
|
||||
print("Successfully downloaded output and log files")
@ -13,7 +13,6 @@ from typing import Generator, Optional, Union
|
|||
from health_azure.utils import ENV_LOCAL_RANK, check_is_any_of, is_global_rank_zero
|
||||
|
||||
logging_stdout_handler: Optional[logging.StreamHandler] = None
|
||||
logging_to_file_handler: Optional[logging.StreamHandler] = None
|
||||
|
||||
|
||||
def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
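The commit message mentions a commandline flag for the logging level; this module's `logging_to_stdout` is the hook for that. A minimal sketch, where the import path is an assumption based on this file's contents:

```python
import logging

from health_azure.logging import logging_to_stdout  # assumed module path

# Route log output to stdout at the requested level, for example the value parsed
# from a hypothetical --log_level commandline argument.
logging_to_stdout(log_level=logging.DEBUG)
logging.getLogger(__name__).debug("Verbose logging enabled")
```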
@ -288,6 +288,30 @@ def find_file_in_parent_to_pythonpath(file_name: str) -> Optional[Path]:
|
|||
return find_file_in_parent_folders(file_name=file_name, stop_at_path=pythonpaths)
|
||||
|
||||
|
||||
def resolve_workspace_config_path(workspace_config_path: Optional[Path] = None) -> Optional[Path]:
|
||||
"""Retrieve the path to the workspace config file, either from the argument, or from the current working directory.
|
||||
|
||||
:param workspace_config_path: A path to a workspace config file that was provided on the commandline, defaults to
|
||||
None
|
||||
:return: The path to the workspace config file, or None if it cannot be found.
|
||||
:raises FileNotFoundError: If the workspace config file that was provided as an argument does not exist.
|
||||
"""
|
||||
if workspace_config_path is None:
|
||||
logger.info(
|
||||
f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
|
||||
"and its parent folders"
|
||||
)
|
||||
result = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
|
||||
if result:
|
||||
logger.info(f"Using the workspace config file {str(result.absolute())}")
|
||||
else:
|
||||
logger.debug("No workspace config file found")
|
||||
return result
|
||||
if not workspace_config_path.is_file():
|
||||
raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
|
||||
return workspace_config_path
|
||||
|
||||
|
||||
def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_path: Optional[Path] = None) -> Workspace:
|
||||
"""
|
||||
Retrieve an Azure ML Workspace by going through the following steps:
|
||||
|
@ -320,26 +344,16 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
|
|||
if aml_workspace:
|
||||
return aml_workspace
|
||||
|
||||
if workspace_config_path is None:
|
||||
logging.info(
|
||||
f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
|
||||
"and its parent folders"
|
||||
)
|
||||
workspace_config_path = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
|
||||
if workspace_config_path:
|
||||
logging.info(f"Using the workspace config file {str(workspace_config_path.absolute())}")
|
||||
|
||||
workspace_config_path = resolve_workspace_config_path(workspace_config_path)
|
||||
auth = get_authentication()
|
||||
if workspace_config_path is not None:
|
||||
if not workspace_config_path.is_file():
|
||||
raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
|
||||
workspace = Workspace.from_config(path=str(workspace_config_path), auth=auth)
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Logged into AzureML workspace {workspace.name} as specified in config file " f"{workspace_config_path}"
|
||||
)
|
||||
return workspace
|
||||
|
||||
logging.info("Trying to load the environment variables that define the workspace.")
|
||||
logger.info("Trying to load the environment variables that define the workspace.")
|
||||
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
|
||||
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
|
||||
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)
|
||||
|
@ -347,7 +361,7 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
|
|||
workspace = Workspace.get(
|
||||
name=workspace_name, auth=auth, subscription_id=subscription_id, resource_group=resource_group
|
||||
)
|
||||
logging.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
|
||||
logger.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
|
||||
return workspace
|
||||
|
||||
raise ValueError(
|
||||
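The config-file lookup that `get_workspace` previously did inline is now factored out into `resolve_workspace_config_path`; a small sketch of the resulting flow, assuming a `config.json` exists in or above the working directory:

```python
from pathlib import Path

from health_azure.utils import get_workspace, resolve_workspace_config_path

# Explicit argument: the file must exist, otherwise a FileNotFoundError is raised.
config_path = resolve_workspace_config_path(Path("config.json"))

# No argument: search the current folder and its parents; returns None if nothing is found.
found = resolve_workspace_config_path()

# get_workspace now delegates the config file search to resolve_workspace_config_path.
workspace = get_workspace(workspace_config_path=config_path)
print(f"Using workspace {workspace.name}")
```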
|
@ -1747,7 +1761,8 @@ class UnitTestWorkspaceWrapper:
|
|||
"""
|
||||
Init.
|
||||
"""
|
||||
self._workspace: Workspace = None
|
||||
self._workspace: Optional[Workspace] = None
|
||||
self._ml_client: Optional[MLClient] = None
|
||||
|
||||
@property
|
||||
def workspace(self) -> Workspace:
|
||||
|
@ -1758,6 +1773,15 @@ class UnitTestWorkspaceWrapper:
|
|||
self._workspace = get_workspace()
|
||||
return self._workspace
|
||||
|
||||
@property
|
||||
def ml_client(self) -> MLClient:
|
||||
"""
|
||||
Lazily load the ML Client.
|
||||
"""
|
||||
if self._ml_client is None:
|
||||
self._ml_client = get_ml_client()
|
||||
return self._ml_client
|
||||
|
||||
|
||||
@contextmanager
|
||||
def check_config_json(script_folder: Path, shared_config_json: Path) -> Generator:
|
||||
|
@ -1895,7 +1919,7 @@ def _get_legitimate_interactive_browser_credential() -> Optional[TokenCredential
|
|||
|
||||
def get_credential() -> Optional[TokenCredential]:
|
||||
"""
|
||||
Get a credential for authenticating with Azure.There are multiple ways to retrieve a credential.
|
||||
Get a credential for authenticating with Azure. There are multiple ways to retrieve a credential.
|
||||
If environment variables pertaining to details of a Service Principal are available, those will be used
|
||||
to authenticate. If no environment variables exist, and the script is not currently
|
||||
running inside of Azure ML or another Azure agent, will attempt to retrieve a credential via a
|
||||
|
@ -1910,6 +1934,7 @@ def get_credential() -> Optional[TokenCredential]:
|
|||
tenant_id = get_secret_from_environment(ENV_TENANT_ID, allow_missing=True)
|
||||
service_principal_password = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_PASSWORD, allow_missing=True)
|
||||
if service_principal_id and tenant_id and service_principal_password:
|
||||
logger.debug("Found environment variables for Service Principal authentication")
|
||||
return _get_legitimate_service_principal_credential(tenant_id, service_principal_id, service_principal_password)
|
||||
|
||||
try:
|
||||
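As a small illustration of the Service Principal path described above (all values are placeholders):

```python
import os

from health_azure.utils import (
    ENV_SERVICE_PRINCIPAL_ID,
    ENV_SERVICE_PRINCIPAL_PASSWORD,
    ENV_TENANT_ID,
    get_credential,
)

# Placeholder values: with all three variables set, get_credential authenticates via
# a Service Principal (a ClientSecretCredential) before trying any other method.
os.environ[ENV_SERVICE_PRINCIPAL_ID] = "<service-principal-id>"
os.environ[ENV_TENANT_ID] = "<tenant-id>"
os.environ[ENV_SERVICE_PRINCIPAL_PASSWORD] = "<service-principal-secret>"

credential = get_credential()
```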
|
@ -1927,66 +1952,76 @@ def get_credential() -> Optional[TokenCredential]:
|
|||
|
||||
raise ValueError(
|
||||
"Unable to generate and validate a credential. Please see Azure ML documentation"
|
||||
"for instructions on diffrent options to get a credential"
|
||||
"for instructions on different options to get a credential"
|
||||
)
|
||||
|
||||
|
||||
def get_ml_client(
|
||||
ml_client: Optional[MLClient] = None,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
workspace_config_path: Optional[PathOrString] = None,
|
||||
subscription_id: Optional[str] = None,
|
||||
resource_group: Optional[str] = None,
|
||||
workspace_name: str = "",
|
||||
workspace_config_path: Optional[Path] = None,
|
||||
) -> MLClient:
|
||||
"""
|
||||
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK.
|
||||
If a ml_client is provided, return that. Otherwise, create one using workspace details
|
||||
coming from either an existing Workspace object, a config.json file or passed in as an argument.
|
||||
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK. The following ways of
|
||||
creating the client are tried, in order:
|
||||
|
||||
1. If an MLClient object has been provided in the `ml_client` argument, return that.
|
||||
|
||||
2. If a path to a workspace config file has been provided, load the MLClient according to that config file.
|
||||
|
||||
3. If a workspace config file is present in the current working directory or one of its parents, load the
|
||||
MLClient according to that config file.
|
||||
|
||||
4. If 3 environment variables are found, use them to identify the workspace (`HIML_RESOURCE_GROUP`,
|
||||
`HIML_SUBSCRIPTION_ID`, `HIML_WORKSPACE_NAME`)
|
||||
|
||||
If none of the above succeeds, an exception is raised.
|
||||
|
||||
:param ml_client: An optional existing MLClient object to be returned.
|
||||
:param aml_workspace: An optional Workspace object to take connection details from.
|
||||
:param workspace_config_path: An optional path to a config.json file containing details of the Workspace.
|
||||
:param subscription_id: An optional subscription ID.
|
||||
:param resource_group: An optional resource group name.
|
||||
:param workspace_name: An optional workspace name.
|
||||
:return: An instance of MLClient to interact with Azure resources.
|
||||
"""
|
||||
if ml_client:
|
||||
if ml_client is not None:
|
||||
return ml_client
|
||||
|
||||
logger.debug("Getting credentials")
|
||||
credential = get_credential()
|
||||
if credential is None:
|
||||
raise ValueError("Can't connect to MLClient without a valid credential")
|
||||
if aml_workspace is not None:
|
||||
ml_client = MLClient(
|
||||
subscription_id=aml_workspace.subscription_id,
|
||||
resource_group_name=aml_workspace.resource_group,
|
||||
workspace_name=aml_workspace.name,
|
||||
credential=credential,
|
||||
) # type: ignore
|
||||
elif workspace_config_path:
|
||||
workspace_config_path = resolve_workspace_config_path(workspace_config_path)
|
||||
if workspace_config_path is not None:
|
||||
logger.debug(f"Retrieving MLClient from workspace config {workspace_config_path}")
|
||||
ml_client = MLClient.from_config(credential=credential, path=str(workspace_config_path)) # type: ignore
|
||||
elif subscription_id and resource_group and workspace_name:
|
||||
logger.info(
|
||||
f"Using MLClient for AzureML workspace {ml_client.workspace_name} as specified in config file"
|
||||
f"{workspace_config_path}"
|
||||
)
|
||||
return ml_client
|
||||
|
||||
logger.info("Trying to load the environment variables that define the workspace.")
|
||||
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
|
||||
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
|
||||
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)
|
||||
if workspace_name and subscription_id and resource_group:
|
||||
logger.debug(
|
||||
"Retrieving MLClient via subscription ID, resource group and workspace name retrieved from "
|
||||
"environment variables."
|
||||
)
|
||||
ml_client = MLClient(
|
||||
subscription_id=subscription_id,
|
||||
resource_group_name=resource_group,
|
||||
workspace_name=workspace_name,
|
||||
credential=credential,
|
||||
) # type: ignore
|
||||
else:
|
||||
try:
|
||||
workspace = get_workspace()
|
||||
ml_client = MLClient(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group_name=workspace.resource_group,
|
||||
workspace_name=workspace.name,
|
||||
credential=credential,
|
||||
) # type: ignore
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Couldn't connect to MLClient: {e}")
|
||||
logging.info(f"Logged into AzureML workspace {ml_client.workspace_name}")
|
||||
return ml_client
|
||||
logger.info(f"Using MLClient for AzureML workspace {workspace_name} as specified by environment variables")
|
||||
return ml_client
|
||||
|
||||
raise ValueError(
|
||||
"Tried all ways of identifying the MLClient, but failed. Please provide a workspace config "
|
||||
f"file {WORKSPACE_CONFIG_JSON} or set the environment variables {ENV_RESOURCE_GROUP}, "
|
||||
f"{ENV_SUBSCRIPTION_ID}, and {ENV_WORKSPACE_NAME}."
|
||||
)
|
||||
|
||||
|
||||
def retrieve_workspace_from_client(ml_client: MLClient, workspace_name: Optional[str] = None) -> WorkspaceV2:
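A hedged sketch of option 4 from the `get_ml_client` docstring above, building the client purely from environment variables (all values are placeholders):

```python
import os

from health_azure.utils import (
    ENV_RESOURCE_GROUP,
    ENV_SUBSCRIPTION_ID,
    ENV_WORKSPACE_NAME,
    get_ml_client,
)

# Placeholder values: this path is only taken when no MLClient, Workspace object or
# workspace config file is available.
os.environ[ENV_SUBSCRIPTION_ID] = "<subscription-id>"
os.environ[ENV_RESOURCE_GROUP] = "<resource-group>"
os.environ[ENV_WORKSPACE_NAME] = "<workspace-name>"

ml_client = get_ml_client()
print(f"Connected to AzureML workspace {ml_client.workspace_name}")
```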
@ -13,7 +13,7 @@ import time
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import DEFAULT, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
from xmlrpc.client import Boolean
|
||||
|
||||
|
@ -23,13 +23,12 @@ import pandas as pd
|
|||
import param
|
||||
import pytest
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
from azure.identity import ClientSecretCredential, DeviceCodeCredential, DefaultAzureCredential
|
||||
from azure.storage.blob import ContainerClient
|
||||
from azureml._restclient.constants import RunStatus
|
||||
from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
|
||||
from azureml.core.run import _OfflineRun
|
||||
from azureml.core.environment import CondaDependencies
|
||||
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
|
||||
import health_azure.utils as util
|
||||
|
@ -41,8 +40,8 @@ from health_azure.utils import (
|
|||
MASTER_PORT_DEFAULT,
|
||||
PackageDependency,
|
||||
download_files_by_suffix,
|
||||
get_credential,
|
||||
download_file_if_necessary,
|
||||
resolve_workspace_config_path,
|
||||
)
|
||||
from testazure.test_himl import RunTarget, render_and_run_test_script
|
||||
from testazure.utils_testazure import (
|
||||
|
@ -1996,136 +1995,6 @@ def test_create_run() -> None:
|
|||
run.complete()
|
||||
|
||||
|
||||
def test_get_credential() -> None:
|
||||
def _mock_validation_error() -> None:
|
||||
raise ClientAuthenticationError("")
|
||||
|
||||
# test the case where service principal credentials are set as environment variables
|
||||
mock_env_vars = {
|
||||
util.ENV_SERVICE_PRINCIPAL_ID: "foo",
|
||||
util.ENV_TENANT_ID: "bar",
|
||||
util.ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
|
||||
}
|
||||
|
||||
with patch.object(os.environ, "get", return_value=mock_env_vars):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT,
|
||||
) as mocks:
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mocks["_get_legitimate_service_principal_credential"].assert_called_once()
|
||||
mocks["_get_legitimate_device_code_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_default_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
|
||||
|
||||
# if the environment variables are not set and we are running on a local machine, a
|
||||
# DefaultAzureCredential should be attempted first
|
||||
with patch.object(os.environ, "get", return_value={}):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT,
|
||||
) as mocks:
|
||||
mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
|
||||
mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
|
||||
mock_get_default_cred = mocks["_get_legitimate_default_credential"]
|
||||
mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
|
||||
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_not_called()
|
||||
mock_get_default_cred.assert_called_once()
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if that fails, a DeviceCode credential should be attempted
|
||||
mock_get_default_cred.side_effect = _mock_validation_error
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_called_once()
|
||||
assert mock_get_default_cred.call_count == 2
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if None of the previous credentials work, an InteractiveBrowser credential should be tried
|
||||
mock_get_device_cred.return_value = None
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
assert mock_get_device_cred.call_count == 2
|
||||
assert mock_get_default_cred.call_count == 3
|
||||
mock_get_browser_cred.assert_called_once()
|
||||
|
||||
# finally, if none of the methods work, an Exception should be raised
|
||||
mock_get_browser_cred.return_value = None
|
||||
with pytest.raises(Exception) as e:
|
||||
get_credential()
|
||||
assert (
|
||||
"Unable to generate and validate a credential. Please see Azure ML documentation"
|
||||
"for instructions on different options to get a credential" in str(e)
|
||||
)
|
||||
|
||||
|
||||
def test_get_legitimate_service_principal_credential() -> None:
|
||||
# first attempt to create and valiadate a credential with non-existant service principal credentials
|
||||
# and check it fails
|
||||
mock_service_principal_id = "foo"
|
||||
mock_service_principal_password = "bar"
|
||||
mock_tenant_id = "baz"
|
||||
expected_error_msg = f"Found environment variables for {util.ENV_SERVICE_PRINCIPAL_ID}, "
|
||||
f"{util.ENV_SERVICE_PRINCIPAL_PASSWORD}, and {util.ENV_TENANT_ID} but was not able to authenticate"
|
||||
with pytest.raises(Exception) as e:
|
||||
util._get_legitimate_service_principal_credential(
|
||||
mock_tenant_id, mock_service_principal_id, mock_service_principal_password
|
||||
)
|
||||
assert expected_error_msg in str(e)
|
||||
|
||||
# now mock the case where validating the credential succeeds and check the value of that
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_service_principal_credential(
|
||||
mock_tenant_id, mock_service_principal_id, mock_service_principal_password
|
||||
)
|
||||
assert isinstance(cred, ClientSecretCredential)
|
||||
|
||||
|
||||
def test_get_legitimate_device_code_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
|
||||
return DeviceCodeCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
|
||||
cred = util._get_legitimate_device_code_credential()
|
||||
assert cred is None
|
||||
|
||||
# now mock the case where validating the credential succeeds
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_device_code_credential()
|
||||
assert isinstance(cred, DeviceCodeCredential)
|
||||
|
||||
|
||||
def test_get_legitimate_default_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
|
||||
return DefaultAzureCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
|
||||
exception_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
|
||||
with pytest.raises(ClientAuthenticationError, match=exception_message):
|
||||
cred = util._get_legitimate_default_credential()
|
||||
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_default_credential()
|
||||
assert isinstance(cred, DefaultAzureCredential)
|
||||
|
||||
|
||||
def test_filter_v2_input_output_args() -> None:
|
||||
def _compare_args(expected: List[str], actual: List[str]) -> None:
|
||||
assert len(actual) == len(expected)
|
||||
|
@ -2244,3 +2113,33 @@ def test_download_files_by_suffix(tmp_path: Path, files: List[str], expected_dow
|
|||
assert f.is_file()
|
||||
downloaded_filenames = [f.name for f in downloaded_list]
|
||||
assert downloaded_filenames == expected_downloaded
|
||||
|
||||
|
||||
def test_resolve_workspace_config_path_no_argument(tmp_path: Path) -> None:
|
||||
"""Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
|
||||
If the file exists, it should return the path"""
|
||||
mocked_file = tmp_path / "foo.json"
|
||||
with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=mocked_file):
|
||||
result = resolve_workspace_config_path()
|
||||
assert result == mocked_file
|
||||
|
||||
|
||||
def test_resolve_workspace_config_path_no_argument_no_file() -> None:
|
||||
"""Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
|
||||
If the file does not exist, return None"""
|
||||
with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=None):
|
||||
result = resolve_workspace_config_path()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_resolve_workspace_config_path_file_exists(tmp_path: Path) -> None:
|
||||
mocked_file = tmp_path / "foo.json"
|
||||
mocked_file.touch()
|
||||
result = resolve_workspace_config_path(mocked_file)
|
||||
assert result == mocked_file
|
||||
|
||||
|
||||
def test_resolve_workspace_config_path_missing(tmp_path: Path) -> None:
|
||||
mocked_file = tmp_path / "foo.json"
|
||||
with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
|
||||
resolve_workspace_config_path(mocked_file)
|
@ -20,6 +20,7 @@ from azureml.data import FileDataset, OutputFileDatasetConfig
|
|||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
||||
from azureml.exceptions._azureml_exception import UserErrorException
|
||||
from health_azure.himl import submit_to_azure_if_needed
|
||||
from testazure.utils_testazure import (
|
||||
DEFAULT_DATASTORE,
|
||||
DEFAULT_WORKSPACE,
|
||||
|
@ -27,7 +28,6 @@ from testazure.utils_testazure import (
|
|||
TEST_DATA_ASSET_NAME,
|
||||
TEST_INVALID_DATA_ASSET_NAME,
|
||||
TEST_DATASTORE_NAME,
|
||||
get_test_ml_client,
|
||||
)
|
||||
|
||||
from health_azure.datasets import (
|
||||
|
@ -46,10 +46,7 @@ from health_azure.datasets import (
|
|||
get_or_create_dataset,
|
||||
_get_latest_v2_asset_version,
|
||||
)
|
||||
from health_azure.utils import PathOrString, get_ml_client
|
||||
|
||||
|
||||
TEST_ML_CLIENT = get_test_ml_client()
|
||||
from health_azure.utils import PathOrString
|
||||
|
||||
|
||||
def test_datasetconfig_init() -> None:
|
||||
|
@ -234,12 +231,11 @@ def test_get_or_create_dataset() -> None:
|
|||
|
||||
data_asset_name = "himl_tiny_data_asset"
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
ml_client = get_ml_client(aml_workspace=workspace)
|
||||
# When creating a dataset, we need a non-empty name
|
||||
with pytest.raises(ValueError) as ex:
|
||||
get_or_create_dataset(
|
||||
workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name="",
|
||||
strictly_aml_v1=True,
|
||||
|
@ -254,7 +250,7 @@ def test_get_or_create_dataset() -> None:
|
|||
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
|
||||
dataset = get_or_create_dataset(
|
||||
workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=True,
|
||||
|
@ -268,7 +264,7 @@ def test_get_or_create_dataset() -> None:
|
|||
mocks["_get_or_create_v2_data_asset"].return_value = mock_v2_dataset
|
||||
dataset = get_or_create_dataset(
|
||||
workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=False,
|
||||
|
@ -281,7 +277,7 @@ def test_get_or_create_dataset() -> None:
|
|||
mocks["_get_or_create_v2_data_asset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
|
||||
dataset = get_or_create_dataset(
|
||||
workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=False,
|
||||
|
@ -417,7 +413,7 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
|
|||
mock_get_v2_asset_version.side_effect = _get_latest_v2_asset_version
|
||||
try:
|
||||
data_asset = _retrieve_v2_data_asset(
|
||||
ml_client=TEST_ML_CLIENT,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
data_asset_name=asset_name,
|
||||
version=asset_version,
|
||||
)
|
||||
|
@ -445,10 +441,12 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
|
|||
|
||||
|
||||
def test_retrieve_v2_data_asset_invalid_version() -> None:
|
||||
invalid_asset_version = str(int(_get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)) + 1)
|
||||
invalid_asset_version = str(
|
||||
int(_get_latest_v2_asset_version(DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME)) + 1
|
||||
)
|
||||
with pytest.raises(ResourceNotFoundError) as ex:
|
||||
_retrieve_v2_data_asset(
|
||||
ml_client=TEST_ML_CLIENT,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
data_asset_name=TEST_DATA_ASSET_NAME,
|
||||
version=invalid_asset_version,
|
||||
)
|
||||
|
@ -459,15 +457,19 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
|
|||
"""Test if calling the get_or_create_data_asset on an existing asset does not increment the version number."""
|
||||
|
||||
with patch("health_azure.datasets._create_v2_data_asset") as mock_create_v2_data_asset:
|
||||
asset_version_before_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
|
||||
asset_version_before_get_or_create = _get_latest_v2_asset_version(
|
||||
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
|
||||
)
|
||||
get_or_create_dataset(
|
||||
TEST_DATASTORE_NAME,
|
||||
TEST_DATA_ASSET_NAME,
|
||||
DEFAULT_WORKSPACE,
|
||||
strictly_aml_v1=False,
|
||||
ml_client=TEST_ML_CLIENT,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
)
|
||||
asset_version_after_get_or_create = _get_latest_v2_asset_version(
|
||||
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
|
||||
)
|
||||
asset_version_after_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
|
||||
|
||||
mock_create_v2_data_asset.assert_not_called()
|
||||
assert asset_version_before_get_or_create == asset_version_after_get_or_create
|
||||
|
@ -485,7 +487,7 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
|
|||
def test_create_v2_data_asset(asset_name: str, datastore_name: str, version: Optional[str]) -> None:
|
||||
try:
|
||||
data_asset = _create_v2_data_asset(
|
||||
ml_client=TEST_ML_CLIENT,
|
||||
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||
datastore_name=TEST_DATASTORE_NAME,
|
||||
data_asset_name=asset_name,
|
||||
version=version,
|
||||
|
@ -558,3 +560,36 @@ def test_create_dataset_configs() -> None:
|
|||
with pytest.raises(Exception) as e:
|
||||
create_dataset_configs(azure_datasets, dataset_mountpoints, local_datasets, datastore, use_mounting)
|
||||
assert "Invalid dataset setup" in str(e)
|
||||
|
||||
|
||||
def test_local_datasets() -> None:
|
||||
"""Test if Azure datasets can be mounted for local runs"""
|
||||
# Dataset hello_world must exist in the test AzureML workspace
|
||||
dataset = DatasetConfig(name="hello_world")
|
||||
run_info = submit_to_azure_if_needed(
|
||||
input_datasets=[dataset],
|
||||
strictly_aml_v1=True,
|
||||
)
|
||||
assert len(run_info.input_datasets) == 1
|
||||
assert isinstance(run_info.input_datasets[0], Path)
|
||||
assert run_info.input_datasets[0].is_dir()
|
||||
assert len(list(run_info.input_datasets[0].glob("*"))) > 0
|
||||
|
||||
|
||||
def test_local_datasets_fails_with_v2() -> None:
|
||||
"""Azure datasets can't be used when using SDK v2"""
|
||||
dataset = DatasetConfig(name="himl-tiny_dataset")
|
||||
with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from AzureML"):
|
||||
submit_to_azure_if_needed(
|
||||
input_datasets=[dataset],
|
||||
strictly_aml_v1=False,
|
||||
)
|
||||
|
||||
|
||||
def test_local_datasets_fail_with_v2() -> None:
|
||||
"""If no datasets are specified, we can still run with SDK v2"""
|
||||
run_info = submit_to_azure_if_needed(
|
||||
input_datasets=[],
|
||||
strictly_aml_v1=False,
|
||||
)
|
||||
assert len(run_info.input_datasets) == 0
@ -0,0 +1,249 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
"""
|
||||
Tests for health_azure.azure_get_workspace and related functions.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import DEFAULT, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from azure.core.exceptions import ClientAuthenticationError
|
||||
from azure.identity import ClientSecretCredential, DefaultAzureCredential, DeviceCodeCredential
|
||||
|
||||
from health_azure.utils import (
|
||||
ENV_RESOURCE_GROUP,
|
||||
ENV_SERVICE_PRINCIPAL_ID,
|
||||
ENV_SERVICE_PRINCIPAL_PASSWORD,
|
||||
ENV_SUBSCRIPTION_ID,
|
||||
ENV_TENANT_ID,
|
||||
ENV_WORKSPACE_NAME,
|
||||
_get_legitimate_default_credential,
|
||||
_get_legitimate_device_code_credential,
|
||||
_get_legitimate_service_principal_credential,
|
||||
get_credential,
|
||||
get_ml_client,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_credential() -> None:
|
||||
def _mock_validation_error() -> None:
|
||||
raise ClientAuthenticationError("")
|
||||
|
||||
# test the case where service principal credentials are set as environment variables
|
||||
mock_env_vars = {
|
||||
ENV_SERVICE_PRINCIPAL_ID: "foo",
|
||||
ENV_TENANT_ID: "bar",
|
||||
ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
|
||||
}
|
||||
|
||||
with patch.object(os.environ, "get", return_value=mock_env_vars):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT,
|
||||
) as mocks:
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mocks["_get_legitimate_service_principal_credential"].assert_called_once()
|
||||
mocks["_get_legitimate_device_code_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_default_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
|
||||
|
||||
# if the environment variables are not set and we are running on a local machine, a
|
||||
# DefaultAzureCredential should be attempted first
|
||||
with patch.object(os.environ, "get", return_value={}):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT,
|
||||
) as mocks:
|
||||
mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
|
||||
mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
|
||||
mock_get_default_cred = mocks["_get_legitimate_default_credential"]
|
||||
mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
|
||||
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_not_called()
|
||||
mock_get_default_cred.assert_called_once()
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if that fails, a DeviceCode credential should be attempted
|
||||
mock_get_default_cred.side_effect = _mock_validation_error
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_called_once()
|
||||
assert mock_get_default_cred.call_count == 2
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if None of the previous credentials work, an InteractiveBrowser credential should be tried
|
||||
mock_get_device_cred.return_value = None
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
assert mock_get_device_cred.call_count == 2
|
||||
assert mock_get_default_cred.call_count == 3
|
||||
mock_get_browser_cred.assert_called_once()
|
||||
|
||||
# finally, if none of the methods work, an Exception should be raised
|
||||
mock_get_browser_cred.return_value = None
|
||||
with pytest.raises(Exception) as e:
|
||||
get_credential()
|
||||
assert (
|
||||
"Unable to generate and validate a credential. Please see Azure ML documentation"
|
||||
"for instructions on different options to get a credential" in str(e)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_legitimate_service_principal_credential() -> None:
|
||||
# first attempt to create and validate a credential with non-existent service principal credentials
|
||||
# and check it fails
|
||||
mock_service_principal_id = "foo"
|
||||
mock_service_principal_password = "bar"
|
||||
mock_tenant_id = "baz"
|
||||
expected_error_msg = f"Found environment variables for {ENV_SERVICE_PRINCIPAL_ID}, "
|
||||
f"{ENV_SERVICE_PRINCIPAL_PASSWORD}, and {ENV_TENANT_ID} but was not able to authenticate"
|
||||
with pytest.raises(Exception) as e:
|
||||
_get_legitimate_service_principal_credential(
|
||||
mock_tenant_id, mock_service_principal_id, mock_service_principal_password
|
||||
)
|
||||
assert expected_error_msg in str(e)
|
||||
|
||||
# now mock the case where validating the credential succeeds and check the value of that
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = _get_legitimate_service_principal_credential(
|
||||
mock_tenant_id, mock_service_principal_id, mock_service_principal_password
|
||||
)
|
||||
assert isinstance(cred, ClientSecretCredential)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_legitimate_device_code_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
|
||||
return DeviceCodeCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
|
||||
cred = _get_legitimate_device_code_credential()
|
||||
assert cred is None
|
||||
|
||||
# now mock the case where validating the credential succeeds
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = _get_legitimate_device_code_credential()
|
||||
assert isinstance(cred, DeviceCodeCredential)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_legitimate_default_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
|
||||
return DefaultAzureCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
|
||||
exception_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
|
||||
with pytest.raises(ClientAuthenticationError, match=exception_message):
|
||||
cred = _get_legitimate_default_credential()
|
||||
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = _get_legitimate_default_credential()
|
||||
assert isinstance(cred, DefaultAzureCredential)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_ml_client_with_existing_client() -> None:
|
||||
"""When passing an existing ml_client, it should be returned"""
|
||||
ml_client = "mock_ml_client"
|
||||
result = get_ml_client(ml_client=ml_client) # type: ignore
|
||||
assert result == ml_client
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_ml_client_without_credentials() -> None:
|
||||
"""When no credentials are available, an exception should be raised"""
|
||||
with patch("health_azure.utils.get_credential", return_value=None):
|
||||
with pytest.raises(ValueError, match="Can't connect to MLClient without a valid credential"):
|
||||
get_ml_client()
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_ml_client_from_config_file() -> None:
|
||||
"""If a workspace config file is found, it should be used to create the MLClient"""
|
||||
mock_credentials = "mock_credentials"
|
||||
mock_config_path = Path("foo")
|
||||
mock_ml_client = MagicMock(workspace_name="workspace")
|
||||
mock_from_config = MagicMock(return_value=mock_ml_client)
|
||||
mock_resolve_config_path = MagicMock(return_value=mock_config_path)
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
get_credential=MagicMock(return_value=mock_credentials),
|
||||
resolve_workspace_config_path=mock_resolve_config_path,
|
||||
MLClient=MagicMock(from_config=mock_from_config),
|
||||
):
|
||||
config_file = Path("foo")
|
||||
result = get_ml_client(workspace_config_path=config_file)
|
||||
assert result == mock_ml_client
|
||||
mock_resolve_config_path.assert_called_once_with(config_file)
|
||||
mock_from_config.assert_called_once_with(
|
||||
credential=mock_credentials,
|
||||
path=str(mock_config_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_ml_client_from_environment_variables() -> None:
|
||||
"""When no workspace config file is found, the MLClient should be created from environment variables"""
|
||||
mock_credentials = "mock_credentials"
|
||||
the_client = "the_client"
|
||||
mock_ml_client = MagicMock(return_value=the_client)
|
||||
workspace = "workspace"
|
||||
subscription = "subscription"
|
||||
resource_group = "resource_group"
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
get_credential=MagicMock(return_value=mock_credentials),
|
||||
resolve_workspace_config_path=MagicMock(return_value=None),
|
||||
MLClient=mock_ml_client,
|
||||
):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{ENV_WORKSPACE_NAME: workspace, ENV_SUBSCRIPTION_ID: subscription, ENV_RESOURCE_GROUP: resource_group},
|
||||
):
|
||||
result = get_ml_client()
|
||||
assert result == the_client
|
||||
mock_ml_client.assert_called_once_with(
|
||||
subscription_id=subscription,
|
||||
resource_group_name=resource_group,
|
||||
workspace_name=workspace,
|
||||
credential=mock_credentials,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_get_ml_client_fails() -> None:
|
||||
"""If neither a workspace config file nor environment variables are found, an exception should be raised"""
|
||||
mock_credentials = "mock_credentials"
|
||||
the_client = "the_client"
|
||||
mock_ml_client = MagicMock(return_value=the_client)
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
get_credential=MagicMock(return_value=mock_credentials),
|
||||
resolve_workspace_config_path=MagicMock(return_value=None),
|
||||
MLClient=mock_ml_client,
|
||||
):
|
||||
# In the GitHub runner, the environment variables are set. We need to unset them to test the exception
|
||||
with patch.dict(os.environ, {ENV_WORKSPACE_NAME: ""}):
|
||||
with pytest.raises(ValueError, match="Tried all ways of identifying the MLClient, but failed"):
|
||||
get_ml_client()
|
|
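# Editor's note, a hedged sketch only: the function name and the environment-variable keys below are
# illustrative, not part of hi-ml. It summarises the resolution order that the tests above exercise for
# get_ml_client: reuse an existing client, then a workspace config file, then environment variables, else fail.
from typing import Dict, Optional


def resolve_ml_client_sketch(
    existing_client: Optional[object] = None,
    config_path: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
) -> str:
    if existing_client is not None:
        return "reuse existing MLClient"  # 1) an already constructed client is returned unchanged
    if config_path is not None:
        return f"MLClient.from_config({config_path!r})"  # 2) otherwise build from a workspace config file
    env = env or {}
    if all(k in env for k in ("WORKSPACE_NAME", "SUBSCRIPTION_ID", "RESOURCE_GROUP")):
        return "MLClient(from environment variables)"  # 3) otherwise fall back to environment variables
    raise ValueError("Tried all ways of identifying the MLClient, but failed")  # 4) otherwise raise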
@@ -10,6 +10,7 @@ from pathlib import Path
from uuid import uuid4

from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.exceptions._azureml_exception import UserErrorException
from _pytest.logging import LogCaptureFixture
import pytest
from unittest.mock import MagicMock, patch

@@ -22,7 +23,6 @@ from health_azure.utils import (
    get_workspace,
)
from health_azure.utils import (
    WORKSPACE_CONFIG_JSON,
    ENV_SERVICE_PRINCIPAL_ID,
    ENV_SERVICE_PRINCIPAL_PASSWORD,
    ENV_TENANT_ID,

@@ -141,13 +141,16 @@ def test_get_workspace_with_given_workspace() -> None:


@pytest.mark.fast
def test_get_workspace_searches_for_file() -> None:
def test_get_workspace_searches_for_file(tmp_path: Path) -> None:
    """get_workspace should try to load a config.json file if not provided with one"""
    found_file = Path("does_not_exist")
    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=found_file) as mock_find:
        with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
            get_workspace(None, None)
        mock_find.assert_called_once_with(WORKSPACE_CONFIG_JSON)
    with change_working_directory(tmp_path):
        found_file = Path("does_not_exist")
        with patch("health_azure.utils.resolve_workspace_config_path", return_value=found_file) as mock_find:
            with pytest.raises(
                UserErrorException, match="workspace configuration file config.json, could not be found"
            ):
                get_workspace(None, None)
            mock_find.assert_called_once_with(None)


@pytest.mark.fast
@@ -773,7 +773,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
    with patch("azure.ai.ml.MLClient") as mock_ml_client:
        with patch("health_azure.himl.command") as mock_command:
            himl.submit_run_v2(
                workspace=None,
                ml_client=mock_ml_client,
                experiment_name=dummy_experiment_name,
                environment=dummy_environment,
                input_datasets_v2=dummy_inputs,

@@ -784,8 +784,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                compute_target=dummy_compute_target,
                tags=dummy_tags,
                docker_shm_size=dummy_docker_shm_size,
                workspace_config_path=None,
                ml_client=mock_ml_client,
                hyperparam_args=None,
                display_name=dummy_display_name,
            )

@@ -835,7 +833,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
            expected_command += " --learning_rate=${{inputs.learning_rate}}"

            himl.submit_run_v2(
                workspace=None,
                experiment_name=dummy_experiment_name,
                environment=dummy_environment,
                input_datasets_v2=dummy_inputs,

@@ -846,7 +843,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                compute_target=dummy_compute_target,
                tags=dummy_tags,
                docker_shm_size=dummy_docker_shm_size,
                workspace_config_path=None,
                ml_client=mock_ml_client,
                hyperparam_args=dummy_hyperparam_args,
                display_name=dummy_display_name,

@@ -882,7 +878,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
            expected_command = f"python {dummy_entry_script_for_module} {expected_arg_str}"

            himl.submit_run_v2(
                workspace=None,
                ml_client=mock_ml_client,
                experiment_name=dummy_experiment_name,
                environment=dummy_environment,
                input_datasets_v2=dummy_inputs,

@@ -893,8 +889,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                compute_target=dummy_compute_target,
                tags=dummy_tags,
                docker_shm_size=dummy_docker_shm_size,
                workspace_config_path=None,
                ml_client=mock_ml_client,
                hyperparam_args=None,
                display_name=dummy_display_name,
            )

@@ -1308,9 +1302,7 @@ def test_mounting_and_downloading_dataset(tmp_path: Path) -> None:
        target_path = tmp_path / action
        dataset_config = DatasetConfig(name="hello_world", use_mounting=use_mounting, target_folder=target_path)
        logging.info(f"ready to {action}")
        paths, mount_contexts = setup_local_datasets(
            dataset_configs=[dataset_config], strictly_aml_v1=True, aml_workspace=workspace
        )
        paths, mount_contexts = setup_local_datasets(dataset_configs=[dataset_config], workspace=workspace)
        logging.info(f"{action} done")
        path = paths[0]
        assert path is not None

@@ -1372,7 +1364,7 @@ class TestOutputDataset:
    @pytest.mark.parametrize(
        ["run_target", "local_folder", "strictly_aml_v1"],
        [
            (RunTarget.LOCAL, True, False),
            (RunTarget.LOCAL, True, True),
            (RunTarget.AZUREML, False, True),
        ],
    )
@@ -10,7 +10,6 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, Optional

from azure.ai.ml import MLClient
from azureml.core import Run
from health_azure.utils import (
    ENV_EXPERIMENT_NAME,

@@ -20,7 +19,6 @@ from health_azure.utils import (
)
from health_azure import create_aml_run_object
from health_azure.himl import effective_experiment_name
from health_azure.utils import get_ml_client, get_workspace

DEFAULT_DATASTORE = "himldatasets"


@@ -111,16 +109,6 @@ def create_unittest_run_object(snapshot_directory: Optional[Path] = None) -> Run
    )


def get_test_ml_client() -> MLClient:
    """Generates an MLClient object for use in tests.

    :return: MLClient object
    """

    workspace = get_workspace()
    return get_ml_client(aml_workspace=workspace)


def current_test_name() -> str:
    """Get the name of the currently executed test. This is read off an environment variable. If that
    is not found, the function returns an empty string."""

@@ -21,7 +21,7 @@ def mount_dataset(dataset_id: str, tmp_root: str = "/tmp/datasets", aml_workspac
    ws = get_workspace(aml_workspace)
    target_folder = "/".join([tmp_root, dataset_id])
    dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
    _, mount_ctx = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
    _, mount_ctx = dataset.to_input_dataset_local(workspace=ws)
    assert mount_ctx is not None  # for mypy
    mount_ctx.start()
    return mount_ctx

@@ -21,5 +21,5 @@ def download_azure_dataset(tmp_path: Path, dataset_id: str) -> None:
    with check_config_json(script_folder=tmp_path, shared_config_json=get_shared_config_json()):
        ws = get_workspace(workspace_config_path=tmp_path / WORKSPACE_CONFIG_JSON)
        dataset = DatasetConfig(name=dataset_id, target_folder=tmp_path, use_mounting=False)
        dataset_dl_folder = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
        dataset_dl_folder = dataset.to_input_dataset_local(workspace=ws)
        logging.info(f"Dataset saved in {dataset_dl_folder}")
@@ -15,6 +15,14 @@ class RunnerMode(Enum):
    EVAL_FULL = "eval_full"


class LogLevel(Enum):
    ERROR = "ERROR"
    WARNING = "WARNING"
    WARN = "WARN"
    INFO = "INFO"
    DEBUG = "DEBUG"


DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"


@@ -87,10 +95,22 @@ class ExperimentConfig(param.Parameterized):
        doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
        "point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
    )
    mode: str = param.ClassSelector(
    mode: RunnerMode = param.ClassSelector(
        class_=RunnerMode,
        default=RunnerMode.TRAIN,
        doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
        f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
        "'get_eval_data_module' method of the container.",
    )
    log_level: Optional[LogLevel] = param.ClassSelector(
        class_=LogLevel,
        default=None,
        doc=f"The log level to use. Can be one of {list(map(str, LogLevel))}",
    )

    @property
    def submit_to_azure_ml(self) -> bool:
        """Returns True if the experiment should be submitted to AzureML, False if it should be run locally.

        :return: True if the experiment should be submitted to AzureML, False if it should be run locally."""
        return self.cluster != ""
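# Editor's note, a hedged sketch only: the handler created here is illustrative, not hi-ml's own
# logging_stdout_handler. It shows how a string-valued level such as LogLevel.DEBUG.value ("DEBUG")
# can be applied to a stdout handler, mirroring what the runner change below does via setLevel(...).
import logging

stdout_handler = logging.StreamHandler()
stdout_handler.setLevel("DEBUG")  # the logging module accepts standard level names as strings
logging.getLogger(__name__).addHandler(stdout_handler)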
@@ -15,8 +15,6 @@ import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from azureml.core import Workspace

# Add hi-ml packages to sys.path so that AML can find them if we are using the runner directly from the git repo
himl_root = Path(__file__).resolve().parent.parent.parent.parent
folders_to_add = [himl_root / "hi-ml" / "src", himl_root / "hi-ml-azure" / "src", himl_root / "hi-ml-cpath" / "src"]

@@ -34,8 +32,6 @@ from health_azure.paths import is_himl_used_from_git_repo  # noqa: E402
from health_azure.utils import (  # noqa: E402
    ENV_LOCAL_RANK,
    ENV_NODE_RANK,
    get_workspace,
    get_ml_client,
    is_local_rank_zero,
    is_running_in_azure_ml,
    set_environment_variables_for_multi_node,

@@ -122,6 +118,11 @@ class Runner:
        parser1_result = parse_arguments(parser1, args=filtered_args)
        experiment_config = ExperimentConfig(**parser1_result.args)

        from health_azure.logging import logging_stdout_handler  # noqa: E402

        if logging_stdout_handler is not None and experiment_config.log_level is not None:
            print(f"Setting custom logging level to {experiment_config.log_level}")
            logging_stdout_handler.setLevel(experiment_config.log_level.value)
        self.experiment_config = experiment_config
        if not experiment_config.model:
            raise ValueError("Parameter 'model' needs to be set to specify which model to run.")

@@ -150,7 +151,7 @@ class Runner:
        """
        Runs sanity checks on the whole experiment.
        """
        if not self.experiment_config.cluster:
        if not self.experiment_config.submit_to_azure_ml:
            if self.lightning_container.hyperdrive:
                raise ValueError(
                    "HyperDrive for hyperparameters tuning is only supported when submitting the job to "

@@ -214,47 +215,26 @@ class Runner:
        script_params = sys.argv[1:]

        environment_variables = self.additional_environment_variables()

        # Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
        # that if we are really submitting to AzureML.
        workspace: Optional[Workspace] = None
        if self.experiment_config.cluster:
            try:
                workspace = get_workspace(workspace_config_path=self.experiment_config.workspace_config_path)
            except ValueError:
                raise ValueError(
                    "Unable to submit the script to AzureML because no workspace configuration file "
                    "(config.json) was found."
                )

        if self.lightning_container.datastore:
            datastore = self.lightning_container.datastore
        elif workspace:
            datastore = workspace.get_default_datastore().name
        else:
            datastore = ""

        local_datasets = self.lightning_container.local_datasets
        all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
        # When running in AzureML, respect the commandline flag for mounting. Outside of AML, we always mount
        # datasets to be quicker.
        use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.cluster else True
        use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.submit_to_azure_ml else True
        input_datasets = create_dataset_configs(
            all_azure_dataset_ids=self.lightning_container.azure_datasets,
            all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
            all_local_datasets=all_local_datasets,  # type: ignore
            datastore=datastore,
            datastore=self.lightning_container.datastore,
            use_mounting=use_mounting,
        )

        if self.experiment_config.cluster and not is_running_in_azure_ml():
        if self.experiment_config.submit_to_azure_ml and not is_running_in_azure_ml():
            if self.experiment_config.strictly_aml_v1:
                hyperdrive_config = self.lightning_container.get_hyperdrive_config()
                hyperparam_args = None
            else:
                hyperparam_args = self.lightning_container.get_hyperparam_args()
                hyperdrive_config = None
            ml_client = get_ml_client(aml_workspace=workspace) if not self.experiment_config.strictly_aml_v1 else None

            env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
            logging.info(f"Using this Conda environment definition: {env_file}")

@@ -265,18 +245,15 @@ class Runner:
                snapshot_root_directory=root_folder,
                script_params=script_params,
                conda_environment_file=env_file,
                aml_workspace=workspace,
                ml_client=ml_client,
                compute_cluster_name=self.experiment_config.cluster,
                environment_variables=environment_variables,
                default_datastore=datastore,
                experiment_name=self.lightning_container.effective_experiment_name,
                input_datasets=input_datasets,  # type: ignore
                num_nodes=self.experiment_config.num_nodes,
                wait_for_completion=self.experiment_config.wait_for_completion,
                max_run_duration=self.experiment_config.max_run_duration,
                ignored_folders=[],
                submit_to_azureml=bool(self.experiment_config.cluster),
                submit_to_azureml=self.experiment_config.submit_to_azure_ml,
                docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
                docker_shm_size=self.experiment_config.docker_shm_size,
                hyperdrive_config=hyperdrive_config,

@@ -292,7 +269,6 @@ class Runner:
                submit_to_azureml=False,
                environment_variables=environment_variables,
                strictly_aml_v1=self.experiment_config.strictly_aml_v1,
                default_datastore=datastore,
            )
        if azure_run_info.run:
            # This code is only reached inside Azure. Set display name again - this will now affect
@@ -26,6 +26,7 @@ from health_ml.lightning_container import LightningContainer
from health_ml.runner import Runner, create_logging_filename, run_with_logging
from health_ml.utils.common_utils import change_working_directory
from health_ml.utils.fixed_paths import repository_root_directory
from testhiml.utils_testhiml import DEFAULT_WORKSPACE


@contextmanager

@@ -132,10 +133,9 @@ def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None
    model_name = "HelloWorld"
    arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.get_workspace"):
            with patch("health_ml.runner.Runner.run_in_situ"):
                with patch.object(sys, "argv", arguments):
                    mock_runner.run()
        with patch("health_ml.runner.Runner.run_in_situ"):
            with patch.object(sys, "argv", arguments):
                mock_runner.run()
        mock_submit_to_azure_if_needed.assert_called_once()
        assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp

@@ -144,12 +144,9 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}", "--cluster=foo"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.check_conda_environment"):
            with patch("health_ml.runner.get_workspace"):
                with patch("health_ml.runner.get_ml_client"):
                    with patch("health_ml.runner.Runner.run_in_situ"):
                        with patch.object(sys, "argv", arguments):
                            mock_runner.run()
        with patch("health_ml.runner.Runner.run_in_situ"):
            with patch.object(sys, "argv", arguments):
                mock_runner.run()
        mock_submit_to_azure_if_needed.assert_called_once()
        assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
        assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]

@@ -162,9 +159,6 @@ def test_additional_environment_variables(mock_runner: Runner) -> None:
    with patch.multiple(
        "health_ml.runner",
        submit_to_azure_if_needed=DEFAULT,
        check_conda_environment=DEFAULT,
        get_workspace=DEFAULT,
        get_ml_client=DEFAULT,
    ) as mocks:
        with patch("health_ml.runner.Runner.run_in_situ"):
            with patch("health_ml.runner.Runner.parse_and_load_model"):

@@ -185,9 +179,8 @@ def test_run(mock_runner: Runner) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}"]
    with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
        with patch("health_ml.runner.get_workspace"):
            with patch.object(sys, "argv", arguments):
                model_config, azure_run_info = mock_runner.run()
        with patch.object(sys, "argv", arguments):
            model_config, azure_run_info = mock_runner.run()
    mock_run_in_situ.assert_called_once()

    assert model_config is not None  # for pyright

@@ -197,17 +190,13 @@ def test_run(mock_runner: Runner) -> None:


@patch("health_ml.runner.choose_conda_env_file")
@patch("health_ml.runner.get_workspace")
@pytest.mark.fast
def test_submit_to_azureml_if_needed(
    mock_get_workspace: MagicMock, mock_get_env_files: MagicMock, mock_runner: Runner
) -> None:
def test_submit_to_azureml_if_needed(mock_get_env_files: MagicMock, mock_runner: Runner) -> None:
    def _mock_dont_submit_to_aml(
        input_datasets: List[DatasetConfig],
        submit_to_azureml: bool,
        strictly_aml_v1: bool,  # type: ignore
        environment_variables: Dict[str, Any],  # type: ignore
        default_datastore: Optional[str],  # type: ignore
    ) -> AzureRunInfo:
        datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
        return AzureRunInfo(

@@ -222,10 +211,6 @@ def test_submit_to_azureml_if_needed(

    mock_get_env_files.return_value = Path("some_file.txt")

    mock_default_datastore = MagicMock()
    mock_default_datastore.name.return_value = "dummy_datastore"
    mock_get_workspace.get_default_datastore.return_value = mock_default_datastore

    with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
        mock_create_datasets.return_value = []
        with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:

@@ -334,11 +319,9 @@ def _test_hyperdrive_submission(
    # start in that temp folder.
    with change_working_folder_and_add_environment(mock_runner.project_root):
        with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
            with patch("health_ml.runner.get_workspace"):
                with patch("health_ml.runner.get_ml_client"):
                    with patch.object(sys, "argv", arguments):
                        with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
                            mock_runner.run()
            with patch.object(sys, "argv", arguments):
                with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
                    mock_runner.run()
    mock_run_in_situ.assert_called_once()
    mock_submit_to_aml.assert_called_once()
    # call_args is a tuple of (args, kwargs)

@@ -364,11 +347,9 @@ def test_submit_to_azure_docker(mock_runner: Runner) -> None:
    # start in that temp folder.
    with change_working_folder_and_add_environment(mock_runner.project_root):
        with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
            with patch("health_ml.runner.get_ml_client"):
                with patch("health_ml.runner.get_workspace"):
                    with patch.object(sys, "argv", arguments):
                        with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
                            mock_runner.run()
            with patch.object(sys, "argv", arguments):
                with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
                    mock_runner.run()
    mock_run_in_situ.assert_called_once()
    mock_submit_to_aml.assert_called_once()
    # call_args is a tuple of (args, kwargs)

@@ -393,16 +374,12 @@ def test_run_hello_world(mock_runner: Runner) -> None:
    """Test running a model end-to-end via the commandline runner"""
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}"]
    with patch("health_ml.runner.get_workspace") as mock_get_workspace:
        with patch.object(sys, "argv", arguments):
            mock_runner.run()
        # get_workspace should not be called when using the runner outside AzureML, to not go through the
        # time-consuming auth
        mock_get_workspace.assert_not_called()
        # Summary.txt is written at start, the other files during inference
        expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
        for file in expected_files:
            assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
    with patch.object(sys, "argv", arguments):
        mock_runner.run()
    # Summary.txt is written at start, the other files during inference
    expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
    for file in expected_files:
        assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"


def test_invalid_args(mock_runner: Runner) -> None:

@@ -425,17 +402,37 @@ def test_invalid_profiler(mock_runner: Runner) -> None:
        mock_runner.run()


def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:
def test_datastore_argument(mock_runner: Runner) -> None:
    """The datastore argument should be respected"""
    model_name = "HelloWorld"
    datastore = "foo"
    arguments = ["", f"--datastore={datastore}", f"--model={model_name}"]
    dataset = "bar"
    arguments = ["", f"--datastore={datastore}", f"--model={model_name}", f"--azure_datasets={dataset}"]
    with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
        with patch("health_ml.runner.get_workspace"):
            with patch("health_ml.runner.Runner.run_in_situ"):
                with patch.object(sys, "argv", arguments):
                    mock_runner.run()
        with patch("health_ml.runner.Runner.run_in_situ"):
            with patch.object(sys, "argv", arguments):
                mock_runner.run()
        mock_submit_to_azure_if_needed.assert_called_once()
        assert mock_submit_to_azure_if_needed.call_args[1]["default_datastore"] == datastore
        input_datasets = mock_submit_to_azure_if_needed.call_args[1]["input_datasets"]
        assert len(input_datasets) == 1
        assert input_datasets[0].datastore == datastore
        assert input_datasets[0].name == dataset


def test_no_authentication_outside_azureml(mock_runner: Runner) -> None:
    """No authentication should happen for a model that runs locally and needs no datasets."""
    model_name = "HelloWorld"
    arguments = ["", f"--datastore=datastore", f"--model={model_name}"]
    with patch("health_ml.runner.Runner.run_in_situ"):
        with patch.object(sys, "argv", arguments):
            mock_get_workspace = MagicMock()
            mock_get_ml_client = MagicMock()
            with patch.multiple(
                "health_azure.himl", get_workspace=mock_get_workspace, get_ml_client=mock_get_ml_client
            ):
                mock_runner.run()
            mock_get_workspace.assert_not_called()
            mock_get_ml_client.assert_not_called()


@pytest.mark.fast
@@ -512,3 +509,111 @@ def test_run_without_logging(tmp_path: Path) -> None:
        run_with_logging(tmp_path)
    mock_create_filename.assert_not_called()
    mock_run.assert_called_once()


@pytest.mark.fast
def test_runner_does_not_use_get_workspace() -> None:
    """Test that the runner does not itself import get_workspace or get_ml_client (otherwise we would need to check
    them in the tests below that count calls to those methods)"""
    with pytest.raises(ImportError):
        from health_ml.runner import get_workspace  # type: ignore
    with pytest.raises(ImportError):
        from health_ml.runner import get_ml_client  # type: ignore


def test_runner_authenticates_once_v1() -> None:
    """Test that the runner requires authentication only once when doing a job submission with the V1 SDK"""
    runner = Runner(project_root=repository_root_directory())
    mock_get_workspace = MagicMock()
    mock_get_ml_client = MagicMock()
    with patch.multiple(
        "health_azure.himl",
        get_workspace=mock_get_workspace,
        get_ml_client=mock_get_ml_client,
        Experiment=MagicMock(),
        register_environment=MagicMock(return_value="env"),
        validate_compute_cluster=MagicMock(),
    ):
        with patch.object(
            sys,
            "argv",
            ["src/health_ml/runner.py", "--model=HelloWorld", "--cluster=pr-gpu", "--strictly_aml_v1"],
        ):
            # Job submission should trigger a system exit
            with pytest.raises(SystemExit):
                runner.run()
            mock_get_workspace.assert_called_once()
            mock_get_ml_client.assert_not_called()


def test_runner_authenticates_once_v2() -> None:
    """Test that the runner requires authentication only once when doing a job submission with the V2 SDK"""
    runner = Runner(project_root=repository_root_directory())
    mock_get_workspace = MagicMock()
    mock_get_ml_client = MagicMock()
    with patch.multiple(
        "health_azure.himl",
        get_workspace=mock_get_workspace,
        get_ml_client=mock_get_ml_client,
        command=MagicMock(),
    ):
        with patch.object(sys, "argv", ["", "--model=HelloWorld", "--cluster=pr-gpu"]):
            # Job submission should trigger a system exit
            with pytest.raises(SystemExit):
                runner.run()
            mock_get_workspace.assert_not_called()
            mock_get_ml_client.assert_called_once()


def test_runner_with_local_dataset_v1() -> None:
    """Test that the runner requires authentication only once when doing a local run and a dataset has to be mounted"""
    runner = Runner(project_root=repository_root_directory())
    mock_get_workspace = MagicMock(return_value=DEFAULT_WORKSPACE.workspace)
    mock_get_ml_client = MagicMock()
    with patch.multiple(
        "health_azure.himl",
        get_workspace=mock_get_workspace,
        get_ml_client=mock_get_ml_client,
    ):
        with patch.object(
            sys,
            "argv",
            [
                "src/health_ml/runner.py",
                "--model=HelloWorld",
                "--strictly_aml_v1",
                "--azure_datasets=hello_world",
            ],
        ):
            runner.run()
            mock_get_workspace.assert_called_once()
            mock_get_ml_client.assert_not_called()


@pytest.mark.parametrize("use_local_dataset", [True, False])
def test_runner_with_local_dataset_v2(use_local_dataset: bool, tmp_path: Path) -> None:
    """Test that the runner requires authentication only once when doing a local run with SDK v2"""
    runner = Runner(project_root=repository_root_directory())
    mock_get_workspace = MagicMock()
    mock_get_ml_client = MagicMock(return_value=DEFAULT_WORKSPACE.ml_client)
    with patch.multiple(
        "health_azure.himl",
        get_workspace=mock_get_workspace,
        get_ml_client=mock_get_ml_client,
    ):
        args = [
            "src/health_ml/runner.py",
            "--model=HelloWorld",
            f"--strictly_aml_v1=False",
            "--azure_datasets=hello_world",
        ]
        if use_local_dataset:
            args.append(f"--local_datasets={tmp_path}")
        with patch.object(sys, "argv", args):
            if use_local_dataset:
                runner.run()
            else:
                with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from"):
                    runner.run()
            mock_get_workspace.assert_not_called()
            mock_get_ml_client.assert_called_once()