Mirror of https://github.com/microsoft/hi-ml.git
BUG: Fix for duplicate authentication (#878)
Resolve issues with AML SDK v1/v2: even in the v2 code path we were creating SDK v1 Workspace objects, leading to duplicate authentication requests. Fixing that required a couple of other changes:
* Deprecate the use of the default datastore, which was read out of an SDK v1 Workspace object even if SDK v2 was chosen.
* No longer allow SDK v2 when mounting datasets for local runs (v2 datasets can't be mounted at all).
Also added more detailed logging for dataset creation, and a commandline flag to control the logging level.
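For illustration only (not part of the commit): a minimal sketch of how a caller obtains a single MLClient on the SDK v2 path after this change, so authentication happens once instead of going through an SDK v1 Workspace first. The config.json path used here is an assumed example.

# Hypothetical usage sketch; assumes the AzureML v2 extras and a workspace config.json are available.
from pathlib import Path

from health_azure.utils import get_ml_client

# Resolution order after this change: an explicitly passed MLClient, then a workspace config
# file (given here, or found in the current folder and its parents), then the HIML_* environment variables.
ml_client = get_ml_client(workspace_config_path=Path("config.json"))
print(f"Connected to AzureML workspace {ml_client.workspace_name}")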
Parent: 683def950a
Commit: f46f60e7fa
@@ -20,7 +20,9 @@ from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
 from azureml.dataprep.fuse.daemon import MountContext
 from azureml.exceptions._azureml_exception import UserErrorException
 
-from health_azure.utils import PathOrString, get_workspace, get_ml_client
+from health_azure.utils import PathOrString, get_ml_client
 
+logger = logging.getLogger(__name__)
+
 
 V1OrV2DataType = Union[FileDataset, Data]

@@ -128,11 +130,14 @@ def _get_or_create_v1_dataset(datastore_name: str, dataset_name: str, workspace:
     try:
         azureml_dataset = _retrieve_v1_dataset(dataset_name, workspace)
     except UserErrorException:
+        logger.warning(f"Dataset '{dataset_name}' was not found, or is not an AzureML SDK v1 dataset.")
+        logger.info(f"Trying to create a new dataset '{dataset_name}' from files in folder '{dataset_name}'")
         if datastore_name == "":
             raise ValueError(
                 "When creating a new dataset, a datastore name must be provided. Please specify a datastore name using "
                 "the --datastore flag"
             )
+        logger.info(f"Trying to create a new dataset '{dataset_name}' in datastore '{datastore_name}'")
         azureml_dataset = _create_v1_dataset(datastore_name, dataset_name, workspace)
     return azureml_dataset
 
@@ -352,10 +357,8 @@ class DatasetConfig:
 
     def to_input_dataset_local(
         self,
-        strictly_aml_v1: bool,
-        workspace: Workspace = None,
-        ml_client: Optional[MLClient] = None,
-    ) -> Tuple[Optional[Path], Optional[MountContext]]:
+        workspace: Workspace,
+    ) -> Tuple[Path, Optional[MountContext]]:
         """
         Return a local path to the dataset when outside of an AzureML run.
         If local_folder is supplied, then this is assumed to be a local dataset, and this is returned.

@@ -364,9 +367,6 @@
         therefore a tuple of Nones will be returned.
 
         :param workspace: The AzureML workspace to read from.
-        :param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
-            Otherwise, attempt to use Azure ML SDK v2.
-        :param ml_client: An Azure MLClient object for interacting with Azure resources.
         :return: Tuple of (path to dataset, optional mountcontext)
         """
         status = f"Dataset '{self.name}' will be "

@@ -381,12 +381,10 @@
                 f"Unable to make dataset '{self.name} available for a local run because no AzureML "
                 "workspace has been provided. Provide a workspace, or set a folder for local execution."
             )
-        azureml_dataset = get_or_create_dataset(
+        azureml_dataset = _get_or_create_v1_dataset(
             datastore_name=self.datastore,
             dataset_name=self.name,
             workspace=workspace,
-            strictly_aml_v1=strictly_aml_v1,
-            ml_client=ml_client,
         )
         if isinstance(azureml_dataset, FileDataset):
             target_path = self.target_folder or Path(tempfile.mkdtemp())

@@ -404,7 +402,7 @@
             print(status)
             return result
         else:
-            return None, None
+            raise ValueError(f"Don't know how to handle dataset '{self.name}' of type {type(azureml_dataset)}")
 
     def to_input_dataset(
         self,
@@ -556,38 +554,10 @@ def create_dataset_configs(
     return datasets
 
 
-def find_workspace_for_local_datasets(
-    aml_workspace: Optional[Workspace], workspace_config_path: Optional[Path], dataset_configs: List[DatasetConfig]
-) -> Optional[Workspace]:
-    """
-    If any of the dataset_configs require an AzureML workspace then try to get one, otherwise return None.
-
-    :param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
-        to pass it in as a parameter.
-    :param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
-        Azure portal from which we can retrieve the existing Workspace.
-    :param dataset_configs: List of DatasetConfig describing the input datasets.
-    :return: Workspace if required, None otherwise.
-    """
-    workspace: Workspace = None
-    # Check whether an attempt will be made to mount or download a dataset when running locally.
-    # If so, try to get the AzureML workspace.
-    if any(dc.local_folder is None for dc in dataset_configs):
-        try:
-            workspace = get_workspace(aml_workspace, workspace_config_path)
-            logging.info(f"Found workspace for datasets: {workspace.name}")
-        except Exception as ex:
-            logging.info(f"Could not find workspace for datasets. Exception: {ex}")
-    return workspace
-
-
 def setup_local_datasets(
     dataset_configs: List[DatasetConfig],
-    strictly_aml_v1: bool,
-    aml_workspace: Optional[Workspace] = None,
-    ml_client: Optional[MLClient] = None,
-    workspace_config_path: Optional[Path] = None,
-) -> Tuple[List[Optional[Path]], List[MountContext]]:
+    workspace: Optional[Workspace],
+) -> Tuple[List[Path], List[MountContext]]:
     """
     When running outside of AzureML, setup datasets to be used locally.
 

@@ -595,21 +565,20 @@ def setup_local_datasets(
     used. Otherwise the dataset is mounted or downloaded to either the target folder or a temporary folder and that is
     used.
 
-    :param aml_workspace: There are two optional parameters used to glean an existing AzureML Workspace. The simplest is
-        to pass it in as a parameter.
-    :param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
-        Azure portal from which we can retrieve the existing Workspace.
+    If a dataset does not exist, an AzureML SDK v1 dataset will be created, assuming that the dataset is given
+    in a folder of the same name (for example, if a dataset is given as "mydataset", then it is created from the files
+    in folder "mydataset" in the datastore).
+
+    :param workspace: The AzureML workspace to work with. Can be None if the list of datasets is empty, or if
+        the datasets are available local.
     :param dataset_configs: List of DatasetConfig describing the input data assets.
-    :param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
-    :param ml_client: An MLClient object for interacting with AML v2 datastores.
-    :return: Pair of: list of optional paths to the input datasets, list of mountcontexts, one for each mounted dataset.
+    :return: Pair of: list of paths to the input datasets, list of mountcontexts, one for each mounted dataset.
     """
-    workspace = find_workspace_for_local_datasets(aml_workspace, workspace_config_path, dataset_configs)
-    mounted_input_datasets: List[Optional[Path]] = []
+    mounted_input_datasets: List[Path] = []
     mount_contexts: List[MountContext] = []
 
     for data_config in dataset_configs:
-        target_path, mount_context = data_config.to_input_dataset_local(strictly_aml_v1, workspace, ml_client)
+        target_path, mount_context = data_config.to_input_dataset_local(workspace)
 
         mounted_input_datasets.append(target_path)
 
@@ -442,22 +442,20 @@ def effective_experiment_name(experiment_name: Optional[str], entry_script: Opti
 
 
 def submit_run_v2(
-    workspace: Optional[Workspace],
+    ml_client: MLClient,
     environment: EnvironmentV2,
+    entry_script: PathOrString,
+    script_params: List[str],
+    compute_target: str,
     environment_variables: Optional[Dict[str, str]] = None,
     experiment_name: Optional[str] = None,
     input_datasets_v2: Optional[Dict[str, Input]] = None,
     output_datasets_v2: Optional[Dict[str, Output]] = None,
     snapshot_root_directory: Optional[Path] = None,
-    entry_script: Optional[PathOrString] = None,
-    script_params: Optional[List[str]] = None,
-    compute_target: Optional[str] = None,
     tags: Optional[Dict[str, str]] = None,
     docker_shm_size: str = "",
     wait_for_completion: bool = False,
     identity_based_auth: bool = False,
-    workspace_config_path: Optional[PathOrString] = None,
-    ml_client: Optional[MLClient] = None,
     hyperparam_args: Optional[Dict[str, Any]] = None,
     num_nodes: int = 1,
     pytorch_processes_per_node: Optional[int] = None,

@@ -466,8 +464,11 @@ def submit_run_v2(
    """
    Starts a v2 AML Job on a given workspace by submitting a command
 
-    :param workspace: The AzureML workspace to use.
+    :param ml_client: An Azure MLClient object for interacting with Azure resources.
     :param environment: An AML v2 Environment object.
+    :param entry_script: The script that should be run in AzureML.
+    :param script_params: A list of parameter to pass on to the script as it runs in AzureML.
+    :param compute_target: The name of a compute target in Azure ML to submit the job to.
     :param environment_variables: The environment variables that should be set when running in AzureML.
     :param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
         characters that are not valid in Azure, those will be removed.

@@ -475,18 +476,11 @@
     :param output_datasets_v2: An optional dictionary of Outputs to pass in to the command.
     :param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
         All Python code that the script uses must be copied over.
-    :param entry_script: The script that should be run in AzureML.
-    :param script_params: A list of parameter to pass on to the script as it runs in AzureML.
-    :param compute_target: Optional name of a compute target in Azure ML to submit the job to. If None, will run
-        locally.
     :param tags: A dictionary of string key/value pairs, that will be added as metadata to the run. If set to None,
         a default metadata field will be added that only contains the commandline arguments that started the run.
     :param docker_shm_size: The Docker shared memory size that should be used when creating a new Docker image.
     :param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
         the completion of this run (if True).
-    :param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
-        config
-    :param ml_client: An Azure MLClient object for interacting with Azure resources.
     :param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
     :param num_nodes: The number of nodes to use for the job in AzureML. The value must be 1 or greater.
     :param pytorch_processes_per_node: For plain PyTorch multi-GPU processing: The number of processes per node.

@@ -496,20 +490,6 @@
         display name will be generated by AzureML.
     :return: An AzureML Run object.
     """
-    if ml_client is None:
-        if workspace is not None:
-            ml_client = get_ml_client(
-                subscription_id=workspace.subscription_id,
-                resource_group=workspace.resource_group,
-                workspace_name=workspace.name,
-            )
-        elif workspace_config_path is not None:
-            ml_client = get_ml_client(workspace_config_path=workspace_config_path)
-        else:
-            raise ValueError("Either workspace or workspace_config_path must be specified to connect to the Workspace")
-
-    assert compute_target is not None, "No compute_target has been provided"
-    assert entry_script is not None, "No entry_script has been provided"
     snapshot_root_directory = snapshot_root_directory or Path.cwd()
     root_dir = Path(snapshot_root_directory)
 

@@ -592,7 +572,11 @@
     job_to_submit = create_command_job(cmd)
 
     returned_job = ml_client.jobs.create_or_update(job_to_submit)
-    print(f"URL to job: {returned_job.services['Studio'].endpoint}")  # type: ignore
+    print("\n==============================================================================")
+    # The ID field looks like /subscriptions/<sub>/resourceGroups/<rg?/providers/Microsoft.MachineLearningServices/..
+    print(f"Successfully queued run {(returned_job.id or '').split('/')[-1]}")
+    print(f"Run URL: {returned_job.services['Studio'].endpoint}")  # type: ignore
+    print("==============================================================================\n")
     if wait_for_completion:
         print("Waiting for the completion of the AzureML job.")
         wait_for_job_completion(ml_client, job_name=returned_job.name)

@@ -671,7 +655,7 @@ def submit_run(
 
     # These need to be 'print' not 'logging.info' so that the calling script sees them outside AzureML
     print("\n==============================================================================")
-    print(f"Successfully queued run number {run.number} (ID {run.id}) in experiment {run.experiment.name}")
+    print(f"Successfully queued run {run.id} in experiment {run.experiment.name}")
     print(f"Experiment name and run ID are available in file {RUN_RECOVERY_FILE}")
     print(f"Experiment URL: {run.experiment.get_portal_url()}")
     print(f"Run URL: {run.get_portal_url()}")
@@ -885,6 +869,18 @@ def submit_to_azure_if_needed(  # type: ignore
     # is necessary. If not, return to the caller for local execution.
     if submit_to_azureml is None:
         submit_to_azureml = AZUREML_FLAG in sys.argv[1:]
+
+    has_input_datasets = len(cleaned_input_datasets) > 0
+    if submit_to_azureml or has_input_datasets:
+        if strictly_aml_v1:
+            aml_workspace = get_workspace(aml_workspace, workspace_config_path)
+            assert aml_workspace is not None
+            print(f"Loaded AzureML workspace {aml_workspace.name}")
+        else:
+            ml_client = get_ml_client(ml_client=ml_client, workspace_config_path=workspace_config_path)
+            assert ml_client is not None
+            print(f"Created MLClient for AzureML workspace {ml_client.workspace_name}")
+
     if not submit_to_azureml:
         # Set the environment variables for local execution.
         environment_variables = {**DEFAULT_ENVIRONMENT_VARIABLES, **(environment_variables or {})}

@@ -898,16 +894,24 @@
         logs_folder = Path.cwd() / LOGS_FOLDER
         logs_folder.mkdir(exist_ok=True)
 
+        any_local_folders_missing = any(dataset.local_folder is None for dataset in cleaned_input_datasets)
+
+        if has_input_datasets and any_local_folders_missing and not strictly_aml_v1:
+            raise ValueError(
+                "AzureML SDK v2 does not support downloading datasets from AzureML for local execution. "
+                "Please switch to AzureML SDK v1 by setting strictly_aml_v1=True, or use "
+                "--strictly_aml_v1 on the commandline, or provide a local folder for each input dataset. "
+                "Note that you will not be able use AzureML datasets for runs outside AzureML if the datasets were "
+                "created via SDK v2."
+            )
+
         mounted_input_datasets, mount_contexts = setup_local_datasets(
             cleaned_input_datasets,
-            strictly_aml_v1,
-            aml_workspace=aml_workspace,
-            ml_client=ml_client,
-            workspace_config_path=workspace_config_path,
+            workspace=aml_workspace,
         )
 
         return AzureRunInfo(
-            input_datasets=mounted_input_datasets,
+            input_datasets=mounted_input_datasets,  # type: ignore
             output_datasets=[d.local_folder for d in cleaned_output_datasets],
             mount_contexts=mount_contexts,
             run=None,

@@ -920,9 +924,6 @@
         print(f"No snapshot root directory given. Uploading all files in the current directory {Path.cwd()}")
         snapshot_root_directory = Path.cwd()
 
-    workspace = get_workspace(aml_workspace, workspace_config_path)
-    print(f"Loaded AzureML workspace {workspace.name}")
-
     if conda_environment_file is None:
         conda_environment_file = find_file_in_parent_to_pythonpath(CONDA_ENVIRONMENT_FILE)
         if conda_environment_file is None:

@@ -938,8 +939,9 @@
 
     with append_to_amlignore(amlignore=amlignore_path, lines_to_append=lines_to_append):
         if strictly_aml_v1:
+            assert aml_workspace is not None, "An AzureML workspace should have been created already."
             run_config = create_run_configuration(
-                workspace=workspace,
+                workspace=aml_workspace,
                 compute_cluster_name=compute_cluster_name,
                 aml_environment_name=aml_environment_name,
                 conda_environment_file=conda_environment_file,

@@ -968,7 +970,7 @@
                 config_to_submit = script_run_config
 
             run = submit_run(
-                workspace=workspace,
+                workspace=aml_workspace,
                 experiment_name=effective_experiment_name(experiment_name, script_run_config.script),
                 script_run_config=config_to_submit,
                 tags=tags,

@@ -979,6 +981,7 @@
             if after_submission is not None:
                 after_submission(run)  # type: ignore
         else:
+            assert ml_client is not None, "An AzureML MLClient should have been created already."
             if conda_environment_file is None:
                 raise ValueError("Argument 'conda_environment_file' must be specified when using AzureML v2")
             environment = create_python_environment_v2(

@@ -987,13 +990,12 @@
             if entry_script is None:
                 entry_script = Path(sys.argv[0])
 
-            ml_client = get_ml_client(ml_client=ml_client, aml_workspace=workspace)
             registered_env = register_environment_v2(environment, ml_client)
             input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
             output_datasets_v2 = create_v2_outputs(ml_client, cleaned_output_datasets)
 
             job = submit_run_v2(
-                workspace=workspace,
+                ml_client=ml_client,
                 input_datasets_v2=input_datasets_v2,
                 output_datasets_v2=output_datasets_v2,
                 experiment_name=experiment_name,
@@ -39,12 +39,7 @@ def main() -> None:  # pragma: no cover
 
     files_to_download = download_config.files_to_download
 
-    workspace = get_workspace()
-    ml_client = get_ml_client(
-        subscription_id=workspace.subscription_id,
-        resource_group=workspace.resource_group,
-        workspace_name=workspace.name,
-    )
+    ml_client = get_ml_client()
     for run_id in download_config.run:
         download_job_outputs_logs(ml_client, run_id, file_to_download_path=files_to_download, download_dir=output_dir)
     print("Successfully downloaded output and log files")
@@ -13,7 +13,6 @@ from typing import Generator, Optional, Union
 from health_azure.utils import ENV_LOCAL_RANK, check_is_any_of, is_global_rank_zero
 
 logging_stdout_handler: Optional[logging.StreamHandler] = None
-logging_to_file_handler: Optional[logging.StreamHandler] = None
 
 
 def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
@@ -288,6 +288,30 @@ def find_file_in_parent_to_pythonpath(file_name: str) -> Optional[Path]:
     return find_file_in_parent_folders(file_name=file_name, stop_at_path=pythonpaths)
 
 
+def resolve_workspace_config_path(workspace_config_path: Optional[Path] = None) -> Optional[Path]:
+    """Retrieve the path to the workspace config file, either from the argument, or from the current working directory.
+
+    :param workspace_config_path: A path to a workspace config file that was provided on the commandline, defaults to
+        None
+    :return: The path to the workspace config file, or None if it cannot be found.
+    :raises FileNotFoundError: If the workspace config file that was provided as an argument does not exist.
+    """
+    if workspace_config_path is None:
+        logger.info(
+            f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
+            "and its parent folders"
+        )
+        result = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
+        if result:
+            logger.info(f"Using the workspace config file {str(result.absolute())}")
+        else:
+            logger.debug("No workspace config file found")
+        return result
+    if not workspace_config_path.is_file():
+        raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
+    return workspace_config_path
+
+
 def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_path: Optional[Path] = None) -> Workspace:
     """
     Retrieve an Azure ML Workspace by going through the following steps:

@@ -320,26 +344,16 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
     if aml_workspace:
         return aml_workspace
 
-    if workspace_config_path is None:
-        logging.info(
-            f"Trying to locate the workspace config file '{WORKSPACE_CONFIG_JSON}' in the current folder "
-            "and its parent folders"
-        )
-        workspace_config_path = find_file_in_parent_to_pythonpath(WORKSPACE_CONFIG_JSON)
-        if workspace_config_path:
-            logging.info(f"Using the workspace config file {str(workspace_config_path.absolute())}")
-
+    workspace_config_path = resolve_workspace_config_path(workspace_config_path)
     auth = get_authentication()
     if workspace_config_path is not None:
-        if not workspace_config_path.is_file():
-            raise FileNotFoundError(f"Workspace config file does not exist: {workspace_config_path}")
         workspace = Workspace.from_config(path=str(workspace_config_path), auth=auth)
-        logging.info(
+        logger.info(
             f"Logged into AzureML workspace {workspace.name} as specified in config file " f"{workspace_config_path}"
         )
         return workspace
 
-    logging.info("Trying to load the environment variables that define the workspace.")
+    logger.info("Trying to load the environment variables that define the workspace.")
     workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
     subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
     resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)

@@ -347,7 +361,7 @@
         workspace = Workspace.get(
             name=workspace_name, auth=auth, subscription_id=subscription_id, resource_group=resource_group
         )
-        logging.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
+        logger.info(f"Logged into AzureML workspace {workspace.name} as specified by environment variables")
         return workspace
 
     raise ValueError(

@@ -1747,7 +1761,8 @@ class UnitTestWorkspaceWrapper:
         """
         Init.
         """
-        self._workspace: Workspace = None
+        self._workspace: Optional[Workspace] = None
+        self._ml_client: Optional[MLClient] = None
 
     @property
     def workspace(self) -> Workspace:

@@ -1758,6 +1773,15 @@
             self._workspace = get_workspace()
         return self._workspace
 
+    @property
+    def ml_client(self) -> MLClient:
+        """
+        Lazily load the ML Client.
+        """
+        if self._ml_client is None:
+            self._ml_client = get_ml_client()
+        return self._ml_client
+
 
 @contextmanager
 def check_config_json(script_folder: Path, shared_config_json: Path) -> Generator:

@@ -1895,7 +1919,7 @@ def _get_legitimate_interactive_browser_credential() -> Optional[TokenCredential
 
 def get_credential() -> Optional[TokenCredential]:
     """
-    Get a credential for authenticating with Azure.There are multiple ways to retrieve a credential.
+    Get a credential for authenticating with Azure. There are multiple ways to retrieve a credential.
     If environment variables pertaining to details of a Service Principal are available, those will be used
     to authenticate. If no environment variables exist, and the script is not currently
     running inside of Azure ML or another Azure agent, will attempt to retrieve a credential via a

@@ -1910,6 +1934,7 @@ def get_credential() -> Optional[TokenCredential]:
     tenant_id = get_secret_from_environment(ENV_TENANT_ID, allow_missing=True)
     service_principal_password = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_PASSWORD, allow_missing=True)
     if service_principal_id and tenant_id and service_principal_password:
+        logger.debug("Found environment variables for Service Principal authentication")
         return _get_legitimate_service_principal_credential(tenant_id, service_principal_id, service_principal_password)
 
     try:
@@ -1927,66 +1952,76 @@
 
         raise ValueError(
             "Unable to generate and validate a credential. Please see Azure ML documentation"
-            "for instructions on diffrent options to get a credential"
+            "for instructions on different options to get a credential"
         )
 
 
 def get_ml_client(
     ml_client: Optional[MLClient] = None,
-    aml_workspace: Optional[Workspace] = None,
-    workspace_config_path: Optional[PathOrString] = None,
-    subscription_id: Optional[str] = None,
-    resource_group: Optional[str] = None,
-    workspace_name: str = "",
+    workspace_config_path: Optional[Path] = None,
 ) -> MLClient:
     """
-    Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK.
-    If a ml_client is provided, return that. Otherwise, create one using workspace details
-    coming from either an existing Workspace object, a config.json file or passed in as an argument.
+    Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK. The following ways of
+    creating the client are tried out:
+
+    1. If an MLClient object has been provided in the `ml_client` argument, return that.
+
+    2. If a path to a workspace config file has been provided, load the MLClient according to that config file.
+
+    3. If a workspace config file is present in the current working directory or one of its parents, load the
+       MLClient according to that config file.
+
+    4. If 3 environment variables are found, use them to identify the workspace (`HIML_RESOURCE_GROUP`,
+       `HIML_SUBSCRIPTION_ID`, `HIML_WORKSPACE_NAME`)
+
+    If none of the above succeeds, an exception is raised.
 
     :param ml_client: An optional existing MLClient object to be returned.
-    :param aml_workspace: An optional Workspace object to take connection details from.
     :param workspace_config_path: An optional path toa config.json file containing details of the Workspace.
     :param subscription_id: An optional subscription ID.
     :param resource_group: An optional resource group name.
     :param workspace_name: An optional workspace name.
     :return: An instance of MLClient to interact with Azure resources.
     """
-    if ml_client:
+    if ml_client is not None:
        return ml_client
+    logger.debug("Getting credentials")
     credential = get_credential()
     if credential is None:
         raise ValueError("Can't connect to MLClient without a valid credential")
-    if aml_workspace is not None:
-        ml_client = MLClient(
-            subscription_id=aml_workspace.subscription_id,
-            resource_group_name=aml_workspace.resource_group,
-            workspace_name=aml_workspace.name,
-            credential=credential,
-        )  # type: ignore
-    elif workspace_config_path:
+    workspace_config_path = resolve_workspace_config_path(workspace_config_path)
+    if workspace_config_path is not None:
+        logger.debug(f"Retrieving MLClient from workspace config {workspace_config_path}")
         ml_client = MLClient.from_config(credential=credential, path=str(workspace_config_path))  # type: ignore
-    elif subscription_id and resource_group and workspace_name:
+        logger.info(
+            f"Using MLClient for AzureML workspace {ml_client.workspace_name} as specified in config file"
+            f"{workspace_config_path}"
+        )
+        return ml_client
+
+    logger.info("Trying to load the environment variables that define the workspace.")
+    workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=True)
+    subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=True)
+    resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=True)
+    if workspace_name and subscription_id and resource_group:
+        logger.debug(
+            "Retrieving MLClient via subscription ID, resource group and workspace name retrieved from "
+            "environment variables."
+        )
         ml_client = MLClient(
             subscription_id=subscription_id,
             resource_group_name=resource_group,
             workspace_name=workspace_name,
             credential=credential,
         )  # type: ignore
-    else:
-        try:
-            workspace = get_workspace()
-            ml_client = MLClient(
-                subscription_id=workspace.subscription_id,
-                resource_group_name=workspace.resource_group,
-                workspace_name=workspace.name,
-                credential=credential,
-            )  # type: ignore
-        except ValueError as e:
-            raise ValueError(f"Couldn't connect to MLClient: {e}")
-    logging.info(f"Logged into AzureML workspace {ml_client.workspace_name}")
-    return ml_client
+        logger.info(f"Using MLClient for AzureML workspace {workspace_name} as specified by environment variables")
+        return ml_client
+
+    raise ValueError(
+        "Tried all ways of identifying the MLClient, but failed. Please provide a workspace config "
+        f"file {WORKSPACE_CONFIG_JSON} or set the environment variables {ENV_RESOURCE_GROUP}, "
+        f"{ENV_SUBSCRIPTION_ID}, and {ENV_WORKSPACE_NAME}."
+    )
 
 
 def retrieve_workspace_from_client(ml_client: MLClient, workspace_name: Optional[str] = None) -> WorkspaceV2:
@@ -13,7 +13,7 @@ import time
 from pathlib import Path
 from typing import Any, Dict, Generator, List, Optional, Union
 from unittest import mock
-from unittest.mock import DEFAULT, MagicMock, patch
+from unittest.mock import MagicMock, patch
 from uuid import uuid4
 from xmlrpc.client import Boolean
 

@@ -23,13 +23,12 @@ import pandas as pd
 import param
 import pytest
 from _pytest.logging import LogCaptureFixture
-from azure.identity import ClientSecretCredential, DeviceCodeCredential, DefaultAzureCredential
 from azure.storage.blob import ContainerClient
 from azureml._restclient.constants import RunStatus
 from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
 from azureml.core.run import _OfflineRun
 from azureml.core.environment import CondaDependencies
-from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
+from azure.core.exceptions import ResourceNotFoundError
 from azureml.data.azure_storage_datastore import AzureBlobDatastore
 
 import health_azure.utils as util

@@ -41,8 +40,8 @@ from health_azure.utils import (
     MASTER_PORT_DEFAULT,
     PackageDependency,
     download_files_by_suffix,
-    get_credential,
     download_file_if_necessary,
+    resolve_workspace_config_path,
 )
 from testazure.test_himl import RunTarget, render_and_run_test_script
 from testazure.utils_testazure import (
@@ -1996,136 +1995,6 @@ def test_create_run() -> None:
     run.complete()
 
 
-def test_get_credential() -> None:
-    def _mock_validation_error() -> None:
-        raise ClientAuthenticationError("")
-
-    # test the case where service principal credentials are set as environment variables
-    mock_env_vars = {
-        util.ENV_SERVICE_PRINCIPAL_ID: "foo",
-        util.ENV_TENANT_ID: "bar",
-        util.ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
-    }
-
-    with patch.object(os.environ, "get", return_value=mock_env_vars):
-        with patch.multiple(
-            "health_azure.utils",
-            is_running_in_azure_ml=DEFAULT,
-            is_running_on_azure_agent=DEFAULT,
-            _get_legitimate_service_principal_credential=DEFAULT,
-            _get_legitimate_device_code_credential=DEFAULT,
-            _get_legitimate_default_credential=DEFAULT,
-            _get_legitimate_interactive_browser_credential=DEFAULT,
-        ) as mocks:
-            mocks["is_running_in_azure_ml"].return_value = False
-            mocks["is_running_on_azure_agent"].return_value = False
-            _ = get_credential()
-            mocks["_get_legitimate_service_principal_credential"].assert_called_once()
-            mocks["_get_legitimate_device_code_credential"].assert_not_called()
-            mocks["_get_legitimate_default_credential"].assert_not_called()
-            mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
-
-    # if the environment variables are not set and we are running on a local machine, a
-    # DefaultAzureCredential should be attempted first
-    with patch.object(os.environ, "get", return_value={}):
-        with patch.multiple(
-            "health_azure.utils",
-            is_running_in_azure_ml=DEFAULT,
-            is_running_on_azure_agent=DEFAULT,
-            _get_legitimate_service_principal_credential=DEFAULT,
-            _get_legitimate_device_code_credential=DEFAULT,
-            _get_legitimate_default_credential=DEFAULT,
-            _get_legitimate_interactive_browser_credential=DEFAULT,
-        ) as mocks:
-            mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
-            mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
-            mock_get_default_cred = mocks["_get_legitimate_default_credential"]
-            mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
-
-            mocks["is_running_in_azure_ml"].return_value = False
-            mocks["is_running_on_azure_agent"].return_value = False
-            _ = get_credential()
-            mock_get_sp_cred.assert_not_called()
-            mock_get_device_cred.assert_not_called()
-            mock_get_default_cred.assert_called_once()
-            mock_get_browser_cred.assert_not_called()
-
-            # if that fails, a DeviceCode credential should be attempted
-            mock_get_default_cred.side_effect = _mock_validation_error
-            _ = get_credential()
-            mock_get_sp_cred.assert_not_called()
-            mock_get_device_cred.assert_called_once()
-            assert mock_get_default_cred.call_count == 2
-            mock_get_browser_cred.assert_not_called()
-
-            # if None of the previous credentials work, an InteractiveBrowser credential should be tried
-            mock_get_device_cred.return_value = None
-            _ = get_credential()
-            mock_get_sp_cred.assert_not_called()
-            assert mock_get_device_cred.call_count == 2
-            assert mock_get_default_cred.call_count == 3
-            mock_get_browser_cred.assert_called_once()
-
-            # finally, if none of the methods work, an Exception should be raised
-            mock_get_browser_cred.return_value = None
-            with pytest.raises(Exception) as e:
-                get_credential()
-            assert (
-                "Unable to generate and validate a credential. Please see Azure ML documentation"
-                "for instructions on different options to get a credential" in str(e)
-            )
-
-
-def test_get_legitimate_service_principal_credential() -> None:
-    # first attempt to create and valiadate a credential with non-existant service principal credentials
-    # and check it fails
-    mock_service_principal_id = "foo"
-    mock_service_principal_password = "bar"
-    mock_tenant_id = "baz"
-    expected_error_msg = f"Found environment variables for {util.ENV_SERVICE_PRINCIPAL_ID}, "
-    f"{util.ENV_SERVICE_PRINCIPAL_PASSWORD}, and {util.ENV_TENANT_ID} but was not able to authenticate"
-    with pytest.raises(Exception) as e:
-        util._get_legitimate_service_principal_credential(
-            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
-        )
-    assert expected_error_msg in str(e)
-
-    # now mock the case where validating the credential succeeds and check the value of that
-    with patch("health_azure.utils._validate_credential"):
-        cred = util._get_legitimate_service_principal_credential(
-            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
-        )
-        assert isinstance(cred, ClientSecretCredential)
-
-
-def test_get_legitimate_device_code_credential() -> None:
-    def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
-        return DeviceCodeCredential(timeout=1)
-
-    with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
-        cred = util._get_legitimate_device_code_credential()
-        assert cred is None
-
-    # now mock the case where validating the credential succeeds
-    with patch("health_azure.utils._validate_credential"):
-        cred = util._get_legitimate_device_code_credential()
-        assert isinstance(cred, DeviceCodeCredential)
-
-
-def test_get_legitimate_default_credential() -> None:
-    def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
-        return DefaultAzureCredential(timeout=1)
-
-    with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
-        exception_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
-        with pytest.raises(ClientAuthenticationError, match=exception_message):
-            cred = util._get_legitimate_default_credential()
-
-        with patch("health_azure.utils._validate_credential"):
-            cred = util._get_legitimate_default_credential()
-            assert isinstance(cred, DefaultAzureCredential)
-
-
 def test_filter_v2_input_output_args() -> None:
     def _compare_args(expected: List[str], actual: List[str]) -> None:
         assert len(actual) == len(expected)
@@ -2244,3 +2113,33 @@ def test_download_files_by_suffix(tmp_path: Path, files: List[str], expected_dow
         assert f.is_file()
     downloaded_filenames = [f.name for f in downloaded_list]
     assert downloaded_filenames == expected_downloaded
+
+
+def test_resolve_workspace_config_path_no_argument(tmp_path: Path) -> None:
+    """Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
+    If the file exists, it should return the path"""
+    mocked_file = tmp_path / "foo.json"
+    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=mocked_file):
+        result = resolve_workspace_config_path()
+        assert result == mocked_file
+
+
+def test_resolve_workspace_config_path_no_argument_no_file() -> None:
+    """Test for resolve_workspace_config_path without argument: It should try to find a config file in the folders.
+    If the file does not exist, return None"""
+    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=None):
+        result = resolve_workspace_config_path()
+        assert result is None
+
+
+def test_resolve_workspace_config_path_file_exists(tmp_path: Path) -> None:
+    mocked_file = tmp_path / "foo.json"
+    mocked_file.touch()
+    result = resolve_workspace_config_path(mocked_file)
+    assert result == mocked_file
+
+
+def test_resolve_workspace_config_path_missing(tmp_path: Path) -> None:
+    mocked_file = tmp_path / "foo.json"
+    with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
+        resolve_workspace_config_path(mocked_file)
@@ -20,6 +20,7 @@ from azureml.data import FileDataset, OutputFileDatasetConfig
 from azureml.data.azure_storage_datastore import AzureBlobDatastore
 from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
 from azureml.exceptions._azureml_exception import UserErrorException
+from health_azure.himl import submit_to_azure_if_needed
 from testazure.utils_testazure import (
     DEFAULT_DATASTORE,
     DEFAULT_WORKSPACE,

@@ -27,7 +28,6 @@ from testazure.utils_testazure import (
     TEST_DATA_ASSET_NAME,
     TEST_INVALID_DATA_ASSET_NAME,
     TEST_DATASTORE_NAME,
-    get_test_ml_client,
 )
 
 from health_azure.datasets import (

@@ -46,10 +46,7 @@ from health_azure.datasets import (
     get_or_create_dataset,
     _get_latest_v2_asset_version,
 )
-from health_azure.utils import PathOrString, get_ml_client
+from health_azure.utils import PathOrString
 
 
-TEST_ML_CLIENT = get_test_ml_client()
-
-
 def test_datasetconfig_init() -> None:
@ -234,12 +231,11 @@ def test_get_or_create_dataset() -> None:
|
||||||
|
|
||||||
data_asset_name = "himl_tiny_data_asset"
|
data_asset_name = "himl_tiny_data_asset"
|
||||||
workspace = DEFAULT_WORKSPACE.workspace
|
workspace = DEFAULT_WORKSPACE.workspace
|
||||||
ml_client = get_ml_client(aml_workspace=workspace)
|
|
||||||
# When creating a dataset, we need a non-empty name
|
# When creating a dataset, we need a non-empty name
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
get_or_create_dataset(
|
get_or_create_dataset(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
ml_client=ml_client,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
datastore_name="himldatasetsv2",
|
datastore_name="himldatasetsv2",
|
||||||
dataset_name="",
|
dataset_name="",
|
||||||
strictly_aml_v1=True,
|
strictly_aml_v1=True,
|
||||||
|
@ -254,7 +250,7 @@ def test_get_or_create_dataset() -> None:
|
||||||
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
|
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
|
||||||
dataset = get_or_create_dataset(
|
dataset = get_or_create_dataset(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
ml_client=ml_client,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
datastore_name="himldatasetsv2",
|
datastore_name="himldatasetsv2",
|
||||||
dataset_name=data_asset_name,
|
dataset_name=data_asset_name,
|
||||||
strictly_aml_v1=True,
|
strictly_aml_v1=True,
|
||||||
|
@ -268,7 +264,7 @@ def test_get_or_create_dataset() -> None:
|
||||||
mocks["_get_or_create_v2_data_asset"].return_value = mock_v2_dataset
|
mocks["_get_or_create_v2_data_asset"].return_value = mock_v2_dataset
|
||||||
dataset = get_or_create_dataset(
|
dataset = get_or_create_dataset(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
ml_client=ml_client,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
datastore_name="himldatasetsv2",
|
datastore_name="himldatasetsv2",
|
||||||
dataset_name=data_asset_name,
|
dataset_name=data_asset_name,
|
||||||
strictly_aml_v1=False,
|
strictly_aml_v1=False,
|
||||||
|
@ -281,7 +277,7 @@ def test_get_or_create_dataset() -> None:
|
||||||
mocks["_get_or_create_v2_data_asset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
|
mocks["_get_or_create_v2_data_asset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
|
||||||
dataset = get_or_create_dataset(
|
dataset = get_or_create_dataset(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
ml_client=ml_client,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
datastore_name="himldatasetsv2",
|
datastore_name="himldatasetsv2",
|
||||||
dataset_name=data_asset_name,
|
dataset_name=data_asset_name,
|
||||||
strictly_aml_v1=False,
|
strictly_aml_v1=False,
|
||||||
|
@ -417,7 +413,7 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
|
||||||
mock_get_v2_asset_version.side_effect = _get_latest_v2_asset_version
|
mock_get_v2_asset_version.side_effect = _get_latest_v2_asset_version
|
||||||
try:
|
try:
|
||||||
data_asset = _retrieve_v2_data_asset(
|
data_asset = _retrieve_v2_data_asset(
|
||||||
ml_client=TEST_ML_CLIENT,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
data_asset_name=asset_name,
|
data_asset_name=asset_name,
|
||||||
version=asset_version,
|
version=asset_version,
|
||||||
)
|
)
|
||||||
|
@ -445,10 +441,12 @@ def test_retrieve_v2_data_asset(asset_name: str, asset_version: Optional[str]) -
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_v2_data_asset_invalid_version() -> None:
|
def test_retrieve_v2_data_asset_invalid_version() -> None:
|
||||||
invalid_asset_version = str(int(_get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)) + 1)
|
invalid_asset_version = str(
|
||||||
|
int(_get_latest_v2_asset_version(DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME)) + 1
|
||||||
|
)
|
||||||
with pytest.raises(ResourceNotFoundError) as ex:
|
with pytest.raises(ResourceNotFoundError) as ex:
|
||||||
_retrieve_v2_data_asset(
|
_retrieve_v2_data_asset(
|
||||||
ml_client=TEST_ML_CLIENT,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
data_asset_name=TEST_DATA_ASSET_NAME,
|
data_asset_name=TEST_DATA_ASSET_NAME,
|
||||||
version=invalid_asset_version,
|
version=invalid_asset_version,
|
||||||
)
|
)
|
||||||
|
@ -459,15 +457,19 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
|
||||||
"""Test if calling the get_or_create_data_asset on an existing asset does not increment the version number."""
|
"""Test if calling the get_or_create_data_asset on an existing asset does not increment the version number."""
|
||||||
|
|
||||||
with patch("health_azure.datasets._create_v2_data_asset") as mock_create_v2_data_asset:
|
with patch("health_azure.datasets._create_v2_data_asset") as mock_create_v2_data_asset:
|
||||||
asset_version_before_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
|
asset_version_before_get_or_create = _get_latest_v2_asset_version(
|
||||||
|
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
|
||||||
|
)
|
||||||
get_or_create_dataset(
|
get_or_create_dataset(
|
||||||
TEST_DATASTORE_NAME,
|
TEST_DATASTORE_NAME,
|
||||||
TEST_DATA_ASSET_NAME,
|
TEST_DATA_ASSET_NAME,
|
||||||
DEFAULT_WORKSPACE,
|
DEFAULT_WORKSPACE,
|
||||||
strictly_aml_v1=False,
|
strictly_aml_v1=False,
|
||||||
ml_client=TEST_ML_CLIENT,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
|
)
|
||||||
|
asset_version_after_get_or_create = _get_latest_v2_asset_version(
|
||||||
|
DEFAULT_WORKSPACE.ml_client, TEST_DATA_ASSET_NAME
|
||||||
)
|
)
|
||||||
asset_version_after_get_or_create = _get_latest_v2_asset_version(TEST_ML_CLIENT, TEST_DATA_ASSET_NAME)
|
|
||||||
|
|
||||||
mock_create_v2_data_asset.assert_not_called()
|
mock_create_v2_data_asset.assert_not_called()
|
||||||
assert asset_version_before_get_or_create == asset_version_after_get_or_create
|
assert asset_version_before_get_or_create == asset_version_after_get_or_create
|
||||||
|
@ -485,7 +487,7 @@ def test_retrieving_v2_data_asset_does_not_increment() -> None:
|
||||||
def test_create_v2_data_asset(asset_name: str, datastore_name: str, version: Optional[str]) -> None:
|
def test_create_v2_data_asset(asset_name: str, datastore_name: str, version: Optional[str]) -> None:
|
||||||
try:
|
try:
|
||||||
data_asset = _create_v2_data_asset(
|
data_asset = _create_v2_data_asset(
|
||||||
ml_client=TEST_ML_CLIENT,
|
ml_client=DEFAULT_WORKSPACE.ml_client,
|
||||||
datastore_name=TEST_DATASTORE_NAME,
|
datastore_name=TEST_DATASTORE_NAME,
|
||||||
data_asset_name=asset_name,
|
data_asset_name=asset_name,
|
||||||
version=version,
|
version=version,
|
||||||
|
@ -558,3 +560,36 @@ def test_create_dataset_configs() -> None:
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
create_dataset_configs(azure_datasets, dataset_mountpoints, local_datasets, datastore, use_mounting)
|
create_dataset_configs(azure_datasets, dataset_mountpoints, local_datasets, datastore, use_mounting)
|
||||||
assert "Invalid dataset setup" in str(e)
|
assert "Invalid dataset setup" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_datasets() -> None:
|
||||||
|
"""Test if Azure datasets can be mounted for local runs"""
|
||||||
|
# Dataset hello_world must exist in the test AzureML workspace
|
||||||
|
dataset = DatasetConfig(name="hello_world")
|
||||||
|
run_info = submit_to_azure_if_needed(
|
||||||
|
input_datasets=[dataset],
|
||||||
|
strictly_aml_v1=True,
|
||||||
|
)
|
||||||
|
assert len(run_info.input_datasets) == 1
|
||||||
|
assert isinstance(run_info.input_datasets[0], Path)
|
||||||
|
assert run_info.input_datasets[0].is_dir()
|
||||||
|
assert len(list(run_info.input_datasets[0].glob("*"))) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_datasets_fails_with_v2() -> None:
|
||||||
|
"""Azure datasets can't be used when using SDK v2"""
|
||||||
|
dataset = DatasetConfig(name="himl-tiny_dataset")
|
||||||
|
with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from AzureML"):
|
||||||
|
submit_to_azure_if_needed(
|
||||||
|
input_datasets=[dataset],
|
||||||
|
strictly_aml_v1=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_datasets_fail_with_v2() -> None:
|
||||||
|
"""If no datasets are specified, we can still run with SDK v2"""
|
||||||
|
run_info = submit_to_azure_if_needed(
|
||||||
|
input_datasets=[],
|
||||||
|
strictly_aml_v1=False,
|
||||||
|
)
|
||||||
|
assert len(run_info.input_datasets) == 0
|
||||||
|
|
|
@@ -0,0 +1,249 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+"""
+Tests for health_azure.azure_get_workspace and related functions.
+"""
+import os
+from pathlib import Path
+from unittest.mock import DEFAULT, MagicMock, patch
+
+import pytest
+from azure.core.exceptions import ClientAuthenticationError
+from azure.identity import ClientSecretCredential, DefaultAzureCredential, DeviceCodeCredential
+
+from health_azure.utils import (
+    ENV_RESOURCE_GROUP,
+    ENV_SERVICE_PRINCIPAL_ID,
+    ENV_SERVICE_PRINCIPAL_PASSWORD,
+    ENV_SUBSCRIPTION_ID,
+    ENV_TENANT_ID,
+    ENV_WORKSPACE_NAME,
+    _get_legitimate_default_credential,
+    _get_legitimate_device_code_credential,
+    _get_legitimate_service_principal_credential,
+    get_credential,
+    get_ml_client,
+)
+
+
+@pytest.mark.fast
+def test_get_credential() -> None:
+    def _mock_validation_error() -> None:
+        raise ClientAuthenticationError("")
+
+    # test the case where service principal credentials are set as environment variables
+    mock_env_vars = {
+        ENV_SERVICE_PRINCIPAL_ID: "foo",
+        ENV_TENANT_ID: "bar",
+        ENV_SERVICE_PRINCIPAL_PASSWORD: "baz",
+    }
+
+    with patch.object(os.environ, "get", return_value=mock_env_vars):
+        with patch.multiple(
+            "health_azure.utils",
+            is_running_in_azure_ml=DEFAULT,
+            is_running_on_azure_agent=DEFAULT,
+            _get_legitimate_service_principal_credential=DEFAULT,
+            _get_legitimate_device_code_credential=DEFAULT,
+            _get_legitimate_default_credential=DEFAULT,
+            _get_legitimate_interactive_browser_credential=DEFAULT,
+        ) as mocks:
+            mocks["is_running_in_azure_ml"].return_value = False
+            mocks["is_running_on_azure_agent"].return_value = False
+            _ = get_credential()
+            mocks["_get_legitimate_service_principal_credential"].assert_called_once()
+            mocks["_get_legitimate_device_code_credential"].assert_not_called()
+            mocks["_get_legitimate_default_credential"].assert_not_called()
+            mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
+
+    # if the environment variables are not set and we are running on a local machine, a
+    # DefaultAzureCredential should be attempted first
+    with patch.object(os.environ, "get", return_value={}):
+        with patch.multiple(
+            "health_azure.utils",
+            is_running_in_azure_ml=DEFAULT,
+            is_running_on_azure_agent=DEFAULT,
+            _get_legitimate_service_principal_credential=DEFAULT,
+            _get_legitimate_device_code_credential=DEFAULT,
+            _get_legitimate_default_credential=DEFAULT,
+            _get_legitimate_interactive_browser_credential=DEFAULT,
+        ) as mocks:
+            mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
+            mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
+            mock_get_default_cred = mocks["_get_legitimate_default_credential"]
+            mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
+
+            mocks["is_running_in_azure_ml"].return_value = False
+            mocks["is_running_on_azure_agent"].return_value = False
+            _ = get_credential()
+            mock_get_sp_cred.assert_not_called()
+            mock_get_device_cred.assert_not_called()
+            mock_get_default_cred.assert_called_once()
+            mock_get_browser_cred.assert_not_called()
+
+            # if that fails, a DeviceCode credential should be attempted
+            mock_get_default_cred.side_effect = _mock_validation_error
+            _ = get_credential()
+            mock_get_sp_cred.assert_not_called()
+            mock_get_device_cred.assert_called_once()
+            assert mock_get_default_cred.call_count == 2
+            mock_get_browser_cred.assert_not_called()
+
+            # if None of the previous credentials work, an InteractiveBrowser credential should be tried
+            mock_get_device_cred.return_value = None
+            _ = get_credential()
+            mock_get_sp_cred.assert_not_called()
+            assert mock_get_device_cred.call_count == 2
+            assert mock_get_default_cred.call_count == 3
+            mock_get_browser_cred.assert_called_once()
+
+            # finally, if none of the methods work, an Exception should be raised
+            mock_get_browser_cred.return_value = None
+            with pytest.raises(Exception) as e:
+                get_credential()
+            assert (
+                "Unable to generate and validate a credential. Please see Azure ML documentation"
+                "for instructions on different options to get a credential" in str(e)
+            )
+
+
+@pytest.mark.fast
+def test_get_legitimate_service_principal_credential() -> None:
+    # first attempt to create and valiadate a credential with non-existant service principal credentials
+    # and check it fails
+    mock_service_principal_id = "foo"
+    mock_service_principal_password = "bar"
+    mock_tenant_id = "baz"
+    expected_error_msg = f"Found environment variables for {ENV_SERVICE_PRINCIPAL_ID}, "
+    f"{ENV_SERVICE_PRINCIPAL_PASSWORD}, and {ENV_TENANT_ID} but was not able to authenticate"
+    with pytest.raises(Exception) as e:
+        _get_legitimate_service_principal_credential(
+            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
+        )
+    assert expected_error_msg in str(e)
+
+    # now mock the case where validating the credential succeeds and check the value of that
+    with patch("health_azure.utils._validate_credential"):
+        cred = _get_legitimate_service_principal_credential(
+            mock_tenant_id, mock_service_principal_id, mock_service_principal_password
+        )
+        assert isinstance(cred, ClientSecretCredential)
+
+
+@pytest.mark.fast
+def test_get_legitimate_device_code_credential() -> None:
+    def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
+        return DeviceCodeCredential(timeout=1)
+
+    with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
+        cred = _get_legitimate_device_code_credential()
+        assert cred is None
+
+    # now mock the case where validating the credential succeeds
+    with patch("health_azure.utils._validate_credential"):
+        cred = _get_legitimate_device_code_credential()
+        assert isinstance(cred, DeviceCodeCredential)
+
+
+@pytest.mark.fast
+def test_get_legitimate_default_credential() -> None:
+    def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
+        return DefaultAzureCredential(timeout=1)
+
+    with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
+        exception_message = r"DefaultAzureCredential failed to retrieve a token from the included credentials."
+        with pytest.raises(ClientAuthenticationError, match=exception_message):
+            cred = _get_legitimate_default_credential()
+
+    with patch("health_azure.utils._validate_credential"):
+        cred = _get_legitimate_default_credential()
+        assert isinstance(cred, DefaultAzureCredential)
+
+
+@pytest.mark.fast
+def test_get_ml_client_with_existing_client() -> None:
+    """When passing an existing ml_client, it should be returned"""
+    ml_client = "mock_ml_client"
+    result = get_ml_client(ml_client=ml_client)  # type: ignore
+    assert result == ml_client
+
+
+@pytest.mark.fast
+def test_get_ml_client_without_credentials() -> None:
+    """When no credentials are available, an exception should be raised"""
+    with patch("health_azure.utils.get_credential", return_value=None):
+        with pytest.raises(ValueError, match="Can't connect to MLClient without a valid credential"):
+            get_ml_client()
+
+
+@pytest.mark.fast
+def test_get_ml_client_from_config_file() -> None:
+    """If a workspace config file is found, it should be used to create the MLClient"""
+    mock_credentials = "mock_credentials"
+    mock_config_path = Path("foo")
+    mock_ml_client = MagicMock(workspace_name="workspace")
+    mock_from_config = MagicMock(return_value=mock_ml_client)
+    mock_resolve_config_path = MagicMock(return_value=mock_config_path)
+    with patch.multiple(
+        "health_azure.utils",
+        get_credential=MagicMock(return_value=mock_credentials),
+        resolve_workspace_config_path=mock_resolve_config_path,
+        MLClient=MagicMock(from_config=mock_from_config),
+    ):
+        config_file = Path("foo")
+        result = get_ml_client(workspace_config_path=config_file)
+        assert result == mock_ml_client
+        mock_resolve_config_path.assert_called_once_with(config_file)
+        mock_from_config.assert_called_once_with(
+            credential=mock_credentials,
+            path=str(mock_config_path),
+        )
+
+
+@pytest.mark.fast
+def test_get_ml_client_from_environment_variables() -> None:
+    """When no workspace config file is found, the MLClient should be created from environment variables"""
+    mock_credentials = "mock_credentials"
+    the_client = "the_client"
+    mock_ml_client = MagicMock(return_value=the_client)
+    workspace = "workspace"
+    subscription = "subscription"
+    resource_group = "resource_group"
+    with patch.multiple(
+        "health_azure.utils",
+        get_credential=MagicMock(return_value=mock_credentials),
+        resolve_workspace_config_path=MagicMock(return_value=None),
+        MLClient=mock_ml_client,
+    ):
+        with patch.dict(
+            os.environ,
+            {ENV_WORKSPACE_NAME: workspace, ENV_SUBSCRIPTION_ID: subscription, ENV_RESOURCE_GROUP: resource_group},
+        ):
+            result = get_ml_client()
+            assert result == the_client
+            mock_ml_client.assert_called_once_with(
+                subscription_id=subscription,
+                resource_group_name=resource_group,
+                workspace_name=workspace,
+                credential=mock_credentials,
+            )
+
+
+@pytest.mark.fast
+def test_get_ml_client_fails() -> None:
+    """If neither a workspace config file nor environment variables are found, an exception should be raised"""
+    mock_credentials = "mock_credentials"
+    the_client = "the_client"
+    mock_ml_client = MagicMock(return_value=the_client)
+    with patch.multiple(
+        "health_azure.utils",
+        get_credential=MagicMock(return_value=mock_credentials),
+        resolve_workspace_config_path=MagicMock(return_value=None),
+        MLClient=mock_ml_client,
+    ):
+        # In the GitHub runner, the environment variables are set. We need to unset them to test the exception
+        with patch.dict(os.environ, {ENV_WORKSPACE_NAME: ""}):
+            with pytest.raises(ValueError, match="Tried all ways of identifying the MLClient, but failed"):
+                get_ml_client()
@@ -10,6 +10,7 @@ from pathlib import Path
 from uuid import uuid4

 from azureml.core.authentication import ServicePrincipalAuthentication
+from azureml.exceptions._azureml_exception import UserErrorException
 from _pytest.logging import LogCaptureFixture
 import pytest
 from unittest.mock import MagicMock, patch
@@ -22,7 +23,6 @@ from health_azure.utils import (
     get_workspace,
 )
 from health_azure.utils import (
-    WORKSPACE_CONFIG_JSON,
     ENV_SERVICE_PRINCIPAL_ID,
     ENV_SERVICE_PRINCIPAL_PASSWORD,
     ENV_TENANT_ID,
@@ -141,13 +141,16 @@ def test_get_workspace_with_given_workspace() -> None:


 @pytest.mark.fast
-def test_get_workspace_searches_for_file() -> None:
+def test_get_workspace_searches_for_file(tmp_path: Path) -> None:
     """get_workspace should try to load a config.json file if not provided with one"""
-    found_file = Path("does_not_exist")
-    with patch("health_azure.utils.find_file_in_parent_to_pythonpath", return_value=found_file) as mock_find:
-        with pytest.raises(FileNotFoundError, match="Workspace config file does not exist"):
-            get_workspace(None, None)
-        mock_find.assert_called_once_with(WORKSPACE_CONFIG_JSON)
+    with change_working_directory(tmp_path):
+        found_file = Path("does_not_exist")
+        with patch("health_azure.utils.resolve_workspace_config_path", return_value=found_file) as mock_find:
+            with pytest.raises(
+                UserErrorException, match="workspace configuration file config.json, could not be found"
+            ):
+                get_workspace(None, None)
+            mock_find.assert_called_once_with(None)


 @pytest.mark.fast
@@ -773,7 +773,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
     with patch("azure.ai.ml.MLClient") as mock_ml_client:
         with patch("health_azure.himl.command") as mock_command:
             himl.submit_run_v2(
-                workspace=None,
+                ml_client=mock_ml_client,
                 experiment_name=dummy_experiment_name,
                 environment=dummy_environment,
                 input_datasets_v2=dummy_inputs,
@@ -784,8 +784,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                 compute_target=dummy_compute_target,
                 tags=dummy_tags,
                 docker_shm_size=dummy_docker_shm_size,
-                workspace_config_path=None,
-                ml_client=mock_ml_client,
                 hyperparam_args=None,
                 display_name=dummy_display_name,
             )
@@ -835,7 +833,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
            expected_command += " --learning_rate=${{inputs.learning_rate}}"

            himl.submit_run_v2(
-                workspace=None,
                experiment_name=dummy_experiment_name,
                environment=dummy_environment,
                input_datasets_v2=dummy_inputs,
@@ -846,7 +843,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                compute_target=dummy_compute_target,
                tags=dummy_tags,
                docker_shm_size=dummy_docker_shm_size,
-                workspace_config_path=None,
                ml_client=mock_ml_client,
                hyperparam_args=dummy_hyperparam_args,
                display_name=dummy_display_name,
@@ -882,7 +878,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
            expected_command = f"python {dummy_entry_script_for_module} {expected_arg_str}"

            himl.submit_run_v2(
-                workspace=None,
+                ml_client=mock_ml_client,
                experiment_name=dummy_experiment_name,
                environment=dummy_environment,
                input_datasets_v2=dummy_inputs,
@@ -893,8 +889,6 @@ def test_submit_run_v2(tmp_path: Path) -> None:
                compute_target=dummy_compute_target,
                tags=dummy_tags,
                docker_shm_size=dummy_docker_shm_size,
-                workspace_config_path=None,
-                ml_client=mock_ml_client,
                hyperparam_args=None,
                display_name=dummy_display_name,
            )
@@ -1308,9 +1302,7 @@ def test_mounting_and_downloading_dataset(tmp_path: Path) -> None:
         target_path = tmp_path / action
         dataset_config = DatasetConfig(name="hello_world", use_mounting=use_mounting, target_folder=target_path)
         logging.info(f"ready to {action}")
-        paths, mount_contexts = setup_local_datasets(
-            dataset_configs=[dataset_config], strictly_aml_v1=True, aml_workspace=workspace
-        )
+        paths, mount_contexts = setup_local_datasets(dataset_configs=[dataset_config], workspace=workspace)
         logging.info(f"{action} done")
         path = paths[0]
         assert path is not None
@@ -1372,7 +1364,7 @@ class TestOutputDataset:
     @pytest.mark.parametrize(
         ["run_target", "local_folder", "strictly_aml_v1"],
         [
-            (RunTarget.LOCAL, True, False),
+            (RunTarget.LOCAL, True, True),
            (RunTarget.AZUREML, False, True),
        ],
    )
@@ -10,7 +10,6 @@ from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, Generator, Optional

-from azure.ai.ml import MLClient
 from azureml.core import Run
 from health_azure.utils import (
     ENV_EXPERIMENT_NAME,
@@ -20,7 +19,6 @@ from health_azure.utils import (
 )
 from health_azure import create_aml_run_object
 from health_azure.himl import effective_experiment_name
-from health_azure.utils import get_ml_client, get_workspace

 DEFAULT_DATASTORE = "himldatasets"

@@ -111,16 +109,6 @@ def create_unittest_run_object(snapshot_directory: Optional[Path] = None) -> Run
     )


-def get_test_ml_client() -> MLClient:
-    """Generates an MLClient object for use in tests.
-
-    :return: MLClient object
-    """
-
-    workspace = get_workspace()
-    return get_ml_client(aml_workspace=workspace)
-
-
 def current_test_name() -> str:
     """Get the name of the currently executed test. This is read off an environment variable. If that
     is not found, the function returns an empty string."""
@@ -21,7 +21,7 @@ def mount_dataset(dataset_id: str, tmp_root: str = "/tmp/datasets", aml_workspac
     ws = get_workspace(aml_workspace)
     target_folder = "/".join([tmp_root, dataset_id])
     dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True)
-    _, mount_ctx = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
+    _, mount_ctx = dataset.to_input_dataset_local(workspace=ws)
     assert mount_ctx is not None  # for mypy
     mount_ctx.start()
     return mount_ctx
@@ -21,5 +21,5 @@ def download_azure_dataset(tmp_path: Path, dataset_id: str) -> None:
     with check_config_json(script_folder=tmp_path, shared_config_json=get_shared_config_json()):
         ws = get_workspace(workspace_config_path=tmp_path / WORKSPACE_CONFIG_JSON)
         dataset = DatasetConfig(name=dataset_id, target_folder=tmp_path, use_mounting=False)
-        dataset_dl_folder = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
+        dataset_dl_folder = dataset.to_input_dataset_local(workspace=ws)
         logging.info(f"Dataset saved in {dataset_dl_folder}")
@@ -15,6 +15,14 @@ class RunnerMode(Enum):
     EVAL_FULL = "eval_full"


+class LogLevel(Enum):
+    ERROR = "ERROR"
+    WARNING = "WARNING"
+    WARN = "WARN"
+    INFO = "INFO"
+    DEBUG = "DEBUG"
+
+
 DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"


@@ -87,10 +95,22 @@ class ExperimentConfig(param.Parameterized):
         doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
         "point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
     )
-    mode: str = param.ClassSelector(
+    mode: RunnerMode = param.ClassSelector(
         class_=RunnerMode,
         default=RunnerMode.TRAIN,
         doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
         f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
         "'get_eval_data_module' method of the container.",
     )
+    log_level: Optional[RunnerMode] = param.ClassSelector(
+        class_=LogLevel,
+        default=None,
+        doc=f"The log level to use. Can be one of {list(map(str, LogLevel))}",
+    )
+
+    @property
+    def submit_to_azure_ml(self) -> bool:
+        """Returns True if the experiment should be submitted to AzureML, False if it should be run locally.
+
+        :return: True if the experiment should be submitted to AzureML, False if it should be run locally."""
+        return self.cluster != ""
@@ -15,8 +15,6 @@ import sys
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-from azureml.core import Workspace
-
 # Add hi-ml packages to sys.path so that AML can find them if we are using the runner directly from the git repo
 himl_root = Path(__file__).resolve().parent.parent.parent.parent
 folders_to_add = [himl_root / "hi-ml" / "src", himl_root / "hi-ml-azure" / "src", himl_root / "hi-ml-cpath" / "src"]
@@ -34,8 +32,6 @@ from health_azure.paths import is_himl_used_from_git_repo  # noqa: E402
 from health_azure.utils import (  # noqa: E402
     ENV_LOCAL_RANK,
     ENV_NODE_RANK,
-    get_workspace,
-    get_ml_client,
     is_local_rank_zero,
     is_running_in_azure_ml,
     set_environment_variables_for_multi_node,
@@ -122,6 +118,11 @@ class Runner:
         parser1_result = parse_arguments(parser1, args=filtered_args)
         experiment_config = ExperimentConfig(**parser1_result.args)

+        from health_azure.logging import logging_stdout_handler  # noqa: E402
+
+        if logging_stdout_handler is not None and experiment_config.log_level is not None:
+            print(f"Setting custom logging level to {experiment_config.log_level}")
+            logging_stdout_handler.setLevel(experiment_config.log_level.value)
         self.experiment_config = experiment_config
         if not experiment_config.model:
             raise ValueError("Parameter 'model' needs to be set to specify which model to run.")
@@ -150,7 +151,7 @@ class Runner:
         """
         Runs sanity checks on the whole experiment.
         """
-        if not self.experiment_config.cluster:
+        if not self.experiment_config.submit_to_azure_ml:
            if self.lightning_container.hyperdrive:
                raise ValueError(
                    "HyperDrive for hyperparameters tuning is only supported when submitting the job to "
@@ -214,47 +215,26 @@ class Runner:
         script_params = sys.argv[1:]

         environment_variables = self.additional_environment_variables()

-        # Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
-        # that if we are really submitting to AzureML.
-        workspace: Optional[Workspace] = None
-        if self.experiment_config.cluster:
-            try:
-                workspace = get_workspace(workspace_config_path=self.experiment_config.workspace_config_path)
-            except ValueError:
-                raise ValueError(
-                    "Unable to submit the script to AzureML because no workspace configuration file "
-                    "(config.json) was found."
-                )
-
-        if self.lightning_container.datastore:
-            datastore = self.lightning_container.datastore
-        elif workspace:
-            datastore = workspace.get_default_datastore().name
-        else:
-            datastore = ""
-
         local_datasets = self.lightning_container.local_datasets
         all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
         # When running in AzureML, respect the commandline flag for mounting. Outside of AML, we always mount
         # datasets to be quicker.
-        use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.cluster else True
+        use_mounting = self.experiment_config.mount_in_azureml if self.experiment_config.submit_to_azure_ml else True
         input_datasets = create_dataset_configs(
            all_azure_dataset_ids=self.lightning_container.azure_datasets,
            all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
            all_local_datasets=all_local_datasets,  # type: ignore
-            datastore=datastore,
+            datastore=self.lightning_container.datastore,
            use_mounting=use_mounting,
        )

-        if self.experiment_config.cluster and not is_running_in_azure_ml():
+        if self.experiment_config.submit_to_azure_ml and not is_running_in_azure_ml():
            if self.experiment_config.strictly_aml_v1:
                hyperdrive_config = self.lightning_container.get_hyperdrive_config()
                hyperparam_args = None
            else:
                hyperparam_args = self.lightning_container.get_hyperparam_args()
                hyperdrive_config = None
-            ml_client = get_ml_client(aml_workspace=workspace) if not self.experiment_config.strictly_aml_v1 else None

            env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
            logging.info(f"Using this Conda environment definition: {env_file}")
@@ -265,18 +245,15 @@ class Runner:
                snapshot_root_directory=root_folder,
                script_params=script_params,
                conda_environment_file=env_file,
-                aml_workspace=workspace,
-                ml_client=ml_client,
                compute_cluster_name=self.experiment_config.cluster,
                environment_variables=environment_variables,
-                default_datastore=datastore,
                experiment_name=self.lightning_container.effective_experiment_name,
                input_datasets=input_datasets,  # type: ignore
                num_nodes=self.experiment_config.num_nodes,
                wait_for_completion=self.experiment_config.wait_for_completion,
                max_run_duration=self.experiment_config.max_run_duration,
                ignored_folders=[],
-                submit_to_azureml=bool(self.experiment_config.cluster),
+                submit_to_azureml=self.experiment_config.submit_to_azure_ml,
                docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
                docker_shm_size=self.experiment_config.docker_shm_size,
                hyperdrive_config=hyperdrive_config,
@@ -292,7 +269,6 @@ class Runner:
                submit_to_azureml=False,
                environment_variables=environment_variables,
                strictly_aml_v1=self.experiment_config.strictly_aml_v1,
-                default_datastore=datastore,
            )
        if azure_run_info.run:
            # This code is only reached inside Azure. Set display name again - this will now affect
@ -26,6 +26,7 @@ from health_ml.lightning_container import LightningContainer
|
||||||
from health_ml.runner import Runner, create_logging_filename, run_with_logging
|
from health_ml.runner import Runner, create_logging_filename, run_with_logging
|
||||||
from health_ml.utils.common_utils import change_working_directory
|
from health_ml.utils.common_utils import change_working_directory
|
||||||
from health_ml.utils.fixed_paths import repository_root_directory
|
from health_ml.utils.fixed_paths import repository_root_directory
|
||||||
|
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -132,10 +133,9 @@ def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None
|
||||||
model_name = "HelloWorld"
|
model_name = "HelloWorld"
|
||||||
arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
|
arguments = ["", f"--debug_ddp={debug_ddp}", f"--model={model_name}"]
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch.object(sys, "argv", arguments):
|
mock_runner.run()
|
||||||
mock_runner.run()
|
|
||||||
mock_submit_to_azure_if_needed.assert_called_once()
|
mock_submit_to_azure_if_needed.assert_called_once()
|
||||||
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
|
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
|
||||||
|
|
||||||
|
@ -144,12 +144,9 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
|
||||||
model_name = "HelloWorld"
|
model_name = "HelloWorld"
|
||||||
arguments = ["", f"--model={model_name}", "--cluster=foo"]
|
arguments = ["", f"--model={model_name}", "--cluster=foo"]
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
||||||
with patch("health_ml.runner.check_conda_environment"):
|
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch("health_ml.runner.get_ml_client"):
|
mock_runner.run()
|
||||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
|
||||||
with patch.object(sys, "argv", arguments):
|
|
||||||
mock_runner.run()
|
|
||||||
mock_submit_to_azure_if_needed.assert_called_once()
|
mock_submit_to_azure_if_needed.assert_called_once()
|
||||||
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||||
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||||
|
@ -162,9 +159,6 @@ def test_additional_environment_variables(mock_runner: Runner) -> None:
|
||||||
with patch.multiple(
|
with patch.multiple(
|
||||||
"health_ml.runner",
|
"health_ml.runner",
|
||||||
submit_to_azure_if_needed=DEFAULT,
|
submit_to_azure_if_needed=DEFAULT,
|
||||||
check_conda_environment=DEFAULT,
|
|
||||||
get_workspace=DEFAULT,
|
|
||||||
get_ml_client=DEFAULT,
|
|
||||||
) as mocks:
|
) as mocks:
|
||||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||||
with patch("health_ml.runner.Runner.parse_and_load_model"):
|
with patch("health_ml.runner.Runner.parse_and_load_model"):
|
||||||
|
@ -185,9 +179,8 @@ def test_run(mock_runner: Runner) -> None:
|
||||||
model_name = "HelloWorld"
|
model_name = "HelloWorld"
|
||||||
arguments = ["", f"--model={model_name}"]
|
arguments = ["", f"--model={model_name}"]
|
||||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch.object(sys, "argv", arguments):
|
model_config, azure_run_info = mock_runner.run()
|
||||||
model_config, azure_run_info = mock_runner.run()
|
|
||||||
mock_run_in_situ.assert_called_once()
|
mock_run_in_situ.assert_called_once()
|
||||||
|
|
||||||
assert model_config is not None # for pyright
|
assert model_config is not None # for pyright
|
||||||
|
@ -197,17 +190,13 @@ def test_run(mock_runner: Runner) -> None:
|
||||||
|
|
||||||
|
|
||||||
@patch("health_ml.runner.choose_conda_env_file")
|
@patch("health_ml.runner.choose_conda_env_file")
|
||||||
@patch("health_ml.runner.get_workspace")
|
|
||||||
@pytest.mark.fast
|
@pytest.mark.fast
|
||||||
def test_submit_to_azureml_if_needed(
|
def test_submit_to_azureml_if_needed(mock_get_env_files: MagicMock, mock_runner: Runner) -> None:
|
||||||
mock_get_workspace: MagicMock, mock_get_env_files: MagicMock, mock_runner: Runner
|
|
||||||
) -> None:
|
|
||||||
def _mock_dont_submit_to_aml(
|
def _mock_dont_submit_to_aml(
|
||||||
input_datasets: List[DatasetConfig],
|
input_datasets: List[DatasetConfig],
|
||||||
submit_to_azureml: bool,
|
submit_to_azureml: bool,
|
||||||
strictly_aml_v1: bool, # type: ignore
|
strictly_aml_v1: bool, # type: ignore
|
||||||
environment_variables: Dict[str, Any], # type: ignore
|
environment_variables: Dict[str, Any], # type: ignore
|
||||||
default_datastore: Optional[str], # type: ignore
|
|
||||||
) -> AzureRunInfo:
|
) -> AzureRunInfo:
|
||||||
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
|
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
|
||||||
return AzureRunInfo(
|
return AzureRunInfo(
|
||||||
|
@ -222,10 +211,6 @@ def test_submit_to_azureml_if_needed(
|
||||||
|
|
||||||
mock_get_env_files.return_value = Path("some_file.txt")
|
mock_get_env_files.return_value = Path("some_file.txt")
|
||||||
|
|
||||||
mock_default_datastore = MagicMock()
|
|
||||||
mock_default_datastore.name.return_value = "dummy_datastore"
|
|
||||||
mock_get_workspace.get_default_datastore.return_value = mock_default_datastore
|
|
||||||
|
|
||||||
with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
|
with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
|
||||||
mock_create_datasets.return_value = []
|
mock_create_datasets.return_value = []
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||||
|
@ -334,11 +319,9 @@ def _test_hyperdrive_submission(
|
||||||
# start in that temp folder.
|
# start in that temp folder.
|
||||||
with change_working_folder_and_add_environment(mock_runner.project_root):
|
with change_working_folder_and_add_environment(mock_runner.project_root):
|
||||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch("health_ml.runner.get_ml_client"):
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||||
with patch.object(sys, "argv", arguments):
|
mock_runner.run()
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
|
||||||
mock_runner.run()
|
|
||||||
mock_run_in_situ.assert_called_once()
|
mock_run_in_situ.assert_called_once()
|
||||||
mock_submit_to_aml.assert_called_once()
|
mock_submit_to_aml.assert_called_once()
|
||||||
# call_args is a tuple of (args, kwargs)
|
# call_args is a tuple of (args, kwargs)
|
||||||
|
@ -364,11 +347,9 @@ def test_submit_to_azure_docker(mock_runner: Runner) -> None:
|
||||||
# start in that temp folder.
|
# start in that temp folder.
|
||||||
with change_working_folder_and_add_environment(mock_runner.project_root):
|
with change_working_folder_and_add_environment(mock_runner.project_root):
|
||||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||||
with patch("health_ml.runner.get_ml_client"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||||
with patch.object(sys, "argv", arguments):
|
mock_runner.run()
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
|
||||||
mock_runner.run()
|
|
||||||
mock_run_in_situ.assert_called_once()
|
mock_run_in_situ.assert_called_once()
|
||||||
mock_submit_to_aml.assert_called_once()
|
mock_submit_to_aml.assert_called_once()
|
||||||
# call_args is a tuple of (args, kwargs)
|
# call_args is a tuple of (args, kwargs)
|
||||||
|
@ -393,16 +374,12 @@ def test_run_hello_world(mock_runner: Runner) -> None:
|
||||||
"""Test running a model end-to-end via the commandline runner"""
|
"""Test running a model end-to-end via the commandline runner"""
|
||||||
model_name = "HelloWorld"
|
model_name = "HelloWorld"
|
||||||
arguments = ["", f"--model={model_name}"]
|
arguments = ["", f"--model={model_name}"]
|
||||||
with patch("health_ml.runner.get_workspace") as mock_get_workspace:
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch.object(sys, "argv", arguments):
|
mock_runner.run()
|
||||||
mock_runner.run()
|
# Summary.txt is written at start, the other files during inference
|
||||||
# get_workspace should not be called when using the runner outside AzureML, to not go through the
|
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
|
||||||
# time-consuming auth
|
for file in expected_files:
|
||||||
mock_get_workspace.assert_not_called()
|
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
|
||||||
# Summary.txt is written at start, the other files during inference
|
|
||||||
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
|
|
||||||
for file in expected_files:
|
|
||||||
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_args(mock_runner: Runner) -> None:
|
def test_invalid_args(mock_runner: Runner) -> None:
|
||||||
|
@ -425,17 +402,37 @@ def test_invalid_profiler(mock_runner: Runner) -> None:
|
||||||
mock_runner.run()
|
mock_runner.run()
|
||||||
|
|
||||||
|
|
||||||
def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:
|
def test_datastore_argument(mock_runner: Runner) -> None:
|
||||||
|
"""The datastore argument should be respected"""
|
||||||
model_name = "HelloWorld"
|
model_name = "HelloWorld"
|
||||||
datastore = "foo"
|
datastore = "foo"
|
||||||
arguments = ["", f"--datastore={datastore}", f"--model={model_name}"]
|
dataset = "bar"
|
||||||
|
arguments = ["", f"--datastore={datastore}", f"--model={model_name}", f"--azure_datasets={dataset}"]
|
||||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
||||||
with patch("health_ml.runner.get_workspace"):
|
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
with patch.object(sys, "argv", arguments):
|
||||||
with patch.object(sys, "argv", arguments):
|
mock_runner.run()
|
||||||
mock_runner.run()
|
|
||||||
mock_submit_to_azure_if_needed.assert_called_once()
|
mock_submit_to_azure_if_needed.assert_called_once()
|
||||||
assert mock_submit_to_azure_if_needed.call_args[1]["default_datastore"] == datastore
|
input_datasets = mock_submit_to_azure_if_needed.call_args[1]["input_datasets"]
|
||||||
|
assert len(input_datasets) == 1
|
||||||
|
assert input_datasets[0].datastore == datastore
|
||||||
|
assert input_datasets[0].name == dataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_authentication_outside_azureml(mock_runner: Runner) -> None:
|
||||||
|
"""No authentication should happen for a model that runs locally and needs no datasets."""
|
||||||
|
model_name = "HelloWorld"
|
||||||
|
arguments = ["", f"--datastore=datastore", f"--model={model_name}"]
|
||||||
|
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||||
|
with patch.object(sys, "argv", arguments):
|
||||||
|
mock_get_workspace = MagicMock()
|
||||||
|
mock_get_ml_client = MagicMock()
|
||||||
|
with patch.multiple(
|
||||||
|
"health_azure.himl", get_workspace=mock_get_workspace, get_ml_client=mock_get_ml_client
|
||||||
|
):
|
||||||
|
mock_runner.run()
|
||||||
|
mock_get_workspace.assert_not_called()
|
||||||
|
mock_get_ml_client.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.fast
|
@pytest.mark.fast
|
||||||
|
@ -512,3 +509,111 @@ def test_run_without_logging(tmp_path: Path) -> None:
|
||||||
         run_with_logging(tmp_path)
         mock_create_filename.assert_not_called()
         mock_run.assert_called_once()
+
+
+@pytest.mark.fast
+def test_runner_does_not_use_get_workspace() -> None:
+    """Test that the runner does not itself import get_workspace or get_ml_client (otherwise we would need to check
+    them in the tests below that count calls to those methods)"""
+    with pytest.raises(ImportError):
+        from health_ml.runner import get_workspace  # type: ignore
+    with pytest.raises(ImportError):
+        from health_ml.runner import get_ml_client  # type: ignore
+
+
+def test_runner_authenticates_once_v1() -> None:
+    """Test that the runner requires authentication only once when doing a job submission with the V1 SDK"""
+    runner = Runner(project_root=repository_root_directory())
+    mock_get_workspace = MagicMock()
+    mock_get_ml_client = MagicMock()
+    with patch.multiple(
+        "health_azure.himl",
+        get_workspace=mock_get_workspace,
+        get_ml_client=mock_get_ml_client,
+        Experiment=MagicMock(),
+        register_environment=MagicMock(return_value="env"),
+        validate_compute_cluster=MagicMock(),
+    ):
+        with patch.object(
+            sys,
+            "argv",
+            ["src/health_ml/runner.py", "--model=HelloWorld", "--cluster=pr-gpu", "--strictly_aml_v1"],
+        ):
+            # Job submission should trigger a system exit
+            with pytest.raises(SystemExit):
+                runner.run()
+            mock_get_workspace.assert_called_once()
+            mock_get_ml_client.assert_not_called()
+
+
+def test_runner_authenticates_once_v2() -> None:
+    """Test that the runner requires authentication only once when doing a job submission with the V2 SDK"""
+    runner = Runner(project_root=repository_root_directory())
+    mock_get_workspace = MagicMock()
+    mock_get_ml_client = MagicMock()
+    with patch.multiple(
+        "health_azure.himl",
+        get_workspace=mock_get_workspace,
+        get_ml_client=mock_get_ml_client,
+        command=MagicMock(),
+    ):
+        with patch.object(sys, "argv", ["", "--model=HelloWorld", "--cluster=pr-gpu"]):
+            # Job submission should trigger a system exit
+            with pytest.raises(SystemExit):
+                runner.run()
+            mock_get_workspace.assert_not_called()
+            mock_get_ml_client.assert_called_once()
+
+
+def test_runner_with_local_dataset_v1() -> None:
+    """Test that the runner requires authentication only once when doing a local run and a dataset has to be mounted"""
+    runner = Runner(project_root=repository_root_directory())
+    mock_get_workspace = MagicMock(return_value=DEFAULT_WORKSPACE.workspace)
+    mock_get_ml_client = MagicMock()
+    with patch.multiple(
+        "health_azure.himl",
+        get_workspace=mock_get_workspace,
+        get_ml_client=mock_get_ml_client,
+    ):
+        with patch.object(
+            sys,
+            "argv",
+            [
+                "src/health_ml/runner.py",
+                "--model=HelloWorld",
+                "--strictly_aml_v1",
+                "--azure_datasets=hello_world",
+            ],
+        ):
+            runner.run()
+            mock_get_workspace.assert_called_once()
+            mock_get_ml_client.assert_not_called()
+
+
+@pytest.mark.parametrize("use_local_dataset", [True, False])
+def test_runner_with_local_dataset_v2(use_local_dataset: bool, tmp_path: Path) -> None:
+    """Test that the runner requires authentication only once when doing a local run with SDK v2"""
+    runner = Runner(project_root=repository_root_directory())
+    mock_get_workspace = MagicMock()
+    mock_get_ml_client = MagicMock(return_value=DEFAULT_WORKSPACE.ml_client)
+    with patch.multiple(
+        "health_azure.himl",
+        get_workspace=mock_get_workspace,
+        get_ml_client=mock_get_ml_client,
+    ):
+        args = [
+            "src/health_ml/runner.py",
+            "--model=HelloWorld",
+            f"--strictly_aml_v1=False",
+            "--azure_datasets=hello_world",
+        ]
+        if use_local_dataset:
+            args.append(f"--local_datasets={tmp_path}")
+        with patch.object(sys, "argv", args):
+            if use_local_dataset:
+                runner.run()
+            else:
+                with pytest.raises(ValueError, match="AzureML SDK v2 does not support downloading datasets from"):
+                    runner.run()
+            mock_get_workspace.assert_not_called()
+            mock_get_ml_client.assert_called_once()
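One pytest idiom worth noting in the last test above: pytest.mark.parametrize runs the test once per parameter value, and pytest.raises(..., match=...) checks both the exception type and a regular-expression match against its message. A tiny, self-contained illustration of just that idiom follows; the function mount_dataset is a made-up example, unrelated to the hi-ml code itself.

# Minimal illustration of the parametrize + raises(match=...) pattern (not part of the commit).
import pytest


def mount_dataset(use_local_dataset: bool) -> str:
    """Toy function: succeeds for local data, raises for the unsupported path."""
    if use_local_dataset:
        return "using local folder"
    raise ValueError("SDK v2 does not support downloading datasets for local runs")


@pytest.mark.parametrize("use_local_dataset", [True, False])
def test_mount_dataset(use_local_dataset: bool) -> None:
    if use_local_dataset:
        assert mount_dataset(use_local_dataset) == "using local folder"
    else:
        # 'match' is a regular expression searched against the exception text.
        with pytest.raises(ValueError, match="does not support downloading"):
            mount_dataset(use_local_dataset)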