ENH: Merge Main into Transfer Main (#653)

Sync 32 commits from main

Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>
Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com>
Co-authored-by: vale-salvatelli <vale-salvatelli@users.noreply.github.com>
Co-authored-by: Melissa Bristow <66642528+mebristo@users.noreply.github.com>
Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fernando Pérez-García <fperezgarcia@microsoft.com>
Co-authored-by: Ozan Oktay <ozan.oktay@microsoft.com>
This commit is contained in:
Kenza Bouzid 2022-11-10 18:36:40 +00:00 коммит произвёл GitHub
Родитель 6ba3cfe685
Коммит e30a0d1f6d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
93 изменённых файлов: 3086 добавлений и 802 удалений

54
.github/actions/prepare_cpath_environment/action.yml поставляемый Normal file
Просмотреть файл

@ -0,0 +1,54 @@
name: 'Hi-ml-cpath environment setup'
description: 'Set up environment hi-ml-cpath workflows'
runs:
using: "composite"
steps:
- name: Create AzureML config.json file
shell: bash
run: ./create_config.sh
# Use a cache action to save the full conda environment, so that we don't have to reinstall it every time.
# Paths are tied to the location of the miniconda installation, and may need adjustment on a different OS.
- name: Retrieve cached Conda environment
id: cache-conda
uses: actions/cache@v3
with:
path: /usr/share/miniconda/envs/HimlHisto
key: hi-ml-cpath-conda-${{ hashFiles('hi-ml-cpath/environment.yml') }}
# If the cache action didn't find a cache, then install the conda environment afresh.
- name: Build Conda environment from scratch
uses: conda-incubator/setup-miniconda@v2
if: steps.cache-conda.outputs.cache-hit != 'true'
with:
environment-file: hi-ml-cpath/environment.yml
activate-environment: HimlHisto
# Modify the path to point to the new or cached Conda environment.
# This is effectively also what `conda activate` does.
- name: Activate environment
shell: bash
run: |
echo "Adding Conda bin folder to path"
echo "/usr/share/miniconda/envs/HimlHisto/bin" >> $GITHUB_PATH
- name: Conda info
shell: bash
run: conda info
- name: Show active Python path
shell: bash
run: which python
- name: Install hi-ml locally
shell: bash
run: |
cd hi-ml
make pip_local
- name: Install hi-ml-azure locally
shell: bash
run: |
cd hi-ml-azure
make pip_local

Просмотреть файл

@ -1,31 +0,0 @@
name: 'Smoke test environment setup'
description: 'Set up environment for running smoke tests'
runs:
using: "composite"
steps:
- name: create config file
shell: bash
run: ./create_config.sh
- name: Set up Python ${{ env.pythonVersion }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.pythonVersion }}
- name: Install required packages
shell: bash
run: |
cd ${{ env.folder }}
make pip_from_conda
- name: Install hi-ml locally
shell: bash
run: |
cd hi-ml
make pip_local
- name: Install hi-ml-azure locally
shell: bash
run: |
cd hi-ml-azure
make pip_local

68
.github/workflows/cpath-pr.yml поставляемый
Просмотреть файл

@ -14,7 +14,7 @@ on:
- "hi-ml/**"
env:
pythonVersion: 3.7
pythonVersion: 3.9
folder: hi-ml-cpath
module_for_coverage_reporting: health_cpath
HIML_TENANT_ID: ${{ secrets.HIML_TENANT_ID }}
@ -69,30 +69,8 @@ jobs:
with:
lfs: true
- name: Set up Python ${{ env.pythonVersion }}
uses: actions/setup-python@v4
with:
python-version: ${{ env.pythonVersion }}
- name: PIP upgrade
run: |
cd ${{ env.folder }}
make pip_upgrade
- name: Install required packages
run: |
cd ${{ env.folder }}
make pip_from_conda
- name: Install hi-ml locally
run: |
cd hi-ml
make pip_local
- name: Install hi-ml-azure locally
run: |
cd hi-ml-azure
make pip_local
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: Test with pytest
run: |
@ -126,9 +104,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-slides-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -142,9 +119,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-tiles-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -158,9 +134,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-tiles-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -174,9 +149,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-sslmil-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -190,9 +164,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-simclr-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -206,9 +179,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-finetuning-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -222,9 +194,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-finetuning-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
@ -238,9 +209,8 @@ jobs:
with:
lfs: true
- name: Set up smoke test environment
id: setup-finetuning-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |

8
.github/workflows/multimodal-pr.yml поставляемый
Просмотреть файл

@ -105,5 +105,9 @@ jobs:
run: |
cd ${{ env.folder }}
ipython kernel install --name "python3" --user
echo "Current branch: ${BRANCH_NAME}"
papermill --parameters repo_branch ${BRANCH_NAME} notebooks/phrase_grounding.ipynb /tmp/phrase_grounding_output.ipynb
PIP_SOURCE="git+https://github.com/microsoft/hi-ml.git@${BRANCH_NAME}#subdirectory=hi-ml-multimodal"
echo "The package will be installed from: $PIP_SOURCE"
papermill \
--parameters pip_source $PIP_SOURCE \
notebooks/phrase_grounding.ipynb \
/tmp/phrase_grounding_output.ipynb

Просмотреть файл

@ -27,7 +27,7 @@ repos:
# The structure of this was suggested by the author of pre-commit and maintainer of flake8
# See https://stackoverflow.com/a/66485642/3956024
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 5.0.2
hooks:
- id: flake8
name: flake8 ./hi-ml/

Просмотреть файл

@ -5,6 +5,7 @@
from pathlib import Path
import numpy as np
from azureml.core import Datastore
from sklearn import datasets
from sklearn.model_selection import KFold
@ -40,8 +41,8 @@ def main() -> None:
np.savetxt(str(target_splits_file), np.vstack(indices_test_splits), delimiter=",")
ws = get_workspace()
datastore = get_datastore(workspace=ws,
datastore_name="himldatasets")
datastore: Datastore = get_datastore(workspace=ws,
datastore_name="himldatasets")
dataset_name = 'himl_kfold_split_iris'
datastore.upload_files(

Просмотреть файл

@ -4,6 +4,8 @@
# ------------------------------------------------------------------------------------------
from pathlib import Path
from azureml.core import Datastore
from health_azure.datasets import get_datastore
from health_azure import get_workspace
@ -13,8 +15,8 @@ def main() -> None:
workspace = get_workspace()
datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
datastore: Datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
# Either download all outputs:
# run.download_files(prefix="outputs", output_directory=str(path))

Просмотреть файл

@ -5,6 +5,7 @@
from pathlib import Path
import numpy as np
from azureml.core import Datastore
from sklearn import datasets
from health_azure.datasets import get_datastore
@ -25,8 +26,8 @@ def main() -> None:
workspace = get_workspace()
datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
datastore: Datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
datastore.upload_files(
[str(X_csv), str(y_csv)],

Просмотреть файл

@ -5,6 +5,7 @@
from pathlib import Path
import numpy as np
from azureml.core import Datastore
from sklearn import datasets
from health_azure import get_workspace
@ -25,8 +26,8 @@ def main() -> None:
workspace = get_workspace()
datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
datastore: Datastore = get_datastore(workspace=workspace,
datastore_name="himldatasets")
datastore.upload_files(
[str(X_csv), str(y_csv)],

Просмотреть файл

@ -1,4 +1,4 @@
# Hyperparameter Search via Hyperdrive
# Hyperparameter Search via Hyperdrive (AML SDK v1)
[HyperDrive runs](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)
can start multiple AzureML jobs in parallel. This can be used for tuning hyperparameters, or executing multiple
@ -27,3 +27,37 @@ submit_to_azure_if_needed(..., hyperdrive_config=hyperdrive_config)
For further examples, please check the [example scripts here](examples.md), and the
[HyperDrive documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters).
# Hyperparameter Search in AML SDK v2
There is no concept of a HyperDriveConfig in AML SDK v2. Instead, hyperparameter search arguments are passed into a
command, and then the 'sweep' method is called [AML
docs](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters). To specify a hyperparameter
search job you must specify the method `get_parameter_tuning_args` in your Container. This should return a dictionary of
the arguments to be passed in to the command. For example:
```python
def get_parameter_tuning_args(self) -> Dict[str, Any]:
from azure.ai.ml.entities import Choice
from health_azure.himl import (MAX_TOTAL_TRIALS_ARG, PARAM_SAMPLING_ARG, SAMPLING_ALGORITHM_ARG,
PRIMARY_METRIC_ARG, GOAL_ARG)
values = [0.1, 0.5, 0.9]
argument_name = "learning_rate"
param_sampling = {argument_name: Choice(values)}
metric_name = "val/loss"
hparam_args = {
MAX_TOTAL_TRIALS_ARG: len(values),
PARAM_SAMPLING_ARG: param_sampling,
SAMPLING_ALGORITHM_ARG: "grid",
PRIMARY_METRIC_ARG: metric_name,
GOAL_ARG: "Minimize"
}
return hparam_args
```
Additional parameters, sampling strategies, limits etc. are described in the link above. Note that each job that is
created will receive an additional command line argument `<argument_name>` and it is your job to update the script to be
able to parse and use this argument.

Просмотреть файл

@ -22,7 +22,7 @@ The `hi-ml` toolbox provides
azure_setup.md
authentication.md
datasets.md
hyperdrive.md
hyperparameter_search.md
lowpriority.md
commandline_tools.md
downloading.md

Просмотреть файл

@ -1 +0,0 @@
../../hi-ml-multimodal/README.md

Просмотреть файл

@ -291,3 +291,12 @@ experiment will get run 3 times with seeds 0, 1 and 2. This is equivalent to sta
These runs will be started in parallel in AzureML via the HyperDrive framework. It is not possible to run with different
seeds on a local machine, other than by manually starting runs with `--random_seed=0` etc.
## Common problems with running in AML
1. `"Your total snapshot size exceeds the limit <SNAPSHOT_LIMIT>"`. Cause: The size of your source directory is larger than
the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.

Просмотреть файл

@ -133,7 +133,7 @@ To use this code with your own data, you will need to:
and `DatasetWithReturnIndex`. See for example how we constructed `RSNAKaggleCXR`
class. WARNING: the first positional argument of your dataset class constructor MUST be the data directory ("root"),
as VisionDataModule expects this in the prepare_data step.
3. In your own container update the `_SSLDataClassMappings` member of the class so that the code knows which data class
3. In your own container update the `DatasetToClassMapping` member of the class so that the code knows which data class
to associate to your new dataset name.
4. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
consumed by the `create_transforms_from_config` function defined in the

Просмотреть файл

@ -39,9 +39,9 @@ etc. Here, we need to define some important parameters:
1. The dataset we want to use for training the image encoder and the linear model we only use for evaluation of the
image encoder. In theory, they could be two different datasets.
+ [```ssl_training_dataset_name=SSLDatasetNameHiml.TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L40)
+ [```ssl_training_dataset_name=SSL_Dataset_TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L40)
+ [```linear_head_dataset_name=SSLDatasetNameHiml.TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L41)
+ [```linear_head_dataset_name=SSL_Dataset_TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L41)
1. Model checkpointing: We use [PyTorch lightning
checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html). Among others, we define

16
hi-ml-azure/run_pytest.py Normal file → Executable file
Просмотреть файл

@ -1,6 +1,14 @@
#! /usr/bin/env python
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import sys
from pathlib import Path
import time
import pytest
import param
@ -8,7 +16,6 @@ from _pytest.main import ExitCode
from azureml._restclient.constants import RunStatus
from azureml.core import Run
# Add hi-ml packages to sys.path so that AML can find them if we are using the runner directly from the git repo
himl_root = Path(__file__).resolve().parent.parent
@ -34,7 +41,7 @@ from health_azure.utils import ( # noqa: E402
is_running_in_azure_ml,
parse_arguments,
)
from health_ml.utils.common_utils import DEFAULT_AML_UPLOAD_DIR # noqa: E402
from health_ml.utils.common_utils import DEFAULT_AML_UPLOAD_DIR, DEFAULT_DOCKER_BASE_IMAGE # noqa: E402
PYTEST_RESULTS_FILE = "pytest_results.xml"
PYTEST_GPU_COVERAGE_FILE = "pytest_gpu_coverage.xml"
@ -68,6 +75,8 @@ class RunPytestConfig(param.Parameterized):
default="",
doc="A folder name that should be added to sys.path. The folder name should be relative to repository root."
)
strictly_aml_v1: bool = param.Boolean(default=True, doc="If True, use AzureML v1 SDK. If False (default), use "
"the v2 of the SDK")
def run_pytest(folder_to_test: str, pytest_mark: str, coverage_module: str) -> None:
@ -178,5 +187,8 @@ if __name__ == "__main__":
experiment_name=config.experiment,
max_run_duration=config.max_run_duration,
after_submission=pytest_after_submission_hook,
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
strictly_aml_v1=config.strictly_aml_v1,
)
run_pytest(folder_to_test=config.folder, pytest_mark=config.mark, coverage_module=config.coverage_module)
time.sleep(10) # Give the AzureML job time to finish uploading the pytest result file.

Просмотреть файл

@ -1,12 +1,16 @@
azureml-core==1.43.0
azureml-dataset-runtime[fuse]
azureml-tensorboard==1.43.0
azureml-train-core==1.43.0
azure-ai-ml>=0.1.0b6
azureml-core>=1.42.0
azureml-dataset-runtime[fuse]>=1.42.0
azureml-mlflow>=1.42.0
azure-storage-blob==12.10.0
azureml-tensorboard>=1.42.0
azureml-train-core>=1.42.0
conda-merge>=0.1.5
mlflow>=1.29.0
pandas>=1.3.4
param>=1.12
protobuf<=3.20.1
protobuf<4.0
pysocks>=1.5.8
ruamel.yaml>=0.16.12
tensorboard>=2.6.0
typing-extensions==4.3.0
typing-extensions>=4.3.0

Просмотреть файл

@ -19,76 +19,76 @@ from setuptools import setup, find_namespace_packages # type: ignore
here = pathlib.Path(__file__).parent.resolve()
# Get the long description from the README file
long_description = (here / 'package_description.md').read_text(encoding='utf-8')
long_description = (here / "package_description.md").read_text(encoding="utf-8")
version = ''
version = ""
# If running from a GitHub Action then a standard set of environment variables will be
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
# Extract this tag and use it as a version string
# See also:
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
# https://github.com/pypa/gh-action-pypi-publish
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
github_ref = os.getenv('GITHUB_REF')
github_ref = os.getenv("GITHUB_REF")
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
# '99.99.post124'. Although this is discouraged, see:
# "99.99.post124". Although this is discouraged, see:
# https://www.python.org/dev/peps/pep-0440/#post-releases
# it is necessary here to avoid duplicate packages in Test.PyPI.
if not version:
build_number = os.getenv('GITHUB_RUN_NUMBER')
build_number = os.getenv("GITHUB_RUN_NUMBER")
if build_number:
# In github workflows, tests for hi-ml pull in hi-ml-azure as a dependency. Usually, we have a condition like
# hi-ml-azure>=0.1.5. This means that a package version from PyPi would trump the local wheels. For this reason,
# use an extremely large version number to give the local wheel priority.
version = '99.991.post' + build_number
version = "99.991.post" + build_number
else:
default_random_version_number = floor(random() * 10_000_000_000)
version = f'99.991.post{str(default_random_version_number)}'
version = f"99.991.post{str(default_random_version_number)}"
(here / 'package_name.txt').write_text('hi-ml-azure')
(here / 'latest_version.txt').write_text(version)
(here / "package_name.txt").write_text("hi-ml-azure")
(here / "latest_version.txt").write_text(version)
# Read run_requirements.txt to get install_requires
install_requires = (here / 'run_requirements.txt').read_text().split("\n")
install_requires = (here / "run_requirements.txt").read_text().split("\n")
# Remove any whitespace and blank lines
install_requires = [line.strip() for line in install_requires if line.strip()]
description = 'Microsoft Health Intelligence package to elevate and monitor scripts to an AzureML workspace'
description = "Microsoft Health Futures package to elevate and monitor scripts to an AzureML workspace"
setup(
name='hi-ml-azure',
name="hi-ml-azure",
version=version,
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/microsoft/hi-ml',
author="Microsoft Research Cambridge InnerEye Team ",
long_description_content_type="text/markdown",
url="https://github.com/microsoft/hi-ml",
author="Biomedical Imaging Team @ Microsoft Health Futures",
author_email="innereyedev@microsoft.com",
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Science/Research',
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.7'
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7"
],
keywords='InnerEye, HealthIntelligence, AzureML',
license='MIT License',
keywords="Health Futures, Health Intelligence, AzureML",
license="MIT License",
packages=find_namespace_packages(where="src"),
package_dir={"": "src"},
include_package_data=True,
install_requires=install_requires,
entry_points={
'console_scripts': [
'himl-tb = health_azure.himl_tensorboard:main',
'himl-download = health_azure.himl_download:main'
"console_scripts": [
"himl-tb = health_azure.himl_tensorboard:main",
"himl-download = health_azure.himl_download:main"
]
}
)

Просмотреть файл

@ -5,17 +5,28 @@
import logging
import tempfile
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
from azureml.core import Dataset, Datastore, Workspace
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Data
from azure.ai.ml.entities import Datastore as V2Datastore
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.operations import DatastoreOperations
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azureml.core import Dataset, Workspace, Datastore
from azureml.data import FileDataset, OutputFileDatasetConfig
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.dataprep.fuse.daemon import MountContext
from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.utils import PathOrString, get_workspace
from health_azure.utils import PathOrString, get_workspace, get_ml_client
def get_datastore(workspace: Workspace, datastore_name: str) -> Datastore:
V1OrV2DataType = Union[FileDataset, Data]
def get_datastore(workspace: Workspace, datastore_name: str) -> Union[AzureBlobDatastore, V2Datastore]:
"""
Retrieves a datastore of a given name from an AzureML workspace. The datastore_name argument can be omitted if
the workspace only contains a single datastore. Raises a ValueError if there is no datastore of the given name.
@ -24,45 +35,199 @@ def get_datastore(workspace: Workspace, datastore_name: str) -> Datastore:
:param datastore_name: The name of the datastore to retrieve.
:return: An AzureML datastore.
"""
datastores = workspace.datastores
existing_stores = list(datastores.keys())
if not datastore_name:
def _retrieve_v1_datastore(datastores: Dict[str, Datastore], datastore_name: str) -> Datastore:
# First check if there is only one datastore, which is then obviously unique.
# Only then try to use the default datastore, because there may not be a default set.
if len(existing_stores) == 1:
return datastores[existing_stores[0]]
datastore = workspace.get_default_datastore()
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
existing_stores = list(datastores.keys())
if not datastore_name:
if len(existing_stores) == 1:
return datastores[existing_stores[0]]
datastore = workspace.get_default_datastore()
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
return datastore
if datastore_name in datastores:
return datastores[datastore_name]
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the \"{workspace.name}\" workspace. "
f"Existing datastores: {existing_stores}")
def _retrieve_v2_datastore(datastores: DatastoreOperations, datastore_name: str) -> V2Datastore:
existing_stores = list(datastores.list())
if not datastore_name:
if len(existing_stores) == 1:
return existing_stores[0]
datastore = datastores.get_default()
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
return datastore
try:
datastore = datastores.get(datastore_name)
except ResourceNotFoundError:
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the workspace")
return datastore
if datastore_name in datastores:
return datastores[datastore_name]
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the \"{workspace.name}\" workspace. "
f"Existing datastores: {existing_stores}")
datastores = workspace.datastores
if isinstance(datastores, DatastoreOperations):
return _retrieve_v2_datastore(datastores, datastore_name)
elif isinstance(datastores, dict):
return _retrieve_v1_datastore(datastores, datastore_name)
else:
raise ValueError(f"Unrecognised type for datastores: {type(datastores)}")
def get_or_create_dataset(workspace: Workspace, datastore_name: str, dataset_name: str) -> FileDataset:
def _retrieve_v1_dataset(dataset_name: str, workspace: Workspace) -> Optional[FileDataset]:
"""
Retrieve an Azure ML v1 Dataset if it exists, otherwise return None
:param dataset_name: The name of the Dataset to look for.
:param workspace: An Azure ML Workspace object for retrieving the Dataset.
:return: A Dataset object if it is found, else None.
"""
logging.info(f"Trying to retrieve AzureML Dataset '{dataset_name}'")
azureml_dataset = Dataset.get_by_name(workspace, name=dataset_name)
return azureml_dataset
def _create_v1_dataset(datastore_name: str, dataset_name: str, workspace: Workspace
) -> FileDataset:
"""
Create a v1 Dataset in the specified Datastore
:param datastore_name: The AML Datastore to create the Dataset in.
:param dataset_name: The name of the Dataset to create.
:param workspace: An AML Workspace object.
:return: An Azure ML (v1) FileDataset object.
"""
if not dataset_name:
raise ValueError(f"Cannot create dataset without a valid dataset name (received '{dataset_name}')")
datastore = get_datastore(workspace, datastore_name)
assert isinstance(datastore, AzureBlobDatastore)
logging.info(f"Creating a new dataset from data in folder '{dataset_name}' in the datastore")
# Ensure that there is a / at the end of the file path, otherwise folder that share a prefix could create
# trouble (for example, folders foo and foo_bar exist, and I'm trying to create a dataset from "foo")
azureml_dataset = Dataset.File.from_files(path=(datastore, dataset_name + "/"))
logging.info("Registering the dataset for future use.")
azureml_dataset.register(workspace, name=dataset_name)
return azureml_dataset
def _get_or_create_v1_dataset(datastore_name: str, dataset_name: str, workspace: Workspace) -> Dataset:
"""
Attempt to retrieve a v1 Dataset object and return that, otherwise attempt to create and register
a v1 Dataset and return that.
:param datastore_name: The name of the Datastore to either retrieve or create and register the Dataset in.
:param dataset_name: The name of the Dataset to be retrieved or registered.
:param workspace: An Azure ML Workspace object.
:return: An Azure ML Dataset object with the provided dataset name, in the provided datastore.
"""
try:
azureml_dataset = _retrieve_v1_dataset(dataset_name, workspace)
except UserErrorException:
azureml_dataset = _create_v1_dataset(datastore_name, dataset_name, workspace)
return azureml_dataset
def _retrieve_v2_dataset(dataset_name: str, ml_client: MLClient) -> Data:
"""
Attempt to retrieve a v2 Data Asset using a provided Azure ML Workspace connection. If
no Data asset can be found with a matching name, the underlying code will raise an Exception
:param dataset_name: The name of the dataset to look for.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:return: An Azure Data asset representing the dataset if found, otherwise an Exception will be raised.
"""
aml_data = ml_client.data.get(name=dataset_name)
return aml_data
def _create_v2_dataset(datastore_name: str, dataset_name: str, ml_client: MLClient) -> Data:
"""
Create or update a v2 Data asset in the specified Datastore
:param datastore_name: The name of the datastore in which to create or update the Data asset.
:param dataset_name: The name of the dataset to be created.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:raises ValueError: If no datastore name is provided to define where to create the data
:return: The created or updated Data asset.
"""
if not dataset_name:
raise ValueError(f"Cannot create data asset without a valid dataset name (received {dataset_name})")
if not datastore_name:
default_datastore = ml_client.datastores.get_default()
datastore_name = default_datastore.name
logging.info(f"Creating a new Data asset from data in folder '{dataset_name}' in the datastore '{datastore_name}'")
azureml_data_asset = Data(
path=f"azureml://datastores/{datastore_name}/paths/{dataset_name}/",
type=AssetTypes.URI_FOLDER,
description="<description>",
name=dataset_name,
version=None
)
ml_client.data.create_or_update(azureml_data_asset)
return azureml_data_asset
def _get_or_create_v2_dataset(datastore_name: str, dataset_name: str, ml_client: MLClient) -> Data:
"""
Attempt to retrieve a v2 Dataset object and return that, otherwise attempt to create and register
a v2 Dataset and return that.
:param datastore_name: The name of the Datastore to either retrieve or create and register the Data asset in.
:param dataset_name: The name of the Data asset to be retrieved or registered.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:return: An Azure Data asset object with the provided dataset name, in the provided datastore
"""
try:
azureml_dataset = _retrieve_v2_dataset(dataset_name, ml_client)
except Exception:
azureml_dataset = _create_v2_dataset(datastore_name, dataset_name, ml_client)
return azureml_dataset
def get_or_create_dataset(datastore_name: str,
dataset_name: str,
workspace: Workspace,
strictly_aml_v1: bool,
ml_client: Optional[MLClient] = None,
) -> V1OrV2DataType:
"""
Looks in the AzureML datastore for a dataset of the given name. If there is no such dataset, a dataset is
created and registered, assuming that the files are in a folder that has the same name as the dataset.
For example, if dataset_name is 'foo', then the 'foo' dataset should be pointing to the folder
<container_root>/datasets/dataset_name/
<container_root>/datasets/dataset_name/.
If the command line arg to strictly use AML SDK v1 is set to True, will attempt to retrieve a dataset using
v1 of the SDK. Otherwise, will attempt to use v2 of the SDK. If no data of this name is found in the v2 datastore,
will attempt to create it, but if the data container provided is v1 version, will fall back to using the
v1 SDK to create and register this dataset.
:param datastore_name: The name of the datastore in which to look for, or create and register, the dataset.
:param dataset_name: The name of the dataset to find or create.
:param workspace: An AML Workspace object for interacting with AML v1 datastores.
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An optional MLClient object for interacting with AML v2 datastores.
"""
if not dataset_name:
raise ValueError("No dataset name provided.")
try:
logging.info(f"Trying to retrieve AzureML Dataset '{dataset_name}'")
azureml_dataset = Dataset.get_by_name(workspace, name=dataset_name)
logging.info("Dataset found.")
except Exception:
logging.info(f"Retrieving datastore '{datastore_name}' from AzureML workspace")
datastore = get_datastore(workspace, datastore_name)
logging.info(f"Creating a new dataset from data in folder '{dataset_name}' in the datastore")
# Ensure that there is a / at the end of the file path, otherwise folder that share a prefix could create
# trouble (for example, folders foo and foo_bar exist, and I'm trying to create a dataset from "foo")
azureml_dataset = Dataset.File.from_files(path=(datastore, dataset_name + "/"))
logging.info("Registering the dataset for future use.")
azureml_dataset.register(workspace, name=dataset_name)
return azureml_dataset
if strictly_aml_v1:
aml_dataset = _get_or_create_v1_dataset(datastore_name, dataset_name, workspace)
return aml_dataset
else:
try:
ml_client = get_ml_client(ml_client=ml_client)
aml_dataset = _get_or_create_v2_dataset(datastore_name, dataset_name, ml_client)
except HttpResponseError as e:
if "Cannot create v2 Data Version in v1 Data Container" in e.message:
logging.info("This appears to be a v1 Data Container. Reverting to API v1 to create this Dataset")
aml_dataset = _get_or_create_v1_dataset(datastore_name, dataset_name, workspace)
return aml_dataset
def _input_dataset_key(index: int) -> str:
@ -118,14 +283,22 @@ class DatasetConfig:
raise ValueError("Can't mount or download a dataset to the current working directory.")
self.local_folder = Path(local_folder) if local_folder else None
def to_input_dataset_local(self, workspace: Optional[Workspace]) -> Tuple[Path, Optional[MountContext]]:
def to_input_dataset_local(self,
strictly_aml_v1: bool,
workspace: Workspace = None,
ml_client: Optional[MLClient] = None,
) -> Tuple[Optional[Path], Optional[MountContext]]:
"""
Return a local path to the dataset when outside of an AzureML run.
If local_folder is supplied, then this is assumed to be a local dataset, and this is returned.
Otherwise the dataset is mounted or downloaded to either the target folder or a temporary folder and that is
returned.
returned. If self.name refers to a v2 dataset, it is not possible to mount the data here,
therefore a tuple of Nones will be returned.
:param workspace: The AzureML workspace to read from.
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:return: Tuple of (path to dataset, optional mountcontext)
"""
status = f"Dataset '{self.name}' will be "
@ -138,55 +311,70 @@ class DatasetConfig:
if workspace is None:
raise ValueError(f"Unable to make dataset '{self.name} available for a local run because no AzureML "
"workspace has been provided. Provide a workspace, or set a folder for local execution.")
azureml_dataset = get_or_create_dataset(workspace=workspace,
ml_client=ml_client,
dataset_name=self.name,
datastore_name=self.datastore)
target_path = self.target_folder or Path(tempfile.mkdtemp())
use_mounting = self.use_mounting if self.use_mounting is not None else False
if use_mounting:
status += f"mounted at {target_path}"
datastore_name=self.datastore,
strictly_aml_v1=strictly_aml_v1)
if isinstance(azureml_dataset, FileDataset):
target_path = self.target_folder or Path(tempfile.mkdtemp())
use_mounting = self.use_mounting if self.use_mounting is not None else False
if use_mounting:
status += f"mounted at {target_path}"
mount_context = azureml_dataset.mount(mount_point=str(target_path)) # type: ignore
result = target_path, mount_context
else:
status += f"downloaded to {target_path}"
azureml_dataset.download(target_path=str(target_path), overwrite=False) # type: ignore
result = target_path, None
print(status)
mount_context = azureml_dataset.mount(mount_point=str(target_path))
result = target_path, mount_context
return result
else:
status += f"downloaded to {target_path}"
print(status)
azureml_dataset.download(target_path=str(target_path), overwrite=False)
result = target_path, None
return result
return None, None
def to_input_dataset(self,
dataset_index: int,
workspace: Workspace,
dataset_index: int) -> DatasetConsumptionConfig:
strictly_aml_v1: bool,
ml_client: Optional[MLClient] = None,
) -> Optional[DatasetConsumptionConfig]:
"""
Creates a configuration for using an AzureML dataset inside of an AzureML run. This will make the AzureML
dataset with given name available as a named input, using INPUT_0 as the key for dataset index 0.
:param workspace: The AzureML workspace to read from.
:param dataset_index: Suffix for using datasets as named inputs, the dataset will be marked INPUT_{index}
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
"""
status = f"In AzureML, dataset {self.name} (index {dataset_index}) will be "
azureml_dataset = get_or_create_dataset(workspace=workspace,
ml_client=ml_client,
dataset_name=self.name,
datastore_name=self.datastore)
named_input = azureml_dataset.as_named_input(_input_dataset_key(index=dataset_index))
datastore_name=self.datastore,
strictly_aml_v1=strictly_aml_v1)
# If running on windows then self.target_folder may be a WindowsPath, make sure it is
# in posix format for Azure.
path_on_compute = self.target_folder.as_posix() if self.target_folder is not None else None
use_mounting = False if self.use_mounting is None else self.use_mounting
if use_mounting:
status += "mounted at "
result = named_input.as_mount(path_on_compute)
if isinstance(azureml_dataset, FileDataset):
named_input = azureml_dataset.as_named_input(_input_dataset_key(index=dataset_index)) # type: ignore
path_on_compute = self.target_folder.as_posix() if self.target_folder is not None else None
if use_mounting:
status += "mounted at "
result = named_input.as_mount(path_on_compute)
else:
status += "downloaded to "
result = named_input.as_download(path_on_compute)
if path_on_compute:
status += f"{path_on_compute}."
else:
status += "a randomly chosen folder."
print(status)
return result
else:
status += "downloaded to "
result = named_input.as_download(path_on_compute)
if path_on_compute:
status += f"{path_on_compute}."
else:
status += "a randomly chosen folder."
print(status)
return result
return None
def to_output_dataset(self,
workspace: Workspace,
@ -316,8 +504,10 @@ def find_workspace_for_local_datasets(aml_workspace: Optional[Workspace],
def setup_local_datasets(dataset_configs: List[DatasetConfig],
strictly_aml_v1: bool,
aml_workspace: Optional[Workspace] = None,
workspace_config_path: Optional[Path] = None
ml_client: Optional[MLClient] = None,
workspace_config_path: Optional[Path] = None,
) -> Tuple[List[Optional[Path]], List[MountContext]]:
"""
When running outside of AzureML, setup datasets to be used locally.
@ -330,16 +520,17 @@ def setup_local_datasets(dataset_configs: List[DatasetConfig],
to pass it in as a parameter.
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
Azure portal from which we can retrieve the existing Workspace.
:param dataset_configs: List of DatasetConfig describing the input datasets.
:param dataset_configs: List of DatasetConfig describing the input data assets.
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
:param ml_client: An MLClient object for interacting with AML v2 datastores.
:return: Pair of: list of optional paths to the input datasets, list of mountcontexts, one for each mounted dataset.
"""
workspace = find_workspace_for_local_datasets(aml_workspace, workspace_config_path, dataset_configs)
mounted_input_datasets: List[Optional[Path]] = []
mount_contexts: List[MountContext] = []
for d in dataset_configs:
target_path, mount_context = d.to_input_dataset_local(workspace)
for data_config in dataset_configs:
target_path, mount_context = data_config.to_input_dataset_local(strictly_aml_v1, workspace, ml_client)
mounted_input_datasets.append(target_path)

Просмотреть файл

@ -11,14 +11,21 @@ See examples/elevate_this.py for a very simple 'hello world' example of use.
import logging
import os
import re
import sys
import warnings
from argparse import ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from azure.ai.ml import MLClient, Input, Output, command
from azure.ai.ml.constants import AssetTypes, InputOutputModes
from azure.ai.ml.entities import Data, Job, Command, Sweep
from azure.ai.ml.entities import Environment as EnvironmentV2
from azure.ai.ml.sweep import Choice
from azureml._base_sdk_common import user_agent
from azureml.core import ComputeTarget, Environment, Experiment, Run, RunConfiguration, ScriptRunConfig, Workspace
from azureml.core.runconfig import DockerConfiguration, MpiConfiguration
@ -31,9 +38,12 @@ from health_azure.amulet import (ENV_AMLT_DATAREFERENCE_DATA, ENV_AMLT_DATAREFER
from health_azure.utils import (create_python_environment, create_run_recovery_id, find_file_in_parent_to_pythonpath,
is_run_and_child_runs_completed, is_running_in_azure_ml, register_environment,
run_duration_string_to_seconds, to_azure_friendly_string, RUN_CONTEXT, get_workspace,
PathOrString, DEFAULT_ENVIRONMENT_VARIABLES)
from health_azure.datasets import (DatasetConfig, StrOrDatasetConfig, _input_dataset_key, _output_dataset_key,
_replace_string_datasets, setup_local_datasets)
PathOrString, DEFAULT_ENVIRONMENT_VARIABLES, get_ml_client,
create_python_environment_v2, register_environment_v2, V2_INPUT_DATASET_PATTERN,
V2_OUTPUT_DATASET_PATTERN)
from health_azure.datasets import (DatasetConfig, StrOrDatasetConfig, setup_local_datasets,
_input_dataset_key, _output_dataset_key, _replace_string_datasets)
logger = logging.getLogger('health_azure')
logger.setLevel(logging.DEBUG)
@ -47,6 +57,13 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
SDK_NAME = "innereye"
SDK_VERSION = "2.0"
# hyperparameter search args
PARAM_SAMPLING_ARG = "parameter_sampling"
MAX_TOTAL_TRIALS_ARG = "max_total_trials"
PRIMARY_METRIC_ARG = "primary_metric"
SAMPLING_ALGORITHM_ARG = "sampling_algorithm"
GOAL_ARG = "goal"
@dataclass
class AzureRunInfo:
@ -215,7 +232,9 @@ def create_run_configuration(workspace: Workspace,
if input_datasets or output_datasets:
inputs, outputs = convert_himl_to_azureml_datasets(cleaned_input_datasets=input_datasets or [],
cleaned_output_datasets=output_datasets or [],
workspace=workspace)
workspace=workspace,
strictly_aml_v1=True
)
run_config.data = inputs
run_config.output_data = outputs
@ -256,6 +275,30 @@ def create_grid_hyperdrive_config(values: List[str],
)
def create_grid_hyperparam_args_v2(values: List[Any],
argument_name: str,
metric_name: str) -> Dict[str, Any]:
"""
Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.
:param values: The list of values to try for the commandline argument given by `argument_name`.
:param argument_name: The name of the commandline argument that each of the child runs gets, to
indicate which value they should work on.
:param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
your responsibility to make sure a metric with this name is logged to the Run in your training script
:return: A dictionary of arguments and values to pass in to the command job.
"""
param_sampling = {argument_name: Choice(values)}
hyperparam_args = {
MAX_TOTAL_TRIALS_ARG: len(values),
PARAM_SAMPLING_ARG: param_sampling,
SAMPLING_ALGORITHM_ARG: "grid",
PRIMARY_METRIC_ARG: metric_name,
GOAL_ARG: "Minimize"
}
return hyperparam_args
def create_crossval_hyperdrive_config(num_splits: int,
cross_val_index_arg_name: str = "crossval_index",
metric_name: str = "val/loss") -> HyperDriveConfig:
@ -276,6 +319,24 @@ def create_crossval_hyperdrive_config(num_splits: int,
metric_name=metric_name)
def create_crossval_hyperparam_args_v2(num_splits: int,
cross_val_index_arg_name: str = "crossval_index",
metric_name: str = "val/loss") -> Dict[str, Any]:
"""
Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.
:param num_splits: The number of splits for k-fold cross validation
:param cross_val_index_arg_name: The name of the commandline argument that each of the child runs gets, to
indicate which split they should work on.
:param metric_name: The name of the metric that the HyperDriveConfig will compare runs by. Please note that it is
your responsibility to make sure a metric with this name is logged to the Run in your training script
:return: A dictionary of arguments and values to pass in to the command job.
"""
return create_grid_hyperparam_args_v2(values=list(map(str, range(num_splits))),
argument_name=cross_val_index_arg_name,
metric_name=metric_name)
def create_script_run(snapshot_root_directory: Optional[Path] = None,
entry_script: Optional[PathOrString] = None,
script_params: Optional[List[str]] = None) -> ScriptRunConfig:
@ -317,12 +378,196 @@ def create_script_run(snapshot_root_directory: Optional[Path] = None,
arguments=script_params)
def _generate_input_dataset_command(input_datasets_v2: Dict[str, Input]) -> str:
"""
Generate command line arguments to pass AML v2 data assets into a script
:param input_datasets_v2: A dictionary of Input objects that have been passed into the AML command
:return: A string representing the input datasets that the script should expect
"""
input_cmd = ""
for i, (input_data_name, input_dataset_v2) in enumerate(input_datasets_v2.items()):
input_name = f"INPUT_{i}"
input_str = "${{inputs." + f"{input_name}" + "}}"
input_cmd += f" --{input_name}={input_str}"
return input_cmd
def _generate_output_dataset_command(output_datasets_v2: Dict[str, Output]) -> str:
"""
Generate command line arguments to pass AML v2 outputs into a script
:param output_datasets_v2: A dictionary of Output objects that have been passed into the AML command
:return: A string representing the output values that the script should expect
"""
output_cmd = ""
for i, (output_data_name, output_dataset_v2) in enumerate(output_datasets_v2.items()):
output_name = f"OUTPUT_{i}"
output_str = "${{outputs." + f"{output_name}" + "}}"
output_cmd += f" --{output_name}={output_str}"
return output_cmd
def submit_run_v2(workspace: Optional[Workspace],
experiment_name: str,
environment: EnvironmentV2,
input_datasets_v2: Optional[Dict[str, Input]] = None,
output_datasets_v2: Optional[Dict[str, Output]] = None,
snapshot_root_directory: Optional[Path] = None,
entry_script: Optional[PathOrString] = None,
script_params: Optional[List[str]] = None,
compute_target: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
wait_for_completion: bool = False,
wait_for_completion_show_output: bool = False,
workspace_config_path: Optional[PathOrString] = None,
ml_client: Optional[MLClient] = None,
hyperparam_args: Optional[Dict[str, Any]] = None) -> Job:
"""
Starts a v2 AML Job on a given workspace by submitting a command
:param workspace: The AzureML workspace to use.
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
characters that are not valid in Azure, those will be removed.
:param environment: An AML v2 Environment object.
:param input_datasets_v2: An optional dictionary of Inputs to pass in to the command.
:param output_datasets_v2: An optional dictionary of Outputs to pass in to the command.
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
All Python code that the script uses must be copied over.
:param entry_script: The script that should be run in AzureML.
:param script_params: A list of parameter to pass on to the script as it runs in AzureML.
:param compute_target: Optional name of a compute target in Azure ML to submit the job to. If None, will run
locally.
:param tags: A dictionary of string key/value pairs, that will be added as metadata to the run. If set to None,
a default metadata field will be added that only contains the commandline arguments that started the run.
:param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
the completion of this run (if True).
:param wait_for_completion_show_output: If wait_for_completion is True this parameter indicates whether to show the
run output on sys.stdout.
:param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
config
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
:return: An AzureML Run object.
"""
if ml_client is None:
if workspace is not None:
ml_client = get_ml_client(
subscription_id=workspace.subscription_id,
resource_group=workspace.resource_group,
workspace_name=workspace.name
)
elif workspace_config_path is not None:
ml_client = get_ml_client(workspace_config_path=workspace_config_path)
else:
raise ValueError("Either workspace or workspace_config_path must be specified to connect to the Workspace")
assert compute_target is not None, "No compute_target has been provided"
assert entry_script is not None, "No entry_script has been provided"
snapshot_root_directory = snapshot_root_directory or Path.cwd()
root_dir = Path(snapshot_root_directory)
entry_script = Path(entry_script).relative_to(root_dir).as_posix()
script_params = script_params or []
args = [p for p in script_params if "conda_env" not in p]
arg_str = " ".join(args)
cmd = "python " + str(entry_script) + " " + arg_str
if input_datasets_v2:
cmd += _generate_input_dataset_command(input_datasets_v2)
else:
input_datasets_v2 = {}
if output_datasets_v2:
cmd += _generate_output_dataset_command(output_datasets_v2)
else:
output_datasets_v2 = {}
job_to_submit: Union[Command, Sweep]
if hyperparam_args:
param_sampling = hyperparam_args[PARAM_SAMPLING_ARG]
for sample_param, choices in param_sampling.items():
input_datasets_v2[sample_param] = choices.values[0]
cmd += f" --{sample_param}=" + "${{inputs." + sample_param + "}}"
command_job = command(
code=str(snapshot_root_directory),
command=cmd,
inputs=input_datasets_v2,
outputs=output_datasets_v2,
environment=environment.name + "@latest",
compute=compute_target,
experiment_name=experiment_name,
environment_variables={
"JOB_EXECUTION_MODE": "Basic",
}
)
del hyperparam_args[PARAM_SAMPLING_ARG]
# override command with parameter expressions
command_job = command_job(
**param_sampling,
)
job_to_submit = command_job.sweep(
compute=compute_target, # AML docs suggest setting this here although already passed to command
**hyperparam_args
)
# AML docs state to reset certain properties here which aren't picked up from the
# underlying command such as experiment name and max_total_trials
job_to_submit.experiment_name = experiment_name
job_to_submit.set_limits(max_total_trials=hyperparam_args.get(MAX_TOTAL_TRIALS_ARG, None))
else:
job_to_submit = command(
code=str(snapshot_root_directory),
command=cmd,
inputs=input_datasets_v2,
outputs=output_datasets_v2,
environment=environment.name + "@latest",
compute=compute_target,
experiment_name=experiment_name,
environment_variables={
"JOB_EXECUTION_MODE": "Basic",
}
)
returned_job = ml_client.jobs.create_or_update(job_to_submit)
logging.info(f"URL to job: {returned_job.services['Studio'].endpoint}") # type: ignore
return returned_job
def download_job_outputs_logs(ml_client: MLClient,
job_name: str,
file_to_download_path: str = "",
download_dir: Optional[PathOrString] = None) -> None:
"""
Download output files from an mlflow job. Outputs will be downloaded to a folder named
`<download_dir>/<job_name>` where download_dir is either provided to this function,
or is "outputs". If a single file is required, the path to this file within the job can
be specified with 'file_to_download_path'
:param ml_client: An MLClient object.
:param job_name: The name (id) of the job to download output files from.
:param file_to_download_path: An optional path to a single file/folder to download.
:param download_dir: An optional folder into which to download the run files.
"""
download_dir = Path(download_dir) if download_dir else Path("outputs")
download_dir = download_dir / job_name
ml_client.jobs.download(job_name, output_name=file_to_download_path, download_path=download_dir)
def submit_run(workspace: Workspace,
experiment_name: str,
script_run_config: Union[ScriptRunConfig, HyperDriveConfig],
tags: Optional[Dict[str, str]] = None,
wait_for_completion: bool = False,
wait_for_completion_show_output: bool = False, ) -> Run:
wait_for_completion_show_output: bool = False,
) -> Run:
"""
Starts an AzureML run on a given workspace, via the script_run_config.
@ -381,11 +626,59 @@ def _str_to_path(s: Optional[PathOrString]) -> Optional[Path]:
return s
def create_v2_inputs(ml_client: MLClient, input_datasets: List[DatasetConfig]) -> Dict[str, Input]:
"""
Create a dictionary of Azure ML v2 Input objects, required for passing input data in to an AML job
:param ml_client: An MLClient object.
:param input_datasets: A list of DatasetConfigs to convert to Inputs.
:return: A dictionary in the format "input_name": Input.
"""
inputs: Dict[str, Input] = {}
for i, input_dataset in enumerate(input_datasets):
input_name = f"INPUT_{i}"
version = input_dataset.version or 1
data_asset: Data = ml_client.data.get(input_dataset.name, version=str(version))
data_path = data_asset.id or ""
# Note that there are alternative formats that the input path can take, such as:
# v1_datastore_path = f"azureml://datastores/{input_dataset.datastore}/paths/<path_to_dataset>"
# v2_dataset_path = f"azureml:{input_dataset.name}:1"
inputs[input_name] = Input( # type: ignore
type=AssetTypes.URI_FOLDER,
path=data_path,
mode=InputOutputModes.MOUNT,
)
return inputs
def create_v2_outputs(output_datasets: List[DatasetConfig]) -> Dict[str, Output]:
"""
Create a dictionary of Azure ML v2 Output objects, required for passing output data in to an AML job
:param output_datasets: A list of DatasetConfigs to convert to Outputs.
:return: A dictionary in the format "output_name": Output.
"""
outputs = {}
for i, output_dataset in enumerate(output_datasets):
output_name = f"OUTPUT_{i}"
v1_datastore_path = f"azureml://datastores/{output_dataset.datastore}/paths/{output_dataset.name}"
# Note that there are alternative formats that the output path can take, such as:
# v2_data_asset_path = f"azureml:{output_dataset.name}@latest"
outputs[output_name] = Output( # type: ignore
type=AssetTypes.URI_FOLDER,
path=v1_datastore_path,
mode=InputOutputModes.DIRECT,
)
return outputs
def submit_to_azure_if_needed( # type: ignore
compute_cluster_name: str = "",
entry_script: Optional[PathOrString] = None,
aml_workspace: Optional[Workspace] = None,
workspace_config_file: Optional[PathOrString] = None,
ml_client: Optional[MLClient] = None,
snapshot_root_directory: Optional[PathOrString] = None,
script_params: Optional[List[str]] = None,
conda_environment_file: Optional[PathOrString] = None,
@ -408,7 +701,9 @@ def submit_to_azure_if_needed( # type: ignore
tags: Optional[Dict[str, str]] = None,
after_submission: Optional[Callable[[Run], None]] = None,
hyperdrive_config: Optional[HyperDriveConfig] = None,
hyperparam_args: Optional[Dict[str, Any]] = None,
create_output_folders: bool = True,
strictly_aml_v1: bool = False,
) -> AzureRunInfo: # pragma: no cover
"""
Submit a folder to Azure, if needed and run it.
@ -434,6 +729,7 @@ def submit_to_azure_if_needed( # type: ignore
to pass it in as a parameter.
:param workspace_config_file: The 2nd option is to specify the path to the config.json file downloaded from the
Azure portal from which we can retrieve the existing Workspace.
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
All Python code that the script uses must be copied over.
:param ignored_folders: A list of folders to exclude from the snapshot when copying it to AzureML.
@ -462,6 +758,7 @@ def submit_to_azure_if_needed( # type: ignore
will be triggered if the commandline flag '--azureml' is present in sys.argv
:param hyperdrive_config: A configuration object for Hyperdrive (hyperparameter search).
:param create_output_folders: If True (default), create folders "outputs" and "logs" in the current working folder.
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
:return: If the script is submitted to AzureML then we terminate python as the script should be executed in AzureML,
otherwise we return a AzureRunInfo object.
"""
@ -472,17 +769,18 @@ def submit_to_azure_if_needed( # type: ignore
default_datastore_name=default_datastore)
cleaned_output_datasets = _replace_string_datasets(output_datasets or [],
default_datastore_name=default_datastore)
# The present function will most likely be called from the script once it is running in AzureML.
# The '--azureml' flag will not be present anymore, but we don't want to rely on that. From Run.get_context we
# can infer if the present code is running in AzureML.
in_azure = is_running_in_azure_ml(RUN_CONTEXT)
if in_azure:
return _generate_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
if strictly_aml_v1:
return _generate_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
else:
return _generate_v2_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
# This codepath is reached when executing outside AzureML. Here we first check if a script submission to AzureML
# is necessary. If not, return to the caller for local execution.
if submit_to_azureml is None:
submit_to_azureml = AZUREML_COMMANDLINE_FLAG in sys.argv[1:]
if not submit_to_azureml:
# Set the environment variables for local execution.
environment_variables = {
@ -500,8 +798,10 @@ def submit_to_azure_if_needed( # type: ignore
logs_folder.mkdir(exist_ok=True)
mounted_input_datasets, mount_contexts = setup_local_datasets(cleaned_input_datasets,
aml_workspace,
workspace_config_path)
strictly_aml_v1,
aml_workspace=aml_workspace,
ml_client=ml_client,
workspace_config_path=workspace_config_path)
return AzureRunInfo(
input_datasets=mounted_input_datasets,
@ -518,6 +818,7 @@ def submit_to_azure_if_needed( # type: ignore
snapshot_root_directory = Path.cwd()
workspace = get_workspace(aml_workspace, workspace_config_path)
ml_client = get_ml_client(ml_client=ml_client, aml_workspace=workspace)
print(f"Loaded AzureML workspace {workspace.name}")
if conda_environment_file is None:
@ -525,47 +826,76 @@ def submit_to_azure_if_needed( # type: ignore
print(f"Using the Conda environment from this file: {conda_environment_file}")
conda_environment_file = _str_to_path(conda_environment_file)
run_config = create_run_configuration(
workspace=workspace,
compute_cluster_name=compute_cluster_name,
aml_environment_name=aml_environment_name,
conda_environment_file=conda_environment_file,
environment_variables=environment_variables,
pip_extra_index_url=pip_extra_index_url,
private_pip_wheel_path=_str_to_path(private_pip_wheel_path),
docker_base_image=docker_base_image,
docker_shm_size=docker_shm_size,
num_nodes=num_nodes,
max_run_duration=max_run_duration,
input_datasets=cleaned_input_datasets,
output_datasets=cleaned_output_datasets
)
script_run_config = create_script_run(snapshot_root_directory=snapshot_root_directory,
entry_script=entry_script,
script_params=script_params)
script_run_config.run_config = run_config
if hyperdrive_config:
config_to_submit: Union[ScriptRunConfig, HyperDriveConfig] = hyperdrive_config
config_to_submit._run_config = script_run_config
else:
config_to_submit = script_run_config
effective_experiment_name = experiment_name or Path(script_run_config.script).stem
amlignore_path = snapshot_root_directory / AML_IGNORE_FILE
lines_to_append = [str(path) for path in (ignored_folders or [])]
with append_to_amlignore(
amlignore=amlignore_path,
lines_to_append=lines_to_append):
run = submit_run(workspace=workspace,
experiment_name=effective_experiment_name,
script_run_config=config_to_submit,
tags=tags,
wait_for_completion=wait_for_completion,
wait_for_completion_show_output=wait_for_completion_show_output)
with append_to_amlignore(amlignore=amlignore_path, lines_to_append=lines_to_append):
if strictly_aml_v1:
run_config = create_run_configuration(
workspace=workspace,
compute_cluster_name=compute_cluster_name,
aml_environment_name=aml_environment_name,
conda_environment_file=conda_environment_file,
environment_variables=environment_variables,
pip_extra_index_url=pip_extra_index_url,
private_pip_wheel_path=_str_to_path(private_pip_wheel_path),
docker_base_image=docker_base_image,
docker_shm_size=docker_shm_size,
num_nodes=num_nodes,
max_run_duration=max_run_duration,
input_datasets=cleaned_input_datasets,
output_datasets=cleaned_output_datasets,
)
script_run_config = create_script_run(snapshot_root_directory=snapshot_root_directory,
entry_script=entry_script,
script_params=script_params)
script_run_config.run_config = run_config
if after_submission is not None:
if hyperdrive_config:
config_to_submit: Union[ScriptRunConfig, HyperDriveConfig] = hyperdrive_config
config_to_submit._run_config = script_run_config
else:
config_to_submit = script_run_config
effective_experiment_name = experiment_name or Path(script_run_config.script).stem
run = submit_run(workspace=workspace,
experiment_name=effective_experiment_name,
script_run_config=config_to_submit,
tags=tags,
wait_for_completion=wait_for_completion,
wait_for_completion_show_output=wait_for_completion_show_output)
else:
assert conda_environment_file is not None
environment = create_python_environment_v2(
conda_environment_file=conda_environment_file,
docker_base_image=docker_base_image
)
if entry_script is None:
entry_script = Path(sys.argv[0])
script_params = script_params or sys.argv[1:]
effective_experiment_name = experiment_name or Path(entry_script).stem
registered_env = register_environment_v2(environment, ml_client)
input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
output_datasets_v2 = create_v2_outputs(cleaned_output_datasets)
run = submit_run_v2(workspace=workspace,
input_datasets_v2=input_datasets_v2,
output_datasets_v2=output_datasets_v2,
experiment_name=effective_experiment_name,
environment=registered_env,
snapshot_root_directory=snapshot_root_directory,
entry_script=entry_script,
script_params=script_params,
compute_target=compute_cluster_name,
tags=tags,
wait_for_completion=wait_for_completion,
wait_for_completion_show_output=wait_for_completion_show_output,
hyperparam_args=hyperparam_args
)
if after_submission is not None and strictly_aml_v1:
after_submission(run)
exit(0)
@ -584,26 +914,36 @@ def _write_run_recovery_file(run: Run) -> None:
def convert_himl_to_azureml_datasets(
cleaned_input_datasets: List[DatasetConfig],
cleaned_output_datasets: List[DatasetConfig],
workspace: Workspace) -> Tuple[Dict[str, DatasetConsumptionConfig], Dict[str, OutputFileDatasetConfig]]:
cleaned_input_datasets: List[DatasetConfig],
cleaned_output_datasets: List[DatasetConfig],
workspace: Workspace,
strictly_aml_v1: bool
) -> Tuple[Dict[str, DatasetConsumptionConfig], Dict[str, OutputFileDatasetConfig]]:
"""
Convert the cleaned input and output datasets into dictionaries of DatasetConsumptionConfigs for use in AzureML.
:param cleaned_input_datasets: The list of input DatasetConfigs
:param cleaned_output_datasets: The list of output DatasetConfigs
:param workspace: The AzureML workspace
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
Otherwise, attempt to use Azure ML SDK v2.
:return: The input and output dictionaries of DatasetConsumptionConfigs.
"""
inputs = {}
for index, d in enumerate(cleaned_input_datasets):
consumption = d.to_input_dataset(workspace=workspace, dataset_index=index)
if consumption.name in inputs:
raise ValueError(f"There is already an input dataset with name '{consumption.name}' set up?")
inputs[consumption.name] = consumption
for index, input_dataset in enumerate(cleaned_input_datasets):
consumption = input_dataset.to_input_dataset(index, workspace, strictly_aml_v1=strictly_aml_v1)
if isinstance(consumption, DatasetConsumptionConfig):
data_name = consumption.name # type: ignore
if data_name in inputs:
raise ValueError(f"There is already an input dataset with name '{data_name}' set up?")
inputs[data_name] = consumption
elif isinstance(consumption, Input):
inputs[input_dataset.name] = consumption
else:
raise ValueError(f"Unrecognised input data type: {type(consumption)}")
outputs = {}
for index, d in enumerate(cleaned_output_datasets):
out = d.to_output_dataset(workspace=workspace, dataset_index=index)
for index, output_dataset in enumerate(cleaned_output_datasets):
out = output_dataset.to_output_dataset(workspace=workspace, dataset_index=index)
if out.name in outputs:
raise ValueError(f"There is already an output dataset with name '{out.name}' set up?")
outputs[out.name] = out
@ -659,6 +999,52 @@ def _generate_azure_datasets(
logs_folder=Path.cwd() / LOGS_FOLDER)
def _get_dataset_names_from_string(sys_arg: str, pattern: str) -> Path:
dataset_string = re.split(pattern, sys_arg)[-1]
dataset_path = Path(dataset_string)
return dataset_path
def _extract_v2_inputs_outputs_from_args() -> Tuple[List[Path], List[Path]]:
"""
Extract all command line arguments of the format INPUT_i=path_to_input or OUTPUT_i=path_to_output (where i is any
integer) and return a list of the Paths for each.
:return: A list of Input paths and a list of Output paths
"""
returned_input_datasets: List[Path] = []
returned_output_datasets: List[Path] = []
for sys_arg in sys.argv:
if re.match(V2_INPUT_DATASET_PATTERN, sys_arg):
returned_input_datasets += [_get_dataset_names_from_string(sys_arg, V2_INPUT_DATASET_PATTERN)]
if re.match(V2_OUTPUT_DATASET_PATTERN, sys_arg):
returned_output_datasets += [_get_dataset_names_from_string(sys_arg, V2_OUTPUT_DATASET_PATTERN)]
return returned_input_datasets, returned_output_datasets
def _generate_v2_azure_datasets(cleaned_input_datasets: List[DatasetConfig],
cleaned_output_datasets: List[DatasetConfig]) -> AzureRunInfo:
"""
Generate returned datasets when running in AzureML. Assumes this is v2 Job, so we need to get
the input datasets from the command line args
:param cleaned_input_datasets: The list of input dataset configs
:param cleaned_output_datasets: The list of output dataset configs
:return: The AzureRunInfo containing the AzureML input and output dataset lists etc.
"""
returned_input_datasets, returned_output_datasets = _extract_v2_inputs_outputs_from_args()
return AzureRunInfo(
input_datasets=returned_input_datasets, # type: ignore
output_datasets=returned_output_datasets, # type: ignore
mount_contexts=[],
run=RUN_CONTEXT,
is_running_in_azure_ml=True,
output_folder=Path.cwd() / OUTPUT_FOLDER,
logs_folder=Path.cwd() / LOGS_FOLDER)
@contextmanager
def append_to_amlignore(lines_to_append: List[str], amlignore: Optional[Path] = None) -> Generator:
"""

Просмотреть файл

@ -6,56 +6,19 @@
import param
import sys
from pathlib import Path
from typing import List
from azureml.core import Run
import health_azure.utils as azure_util
from health_azure.himl import RUN_RECOVERY_FILE
from health_azure.himl import download_job_outputs_logs
class HimlDownloadConfig(azure_util.AmlRunScriptConfig):
output_dir: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
output_dir: Path = param.ClassSelector(class_=Path, default=Path("outputs"), instantiate=False,
doc="Path to directory to store files downloaded from the AML Run")
config_file: Path = param.ClassSelector(class_=Path, default=None, instantiate=False,
doc="Path to config.json where Workspace name is defined. If not provided, "
"the code will try to locate a config.json file in any of the parent "
"folders of the current working directory")
prefix: str = param.String(default=None, allow_None=True, doc="Optional prefix to filter Run files by")
def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:
"""
Retrieve a list of AML Run objects, given a HimlDownloadConfig object which contains values for either run
(one or more run ids), experiment (experiment name) or latest_run_file. If none of these are provided,
the parent directories of this script will be searched for a "most_recent_run.txt" file, and the run id will
be extracted from there, to retrieve the run object(s). If no Runs are found, a ValueError will be raised.
:param download_config: A HimlDownloadConfig object containing run information (e.g. run ids or experiment name)
:return: List of AML Run objects
"""
if download_config.run is not None:
run_ids: List[str] = download_config.run
runs = [azure_util.get_aml_run_from_run_id(r_id) for r_id in run_ids]
if len(runs) == 0:
raise ValueError(f"Did not find any runs with the given run id(s): {download_config.run}")
elif download_config.experiment is not None:
runs = azure_util.get_latest_aml_runs_from_experiment(download_config.experiment,
download_config.num_runs,
download_config.tags,
workspace_config_path=download_config.config_file)
if len(runs) == 0:
raise ValueError(f"Did not find any runs under the given experiment name: {download_config.experiment}")
else:
most_recent_run_path = download_config.latest_run_file or Path(RUN_RECOVERY_FILE)
run_or_recovery_id = azure_util.get_most_recent_run_id(most_recent_run_path)
runs = [azure_util.get_aml_run_from_run_id(run_or_recovery_id,
workspace_config_path=download_config.config_file)]
if len(runs) == 0:
raise ValueError(f"Did not find any runs with run id {run_or_recovery_id} as found in"
f" {download_config.latest_run_file}")
return runs
files_to_download: str = param.String(default=None, allow_None=True, doc="Path to the file to download")
def main() -> None: # pragma: no cover
@ -66,17 +29,17 @@ def main() -> None: # pragma: no cover
output_dir = download_config.output_dir
output_dir.mkdir(exist_ok=True)
runs = retrieve_runs(download_config)
files_to_download = download_config.files_to_download
for run in runs:
output_folder = output_dir / run.id
try: # pragma: no cover
azure_util.download_files_from_run_id(run.id, output_folder=output_folder, prefix=download_config.prefix,
workspace_config_path=download_config.config_file)
print(f"Downloaded file(s) to '{output_folder}'")
except Exception as e: # pragma: no cover
raise ValueError(f"Failed to download files from run {run.id}: {e}")
workspace = azure_util.get_workspace()
ml_client = azure_util.get_ml_client(
subscription_id=workspace.subscription_id,
resource_group=workspace.resource_group,
workspace_name=workspace.name
)
for run_id in download_config.run:
download_job_outputs_logs(ml_client, run_id, file_to_download_path=files_to_download, download_dir=output_dir)
print("Successfully downloaded output and log files")
if __name__ == "__main__": # pragma: no cover

Просмотреть файл

@ -0,0 +1,48 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
"""
Utility functions for interacting with mlflow runs
"""
from typing import Any, Dict, List
from mlflow.client import MlflowClient
from mlflow.entities import Run, Metric
def get_mlflow_run(mlflow_client: MlflowClient, mlflow_run_id: str) -> Run:
"""
Retrieve a Run from an MLFlow client
:param mlflow_client: An MLflowClient object.
:param mlflow_run_id: The id of an mlflow run to retrieve.
:return: An mlflow Run object
"""
mlflow_run = mlflow_client.get_run(mlflow_run_id)
return mlflow_run
def get_last_metrics_from_mlflow_run(mlflow_run: Run) -> Dict[str, Any]:
"""
Retrieve the last logged metrics from an mlflow Run
:param mlflow_run: the mlflow Run to retrieve metrics from
:return: A dictionary of metric_name to value
"""
metrics = mlflow_run.data.metrics
return metrics
def get_metric_from_mlflow_run(mlflow_client: MlflowClient, run_id: str, metric_name: str
) -> List[Metric]:
"""
For a given metric name, get the entire history of logged values from an mlflow Run
:param mlflow_client: An MLFlowClient object.
:param run_id: The id of the run to retrieve the metrics from.
:param metric_name: The name of the metric to retrieve values for.
:return: A list of mlflow Metric objects representing the all of the values of the given
metric throughout the run
"""
metric_history = mlflow_client.get_metric_history(run_id, metric_name)
return metric_history

Просмотреть файл

@ -27,6 +27,7 @@ from typing import (Any, Callable, DefaultDict, Dict, Generator, Iterable, List,
import conda_merge
import pandas as pd
import param
from azureml._restclient.constants import RunStatus
from azureml.core import Environment, Experiment, Run, Workspace, get_run
from azureml.core.authentication import InteractiveLoginAuthentication, ServicePrincipalAuthentication
@ -35,6 +36,15 @@ from azureml.core.run import _OfflineRun
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.train.hyperdrive import HyperDriveRun
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Job
from azure.ai.ml.entities import Workspace as WorkspaceV2
from azure.ai.ml.entities import Environment as EnvironmentV2
from azure.core.credentials import TokenCredential
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
from azure.identity import (ClientSecretCredential, DeviceCodeCredential,
DefaultAzureCredential, InteractiveBrowserCredential)
T = TypeVar("T")
@ -96,6 +106,10 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
"AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
}
V2_INPUT_DATASET_PATTERN = r"--INPUT_\d[=| ]"
V2_OUTPUT_DATASET_PATTERN = r"--OUTPUT_\d[=| ]"
PathOrString = Union[Path, str]
@ -702,6 +716,7 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
if is_running_in_azure_ml(RUN_CONTEXT):
return RUN_CONTEXT.experiment.workspace
# If aml_workspace has been provided, use that
if aml_workspace:
return aml_workspace
@ -1040,6 +1055,22 @@ def is_conda_file_with_pip_include(conda_file: Path) -> Tuple[bool, Dict]:
return False, conda_yaml
def generate_unique_environment_name(environment_description_string: str) -> str:
"""
Generates a unique environment name beginning with "HealthML" and ending with a hash string generated
from the environment description.
:param environment_description_string: String to be hashed that should include everything that can
reasonably change between environments.
:return: A string representing the unique environment name.
"""
sha1 = hashlib.sha1(environment_description_string.encode("utf8"))
overall_hash = sha1.hexdigest()[:32]
unique_env_name = f"HealthML-{overall_hash}"
return unique_env_name
def create_python_environment(
conda_environment_file: Path,
pip_extra_index_url: str = "",
@ -1080,8 +1111,7 @@ def create_python_environment(
logging.info(f"Added add_private_pip_wheel {private_pip_wheel_path} to AzureML environment.")
# Create a name for the environment that will likely uniquely identify it. AzureML does hashing on top of that,
# and will re-use existing environments even if they don't have the same name.
# Hashing should include everything that can reasonably change. Rely on hashlib here, because the built-in
hash_string = "\n".join(
env_description_string = "\n".join(
[
yaml_contents,
docker_base_image,
@ -1096,9 +1126,7 @@ def create_python_environment(
)
# Python's hash function gives different results for the same string in different python instances,
# hence need to use hashlib
sha1 = hashlib.sha1(hash_string.encode("utf8"))
overall_hash = sha1.hexdigest()[:32]
unique_env_name = f"HealthML-{overall_hash}"
unique_env_name = generate_unique_environment_name(env_description_string)
env = Environment(name=unique_env_name)
env.python.conda_dependencies = conda_dependencies
if docker_base_image:
@ -1132,6 +1160,69 @@ def register_environment(workspace: Workspace, environment: Environment) -> Envi
return environment.register(workspace)
def create_python_environment_v2(
conda_environment_file: Path,
pip_extra_index_url: str = "",
private_pip_wheel_path: Optional[Path] = None,
docker_base_image: str = ""
) -> EnvironmentV2:
"""
Creates a description for the V2 Python execution environment in AzureML, based on the arguments.
The environment will have a name that uniquely identifies it (it is based on hashing the contents of the
Conda file, the docker base image, environment variables and private wheels.
:param docker_base_image: The Docker base image that should be used when creating a new Docker image.
:param pip_extra_index_url: If provided, use this PIP package index to find additional packages when building
the Docker image.
:param private_pip_wheel_path: If provided, add this wheel as a private package to the AzureML environment.
:param conda_environment_file: The file that contains the Conda environment definition.
:return: A v2 Azure ML Environment object
"""
yaml_contents = conda_environment_file.read_text()
environment_description_string = "\n".join(
[
yaml_contents,
docker_base_image,
# Changing the index URL can lead to differences in package version resolution
pip_extra_index_url,
# Use the path of the private wheel as a proxy. This could lead to problems if
# a new environment uses the same private wheel file name, but the wheel has different
# contents. In hi-ml PR builds, the wheel file name is unique to the build, so it
# should not occur there.
str(private_pip_wheel_path),
]
)
unique_env_name = generate_unique_environment_name(environment_description_string)
environment = EnvironmentV2(
image=docker_base_image,
name=unique_env_name + "-v2",
conda_file=conda_environment_file,
)
return environment
def register_environment_v2(environment: EnvironmentV2, ml_client: MLClient) -> EnvironmentV2:
"""
Try to get the v2 AzureML environment by name and version from the AzureML workspace. If it succeeds, return that
environment object. If that fails, register the environment with the MLClient.
:param ml_client: An AzureML MLClient object.
:param environment: An AzureML execution environment.
:return: A v2 AzureML Environment object. If the environment did already exist on the workspace, returns that,
otherwise returns the newly registered environment.
"""
try:
if environment.version:
env = ml_client.environments.get(environment.name, environment.version)
else:
env = ml_client.environments.get(environment.name, label="latest")
logging.info(f"Found a registered environment with name {environment.name}, returning that.")
except ResourceNotFoundError:
logging.info("Didn't find existing environment. Registering a new one.")
env = ml_client.environments.create_or_update(environment)
return env
def run_duration_string_to_seconds(s: str) -> Optional[int]:
"""
Parse a string that represents a timespan, and returns it converted into seconds. The string is expected to be
@ -1309,7 +1400,7 @@ def get_run_file_names(run: Run, prefix: str = "") -> List[str]:
:return: A list of paths within the Run's container
"""
all_files = run.get_file_names()
print(f"Selecting files with prefix {prefix}")
logging.info(f"Selecting files with prefix {prefix}")
return [f for f in all_files if f.startswith(prefix)] if prefix else all_files
@ -1429,12 +1520,12 @@ def download_file_if_necessary(run: Run, filename: str, output_file: Path, overw
:return: Local path to the downloaded file.
"""
if not overwrite and output_file.exists():
print("File already exists at", output_file)
logging.info(f"File already exists at {output_file}")
else:
output_file.parent.mkdir(exist_ok=True, parents=True)
_download_file_from_run(run, filename, output_file, validate_checksum=True)
assert output_file.exists()
print("File is downloaded at", output_file)
logging.info(f"File is downloaded at {output_file}")
return output_file
@ -2022,3 +2113,219 @@ def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optiona
if actual not in valid:
all_valid = ", ".join(["<None>" if v is None else v for v in valid])
raise ValueError("{} must be one of [{}], but got: {}".format(message, all_valid, actual))
def _validate_credential(credential: TokenCredential) -> None:
"""
Validate credential by attempting to get token. If authentication has been successful, get_token
will succeed. Otherwise an exception will be raised
:param credential: The credential object to validate.
"""
credential.get_token("https://management.azure.com/.default")
def _get_legitimate_service_principal_credential(tenant_id: str, service_principal_id: str,
service_principal_password: str) -> TokenCredential:
"""
Create a ClientSecretCredential given a tenant id, service principal id and password
:param tenant_id: The Azure tenant id.
:param service_principal_id: The id of an existing Service Principal.
:param service_principal_password: The password of an existing Service Principal.
:raises ValueError: If the credential cannot be validated (i.e. authentication was unsucessful).
:return: The validated credential.
"""
cred = ClientSecretCredential(tenant_id=tenant_id,
client_id=service_principal_id,
client_secret=service_principal_password)
try:
_validate_credential(cred)
return cred
except ClientAuthenticationError as e:
raise ValueError(f"Found environment variables for {ENV_SERVICE_PRINCIPAL_ID}, "
f"{ENV_SERVICE_PRINCIPAL_PASSWORD}, and {ENV_TENANT_ID} but was "
f"not able to authenticate: {e}")
def _get_legitimate_device_code_credential() -> Optional[TokenCredential]:
"""
Create a DeviceCodeCredential for interacting with Azure resources. If the credential can't be
validated, return None.
:return: A valid Azure credential.
"""
cred = DeviceCodeCredential(timeout=60)
try:
_validate_credential(cred)
return cred
except ClientAuthenticationError:
return None
def _get_legitimate_default_credential() -> Optional[TokenCredential]:
"""
Create a DefaultAzure credential for interacting with Azure resources. If the credential can't be
validated, return None.
:return: A valid Azure credential.
"""
cred = DefaultAzureCredential(timeout=60)
try:
_validate_credential(cred)
return cred
except ClientAuthenticationError:
return None
def _get_legitimate_interactive_browser_credential() -> Optional[TokenCredential]:
"""
Create an InteractiveBrowser credential for interacting with Azure resources. If the credential can't be
validated, return None.
:return: A valid Azure credential.
"""
cred = InteractiveBrowserCredential(timeout=60)
try:
_validate_credential(cred)
return cred
except ClientAuthenticationError:
return None
def get_credential() -> Optional[TokenCredential]:
"""
Get a credential for authenticating with Azure.There are multiple ways to retrieve a credential.
If environment variables pertaining to details of a Service Principal are available, those will be used
to authenticate. If no environment variables exist, and the script is not currently
running inside of Azure ML or another Azure agent, will attempt to retrieve a credential via a
device code (which requires the user to visit a link and enter a provided code). If this fails, or if running in
Azure, DefaultAzureCredential will be used which iterates through a number of possible authentication methods
including identifying an Azure managed identity, cached credentials from VS code, Azure CLI, Powershell etc.
Otherwise returns None.
:return: Any of the aforementioned credentials if available, else None.
"""
service_principal_id = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_ID, allow_missing=True)
tenant_id = get_secret_from_environment(ENV_TENANT_ID, allow_missing=True)
service_principal_password = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_PASSWORD, allow_missing=True)
if service_principal_id and tenant_id and service_principal_password:
return _get_legitimate_service_principal_credential(tenant_id, service_principal_id, service_principal_password)
try:
cred = _get_legitimate_default_credential()
if cred is not None:
return cred
except ClientAuthenticationError:
cred = _get_legitimate_device_code_credential()
if cred is not None:
return cred
cred = _get_legitimate_interactive_browser_credential()
if cred is not None:
return cred
raise ValueError("Unable to generate and validate a credential. Please see Azure ML documentation"
"for instructions on diffrent options to get a credential")
def get_ml_client(ml_client: Optional[MLClient] = None,
aml_workspace: Optional[Workspace] = None,
workspace_config_path: Optional[PathOrString] = None,
subscription_id: Optional[str] = None,
resource_group: Optional[str] = None,
workspace_name: str = "",
) -> MLClient:
"""
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK.
If a ml_client is provided, return that. Otherwise, create one using workspace details
coming from either an existing Workspace object, a config.json file or passed in as an argument.
:param ml_client: An optional existing MLClient object to be returned.
:param aml_workspace: An optional Workspace object to take connection details from.
:param workspace_config_path: An optional path toa config.json file containing details of the Workspace.
:param subscription_id: An optional subscription ID.
:param resource_group: An optional resource group name.
:param workspace_name: An optional workspace name.
:return: An instance of MLClient to interact with Azure resources.
"""
if ml_client:
return ml_client
credential = get_credential()
if credential is None:
raise ValueError("Can't connect to MLClient without a valid credential")
if aml_workspace is not None:
ml_client = MLClient(
subscription_id=aml_workspace.subscription_id,
resource_group_name=aml_workspace.resource_group,
workspace_name=aml_workspace.name,
credential=credential) # type: ignore
elif workspace_config_path:
ml_client = MLClient.from_config(
credential=credential, # type: ignore
path=str(workspace_config_path))
elif subscription_id and resource_group and workspace_name:
ml_client = MLClient(
subscription_id=subscription_id,
resource_group_name=resource_group,
workspace_name=workspace_name,
credential=credential) # type: ignore
else:
try:
workspace = get_workspace()
ml_client = MLClient(
subscription_id=workspace.subscription_id,
resource_group_name=workspace.resource_group,
workspace_name=workspace.name,
credential=credential) # type: ignore
except ValueError as e:
raise ValueError(f"Couldn't connect to MLClient: {e}")
logging.info(f"Logged into AzureML workspace {ml_client.workspace_name}")
return ml_client
def retrieve_workspace_from_client(ml_client: MLClient, workspace_name: Optional[str] = None
) -> WorkspaceV2:
"""
Get a v2 Workspace object from an MLClient object. If a workspace_name is passed, will attempt
to retrieve a workspace with that name. Otherweise will use the MLClient's default workspace_name
:param ml_client: An MLClient object to retrieve the Workspace from
:param workspace_name: An optional name of the workspace to retrieve.
:return: A v2 Workspace object.
"""
if workspace_name is not None:
workspace_name = workspace_name
elif ml_client.workspace_name is not None:
workspace_name = ml_client.workspace_name
else:
workspace_name = ""
workspace = ml_client.workspaces.get(workspace_name)
return workspace
def fetch_job(ml_client: MLClient, run_id: str) -> Job:
"""
Retrieve a job with a given run_id from an MLClient
:param ml_client: An MLClient object.
:param run_id: The id of the run to retrieve.
:return: An Azure ML (v2) Job object.
"""
job = ml_client.jobs.get(run_id)
return job
def filter_v2_input_output_args(args: List[str]) -> List[str]:
"""
Filter out AML v2 Input and Output entries from a list of args. Under AML SDK v2 it is necessary to
pass input and output arguments to a script via the command line, of which there can be an unknown number.
Therefore we need to remove these from the list of args passed to the argument parsers.
:param args: A list of arguments from which to remove input and output args
:return: A filtered list of arguments, without entries in the format of INPUT_i or OUTPUT_i where i is
any integer.
"""
return [a for a in args if
not re.match(V2_INPUT_DATASET_PATTERN, a) and not re.match(V2_OUTPUT_DATASET_PATTERN, a)]

Просмотреть файл

@ -15,7 +15,7 @@ from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import DEFAULT, MagicMock, patch
from uuid import uuid4
from xmlrpc.client import Boolean
@ -26,15 +26,18 @@ import param
import pytest
from _pytest.capture import CaptureFixture
from _pytest.logging import LogCaptureFixture
from azure.identity import (ClientSecretCredential, DeviceCodeCredential, DefaultAzureCredential)
from azure.storage.blob import ContainerClient
from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.core.environment import CondaDependencies
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
from azureml.data.azure_storage_datastore import AzureBlobDatastore
import health_azure.utils as util
from health_azure.himl import AML_IGNORE_FILE, append_to_amlignore
from health_azure.utils import (ENV_MASTER_ADDR, ENV_MASTER_PORT, MASTER_PORT_DEFAULT,
PackageDependency, create_argparser)
PackageDependency, create_argparser, get_credential)
from testazure.test_himl import RunTarget, render_and_run_test_script
from testazure.utils_testazure import (DEFAULT_IGNORE_FOLDERS, DEFAULT_WORKSPACE, MockRun, change_working_directory,
himl_azure_root, repository_root)
@ -67,7 +70,7 @@ def test_find_file_in_parent_folders(caplog: LogCaptureFixture) -> None:
)
last_caplog_msg = caplog.messages[-1]
assert found_file_path == current_file_path
assert(f"Searching for file {current_file_path.name} in {himl_azure_test_root}" in last_caplog_msg)
assert f"Searching for file {current_file_path.name} in {himl_azure_test_root}" in last_caplog_msg
# Now try to search for a nonexistent path in the same folder. This should return None
nonexistent_path = himl_az_root / "idontexist.py"
@ -652,6 +655,19 @@ def test_nonexisting_amlignore(random_folder: Path) -> None:
os.chdir(cwd)
def test_generate_unique_environment_name() -> None:
dummy_env_description_string_1 = "A pretend environment description\ncontaining information about pip "
"packages\netc etc"
env_name_1 = util.generate_unique_environment_name(dummy_env_description_string_1)
assert env_name_1.startswith("HealthML-")
dummy_env_description_string_2 = "A slightly differetpretend environment description\ncontaining "
"information about pip packages\netc etc"
env_name_2 = util.generate_unique_environment_name(dummy_env_description_string_2)
assert env_name_2.startswith("HealthML-")
assert env_name_1 != env_name_2
@patch("health_azure.utils.Workspace")
def test_create_python_environment(
mock_workspace: mock.MagicMock,
@ -700,6 +716,42 @@ dependencies:
assert private_pip_wheel_url in envs_pip_packages
@patch("health_azure.utils.Workspace")
def test_create_python_environment_v2(
mock_workspace: mock.MagicMock,
random_folder: Path,
) -> None:
conda_str = """name: simple-env
dependencies:
- pip=20.1.1
- python=3.7.3
- pip:
- azureml-sdk==1.23.0
- something-else==0.1.5
- pip:
- --index-url https://test.pypi.org/simple/
- --extra-index-url https://pypi.org/simple
- hi-ml-azure
"""
conda_environment_file = random_folder / "environment.yml"
conda_environment_file.write_text(conda_str)
env = util.create_python_environment_v2(conda_environment_file=conda_environment_file)
# Check that the environment has a reasonable name. Detailed checks for uniqueness of the name follow below.
assert env.name.startswith("HealthML")
assert env.name.endswith("-v2")
assert env._conda_file_path == conda_environment_file
pip_extra_index_url = "https://where.great.packages.live/"
docker_base_image = "viennaglobal.azurecr.io/azureml/azureml_a187a87cc7c31ac4d9f67496bc9c8239"
env = util.create_python_environment_v2(
conda_environment_file=conda_environment_file,
pip_extra_index_url=pip_extra_index_url,
docker_base_image=docker_base_image)
assert env.image == docker_base_image
def test_create_environment_unique_name(random_folder: Path) -> None:
"""
Test if the name of the conda environment changes with each of the components
@ -815,6 +867,33 @@ def test_register_environment(
assert env.version == util.ENVIRONMENT_VERSION
@patch("azure.ai.ml.entities.Environment")
@patch("azure.ai.ml.MLClient")
def test_register_environment_v2(
mock_ml_client: MagicMock,
mock_environment_v2: MagicMock,
caplog: LogCaptureFixture,
) -> None:
def _mock_cant_find_env(env_name: str, label_or_version: str) -> None:
raise ResourceNotFoundError("Does not exist")
env_name = "an environment"
env_version = "environment version"
mock_ml_client.environments.get.return_value = mock_environment_v2
mock_environment_v2.name = env_name
mock_environment_v2.version = env_version
with caplog.at_level(logging.INFO): # type: ignore
_ = util.register_environment_v2(mock_environment_v2, mock_ml_client)
caplog_text = caplog.text
assert f"Found a registered environment with name {env_name}, returning that." in caplog_text
# test that log is correct when exception is triggered
mock_ml_client.environments.get.side_effect = _mock_cant_find_env
_ = util.register_environment_v2(mock_environment_v2, mock_ml_client)
caplog_text = caplog.text
assert "Didn't find existing environment. Registering a new one." in caplog_text
def test_set_environment_variables_for_multi_node(caplog: LogCaptureFixture) -> None:
# If none of AZ_BATCHAI_MPI_MASTER_NODE, AZ_BATCH_MASTER_NODE or ENV_MASTER_IP are set, should assume
# single node training job
@ -1267,16 +1346,19 @@ import sys
from pathlib import Path
from azureml.core import Run
from health_azure.utils import download_files_from_run_id""",
"body": script_body
"body": script_body,
}
# Run the script locally first, then in the cloud. In local runs, the workspace should be picked up from the
# config.json file, in AzureML runs it should be read off the run context.
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options, extra_args=[], expected_pass=True)
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options,
extra_args=[], expected_pass=True)
print("Local run finished")
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options, extra_args=[], expected_pass=True)
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options,
extra_args=[], expected_pass=True)
def test_replace_directory(tmp_path: Path) -> None:
extra_options = {
"imports": """
import sys
@ -1297,13 +1379,15 @@ from health_azure.utils import replace_directory
assert not output_dir.exists()
assert (new_output_dir / file_name).exists()
"""
""",
}
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options, extra_args=[], expected_pass=True)
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options,
extra_args=[], expected_pass=True)
print("Local run finished")
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options, extra_args=[], expected_pass=True)
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options,
extra_args=[], expected_pass=True)
def test_is_global_rank_zero() -> None:
@ -1346,17 +1430,32 @@ def test_get_run_source(dummy_recovery_id: str,
assert isinstance(script_config.run, str)
def delete_existing_blobs(datastore: AzureBlobDatastore, prefix: str) -> None:
def get_container_client(datastore: AzureBlobDatastore) -> ContainerClient:
"""Gets a ContainerClient to interact with the blobs in the given datastore.
param datastore: The datastore from which the files should be read.
"""
return datastore.blob_service.get_container_client(datastore.container_name)
def get_blobs_in_datastore(datastore: AzureBlobDatastore, prefix: str) -> List[Any]:
"""Gets all blobs in the datastore where the name starts with the given prefix.
param datastore: The datastore from which the files should be read.
param prefix: The prefix string for the files that should be returned.
"""
return list(get_container_client(datastore).list_blobs(name_starts_with=prefix))
def delete_blobs_in_datastore(datastore: AzureBlobDatastore, prefix: str) -> None:
"""Deletes all existing files in blob storage at the location that the test uses.
param datastore: The datastore from which the files should be deleted.
param prefix: The prefix string for the files that should be deleted.
"""
container = datastore.container_name
existing_blobs = list(datastore.blob_service.list_blobs(prefix=prefix,
container_name=container))
for existing_blob in existing_blobs:
datastore.blob_service.delete_blob(container_name=container, blob_name=existing_blob.name)
container_client = get_container_client(datastore)
for existing_blob in get_blobs_in_datastore(datastore, prefix):
container_client.delete_blob(existing_blob.name)
@pytest.mark.parametrize("overwrite", [True, False])
@ -1374,7 +1473,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
local_data_path.mkdir()
test_data_path_remote = "test_data/abc"
delete_existing_blobs(datastore=default_datastore, prefix=test_data_path_remote)
delete_blobs_in_datastore(datastore=default_datastore, prefix=test_data_path_remote)
try:
# Create dummy data files and upload to datastore (checking they are uploaded)
dummy_filenames = []
@ -1387,8 +1486,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
default_datastore.upload(str(local_data_path), test_data_path_remote, overwrite=False)
# Wait a bit because there seem to be spurious errors with files not yet existing at this point
time.sleep(0.1)
existing = list(default_datastore.blob_service.list_blobs(prefix=test_data_path_remote,
container_name=default_datastore.container_name))
existing = get_blobs_in_datastore(default_datastore, prefix=test_data_path_remote)
assert len(existing) == num_dummy_files
# Check that the file doesn't currently exist at download location
@ -1403,7 +1501,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
expected_download_paths = [expected_local_download_dir / dummy_filename for dummy_filename in dummy_filenames]
assert all([p.exists() for p in expected_download_paths])
finally:
delete_existing_blobs(datastore=default_datastore, prefix=test_data_path_remote)
delete_blobs_in_datastore(datastore=default_datastore, prefix=test_data_path_remote)
@pytest.mark.parametrize("overwrite", [True, False])
@ -1416,14 +1514,13 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool) -> None:
"""
ws = DEFAULT_WORKSPACE.workspace
default_datastore: AzureBlobDatastore = ws.get_default_datastore()
container = default_datastore.container_name
dummy_file_content = "Hello world"
remote_data_dir = "test_data"
dummy_file_name = Path("abc/uploaded_file.txt")
expected_remote_path = Path(remote_data_dir) / dummy_file_name.name
delete_existing_blobs(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
delete_blobs_in_datastore(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
try:
# Create a dummy data file and upload to datastore
@ -1436,11 +1533,10 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool) -> None:
# Wait a bit because there seem to be spurious errors with files not yet existing at this point
time.sleep(0.1)
existing_blobs = list(default_datastore.blob_service.list_blobs(prefix=str(expected_remote_path.as_posix()),
container_name=container))
existing_blobs = get_blobs_in_datastore(default_datastore, prefix=str(expected_remote_path.as_posix()))
assert len(existing_blobs) == 1
finally:
delete_existing_blobs(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
delete_blobs_in_datastore(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
@pytest.mark.parametrize("arguments, run_id", [
@ -2263,3 +2359,150 @@ def test_create_run() -> None:
finally:
if run is not None:
run.complete()
def test_get_credential() -> None:
def _mock_validation_error() -> None:
raise ClientAuthenticationError("")
# test the case where service principal credentials are set as environment variables
mock_env_vars = {
util.ENV_SERVICE_PRINCIPAL_ID: "foo",
util.ENV_TENANT_ID: "bar",
util.ENV_SERVICE_PRINCIPAL_PASSWORD: "baz"
}
with patch.object(os.environ, "get", return_value=mock_env_vars):
with patch.multiple(
"health_azure.utils",
is_running_in_azure_ml=DEFAULT,
is_running_on_azure_agent=DEFAULT,
_get_legitimate_service_principal_credential=DEFAULT,
_get_legitimate_device_code_credential=DEFAULT,
_get_legitimate_default_credential=DEFAULT,
_get_legitimate_interactive_browser_credential=DEFAULT
) as mocks:
mocks["is_running_in_azure_ml"].return_value = False
mocks["is_running_on_azure_agent"].return_value = False
_ = get_credential()
mocks["_get_legitimate_service_principal_credential"].assert_called_once()
mocks["_get_legitimate_device_code_credential"].assert_not_called()
mocks["_get_legitimate_default_credential"].assert_not_called()
mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
# if the environment variables are not set and we are running on a local machine, a
# DefaultAzureCredential should be attempted first
with patch.object(os.environ, "get", return_value={}):
with patch.multiple(
"health_azure.utils",
is_running_in_azure_ml=DEFAULT,
is_running_on_azure_agent=DEFAULT,
_get_legitimate_service_principal_credential=DEFAULT,
_get_legitimate_device_code_credential=DEFAULT,
_get_legitimate_default_credential=DEFAULT,
_get_legitimate_interactive_browser_credential=DEFAULT
) as mocks:
mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
mock_get_default_cred = mocks["_get_legitimate_default_credential"]
mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
mocks["is_running_in_azure_ml"].return_value = False
mocks["is_running_on_azure_agent"].return_value = False
_ = get_credential()
mock_get_sp_cred.assert_not_called()
mock_get_device_cred.assert_not_called()
mock_get_default_cred.assert_called_once()
mock_get_browser_cred.assert_not_called()
# if that fails, a DeviceCode credential should be attempted
mock_get_default_cred.side_effect = _mock_validation_error
_ = get_credential()
mock_get_sp_cred.assert_not_called()
mock_get_device_cred.assert_called_once()
assert mock_get_default_cred.call_count == 2
mock_get_browser_cred.assert_not_called()
# if None of the previous credentials work, an InteractiveBrowser credential should be tried
mock_get_device_cred.return_value = None
_ = get_credential()
mock_get_sp_cred.assert_not_called()
assert mock_get_device_cred.call_count == 2
assert mock_get_default_cred.call_count == 3
mock_get_browser_cred.assert_called_once()
# finally, if none of the methods work, an Exception should be raised
mock_get_browser_cred.return_value = None
with pytest.raises(Exception) as e:
get_credential()
assert "Unable to generate and validate a credential. Please see Azure ML documentation"\
"for instructions on different options to get a credential" in str(e)
def test_get_legitimate_service_principal_credential() -> None:
# first attempt to create and valiadate a credential with non-existant service principal credentials
# and check it fails
mock_service_principal_id = "foo"
mock_service_principal_password = "bar"
mock_tenant_id = "baz"
expected_error_msg = f"Found environment variables for {util.ENV_SERVICE_PRINCIPAL_ID}, "
f"{util.ENV_SERVICE_PRINCIPAL_PASSWORD}, and {util.ENV_TENANT_ID} but was not able to authenticate"
with pytest.raises(Exception) as e:
util._get_legitimate_service_principal_credential(mock_tenant_id, mock_service_principal_id,
mock_service_principal_password)
assert expected_error_msg in str(e)
# now mock the case where validating the credential succeeds and check the value of that
with patch("health_azure.utils._validate_credential"):
cred = util._get_legitimate_service_principal_credential(mock_tenant_id, mock_service_principal_id,
mock_service_principal_password)
assert isinstance(cred, ClientSecretCredential)
def test_get_legitimate_device_code_credential() -> None:
def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
return DeviceCodeCredential(timeout=1)
with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
cred = util._get_legitimate_device_code_credential()
assert cred is None
# now mock the case where validating the credential succeeds
with patch("health_azure.utils._validate_credential"):
cred = util._get_legitimate_device_code_credential()
assert isinstance(cred, DeviceCodeCredential)
def test_get_legitimate_default_credential() -> None:
def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
return DefaultAzureCredential(timeout=1)
with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
cred = util._get_legitimate_default_credential()
assert cred is None
with patch("health_azure.utils._validate_credential"):
cred = util._get_legitimate_default_credential()
assert isinstance(cred, DefaultAzureCredential)
def test_filter_v2_input_output_args() -> None:
def _compare_args(expected: List[str], actual: List[str]) -> None:
assert len(actual) == len(expected)
for actual_entry in actual:
assert actual_entry in expected
args_to_filter = ["--a=foo", "--INPUT_0=input0", "--b=bar", "--INPUT_1=input1"]
expected_filtered = ["--a=foo", "--b=bar"]
actual_filtered = util.filter_v2_input_output_args(args_to_filter)
_compare_args(expected_filtered, actual_filtered)
# try passing empty list
empty_list: List[str] = []
actual_filtered = util.filter_v2_input_output_args(empty_list)
assert actual_filtered == empty_list
# pass args with similar but different input and output args
args_to_filter = ["--input_0=input0", "--a=foo"]
expected_filtered = ["--input_0=input0", "--a=foo"]
actual_filtered = util.filter_v2_input_output_args(args_to_filter)
_compare_args(expected_filtered, actual_filtered)

Просмотреть файл

@ -110,6 +110,8 @@ def render_test_script(entry_script_path: Path, extra_options: Dict[str, str],
default_options['args'] = ''
default_options['body'] = ''
default_options["tags"] = '{}'
default_options["strictly_aml_v1"] = 'True'
default_options["submit_to_azureml"] = 'False'
all_options = dict(default_options, **extra_options)

Просмотреть файл

@ -1,3 +1,5 @@
#! /usr/bin/env python
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
@ -49,7 +51,9 @@ def main() -> None:
output_datasets={{ output_datasets }},
wait_for_completion={{ wait_for_completion }},
wait_for_completion_show_output={{ wait_for_completion_show_output }},
tags={{ tags }})
tags={{ tags }},
submit_to_azureml={{ submit_to_azureml}},
strictly_aml_v1={{ strictly_aml_v1 }})
parser = ArgumentParser()
{{ args }}

Просмотреть файл

@ -6,13 +6,17 @@
Test the data input and output functionality
"""
from pathlib import Path
from unittest import mock
from health_azure.utils import PathOrString
from unittest.mock import create_autospec, DEFAULT, MagicMock, patch
from health_azure.utils import PathOrString, get_ml_client
from typing import List, Union, Optional
import pytest
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Data
from azure.ai.ml.operations import DatastoreOperations
from azure.core.exceptions import HttpResponseError
from azureml._restclient.exceptions import ServiceException
from azureml.core import Dataset
from azureml.core import Dataset, Workspace
from azureml.data import FileDataset, OutputFileDatasetConfig
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
@ -20,7 +24,8 @@ from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key,
_replace_string_datasets, get_datastore, get_or_create_dataset,
create_dataset_configs)
create_dataset_configs, _get_or_create_v1_dataset, _get_or_create_v2_dataset,
_retrieve_v1_dataset, _create_v1_dataset, _retrieve_v2_dataset, _create_v2_dataset)
from testazure.utils_testazure import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
@ -56,11 +61,39 @@ def test_get_datastore() -> None:
# Now mock the datastores property of the workspace, to pretend there is only a single datastore.
# With that in place, we can get the datastore without the name
faked_stores = {name: datastore}
with mock.patch("azureml.core.Workspace.datastores", faked_stores):
with patch("azureml.core.Workspace.datastores", faked_stores):
single_store = get_datastore(workspace=workspace, datastore_name="")
assert isinstance(single_store, AzureBlobDatastore)
assert single_store.name == name
# Now test retrieving a v2 datastore by name
mock_v2_dataset_name = "dummy_v2_datastore"
mock_returned_datastore = MagicMock()
mock_returned_datastore.name = mock_v2_dataset_name
mock_workspace = MagicMock()
mock_workspace.datastores = create_autospec(DatastoreOperations)
mock_workspace.datastores.get.return_value = mock_returned_datastore
v2_datastore = get_datastore(mock_workspace, datastore_name=mock_v2_dataset_name)
assert v2_datastore.name == mock_v2_dataset_name
# Test retrieving a default v2 datastore
mock_workspace.datastores.list.return_value = [mock_returned_datastore]
v2_default_datastore = get_datastore(mock_workspace, datastore_name="")
assert v2_default_datastore.name == mock_v2_dataset_name
# Mock case where list is empty but get_default returns a value
mock_workspace.datastores.list.return_value = []
mock_workspace.datastores.get_default.return_value = mock_returned_datastore
v2_default_datastore = get_datastore(mock_workspace, datastore_name="")
assert v2_default_datastore.name == mock_v2_dataset_name
# If datastores has an unknown format, an exception should be raised
mock_workspace = MagicMock()
mock_workspace.datastores.return_value = ["dummy_datastore_name"]
with pytest.raises(Exception) as e:
get_datastore(workspace=mock_workspace, datastore_name="")
assert "Unrecognised type for datastores" in str(e)
def test_dataset_input() -> None:
"""
@ -69,21 +102,23 @@ def test_dataset_input() -> None:
workspace = DEFAULT_WORKSPACE.workspace
# This dataset must exist in the workspace already, or at least in blob storage.
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(dataset_index=1, workspace=workspace, strictly_aml_v1=True)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.path_on_compute is None
assert aml_dataset.mode == "download"
assert aml_dataset.path_on_compute is None # type: ignore
assert aml_dataset.mode == "download" # type: ignore
# Downloading or mounting to a given path
target_folder = "/tmp/foo"
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, target_folder=target_folder)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(
dataset_index=1, workspace=workspace, strictly_aml_v1=True)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.path_on_compute == target_folder
assert aml_dataset.path_on_compute == target_folder # type: ignore
# Use mounting instead of downloading
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, use_mounting=True)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
aml_dataset = dataset_config.to_input_dataset(
dataset_index=1, workspace=workspace, strictly_aml_v1=True)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.mode == "mount"
assert aml_dataset.mode == "mount" # type: ignore
@pytest.mark.parametrize("target_folder", [
@ -97,7 +132,8 @@ def test_dataset_input_target_empty(target_folder: PathOrString) -> None:
workspace = DEFAULT_WORKSPACE.workspace
# This dataset must exist in the workspace already, or at least in blob storage.
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, target_folder=target_folder)
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
aml_dataset: DatasetConsumptionConfig = dataset_config.to_input_dataset(
workspace=workspace, dataset_index=1, strictly_aml_v1=True)
assert isinstance(aml_dataset, DatasetConsumptionConfig)
assert aml_dataset.path_on_compute is None
@ -159,21 +195,138 @@ def test_datasets_from_string() -> None:
assert replaced[1] == original[1]
def test_get_dataset() -> None:
"""
Test if a dataset that does not yet exist can be created from a folder in blob storage
"""
# A folder with a single tiny file
tiny_dataset = "himl_tiny_dataset"
def test_get_or_create_dataset() -> None:
def _mock_retrieve_or_create_v2_dataset_fails(datastore_name: str, dataset_name: str, ml_client: MLClient) -> None:
raise HttpResponseError("Cannot create v2 Data Version in v1 Data Container")
data_asset_name = "himl_tiny_data_asset"
workspace = DEFAULT_WORKSPACE.workspace
ml_client = get_ml_client(aml_workspace=workspace)
# When creating a dataset, we need a non-empty name
with pytest.raises(ValueError) as ex:
get_or_create_dataset(workspace=workspace,
datastore_name=DEFAULT_DATASTORE,
dataset_name="")
ml_client=ml_client,
datastore_name="himldatasetsv2",
dataset_name="",
strictly_aml_v1=True)
assert "No dataset name" in str(ex)
# Check first that there is no dataset yet of that name. If there is, delete that dataset (it would come
# from previous runs of this test)
# pass strictly_aml_v1 = True and check the expected function is called
mock_v1_dataset = "v1_dataset"
with patch.multiple("health_azure.datasets",
_get_or_create_v1_dataset=DEFAULT,
_get_or_create_v2_dataset=DEFAULT) as mocks:
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
dataset = get_or_create_dataset(workspace=workspace,
ml_client=ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=True)
mocks["_get_or_create_v1_dataset"].assert_called_once()
mocks["_get_or_create_v2_dataset"].assert_not_called()
assert dataset == mock_v1_dataset
# Now pass strictly_aml_v1 as False
mock_v2_dataset = "v2_dataset"
mocks["_get_or_create_v2_dataset"].return_value = mock_v2_dataset
dataset = get_or_create_dataset(workspace=workspace,
ml_client=ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=False)
mocks["_get_or_create_v1_dataset"].assert_called_once()
mocks["_get_or_create_v2_dataset"].assert_called_once()
assert dataset == mock_v2_dataset
# if trying to get or create a v2 dataset fails, should revert back to _get_or_create_v1_dataset
mocks["_get_or_create_v2_dataset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
dataset = get_or_create_dataset(workspace=workspace,
ml_client=ml_client,
datastore_name="himldatasetsv2",
dataset_name=data_asset_name,
strictly_aml_v1=False)
assert mocks["_get_or_create_v1_dataset"].call_count == 2
assert mocks["_get_or_create_v2_dataset"].call_count == 2
assert dataset == mock_v1_dataset
def test_get_or_create_v1_dataset() -> None:
def _mock_error_from_retrieve_v1_dataset(dataset_name: str, workspace: Workspace) -> None:
raise UserErrorException("Error Message")
workspace = DEFAULT_WORKSPACE.workspace
datastore = workspace.get_default_datastore()
dataset_name = "foo"
with patch.multiple("health_azure.datasets",
_retrieve_v1_dataset=DEFAULT,
_create_v1_dataset=DEFAULT) as mocks:
_get_or_create_v1_dataset(datastore, dataset_name, workspace)
mocks["_retrieve_v1_dataset"].assert_called_once()
mocks["_create_v1_dataset"].assert_not_called()
mocks["_retrieve_v1_dataset"].side_effect = _mock_error_from_retrieve_v1_dataset
_get_or_create_v1_dataset(datastore, dataset_name, workspace)
assert mocks["_retrieve_v1_dataset"].call_count == 2
mocks["_create_v1_dataset"].assert_called_once()
def test_get_or_create_v2_dataset() -> None:
def _mock_error_from_retrieve_v2_dataset(dataset_name: str, workspace: Workspace) -> None:
raise Exception("Error Message")
ml_client = MagicMock()
datastore = "dummy_datastore"
dataset_name = "foo"
with patch.multiple("health_azure.datasets",
_retrieve_v2_dataset=DEFAULT,
_create_v2_dataset=DEFAULT) as mocks:
_get_or_create_v2_dataset(datastore, dataset_name, ml_client)
mocks["_retrieve_v2_dataset"].assert_called_once()
mocks["_create_v2_dataset"].assert_not_called()
mocks["_retrieve_v2_dataset"].side_effect = _mock_error_from_retrieve_v2_dataset
_get_or_create_v2_dataset(datastore, dataset_name, ml_client)
assert mocks["_retrieve_v2_dataset"].call_count == 2
mocks["_create_v2_dataset"].assert_called_once()
def test_retrieve_v1_dataset() -> None:
nonexistent_dataset = "idontexist"
workspace = DEFAULT_WORKSPACE.workspace
# patch get_by_name to ensure it is called
with patch("azureml.core.Dataset.get_by_name") as mock_get_dataset:
_retrieve_v1_dataset(nonexistent_dataset, workspace)
mock_get_dataset.assert_called_once()
# Expect a ValueError to be raised if the dataset doesnt exist
with pytest.raises(Exception) as e:
_retrieve_v1_dataset(nonexistent_dataset, workspace)
assert "Cannot find dataset registered with name \"idontexist\"" in str(e)
def test_create_v1_dataset() -> None:
# If dataset_name or datastore_name are empty strings expect an Exception
empty_dataset_name = ""
empty_datastore_name = ""
nonempty_dataset_name = "foo"
nonempty_datastore_name = "bar"
workspace = DEFAULT_WORKSPACE.workspace
tiny_dataset = "himl_tiny_dataset"
with pytest.raises(Exception) as e:
_create_v1_dataset(empty_datastore_name, nonempty_dataset_name, workspace)
expected_str = "Cannot create dataset without a valid datastore name (received '') and a valid dataset name"
f" (received '{nonempty_dataset_name}')"
assert expected_str in str(e)
_create_v1_dataset(nonempty_datastore_name, empty_dataset_name, workspace)
expected_str = f"Cannot create dataset without a valid datastore name (received '{empty_dataset_name}') and "
"a valid dataset name (received '')"
assert expected_str in str(e)
try:
existing_dataset = Dataset.get_by_name(workspace, name=tiny_dataset)
try:
@ -183,10 +336,10 @@ def test_get_dataset() -> None:
pass
except Exception as ex:
assert "Cannot find dataset registered" in str(ex)
dataset = get_or_create_dataset(workspace=workspace,
datastore_name=DEFAULT_DATASTORE,
dataset_name=tiny_dataset)
dataset = _create_v1_dataset(DEFAULT_DATASTORE, tiny_dataset, workspace)
assert isinstance(dataset, FileDataset)
# We should now be able to get that dataset without special means
dataset2 = Dataset.get_by_name(workspace, name=tiny_dataset)
try:
@ -197,9 +350,31 @@ def test_get_dataset() -> None:
pass
def test_retrieve_v2_dataset() -> None:
dataset_name = "dummydataset"
mock_ml_client = MagicMock()
mock_retrieved_dataset = "dummy_data"
mock_ml_client.data.get.return_value = mock_retrieved_dataset
data_asset = _retrieve_v2_dataset(dataset_name, mock_ml_client)
assert data_asset == mock_retrieved_dataset
def test_create_v2_dataset() -> None:
dataset_name = "dummydataset"
datastore = "dummydataset"
mock_ml_client = MagicMock()
# mock_ml_client.data.create_or_update.return_value = None
data_asset = _create_v2_dataset(datastore, dataset_name, mock_ml_client)
assert isinstance(data_asset, Data)
assert data_asset.path == "azureml://datastores/dummydataset/paths/dummydataset/"
assert data_asset.type == "uri_folder"
assert data_asset.name == dataset_name
def test_dataset_keys() -> None:
"""
Check that dataset keys are non-empty strings, and that inputs and outputs have different keys.
Check that dataset keys are non-e
mpty strings, and that inputs and outputs have different keys.
"""
in1 = _input_dataset_key(1)
out1 = _output_dataset_key(1)

Просмотреть файл

@ -19,15 +19,20 @@ from ruamel import yaml
from ruamel.yaml.comments import CommentedMap as OrderedDict, CommentedSeq as OrderedList
from typing import Any, Dict, List, Optional, Tuple
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch, DEFAULT
from uuid import uuid4
import pytest
from _pytest.capture import CaptureFixture
from azure.ai.ml import Input, Output, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Data
from azure.ai.ml.sweep import Choice
from azureml._restclient.constants import RunStatus
from azureml.core import ComputeTarget, Environment, RunConfiguration, ScriptRunConfig, Workspace
from azureml.data.azure_storage_datastore import AzureBlobDatastore
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.dataprep.fuse.daemon import MountContext
from azureml.train.hyperdrive import HyperDriveConfig
import health_azure.himl as himl
@ -69,30 +74,34 @@ logger.setLevel(logging.DEBUG)
# region Small fast local unit tests
@pytest.mark.fast
def test_submit_to_azure_if_needed_returns_immediately() -> None:
def test_submit_to_azure_if_needed_returns_immediately(tmp_path: Path) -> None:
"""
Test that himl.submit_to_azure_if_needed can be called, and returns immediately.
"""
with mock.patch("sys.argv", ["", "--azureml"]):
shared_config_json = get_shared_config_json()
with check_config_json(tmp_path, shared_config_json=shared_config_json):
with pytest.raises(Exception) as ex:
himl.submit_to_azure_if_needed(
aml_workspace=None,
workspace_config_file=None,
entry_script=Path(__file__),
compute_cluster_name="foo",
snapshot_root_directory=Path(__file__).parent)
snapshot_root_directory=Path(__file__).parent,
submit_to_azureml=True)
# N.B. This assert may fail when run locally since we may find a workspace_config_file through the call to
# _find_file(CONDA_ENVIRONMENT_FILE) in submit_to_azure_if_needed
if _is_running_in_github_pipeline():
assert "No workspace config file given, nor can we find one" in str(ex)
with mock.patch("sys.argv", [""]):
result = himl.submit_to_azure_if_needed(
entry_script=Path(__file__),
compute_cluster_name="foo",
conda_environment_file=Path("env.yml"))
assert isinstance(result, himl.AzureRunInfo)
assert not result.is_running_in_azure_ml
assert result.run is None
with mock.patch("sys.argv", [""]):
result = himl.submit_to_azure_if_needed(
entry_script=Path(__file__),
compute_cluster_name="foo",
conda_environment_file=shared_config_json)
assert isinstance(result, himl.AzureRunInfo)
assert not result.is_running_in_azure_ml
assert result.run is None
def _is_running_in_github_pipeline() -> bool:
@ -254,20 +263,21 @@ def test_validate_compute_real(tmp_path: Path) -> None:
@pytest.mark.fast
@patch("azureml.data.OutputFileDatasetConfig")
@patch("health_azure.himl.DatasetConsumptionConfig")
@patch("health_azure.himl.Workspace")
@patch("health_azure.himl.DatasetConfig")
def test_to_datasets(
mock_dataset_config: mock.MagicMock,
mock_workspace: mock.MagicMock,
mock_dataset_consumption_config: mock.MagicMock,
mock_output_file_dataset_config: mock.MagicMock) -> None:
def to_input_dataset(workspace: Workspace, dataset_index: int, ) -> DatasetConsumptionConfig:
def to_input_dataset(workspace: Workspace, dataset_index: int, strictly_aml_v1: bool,
ml_client: Optional[MLClient] = None) -> DatasetConsumptionConfig:
return mock_dataset_consumption_config
def to_output_dataset(workspace: Workspace, dataset_index: int, ) -> DatasetConsumptionConfig:
def to_output_dataset(workspace: Workspace, dataset_index: int) -> DatasetConsumptionConfig:
return mock_output_file_dataset_config
mock_dataset_consumption_config = mock.create_autospec(DatasetConsumptionConfig)
mock_dataset_consumption_config.__class__.return_value = DatasetConsumptionConfig
mock_dataset_consumption_config.name = "A Consumption Config"
mock_output_file_dataset_config.name = "An Output File Dataset Config"
mock_dataset_config.to_input_dataset = to_input_dataset
@ -276,12 +286,14 @@ def test_to_datasets(
himl.convert_himl_to_azureml_datasets(
cleaned_input_datasets=[mock_dataset_config, mock_dataset_config],
cleaned_output_datasets=[],
strictly_aml_v1=True,
workspace=mock_workspace)
assert "already an input dataset with name" in str(ex1)
with pytest.raises(ValueError) as ex2:
himl.convert_himl_to_azureml_datasets(
cleaned_input_datasets=[mock_dataset_config, mock_dataset_config],
cleaned_output_datasets=[],
strictly_aml_v1=True,
workspace=mock_workspace)
assert "already an output dataset with name" in str(ex2)
@ -290,6 +302,7 @@ def test_to_datasets(
inputs, outputs = himl.convert_himl_to_azureml_datasets(
cleaned_input_datasets=cleaned_input_datasets,
cleaned_output_datasets=cleaned_output_datasets,
strictly_aml_v1=True,
workspace=mock_workspace)
assert len(inputs) == 1
assert len(outputs) == 1
@ -341,9 +354,9 @@ def test_create_run_configuration(
mock_env_name = "Mock Env"
mock_environment_get.return_value = mock_env_name
mock_workspace.compute_targets = {existing_compute_target: mock_compute_cluster}
aml_input_dataset = MagicMock()
aml_input_dataset = create_autospec(DatasetConsumptionConfig)
aml_input_dataset.name = "dataset_in"
aml_output_dataset = MagicMock()
aml_output_dataset = create_autospec(DatasetConsumptionConfig)
aml_output_dataset.name = "dataset_out"
mock_to_input_dataset.return_value = aml_input_dataset
mock_to_output_dataset.return_value = aml_output_dataset
@ -460,7 +473,8 @@ def test_create_run_configuration_correct_env(mock_create_environment: MagicMock
with pytest.raises(Exception) as e:
himl.create_run_configuration(mock_workspace,
"dummy_compute_cluster",
conda_environment_file=conda_env_path)
conda_environment_file=conda_env_path,
)
assert "you must specify the python version" in str(e)
# check that when create_run_configuration is called, whatever is returned from register_environment
@ -538,13 +552,11 @@ def test_get_workspace_no_config(
mock_is_running_in_azure.return_value = False
with change_working_directory(tmp_path):
with pytest.raises(ValueError) as ex:
with mock.patch("sys.argv", ["", "--azureml"]):
himl.submit_to_azure_if_needed(compute_cluster_name="foo")
himl.submit_to_azure_if_needed(compute_cluster_name="foo", submit_to_azureml=True)
assert "No workspace config file given" in str(ex)
@pytest.mark.fast
@patch("health_azure.himl.Run")
@patch("health_azure.himl.Workspace")
@patch("health_azure.himl._generate_azure_datasets")
@patch("health_azure.himl.RUN_CONTEXT")
@ -552,7 +564,7 @@ def test_submit_to_azure_if_needed_azure_return(
mock_run_context: mock.MagicMock,
mock_generate_azure_datasets: mock.MagicMock,
mock_workspace: mock.MagicMock,
mock_run: mock.MagicMock) -> None:
) -> None:
"""
When running in AzureML, the call to submit_to_azure_if_needed should return immediately, without trying to
submit a new job.
@ -561,13 +573,13 @@ def test_submit_to_azure_if_needed_azure_return(
mock_run_context.experiment = mock.MagicMock(workspace=mock_workspace)
assert is_running_in_azure_ml(himl.RUN_CONTEXT)
expected_run_info = himl.AzureRunInfo(
run=mock_run,
run=mock_run_context,
input_datasets=[],
output_datasets=[],
mount_contexts=[],
is_running_in_azure_ml=True,
output_folder=Path.cwd(),
logs_folder=Path.cwd())
output_folder=Path.cwd() / himl.OUTPUT_FOLDER,
logs_folder=Path.cwd() / himl.LOGS_FOLDER)
mock_generate_azure_datasets.return_value = expected_run_info
with mock.patch("sys.argv", ["", "--azureml"]):
run_info = himl.submit_to_azure_if_needed(
@ -835,6 +847,9 @@ def render_and_run_test_script(path: Path,
run_requirements = True
print("Copied 'src' folder.")
if run_target == RunTarget.AZUREML:
extra_options["submit_to_azureml"] = 'True'
environment_yaml_path = path / "environment.yml"
render_environment_yaml(environment_yaml_path, version, run_requirements, extra_options=extra_options)
@ -844,8 +859,6 @@ def render_and_run_test_script(path: Path,
workspace_config_file_arg=workspace_config_file_arg)
score_args = [str(entry_script_path)]
if run_target == RunTarget.AZUREML:
score_args.append("--azureml")
score_args.extend(extra_args)
env = dict(os.environ.items())
@ -907,10 +920,11 @@ def test_invoking_hello_world_no_config(run_target: RunTarget, tmp_path: Path) -
:param run_target: Where to run the script.
:param tmp_path: PyTest test fixture for temporary path.
"""
parser_args = "parser.add_argument('-m', '--message', type=str, required=True, help='The message to print out')"
message_guid = uuid4().hex
extra_options = {
'workspace_config_file': 'None',
'args': 'parser.add_argument("-m", "--message", type=str, required=True, help="The message to print out")',
'args': parser_args,
'body': 'print(f"The message was: {args.message}")'
}
extra_args = [f"--message={message_guid}"]
@ -946,8 +960,9 @@ def test_invoking_hello_world_config(run_target: RunTarget, use_package: bool, t
return
message_guid = uuid4().hex
parser_args = "parser.add_argument('-m', '--message', type=str, required=True, help='The message to print out')"
extra_options = {
'args': 'parser.add_argument("-m", "--message", type=str, required=True, help="The message to print out")',
'args': parser_args,
'body': 'print(f"The message was: {args.message}")'
}
extra_args = [f"--message={message_guid}"]
@ -1006,8 +1021,9 @@ def test_invoking_hello_world_env_var(run_target: RunTarget, tmp_path: Path) ->
import os
import sys""",
'environment_variables': f"{{'message_guid': '{message_guid}'}}",
'body': 'print(f"The message_guid env var was: {os.getenv(\'message_guid\')}")'
'body': 'print(f"The message_guid env var was: {os.getenv(\'message_guid\')}")',
}
extra_args: List[str] = []
output = render_and_run_test_script(tmp_path, run_target, extra_options, extra_args, True)
expected_output = f"The message_guid env var was: {message_guid}"
@ -1046,7 +1062,11 @@ def test_mounting_and_downloading_dataset(tmp_path: Path) -> None:
use_mounting=use_mounting,
target_folder=target_path)
logging.info(f"ready to {action}")
paths, mount_contexts = setup_local_datasets(dataset_configs=[dataset_config], aml_workspace=workspace)
paths, mount_contexts = setup_local_datasets(
dataset_configs=[dataset_config],
strictly_aml_v1=True,
aml_workspace=workspace
)
logging.info(f"{action} done")
path = paths[0]
assert path is not None
@ -1235,7 +1255,7 @@ import sys
file = input_folder / filename
shutil.copy(file, output_folder)
print(f"Copied file: {{file.name}} from {{input_blob_name}} to {{output_blob_name}}")
"""
""",
}
extra_args: List[str] = []
output = render_and_run_test_script(tmp_path, run_target, extra_options, extra_args, True)
@ -1288,33 +1308,204 @@ def test_create_crossval_hyperdrive_config(_: MagicMock, num_crossval_splits: in
assert crossval_config._max_total_runs == num_crossval_splits
def test_create_crossval_hyperparam_args_v2() -> None:
num_splits = 3
crossval_args = himl.create_crossval_hyperparam_args_v2(num_splits)
assert isinstance(crossval_args, Dict)
assert crossval_args[himl.MAX_TOTAL_TRIALS_ARG] == num_splits
assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG], Dict)
assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG]["crossval_index"], Choice)
assert crossval_args[himl.PRIMARY_METRIC_ARG] == "val/loss"
assert crossval_args[himl.SAMPLING_ALGORITHM_ARG] == "grid"
assert crossval_args[himl.GOAL_ARG] == "Minimize"
def test_create_grid_hyperparam_args_v2() -> None:
mock_values_float = [0.1, 0.2, 0.5]
mock_arg_name_float = "float_number"
mock_metric_name_float = mock_arg_name_float
hparams_args_float = himl.create_grid_hyperparam_args_v2(mock_values_float, mock_arg_name_float,
mock_metric_name_float)
assert isinstance(hparams_args_float, Dict)
assert hparams_args_float[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_float)
assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG], Dict)
assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG][mock_arg_name_float], Choice)
assert hparams_args_float[himl.PRIMARY_METRIC_ARG] == mock_metric_name_float
assert hparams_args_float[himl.SAMPLING_ALGORITHM_ARG] == "grid"
assert hparams_args_float[himl.GOAL_ARG] == "Minimize"
mock_values_str = ["a", "b", "c"]
mock_arg_name_str = "letter"
mock_metric_name_str = mock_arg_name_str
hparam_args_str = himl.create_grid_hyperparam_args_v2(mock_values_str, mock_arg_name_str,
mock_metric_name_str)
assert isinstance(hparam_args_str, Dict)
assert hparam_args_str[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_str)
assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG], Dict)
assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG][mock_arg_name_str], Choice)
assert hparam_args_str[himl.PRIMARY_METRIC_ARG] == mock_metric_name_str
assert hparam_args_str[himl.SAMPLING_ALGORITHM_ARG] == "grid"
assert hparam_args_str[himl.GOAL_ARG] == "Minimize"
@pytest.mark.fast
@pytest.mark.parametrize("cross_validation_metric_name", [None, "accuracy"])
@patch("sys.argv")
@patch("health_azure.himl.exit")
def test_submit_to_azure_if_needed_with_hyperdrive(mock_sys_args: MagicMock, mock_exit: MagicMock,
def test_submit_to_azure_if_needed_with_hyperdrive(mock_sys_args: MagicMock,
mock_exit: MagicMock,
mock_compute_cluster: MagicMock,
cross_validation_metric_name: Optional[str]) -> None:
cross_validation_metric_name: Optional[str],
) -> None:
"""
Test that himl.submit_to_azure_if_needed can be called, and returns immediately.
"""
cross_validation_metric_name = cross_validation_metric_name or ""
mock_sys_args.return_value = ["", "--azureml"]
with patch.object(Environment, "get", return_value="dummy_env"):
with patch("azureml.core.Workspace") as mock_workspace:
with patch("health_azure.himl.get_ml_client") as mock_get_ml_client:
mock_ml_client = MagicMock()
mock_get_ml_client.return_value = mock_ml_client
with patch.object(Environment, "get", return_value="dummy_env"):
mock_workspace = MagicMock()
mock_workspace.compute_targets = {"foo": mock_compute_cluster}
with patch("health_azure.datasets.setup_local_datasets") as mock_setup_local_datasets:
mock_setup_local_datasets.return_value = [], []
with patch("health_azure.himl.submit_run") as mock_submit_run:
with patch("health_azure.himl.HyperDriveConfig") as mock_hyperdrive_config:
crossval_config = himl.create_crossval_hyperdrive_config(
num_splits=2,
cross_val_index_arg_name="cross_val_split_index",
metric_name=cross_validation_metric_name)
himl.submit_to_azure_if_needed(
aml_workspace=mock_workspace,
ml_client=mock_ml_client,
entry_script=Path(__file__),
snapshot_root_directory=Path(__file__).parent,
compute_cluster_name="foo",
aml_environment_name="dummy_env",
submit_to_azureml=True,
hyperdrive_config=crossval_config,
strictly_aml_v1=True)
mock_submit_run.assert_called_once()
mock_hyperdrive_config.assert_called_once()
@pytest.mark.fast
def test_create_v2_inputs() -> None:
mock_ml_client = MagicMock()
mock_data_name = "mock_data"
mock_data_version = "1"
mock_data_path = "path/to/mock/data"
mock_ml_client.data.get.return_value = Data(
name=mock_data_name,
version=mock_data_version,
id=mock_data_path
)
mock_input_dataconfigs = [DatasetConfig(name="dummy_dataset")]
inputs = himl.create_v2_inputs(mock_ml_client, mock_input_dataconfigs)
assert isinstance(inputs, Dict)
assert len(inputs) == len(mock_input_dataconfigs)
input_entry = inputs["INPUT_0"]
assert isinstance(input_entry, Input)
assert input_entry.type == AssetTypes.URI_FOLDER
actual_path: str = input_entry.path # type: ignore
assert actual_path == mock_data_path
@pytest.mark.fast
def test_create_v2_outputs() -> None:
mock_datastore_name = "dummy_datastore"
mock_data_name = "dummy_dataset"
mock_output_dataconfigs = [DatasetConfig(name=mock_data_name, datastore=mock_datastore_name)]
outputs = himl.create_v2_outputs(mock_output_dataconfigs)
assert isinstance(outputs, Dict)
assert len(outputs) == len(mock_output_dataconfigs)
output_entry = outputs["OUTPUT_0"]
assert isinstance(output_entry, Output)
assert output_entry.type == AssetTypes.URI_FOLDER
expected_path = f"azureml://datastores/{mock_datastore_name}/paths/{mock_data_name}"
assert expected_path in output_entry['path']
def test_submit_to_azure_if_needed_v2() -> None:
"""
Check that submit_run_v2 is called when submit_to_azure_if_needed is called, unless strictly_aml_v1 is
set to True, in which case submit_run should be called instead
"""
dummy_input_datasets: List[Optional[Path]] = []
dummy_mount_contexts: List[MountContext] = []
with patch.multiple(
"health_azure.himl",
_package_setup=DEFAULT,
get_workspace=DEFAULT,
get_ml_client=DEFAULT,
create_run_configuration=DEFAULT,
create_script_run=DEFAULT,
append_to_amlignore=DEFAULT,
exit=DEFAULT
) as mocks:
mock_script_run = mocks["create_script_run"].return_value
mock_script_run.script = "dummy_script"
mock_script_run.source_directory = "dummy_dir"
with patch("health_azure.himl.setup_local_datasets") as mock_setup_datasets:
mock_setup_datasets.return_value = dummy_input_datasets, dummy_mount_contexts
with patch("health_azure.himl.submit_run_v2") as mock_submit_run_v2:
return_value = himl.submit_to_azure_if_needed(
workspace_config_file="mockconfig.json",
snapshot_root_directory="dummy",
submit_to_azureml=True,
strictly_aml_v1=False
)
mock_submit_run_v2.assert_called_once()
assert return_value is None
# Now supply strictly_aml_v1=True, and check that submit_run is called
with patch("health_azure.himl.submit_run") as mock_submit_run:
with patch("health_azure.himl.HyperDriveConfig") as mock_hyperdrive_config:
crossval_config = himl.create_crossval_hyperdrive_config(
num_splits=2,
cross_val_index_arg_name="cross_val_split_index",
metric_name=cross_validation_metric_name)
himl.submit_to_azure_if_needed(
aml_workspace=mock_workspace,
entry_script=Path(__file__),
compute_cluster_name="foo",
aml_environment_name="dummy_env",
submit_to_azureml=True,
hyperdrive_config=crossval_config)
mock_submit_run.assert_called_once()
mock_hyperdrive_config.assert_called_once()
return_value = himl.submit_to_azure_if_needed(
workspace_config_file="mockconfig.json",
snapshot_root_directory="dummy",
submit_to_azureml=True,
strictly_aml_v1=True,
)
mock_submit_run.assert_called_once()
assert return_value is None
@pytest.mark.fast
def test_generate_input_dataset_command() -> None:
input_datasets = {"INPUT_0": Input(), "INPUT_1": Input()}
input_data_cmd = himl._generate_input_dataset_command(input_datasets)
assert input_data_cmd == " --INPUT_0=${{inputs.INPUT_0}} --INPUT_1=${{inputs.INPUT_1}}"
@pytest.mark.fast
def test_generate_output_dataset_command() -> None:
output_datasets = {"OUTPUT_0": Output(), "OUTPUT_1": Output()}
output_data_cmd = himl._generate_output_dataset_command(output_datasets)
assert output_data_cmd == " --OUTPUT_0=${{outputs.OUTPUT_0}} --OUTPUT_1=${{outputs.OUTPUT_1}}"
@pytest.mark.fast
def test_extract_v2_inputs_outputs_from_args() -> None:
path_to_input_0 = "path_to_input_0"
path_to_output_0 = "path_to_output_0"
mock_args = [f"--INPUT_0={path_to_input_0}", "--INPUT_1=path_to_input_1", f"--OUTPUT_0={path_to_output_0}",
"--a=foo", "--b=bar"]
with patch.object(sys, "argv", new=mock_args):
input_datasets, output_datasets = himl._extract_v2_inputs_outputs_from_args()
assert len(input_datasets) == 2
assert input_datasets[0] == Path(path_to_input_0)
assert len(output_datasets) == 1
assert output_datasets[0] == Path(path_to_output_0)
# similar args should be ignored
mock_args_similar = [f"--input_0={path_to_input_0}", "--input_1=path_to_input_1", f"--output_0={path_to_output_0}",
"--a=foo", "--b=bar"]
with patch.object(sys, "argv", new=mock_args_similar):
input_datasets, output_datasets = himl._extract_v2_inputs_outputs_from_args()
assert len(input_datasets) == 0
assert len(output_datasets) == 0

Просмотреть файл

@ -3,13 +3,10 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from unittest.mock import patch
import pytest
import subprocess
from health_azure import himl_download
from testazure.utils_testazure import MockRun
DOWNLOAD_SCRIPT_PATH = himl_download.__file__
@ -37,12 +34,3 @@ def test_download_aml_run_no_runs(tmp_path: Path) -> None:
with pytest.raises(Exception) as e:
subprocess.Popen(["python", DOWNLOAD_SCRIPT_PATH, "--run_id", "madeuprun", "--output_dir", str(tmp_path)])
assert "was not found" in str(e)
def test_retrieve_runs() -> None:
with patch("health_azure.utils.get_aml_run_from_run_id") as mock_get_run:
dummy_run_id = "run_id_123"
mock_get_run.return_value = MockRun(dummy_run_id)
dummy_download_config = himl_download.HimlDownloadConfig(run=[dummy_run_id])
_ = himl_download.retrieve_runs(dummy_download_config)
mock_get_run.assert_called_with(dummy_run_id)

Просмотреть файл

@ -12,7 +12,7 @@ def test_git_repo_root_folder() -> None:
with patch("health_azure.paths.is_himl_used_from_git_repo", return_value=False):
with pytest.raises(ValueError) as e:
git_repo_root_folder()
assert"This function should not be used if hi-ml is used as an installed package." in str(e)
assert "This function should not be used if hi-ml is used as an installed package." in str(e)
@pytest.mark.fast

Просмотреть файл

@ -96,7 +96,7 @@ import sys
loss.backward()
optimizer.step()
writer.flush()
"""
""",
}
extra_args: List[str] = []

Просмотреть файл

@ -1,5 +1,5 @@
[mypy]
python_version=3.7
python_version=3.9
scripts_are_modules=True
namespace_packages=True
show_traceback=True

Просмотреть файл

@ -87,7 +87,7 @@ SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
cd ../ ; \
python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/environment.yml
python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/environment.yml --datastore=himldatasets
endef
define DEEPSMILEPANDASLIDES_ARGS

Просмотреть файл

@ -3,166 +3,205 @@
name: HimlHisto
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- blas=1.0=mkl
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.4.26=h06a4308_0
- certifi=2022.6.15=py37h06a4308_0
- ca-certificates=2022.10.11=h06a4308_0
- cairo=1.16.0=hf32fb01_1
- certifi=2022.9.24=py39h06a4308_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.2.2=h20bf706_0
- freetype=2.11.0=h70c0345_0
- fontconfig=2.13.1=hef1e5e3_1
- freetype=2.12.1=h4a9f257_0
- gdk-pixbuf=2.42.6=h04a7f16_0
- gettext=0.21.0=hf68c758_0
- giflib=5.2.1=h7b6447c_0
- glib=2.68.2=h36276a3_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- icu=58.2=he6710b0_3
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h7f8727e_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- libedit=3.1.20210910=h7f8727e_0
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.8=h7f8727e_5
- libffi=3.3=he6710b0_2
- libgcc-ng=12.2.0=h65d4601_19
- libglib=2.68.2=h3e27bee_0
- libgomp=12.2.0=h65d4601_19
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.2=h7f8727e_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h2818925_1
- libtiff=4.4.0=hecacb30_1
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h7f8727e_2
- libuv=1.40.0=h7b6447c_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.2=h55f646e_0
- libwebp-base=1.2.2=h7f8727e_0
- libwebp=1.2.4=h11a3e52_0
- libwebp-base=1.2.4=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxml2=2.9.14=h74e7548_0
- lz4-c=1.9.3=h295c915_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py37h7f8727e_0
- mkl_fft=1.3.1=py37hd3c417c_0
- mkl_random=1.2.2=py37h51133e4_0
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- ncurses=6.3=h5eee18b_3
- nettle=3.7.3=hbbd107a_1
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1q=h7f8727e_0
- pip=20.1.1=py37_1
- python=3.7.3
- pytorch=1.10.0=py3.7_cuda11.3_cudnn8.2.0_0
- openjpeg=2.4.0=h3ad879b_0
- openslide=3.4.1=h978ee9a_4
- openslide-python=1.2.0=py39hb9d737c_2
- openssl=1.1.1s=h7f8727e_0
- pcre=8.45=h295c915_0
- pip=20.1.1=py_1
- pixman=0.40.0=h7f8727e_1
- python=3.9.13
- python_abi=3.9=2_cp39
- pytorch=1.10.0=py3.9_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=7.0=h7b6447c_5
- readline=8.2=h5eee18b_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.33.0=h62c20be_0
- sqlite=3.39.3=h5082296_0
- tk=8.6.12=h1ccaba5_0
- torchvision=0.11.1=py37_cu113
- typing_extensions=4.1.1=pyh06a4308_0
- torchvision=0.11.1=py39_cu113
- typing_extensions=4.3.0=py39h06a4308_0
- tzdata=2022f=h04d1e81_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h7f8727e_2
- xz=5.2.6=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0
- pip:
- absl-py==1.2.0
- absl-py==1.3.0
- adal==1.2.7
- aiohttp==3.8.1
- aiosignal==1.2.0
- aiohttp==3.8.3
- aiosignal==1.3.1
- alabaster==0.7.12
- alembic==1.8.1
- applicationinsights==0.11.10
- argcomplete==2.0.0
- astroid==2.9.3
- astroid==2.12.12
- async-timeout==4.0.2
- asynctest==0.13.0
- attrs==21.4.0
- azure-ai-ml==1.1.0
- azure-common==1.1.28
- azure-core==1.24.2
- azure-core==1.26.1
- azure-graphrbac==0.61.1
- azure-identity==1.7.0
- azure-mgmt-authorization==2.0.0
- azure-mgmt-containerregistry==10.0.0
- azure-mgmt-core==1.3.0
- azure-mgmt-core==1.3.1
- azure-mgmt-keyvault==10.0.0
- azure-mgmt-resource==21.1.0
- azure-mgmt-resource==21.2.1
- azure-mgmt-storage==20.0.0
- azure-storage-blob==12.5.0
- azure-storage-blob==12.10.0
- azure-storage-file-datalake==12.9.1
- azure-storage-file-share==12.10.1
- azureml-core==1.43.0
- azureml-dataprep==4.0.4
- azureml-dataprep-native==38.0.0
- azureml-dataprep-rslex==2.6.3
- azureml-dataset-runtime==1.43.0.post2
- azureml-dataset-runtime==1.43.0
- azureml-mlflow==1.47.0
- azureml-telemetry==1.43.0
- azureml-tensorboard==1.43.0
- azureml-train-core==1.43.0
- azureml-train-restclients-hyperdrive==1.43.0
- babel==2.10.3
- babel==2.11.0
- backports-tempfile==1.0
- backports-weakref==1.0.post1
- bcrypt==3.2.2
- bcrypt==4.0.1
- black==22.1.0
- bleach==5.0.1
- cachetools==4.2.4
- cffi==1.15.1
- cfgv==3.3.1
- charset-normalizer==2.1.0
- charset-normalizer==2.1.1
- click==8.1.3
- cloudpickle==1.6.0
- colorama==0.4.5
- colorama==0.4.6
- coloredlogs==15.0.1
- conda-merge==0.1.5
- contextlib2==21.6.0
- coverage==6.3.2
- cryptography==37.0.4
- cryptography==38.0.3
- cucim==22.4.0
- cycler==0.11.0
- databricks-cli==0.17.3
- dataclasses-json==0.5.2
- dill==0.3.6
- diskcache==5.4.0
- distlib==0.3.5
- distro==1.7.0
- distlib==0.3.6
- distro==1.8.0
- docker==5.0.3
- docutils==0.16
- dotnetcore2==3.1.23
- filelock==3.7.1
- entrypoints==0.4
- filelock==3.8.0
- flake8==4.0.1
- frozenlist==1.3.0
- fsspec==2022.5.0
- flask==2.2.2
- frozenlist==1.3.3
- fsspec==2022.10.0
- fusepy==3.0.1
- girder-client==3.1.14
- gitdb==4.0.9
- gitpython==3.1.29
- google-api-core==2.10.2
- google-auth==1.35.0
- google-auth-oauthlib==0.4.6
- grpcio==1.47.0
- hi-ml==0.2.3
- hi-ml-azure==0.2.3
- googleapis-common-protos==1.56.4
- greenlet==2.0.1
- grpcio==1.50.0
- gunicorn==20.1.0
- hi-ml==0.2.9
- hi-ml-azure==0.2.9
- humanfriendly==10.0
- identify==2.5.2
- idna==3.3
- imageio==2.19.5
- identify==2.5.8
- idna==3.4
- imageio==2.22.4
- imagesize==1.4.1
- importlib-metadata==4.2.0
- importlib-resources==5.8.0
- iniconfig==1.1.1
- isodate==0.6.1
- isort==5.10.1
- itsdangerous==2.1.2
- jaraco-classes==3.2.3
- jeepney==0.8.0
- jinja2==3.0.2
- jmespath==1.0.0
- joblib==1.1.0
- joblib==1.2.0
- jsonpickle==2.2.0
- jsonschema==4.7.2
- keyring==23.7.0
- jsonschema==4.17.0
- keyring==23.11.0
- kiwisolver==1.4.4
- knack==0.9.0
- lazy-object-proxy==1.7.1
- lazy-object-proxy==1.8.0
- lightning-bolts==0.4.0
- llvmlite==0.38.1
- llvmlite==0.39.1
- lxml==4.9.1
- mako==1.2.3
- markdown==2.6.8
- markdown-it-py==1.1.0
- markupsafe==2.1.1
- marshmallow==3.17.0
- marshmallow==3.18.0
- marshmallow-enum==1.5.1
- matplotlib==3.4.3
- mccabe==0.6.1
- mdit-py-plugins==0.2.8
- mlflow==1.30.0
- mlflow-skinny==1.30.0
- monai==0.8.0
- more-itertools==8.10.0
- msal==1.18.0
- msal==1.20.0
- msal-extensions==0.3.1
- msrest==0.6.21
- msrestazure==0.6.4
@ -171,25 +210,31 @@ dependencies:
- mypy-extensions==0.4.3
- myst-parser==0.15.2
- ndg-httpsclient==0.5.1
- networkx==2.6.3
- networkx==2.8.8
- nodeenv==1.7.0
- numba==0.55.2
- numpy==1.21.6
- oauthlib==3.2.0
- numba==0.56.4
- numpy==1.22.0
- oauthlib==3.2.2
- opencensus==0.11.0
- opencensus-context==0.1.3
- opencensus-ext-azure==1.1.7
- opencv-python-headless==4.5.1.48
- packaging==21.3
- pandas==1.3.4
- param==1.12.0
- paramiko==2.11.0
- pathspec==0.9.0
- paramiko==2.12.0
- pathspec==0.10.1
- pillow==9.0.1
- pipdeptree==2.2.1
- pkginfo==1.8.3
- platformdirs==2.5.2
- platformdirs==2.5.3
- pluggy==0.13.1
- portalocker==2.5.1
- portalocker==2.6.0
- pre-commit==2.19.0
- prometheus-client==0.15.0
- prometheus-flask-exporter==0.20.3
- protobuf==3.20.1
- psutil==5.9.4
- py==1.11.0
- pyarrow==3.0.0
- pyasn1==0.4.8
@ -197,43 +242,45 @@ dependencies:
- pycobertura==2.0.1
- pycodestyle==2.8.0
- pycparser==2.21
- pydash==5.1.1
- pydeprecate==0.3.2
- pydicom==2.3.0
- pyflakes==2.4.0
- pygments==2.12.0
- pyjwt==2.4.0
- pylint==2.12.2
- pygments==2.13.0
- pyjwt==2.6.0
- pylint==2.15.0
- pynacl==1.5.0
- pynndescent==0.5.7
- pyopenssl==22.0.0
- pynndescent==0.5.8
- pyopenssl==22.1.0
- pyparsing==3.0.9
- pyrsistent==0.18.1
- pysocks==1.6.0
- pyrsistent==0.19.2
- pysocks==1.7.1
- pytest==6.2.2
- pytest-cov==2.11.1
- pytest-rerunfailures==10.2
- pytest-timeout==2.0.1
- python-dateutil==2.8.2
- pytorch-lightning==1.6.5
- pytz==2022.1
- pywavelets==1.3.0
- pytz==2022.6
- pywavelets==1.4.1
- pyyaml==6.0
- readme-renderer==35.0
- querystring-parser==1.2.4
- readme-renderer==37.3
- requests==2.28.1
- requests-oauthlib==1.3.1
- requests-toolbelt==0.9.1
- requests-toolbelt==0.10.1
- rfc3986==2.0.0
- rpdb==0.1.6
- rsa==4.9
- ruamel-yaml==0.16.12
- ruamel-yaml-clib==0.2.6
- scikit-image==0.19.3
- scikit-learn==1.0.2
- scikit-learn==1.1.3
- scipy==1.7.3
- seaborn==0.10.1
- secretstorage==3.3.2
- secretstorage==3.3.3
- setuptools==59.5.0
- simpleitk==2.1.1.2
- smmap==5.0.0
- snowballstemmer==2.2.0
- sphinx==4.1.2
- sphinx-autodoc-typehints==1.12.0
@ -245,8 +292,11 @@ dependencies:
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlalchemy==1.4.43
- sqlparse==0.4.3
- strictyaml==1.6.2
- stringcase==1.2.0
- tabulate==0.8.10
- tabulate==0.9.0
- tensorboard==2.6.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
@ -254,19 +304,19 @@ dependencies:
- tifffile==2021.11.2
- toml==0.10.2
- tomli==2.0.1
- tomlkit==0.11.6
- torchmetrics==0.6.0
- tqdm==4.64.0
- tqdm==4.64.1
- twine==3.3.0
- typed-ast==1.5.4
- typing-inspect==0.7.1
- typing-inspect==0.8.0
- umap-learn==0.5.2
- urllib3==1.26.9
- virtualenv==20.15.1
- virtualenv==20.16.6
- webencodings==0.5.1
- websocket-client==1.3.3
- werkzeug==2.1.2
- websocket-client==1.4.2
- werkzeug==2.2.2
- wheel==0.36.2
- wrapt==1.13.3
- wrapt==1.14.1
- yacs==0.1.8
- yarl==1.7.2
- zipp==3.8.1
- yarl==1.8.1
- zipp==3.10.0

Просмотреть файл

@ -6,9 +6,11 @@ channels:
dependencies:
- cudatoolkit=11.3.1
- pip=20.1.1
- python=3.7.3
- python=3.9.13
- pytorch=1.10.0
- torchvision=0.11.1
- openslide=3.4.1
- openslide-python=1.2.0
- pip:
# Run requirements for hi-ml
- dataclasses-json==0.5.2
@ -23,13 +25,12 @@ dependencies:
- setuptools==59.5.0
# Run requirements for hi-ml-azure
- azureml-core==1.43.0
- azureml-dataset-runtime[fuse]
- azureml-dataset-runtime[fuse]==1.43.0
- azureml-tensorboard==1.43.0
- azureml-train-core==1.43.0
- conda-merge==0.1.5
- msal-extensions==0.3.1
- param==1.12
- pysocks==1.6.0
- ruamel.yaml==0.16.12
- tensorboard==2.6.0
# Histopathology requirements
@ -39,6 +40,10 @@ dependencies:
# Build requirements
- -r requirements_build.txt
# Pinned secondary dependencies to prevent clashes
- attrs==21.4.0
- azure-mgmt-core==1.3.1
- azure-mgmt-keyvault==10.0.0
- cryptography>=38.0.3
- cloudpickle==1.6.0
- importlib-metadata==4.2.0
- markdown==2.6.8

Просмотреть файл

@ -6,7 +6,7 @@ hi-ml
lightning-bolts==0.4.0
monai==0.8.0
more-itertools==8.10.0
numpy==1.21.6
numpy==1.22.0
pillow==9.0.1
pydicom==2.3.0
scikit-image==0.19.3

Просмотреть файл

@ -5,7 +5,7 @@ mypy==0.961
mypy-extensions==0.4.3
pipdeptree==2.2.1
pre-commit==2.19.0
pylint==2.12.2
pylint==2.15.0
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1

Просмотреть файл

@ -19,69 +19,69 @@ from setuptools import find_namespace_packages, setup # type: ignore
here = pathlib.Path(__file__).parent.resolve()
# Get the long description from the README file
long_description = (here / 'README.md').read_text(encoding='utf-8')
long_description = (here / "README.md").read_text(encoding="utf-8")
version = ''
version = ""
# If running from a GitHub Action then a standard set of environment variables will be
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
# Extract this tag and use it as a version string
# See also:
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
# https://github.com/pypa/gh-action-pypi-publish
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
github_ref = os.getenv('GITHUB_REF')
github_ref = os.getenv("GITHUB_REF")
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
# '99.99.post124'. Although this is discouraged, see:
# "99.99.post124". Although this is discouraged, see:
# https://www.python.org/dev/peps/pep-0440/#post-releases
# it is necessary here to avoid duplicate packages in Test.PyPI.
if not version:
build_number = os.getenv('GITHUB_RUN_NUMBER')
build_number = os.getenv("GITHUB_RUN_NUMBER")
if build_number:
# In github workflows, we may pull in hi-ml-multimodal as a dependency. Usually, we have a condition like
# hi-ml-multimodal>=0.1.5. This means that a package version from PyPi would trump the local wheels. For this
# reason, use an extremely large version number to give the local wheel priority.
version = '99.99.post' + build_number
version = "99.99.post" + build_number
else:
default_random_version_number = floor(random() * 10_000_000_000)
version = f'99.99.post{str(default_random_version_number)}'
version = f"99.99.post{str(default_random_version_number)}"
package_name = 'hi-ml-cpath'
(here / 'package_name.txt').write_text(package_name)
(here / 'latest_version.txt').write_text(version)
package_name = "hi-ml-cpath"
(here / "package_name.txt").write_text(package_name)
(here / "latest_version.txt").write_text(version)
# Read run_requirements.txt to get install_requires
install_requires = (here / 'requirements_run.txt').read_text().split("\n")
install_requires = (here / "requirements_run.txt").read_text().split("\n")
# Remove any whitespace and blank lines
install_requires = [line.strip() for line in install_requires if line.strip()]
description = 'Microsoft Health Futures package for deep learning on histopathology images'
description = "Microsoft Health Futures package for deep learning on histopathology images"
setup(
name=package_name,
version=version,
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/microsoft/hi-ml',
author="Microsoft Research Cambridge Medical Imaging Team ",
long_description_content_type="text/markdown",
url="https://github.com/microsoft/hi-ml",
author="Biomedical Imaging Team @ Microsoft Health Futures",
author_email="innereyedev@microsoft.com",
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Science/Research',
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.7'
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7"
],
keywords='Health Futures, Health Intelligence, Computational Pathology, AzureML',
license='MIT License',
keywords="Health Futures, Health Intelligence, Computational Pathology, AzureML",
license="MIT License",
packages=find_namespace_packages(where="src", include=["health_cpath.*"]),
package_dir={"": "src"},
include_package_data=False,

Просмотреть файл

@ -2,18 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from enum import Enum
from typing import Any
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
from SSL.utils import SSLTrainingType
from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_ID
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDatasetWithReturnIndex
from SSL.configs.HistoSimCLRContainer import HistoSSLContainer
class SSLDatasetNameHiml(SSLDatasetName, Enum): # type: ignore
TCGA_CRCK = "CRCKTilesDataset"
SSL_Dataset_TCGA_CRCK = "CRCKTilesDataset"
class CRCK_SimCLR(HistoSSLContainer):
@ -23,7 +21,7 @@ class CRCK_SimCLR(HistoSSLContainer):
in the _get_transforms method.
It has been tested locally and on AML on the full training dataset (93408 tiles).
"""
SSLContainer._SSLDataClassMappings.update({SSLDatasetNameHiml.TCGA_CRCK.value:
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_TCGA_CRCK:
TcgaCrck_TilesDatasetWithReturnIndex})
def __init__(self, **kwargs: Any) -> None:
@ -32,8 +30,8 @@ class CRCK_SimCLR(HistoSSLContainer):
# --num_workers = 0
# --max_epochs = 2
super().__init__(ssl_training_dataset_name=SSLDatasetNameHiml.TCGA_CRCK,
linear_head_dataset_name=SSLDatasetNameHiml.TCGA_CRCK,
super().__init__(ssl_training_dataset_name=SSL_Dataset_TCGA_CRCK,
linear_head_dataset_name=SSL_Dataset_TCGA_CRCK,
azure_datasets=[TCGA_CRCK_DATASET_ID],
random_seed=1,
num_workers=8,

Просмотреть файл

@ -71,7 +71,7 @@ class HistoSSLContainer(SSLContainer):
self.online_eval = SslOnlineEvaluatorHiml(class_weights=self.data_module.class_weights, # type: ignore
z_dim=self.encoder_output_dim,
num_classes=self.data_module.num_classes, # type: ignore
dataset=self.linear_head_dataset_name.value, # type: ignore
dataset=self.linear_head_dataset_name, # type: ignore
drop_p=0.2,
learning_rate=self.learning_rate_linear_head_during_ssl_training)

Просмотреть файл

@ -2,10 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from enum import Enum
from typing import Any
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
from SSL.utils import SSLTrainingType
from health_azure.utils import is_running_in_azure_ml
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDatasetWithReturnIndex
@ -13,8 +12,7 @@ from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID
from SSL.configs.HistoSimCLRContainer import HistoSSLContainer
class SSLDatasetNameHiml(SSLDatasetName, Enum): # type: ignore
PANDA = "PandaTilesDataset"
SSL_Dataset_PANDA = "PandaTilesDataset"
class PANDA_SimCLR(HistoSSLContainer):
@ -24,11 +22,11 @@ class PANDA_SimCLR(HistoSSLContainer):
in the _get_transforms method.
It has been tested on a toy local dataset (2 slides) and on AML on (~25 slides).
"""
SSLContainer._SSLDataClassMappings.update({SSLDatasetNameHiml.PANDA.value: PandaTilesDatasetWithReturnIndex})
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_PANDA: PandaTilesDatasetWithReturnIndex})
def __init__(self, **kwargs: Any) -> None:
super().__init__(ssl_training_dataset_name=SSLDatasetNameHiml.PANDA,
linear_head_dataset_name=SSLDatasetNameHiml.PANDA,
super().__init__(ssl_training_dataset_name=SSL_Dataset_PANDA,
linear_head_dataset_name=SSL_Dataset_PANDA,
azure_datasets=[PANDA_5X_TILES_DATASET_ID],
random_seed=1,
num_workers=5,

Просмотреть файл

@ -101,7 +101,7 @@ class RSNAKaggleCXR(CxrDatasetWithReturnIndex):
self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
self.targets = self.dataset_dataframe.label.values.astype(np.int64)
self.subject_ids = self.dataset_dataframe.subject.values
self.indices = np.arange(len(self.dataset_dataframe))
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
self.filenames = [self.root / f"{subject_id}.dcm" for subject_id in self.subject_ids]
else:
# No test set implemented for this data class.
@ -141,7 +141,7 @@ class NIHCXR(CxrDatasetWithReturnIndex):
train_ids = pd.read_csv(self.root / "train_val_list.txt", header=None).values.reshape(-1)
is_train_val_ids = self.dataset_dataframe["Image Index"].isin(train_ids).values
self.subject_ids = np.where(is_train_val_ids)[0] if self.train else np.where(~is_train_val_ids)[0]
self.indices = np.arange(len(self.subject_ids))
self.indices = np.arange(len(self.subject_ids)).tolist()
self.filenames = [self.root / f"{subject_id}" for subject_id in self.subject_ids]
@ -175,7 +175,7 @@ class CheXpert(CxrDatasetWithReturnIndex):
# Strip away the name of the folder that is included in the path column of the dataset
strip_n = len("CheXpert-v1.0-small/")
self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
self.indices = np.arange(len(self.dataset_dataframe))
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]
@ -191,7 +191,7 @@ class CovidDataset(CxrDatasetWithReturnIndex):
mapping = {0: 0, 3: 0, 1: 1, 2: 1}
# For monitoring purpose with use binary classification CV03vsCV12
self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
self.indices = np.arange(len(self.dataset_dataframe))
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
self.subject_ids = self.dataset_dataframe.subject.values
self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)

Просмотреть файл

@ -91,4 +91,4 @@ def get_encoder_output_dim(
with torch.no_grad():
representations = pl_module(x)
return representations.shape[1]
return representations.shape[1] # type: ignore

Просмотреть файл

@ -42,7 +42,7 @@ class EncoderName(Enum):
densenet121 = "densenet121"
class SSLDatasetName(Enum):
class SSLDatasetName:
CIFAR10 = "CIFAR10"
CIFAR100 = "CIFAR100"
RSNAKaggleCXR = "RSNAKaggleCXR"
@ -64,17 +64,17 @@ class SSLContainer(LightningContainer):
Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
setup and datamodule methods.
"""
_SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: HimlCifar10,
SSLDatasetName.CIFAR100.value: HimlCifar100,
SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
SSLDatasetName.NIHCXR.value: NIHCXR,
SSLDatasetName.CheXpert.value: CheXpert,
SSLDatasetName.Covid.value: CovidDataset}
DatasetToClassMapping = {SSLDatasetName.CIFAR10: HimlCifar10,
SSLDatasetName.CIFAR100: HimlCifar100,
SSLDatasetName.RSNAKaggleCXR: RSNAKaggleCXR,
SSLDatasetName.NIHCXR: NIHCXR,
SSLDatasetName.CheXpert: CheXpert,
SSLDatasetName.Covid: CovidDataset}
ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
doc="The path to the yaml config defining the parameters of the "
"augmentations. Ignored for CIFAR10 example")
ssl_training_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="The name of the dataset")
ssl_training_dataset_name: str = param.String(default="", doc="The name of the dataset")
ssl_training_batch_size = param.Integer(
doc="Training batch size per GPU. The effective batch size will be the number of GPUs times this number. "
"For example, if you specify ssl_training_batch_size=100 and use 4 nodes with 4 gpus each, "
@ -91,8 +91,8 @@ class SSLContainer(LightningContainer):
linear_head_augmentation_config = param.ClassSelector(class_=Path,
doc="The path to the yaml config for the linear head "
"augmentations")
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
doc="Name of the dataset to use for the linear head training")
linear_head_dataset_name: str = param.String(default="",
doc="Name of the dataset to use for the linear head training")
linear_head_batch_size = param.Integer(default=16, doc="Batch size for linear head tuning")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
@ -114,7 +114,7 @@ class SSLContainer(LightningContainer):
# may contain only one dataset entry
elif (
(self.linear_head_dataset_name == self.ssl_training_dataset_name) # noqa: W504
or (self.ssl_training_dataset_name is None and self.linear_head_dataset_name is not None)
or (not (self.ssl_training_dataset_name) and self.linear_head_dataset_name)
) and len(self.local_datasets) == 1:
# self.extra_local_dataset_paths = [self.local_dataset]
linear_head_dataset_path = self.local_datasets[0]
@ -131,7 +131,7 @@ class SSLContainer(LightningContainer):
self.datamodule_args = {SSLDataModuleType.LINEAR_HEAD:
DataModuleArgs(augmentation_params=self.classifier_augmentation_params,
dataset_name=self.linear_head_dataset_name.value,
dataset_name=self.linear_head_dataset_name,
dataset_path=linear_head_dataset_path,
batch_size=self.linear_head_batch_size)}
if self.ssl_training_dataset_name is not None:
@ -142,7 +142,7 @@ class SSLContainer(LightningContainer):
training_dataset_path = None
self.datamodule_args.update(
{SSLDataModuleType.ENCODER: DataModuleArgs(augmentation_params=self.ssl_augmentation_params,
dataset_name=self.ssl_training_dataset_name.value,
dataset_name=self.ssl_training_dataset_name,
dataset_path=training_dataset_path,
batch_size=self.ssl_training_batch_size)})
self.data_module: DataModuleTypes = self.get_data_module()
@ -166,7 +166,7 @@ class SSLContainer(LightningContainer):
"""
# For small images like CIFAR, if using a resnet encoder, switch the first conv layer to a 3x3 kernel instead
# of a 7x7 conv layer.
use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.value.startswith("CIFAR") else True
use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.startswith("CIFAR") else True
# Rescale the learning rate linearly according to the number of available GPUs, as seen in:
# https://arxiv.org/abs/1706.02677, to avoid a drop in performance.
@ -180,7 +180,7 @@ class SSLContainer(LightningContainer):
if self.ssl_training_type == SSLTrainingType.SimCLR:
model: LightningModule = SimClrHiml(encoder_name=self.ssl_encoder.value,
dataset_name=self.ssl_training_dataset_name.value,
dataset_name=self.ssl_training_dataset_name,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
num_samples=self.data_module.num_train_samples,
batch_size=self.data_module.batch_size,
@ -239,7 +239,7 @@ class SSLContainer(LightningContainer):
effective_batch_size = datamodule_args.batch_size * batch_multiplier
logging.info(f"Batch size per GPU: {datamodule_args.batch_size}")
logging.info(f"Effective batch size on {batch_multiplier} GPUs: {effective_batch_size}")
dm = HimlVisionDataModule(dataset_cls=self._SSLDataClassMappings[datamodule_args.dataset_name],
dm = HimlVisionDataModule(dataset_cls=self.DatasetToClassMapping[datamodule_args.dataset_name],
return_index=not is_ssl_encoder_module, # index is only needed for linear head
train_transforms=train_transforms,
val_split=0.1,
@ -268,17 +268,17 @@ class SSLContainer(LightningContainer):
will return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
SSLDatasetName.NIHCXR.value,
SSLDatasetName.CheXpert.value,
SSLDatasetName.Covid.value]:
if dataset_name in [SSLDatasetName.RSNAKaggleCXR,
SSLDatasetName.NIHCXR,
SSLDatasetName.CheXpert,
SSLDatasetName.Covid]:
assert augmentation_config is not None
train_transforms, val_transforms = get_ssl_transforms_from_config(
augmentation_config,
return_two_views_per_sample=is_ssl_encoder_module,
use_training_augmentations_for_validation=is_ssl_encoder_module
)
elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
elif dataset_name in [SSLDatasetName.CIFAR10, SSLDatasetName.CIFAR100]:
train_transforms = \
CIFARTrainTransform(32) if is_ssl_encoder_module else CIFARLinearHeadTransform(32)
val_transforms = \
@ -302,7 +302,7 @@ class SSLContainer(LightningContainer):
self.online_eval = SslOnlineEvaluatorHiml(class_weights=self.data_module.class_weights, # type: ignore
z_dim=self.encoder_output_dim,
num_classes=self.data_module.num_classes, # type: ignore
dataset=self.linear_head_dataset_name.value, # type: ignore
dataset=self.linear_head_dataset_name, # type: ignore
drop_p=0.2,
learning_rate=self.learning_rate_linear_head_during_ssl_training)
return [self.online_eval]

Просмотреть файл

@ -97,7 +97,7 @@ class SslOnlineEvaluatorHiml(SSLOnlineEvaluator):
self.evaluator.to(pl_module.device)
if hasattr(trainer, "accelerator_connector"):
# This works with Lightning 1.3.8
accelerator = trainer.accelerator_connector
accelerator = trainer.accelerator_connector # type: ignore
elif hasattr(trainer, "_accelerator_connector"):
# This works with Lightning 1.5.5
accelerator = trainer._accelerator_connector

Просмотреть файл

@ -292,7 +292,6 @@ class BaseMILTiles(BaseMIL):
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),

Просмотреть файл

@ -106,13 +106,22 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
raise NotImplementedError
def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, shuffle=True, stage=ModelKey.TRAIN, **self.dataloader_kwargs)
return self._get_dataloader(self.train_dataset, # type: ignore
shuffle=True,
stage=ModelKey.TRAIN,
**self.dataloader_kwargs)
def val_dataloader(self) -> DataLoader:
return self._get_dataloader(self.val_dataset, shuffle=False, stage=ModelKey.VAL, **self.dataloader_kwargs)
return self._get_dataloader(self.val_dataset, # type: ignore
shuffle=False,
stage=ModelKey.VAL,
**self.dataloader_kwargs)
def test_dataloader(self) -> DataLoader:
return self._get_dataloader(self.test_dataset, shuffle=False, stage=ModelKey.TEST, **self.dataloader_kwargs)
return self._get_dataloader(self.test_dataset, # type: ignore
shuffle=False,
stage=ModelKey.TEST,
**self.dataloader_kwargs)
class TilesDataModule(HistoDataModule[TilesDataset]):

Просмотреть файл

@ -29,6 +29,7 @@ BEST_VAL_MARKER_STYLE = dict(marker='*', markeredgecolor='w', markersize=11)
def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
"""
Get the t-sne projection of high dimensional data in a lower dimensional space
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
:param **kwargs: keyword arguments to be passed to TSNE()
:return: list of features in lower dimensional space (n x c for n samples and c components)
@ -41,6 +42,7 @@ def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int
def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
"""
Get the umap projection of high dimensional data in a lower dimensional space
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
:param **kwargs: keyword arguments to be passed to UMAP()
:return: list of features in lower dimensional space (n x c for n samples and c components)
@ -50,18 +52,20 @@ def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int
return umap_proj
def normalize_array_minmax(arr: List[float]) -> List[float]:
def normalize_array_minmax(arr: np.ndarray) -> np.ndarray:
"""
Normalize an array in range 0 to 1
:param arr: array to be normalized
:return: normalized array
"""
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
def normalize_array_mean(arr: List[float]) -> List[float]:
def normalize_array_mean(arr: np.ndarray) -> np.ndarray:
"""
Normalize an array with zero mean and unit variance
:param arr: array to be normalized
:return: normalized array
"""
@ -71,6 +75,7 @@ def normalize_array_mean(arr: List[float]) -> List[float]:
def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str], title: str = "") -> None:
"""
Plot a scatter plot of projected features in two dimensions
:param data: features projected in 2d space (nx2)
:param labels: corresponding labels of the data (nx1)
:param classes: list of classes in the dataset
@ -85,6 +90,7 @@ def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str],
def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outliers: bool, title: str = "") -> None:
"""
Plot a box whisker plot of column data
:param columns: data to be plotted in columns
:param column_names: names of the columns
:param show_outliers: whether outliers need to be shown
@ -105,6 +111,7 @@ def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outlier
def plot_histogram(data: List[Any], title: str = "") -> None:
"""
Plot a histogram given some data
:param data: data to be plotted
:param title: plot title string
"""

Просмотреть файл

@ -42,7 +42,7 @@ def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module:
https://github.com/ozanciga/self-supervised-histopathology
"""
map_location = device('cpu')
state = load_state_dict_from_url(weights_url, map_location=map_location)
state = load_state_dict_from_url(weights_url, map_location=map_location) # type: ignore
state_dict = state['state_dict']
model_dict = model.state_dict()

Просмотреть файл

@ -4,7 +4,6 @@
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable, Optional
from enum import Enum
from yacs.config import CfgNode
import pandas as pd
@ -23,7 +22,7 @@ from health_ml.lightning_container import LightningContainer, LightningModuleWit
from health_cpath.datasets.dataset_return_index import DatasetWithReturnIndex
from SSL.data.transforms_utils import DualViewTransformWrapper
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
from SSL.utils import SSLTrainingType
from SSL.data.transform_pipeline import ImageTransformationPipeline
@ -322,19 +321,18 @@ class DummySimCLRHimlData(DatasetWithReturnIndex, DummySimCLRData):
return 2
class DummySimCLRSSLDatasetName(SSLDatasetName, Enum): # type: ignore
DUMMY = "DUMMY"
SSL_Dataset_Dummy = "DUMMY"
class DummySimCLR(SSLContainer):
"""
This module trains an SSL encoder using SimCLR on the DummySimCLRData and finetunes a linear head too.
"""
SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRHimlData})
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_Dummy: DummySimCLRHimlData})
def __init__(self) -> None:
super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
super().__init__(ssl_training_dataset_name=SSL_Dataset_Dummy,
linear_head_dataset_name=SSL_Dataset_Dummy,
# Train with as little data as possible for the test
ssl_training_batch_size=2,
linear_head_batch_size=2,

Просмотреть файл

@ -38,7 +38,7 @@ def test_update_tau() -> None:
initial_tau = 0.99
byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=initial_tau)
trainer = Trainer(max_epochs=5)
trainer.train_dataloader = dummy_rsna_train_dataloader
trainer.train_dataloader = dummy_rsna_train_dataloader # type: ignore
total_steps = len(trainer.train_dataloader) * trainer.max_epochs # type: ignore
global_step = 15
byol_module = BootstrapYourOwnLatent(num_samples=16,

Просмотреть файл

@ -120,7 +120,7 @@ def test_get_transforms_in_ssl_container_for_cxr_data() -> None:
ssl_augmentation_config=path_encoder_augmentation_cxr)
test_container._load_config()
dual_view_transform, _ = test_container._get_transforms(augmentation_config=test_container.ssl_augmentation_params,
dataset_name=SSLDatasetName.NIHCXR.value,
dataset_name=SSLDatasetName.NIHCXR,
is_ssl_encoder_module=True)
test_img = PIL.Image.fromarray(np.ones([312, 312]) * 255.).convert("L")
@ -134,7 +134,7 @@ def test_get_transforms_in_ssl_container_for_cxr_data() -> None:
single_view_transform, _ = test_container._get_transforms(
augmentation_config=test_container.ssl_augmentation_params,
dataset_name=SSLDatasetName.NIHCXR.value,
dataset_name=SSLDatasetName.NIHCXR,
is_ssl_encoder_module=False)
v1 = single_view_transform(test_img)
# Images should be cropped to 224 x 224 and expanded to 3 channels according to config
@ -149,7 +149,7 @@ def test_get_transforms_in_SSL_container_for_cifar_data() -> None:
"""
test_container = SSLContainer()
dual_view_transform, _ = test_container._get_transforms(augmentation_config=None,
dataset_name=SSLDatasetName.CIFAR10.value,
dataset_name=SSLDatasetName.CIFAR10,
is_ssl_encoder_module=True)
img_array_with_black_square = np.ones([32, 32, 3], dtype=np.uint8)
img_array_with_black_square[10:20, 10:20, :] = 255
@ -161,7 +161,7 @@ def test_get_transforms_in_SSL_container_for_cifar_data() -> None:
assert (v1 != v2).any()
single_view_transform, _ = test_container._get_transforms(augmentation_config=None,
dataset_name=SSLDatasetName.CIFAR10.value,
dataset_name=SSLDatasetName.CIFAR10,
is_ssl_encoder_module=False)
v1 = single_view_transform(test_img)
# Images should be cropped to 224 x 224 and expanded to 3 channels according to config

Просмотреть файл

@ -129,7 +129,7 @@ def test_ssl_container_cifar10_resnet_simclr() -> None:
assert loaded_config.max_epochs == 1
assert loaded_config.ssl_training_type == SSLTrainingType.SimCLR
assert loaded_config.online_eval.num_classes == 10
assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value
assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10
assert loaded_config.ssl_training_dataset_name == SSLDatasetName.CIFAR10
assert not loaded_config.use_balanced_binary_loss_for_linear_head
assert isinstance(loaded_config.model.encoder.cnn_model, ResNet)
@ -208,7 +208,7 @@ def test_ssl_container_rsna() -> None:
loaded_config, _ = runner.run()
assert loaded_config is not None
assert isinstance(loaded_config.model, BootstrapYourOwnLatent)
assert loaded_config.online_eval.dataset == SSLDatasetName.RSNAKaggleCXR.value
assert loaded_config.online_eval.dataset == SSLDatasetName.RSNAKaggleCXR
assert loaded_config.online_eval.num_classes == 2
assert loaded_config.ssl_training_dataset_name == SSLDatasetName.NIHCXR
assert loaded_config.ssl_training_type == SSLTrainingType.BYOL

Просмотреть файл

@ -4,15 +4,13 @@
# ------------------------------------------------------------------------------------------
from enum import Enum
from pathlib import Path
import numpy as np
import pandas as pd
from torch import Tensor
from tifffile import TiffWriter
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import pandas as pd
import torch
from tifffile import TiffWriter
from torch import Tensor
from health_cpath.datasets.panda_dataset import PandaDataset
from testhisto.mocks.base_data_generator import MockHistoDataGenerator, MockHistoDataType, PANDA_N_CLASSES
@ -209,7 +207,7 @@ class MockPandaSlidesGenerator(MockHistoDataGenerator):
if self.n_tiles_list:
self.total_tiles = self.n_tiles_list[slide_counter]
self.n_tiles: int = self.n_tiles_list[slide_counter]
self.dataloader: torch.data.utils.Dataloader = self.get_dataloader()
self.dataloader: torch.utils.data.DataLoader = self.get_dataloader()
iterator = iter(self.dataloader)
tiles, _ = next(iterator) if iterator else (None, None)

Просмотреть файл

@ -21,5 +21,5 @@ def download_azure_dataset(tmp_path: Path, dataset_id: str) -> None:
with check_config_json(script_folder=tmp_path, shared_config_json=get_shared_config_json()):
ws = get_workspace(workspace_config_path=tmp_path / WORKSPACE_CONFIG_JSON)
dataset = DatasetConfig(name=dataset_id, target_folder=tmp_path, use_mounting=False)
dataset_dl_folder = dataset.to_input_dataset_local(ws)
dataset_dl_folder = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
logging.info(f"Dataset saved in {dataset_dl_folder}")

Просмотреть файл

@ -61,8 +61,9 @@ def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: i
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
get_simclr_imagenet_encoder,
get_histo_ssl_encoder,
get_ssl_encoder])
get_histo_ssl_encoder
# get_ssl_encoder # Removed because of test failure
])
def test_encoder(create_encoder_fn: Callable[[], TileEncoder], tmp_path: Path) -> None:
if create_encoder_fn == get_ssl_encoder:
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:

Просмотреть файл

@ -41,6 +41,7 @@ def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
assert isinstance(encoder, SSLEncoder)
@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
blob_url = get_checkpoint_url_from_aml_run(
run_id=TEST_SSL_RUN_ID,
@ -56,6 +57,7 @@ def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
assert isinstance(encoder, SSLEncoder)
@pytest.mark.skip(reason="This test is failing because of issue #655")
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
assert encoder_params.ssl_checkpoint.is_aml_run_id

Просмотреть файл

@ -278,8 +278,10 @@ def test_location_selected_tiles(level: int) -> None:
slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore
slide_idx = slide_ids.index(slide)
for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore
tile_coords = np.transpose(np.array([test_dict[ResultsKey.TILE_LEFT][slide_idx][tile_idx].cpu().numpy(),
test_dict[ResultsKey.TILE_TOP][slide_idx][tile_idx].cpu().numpy()]))
tile_coords = np.transpose(np.array(
[test_dict[ResultsKey.TILE_LEFT][slide_idx][tile_idx].cpu().numpy(), # type: ignore
test_dict[ResultsKey.TILE_TOP][slide_idx][tile_idx].cpu().numpy() # type: ignore
]))
coords_list.append(tile_coords)
coords = np.array(coords_list)

Просмотреть файл

@ -34,6 +34,13 @@ This will create a `conda` environment named `multimodal` and install all the de
You can visit the [API documentation][9] for a deeper understanding of our tools.
## Examples
For zero-shot classification of images using text prompts, please refer to the [example
script](./test_multimodal/vlp/test_zero_shot_classification.py) that utilises a small subset of [Open-Indiana CXR
dataset][10] for pneumonia detection in Chest X-ray images. Please note that the examples and models are not intended for
deployed use cases -- commercial or otherwise -- which is currently out-of-scope.
## Hugging Face 🤗
While the [GitHub repository][1] provides examples and pipelines to use our models,
@ -70,3 +77,4 @@ If you use our code or models in your research, please cite [the manuscript][7]
[7]: https://arxiv.org/abs/2204.09817
[8]: https://eccv2022.ecva.net/
[9]: https://hi-ml.readthedocs.io/en/latest/api/multimodal.html
[10]: https://openi.nlm.nih.gov/faq

Просмотреть файл

@ -51,7 +51,7 @@
},
"outputs": [],
"source": [
"repo_branch = \"main\""
"pip_source = \"hi-ml-multimodal\""
]
},
{
@ -60,9 +60,6 @@
"metadata": {},
"outputs": [],
"source": [
"repo_url = \"git+https://github.com/microsoft/hi-ml.git\"\n",
"subdirectory = \"hi-ml-multimodal\"\n",
"pip_source = f\"{repo_url}@{repo_branch}#subdirectory={subdirectory}\"\n",
"%pip install --quiet {pip_source}"
]
},
@ -203,7 +200,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.13"
},
"vscode": {
"interpreter": {

Просмотреть файл

@ -1,4 +1,4 @@
flake8==4.0.1
flake8==5.0.2
ipykernel==6.15.0
ipython==7.34.0
mypy==0.931

Просмотреть файл

@ -8,7 +8,7 @@ from setuptools import find_namespace_packages, setup # type: ignore
long_description = Path("README.md").read_text(encoding="utf-8")
version = "0.1.0"
version = "0.1.1"
package_name = "hi-ml-multimodal"
install_requires = Path("requirements_run.txt").read_text().splitlines()

Просмотреть файл

@ -3,4 +3,4 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
__version__ = "0.1.0"
__version__ = "0.1.1"

Просмотреть файл

@ -106,7 +106,7 @@ def _plot_heatmap(
axis.set_title(title)
def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> None:
def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> plt.Figure:
"""Plot visualization of the input image, the similarity heatmap and the heatmap isolines.
:param image_path: Path to the input image.
@ -117,3 +117,4 @@ def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.nd
_plot_image(image, axis=axes[0], title="Input image")
_plot_isolines(image, similarity_map, axis=axes[1], title="Similarity isolines")
_plot_heatmap(image, similarity_map, figure=fig, axis=axes[2], title="Similarity heatmap")
return fig

Просмотреть файл

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Callable, Tuple
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
from health_multimodal.image.data.io import load_image
@ -55,8 +56,8 @@ class ImageInferenceEngine:
return transformed_image, image.size
@torch.no_grad()
def get_patch_embeddings_from_image(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
"""Compute image embeddings in the joint latent space, preserving the image grid.
def get_projected_patch_embeddings(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
"""Compute image patch embeddings in the joint latent space, preserving the image grid.
:param image_path: Path to the image to compute embeddings for.
:return: A tuple containing the image patch embeddings and
@ -67,3 +68,20 @@ class ImageInferenceEngine:
assert projected_img_emb.shape[0] == 1
return projected_img_emb[0], img_shape
@torch.no_grad()
def get_projected_global_embedding(self, image_path: Path) -> torch.Tensor:
"""Compute global image embedding in the joint latent space.
:param image_path: Path to the image to compute embeddings for.
:return: Torch tensor containing l2-normalised global image embedding [joint_feature_dim,]
where joint_feature_dim is the dimensionality of the joint latent space.
"""
input_image, _ = self.load_and_transform_input_image(image_path, self.transform)
projected_img_emb = self.model.forward(input_image).projected_global_embedding
projected_img_emb = F.normalize(projected_img_emb, dim=-1)
assert projected_img_emb.shape[0] == 1
assert projected_img_emb.ndim == 2
return projected_img_emb[0]

Просмотреть файл

@ -48,10 +48,14 @@ class TextInferenceEngine(TextInput):
return tokenizer_output
@torch.no_grad()
def get_embeddings_from_prompt(self, prompts: Union[str, List[str]], verbose: bool = True) -> torch.Tensor:
def get_embeddings_from_prompt(self,
prompts: Union[str, List[str]],
normalize: bool = True,
verbose: bool = True) -> torch.Tensor:
"""Generate L2-normalised embeddings for a list of input text prompts.
:param prompts: Input text prompt(s) either in string or list of string format.
:param normalize: If True, L2-normalise the embeddings.
:param verbose: If set to True, tokenized words are displayed in the console.
:return: Tensor of shape (batch_size, embedding_size).
"""
@ -60,7 +64,8 @@ class TextInferenceEngine(TextInput):
tokenizer_output = self.tokenize_input_prompts(prompts=prompts, verbose=verbose)
txt_emb = self.model.get_projected_text_embeddings( # type: ignore
input_ids=tokenizer_output.input_ids,
attention_mask=tokenizer_output.attention_mask)
attention_mask=tokenizer_output.attention_mask,
normalize_embeddings=normalize)
return txt_emb

Просмотреть файл

@ -114,13 +114,17 @@ class CXRBertModel(BertForMaskedLM):
bert_for_masked_lm_output.hidden_states,
bert_for_masked_lm_output.attentions,)
def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
def get_projected_text_embeddings(self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
normalize_embeddings: bool = True) -> torch.Tensor:
"""
Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
The joint latent space is trained using a contrastive objective between image and text data modalities.
:param input_ids: (batch_size, sequence_length)
:param attention_mask: (batch_size, sequence_length)
:param normalize_embeddings: Whether to l2-normalise the embeddings.
:return: (batch_size, projection_size)
"""
@ -128,6 +132,10 @@ class CXRBertModel(BertForMaskedLM):
output_cls_projected_embedding=True, return_dict=True)
assert isinstance(outputs, CXRBertOutput)
assert outputs.cls_projected_embedding is not None
normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
return normalized_cls_embedding
cls_projected_embedding = outputs.cls_projected_embedding
assert cls_projected_embedding is not None
if normalize_embeddings:
return F.normalize(cls_projected_embedding, dim=1)
return cls_projected_embedding

Просмотреть файл

@ -7,7 +7,7 @@
from math import ceil, floor
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Union
import numpy as np
import torch
@ -27,6 +27,35 @@ class ImageTextInferenceEngine:
self.image_inference_engine = image_inference_engine
self.text_inference_engine = text_inference_engine
@torch.no_grad()
def get_similarity_score_from_raw_data(self,
image_path: Path,
query_text: Union[List[str], str]) -> float:
"""Compute the cosine similarity score between an image and one or more strings.
If multiple strings are passed, their embeddings are averaged before L2-normalization.
:param image_path: Path to the input chest X-ray, either a DICOM or JPEG file.
:param query_text: Input radiology text phrase.
:return: The similarity score between the image and the text.
"""
assert not self.image_inference_engine.model.training
assert not self.text_inference_engine.model.training
query_text = [query_text] if isinstance(query_text, str) else query_text
num_prompts = len(query_text)
image_embedding = self.image_inference_engine.get_projected_global_embedding(image_path)
text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text, normalize=False)
assert text_embedding.shape[0] == num_prompts
text_embedding = text_embedding.mean(dim=0)
text_embedding = F.normalize(text_embedding, dim=0, p=2)
cos_similarity = image_embedding @ text_embedding.t()
return cos_similarity.item()
def get_similarity_map_from_raw_data(self,
image_path: Path,
query_text: str,
@ -42,10 +71,10 @@ class ImageTextInferenceEngine:
"""
assert not self.image_inference_engine.model.training
assert not self.text_inference_engine.model.training
assert isinstance(query_text, str)
# TODO: Add checks in here regarding the text query, etc.
image_embedding, (width, height) = self.image_inference_engine.get_patch_embeddings_from_image(image_path)
image_embedding, (width, height) = self.image_inference_engine.get_projected_patch_embeddings(image_path)
text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text)
sim = self._get_similarity_map_from_embeddings(image_embedding, text_embedding)
@ -65,8 +94,7 @@ class ImageTextInferenceEngine:
def _get_similarity_map_from_embeddings(projected_patch_embeddings: torch.Tensor,
projected_text_embeddings: torch.Tensor,
sigma: float = 1.5) -> torch.Tensor:
"""
Get smoothed similarity map for a given image patch embeddings and text embeddings.
"""Get smoothed similarity map for a given image patch embeddings and text embeddings.
:param projected_patch_embeddings: [n_patches_h, n_patches_w, feature_size]
:param projected_text_embeddings: [1, feature_size]

Просмотреть файл

@ -0,0 +1,38 @@
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import tempfile
from pathlib import Path
import pytest
import torch
from PIL import Image
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
@pytest.mark.parametrize("height", (400, 500, 650))
def test_image_inference_engine(height: int) -> None:
"""Test the image inference engine with a dummy image and ensure that the output is of the correct shape."""
joint_feature_size = 128
resize = 512
center_crop_size = 480
width = 600
image_inference = ImageInferenceEngine(
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
image_path = Path(f.name)
image = Image.new('RGB', (width, height))
image.save(image_path)
# Test individual components
image_embedding = image_inference.get_projected_global_embedding(image_path)
assert image_embedding.shape == (joint_feature_size,)
assert torch.allclose(torch.norm(image_embedding), torch.tensor([1.00]))

Просмотреть файл

@ -5,38 +5,44 @@
import tempfile
from pathlib import Path
from typing import List, Union
import torch
import pytest
import numpy as np
import pytest
import torch
from health_multimodal.image import ImageInferenceEngine, ImageModel, ResnetType
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.image.model.model import JOINT_FEATURE_SIZE
from health_multimodal.text.utils import get_cxr_bert_inference
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
from PIL import Image
from health_multimodal.text.utils import get_cxr_bert_inference
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
CENTER_CROP_SIZE = 480
@pytest.mark.parametrize("height", (400, 500, 650))
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
def test_vlp_inference(height: int, query_text: str) -> None:
image_embedding_shapes = {
480: (15, 15),
}
def _get_vlp_inference_engine() -> ImageTextInferenceEngine:
joint_feature_size = 128
resize = 512
center_crop_size = 480
width = 600
image_inference = ImageInferenceEngine(
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=JOINT_FEATURE_SIZE),
transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=CENTER_CROP_SIZE))
img_txt_inference = ImageTextInferenceEngine(
image_inference_engine=image_inference,
text_inference_engine=get_cxr_bert_inference(),
)
return img_txt_inference
@pytest.mark.parametrize("height", (400, 500, 650))
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
def test_vlp_inference(height: int, query_text: Union[str, List[str]]) -> None:
image_embedding_shapes = {480: (15, 15), }
width = 600
img_txt_inference = _get_vlp_inference_engine()
image_inference = img_txt_inference.image_inference_engine
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
image_path = Path(f.name)
image = Image.new('RGB', (width, height))
@ -53,16 +59,34 @@ def test_vlp_inference(height: int, query_text: str) -> None:
assert resampled_similarity_map.max() <= 1
# Test individual components
image_embedding, size = img_txt_inference.image_inference_engine.get_patch_embeddings_from_image(image_path)
image_embedding, size = image_inference.get_projected_patch_embeddings(image_path)
assert (width, height) == size
expected_image_embedding_size = image_embedding_shapes[center_crop_size]
assert image_embedding.shape == (*expected_image_embedding_size, joint_feature_size)
expected_image_embedding_size = image_embedding_shapes[CENTER_CROP_SIZE]
assert image_embedding.shape == (*expected_image_embedding_size, JOINT_FEATURE_SIZE)
normalized_image_embedding = torch.norm(image_embedding, p=2, dim=-1)
assert torch.allclose(normalized_image_embedding, torch.ones_like(normalized_image_embedding))
text_embedding = img_txt_inference.text_inference_engine.get_embeddings_from_prompt(query_text)
assert text_embedding.shape == (1, joint_feature_size)
assert text_embedding.shape == (1, JOINT_FEATURE_SIZE)
similarity_map = img_txt_inference._get_similarity_map_from_embeddings(image_embedding, text_embedding)
assert similarity_map.shape == expected_image_embedding_size
@pytest.mark.parametrize("query_text", ("this is a test", ["Test prompt 1", "Test prompt 2"]))
def test_vlp_inference_global_similarity(query_text: str) -> None:
img_txt_inference = _get_vlp_inference_engine()
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
image_path = Path(f.name)
height, width = 500, 600
image = Image.new('RGB', (width, height))
image.save(image_path)
# Test global similarity score
sim_score = img_txt_inference.get_similarity_score_from_raw_data(image_path=image_path,
query_text=query_text)
assert isinstance(sim_score, float)
assert 1 >= sim_score >= -1

Просмотреть файл

@ -0,0 +1,98 @@
import tempfile
from enum import Enum, unique
from pathlib import Path
from typing import List, Tuple, Union
import requests
from torchvision.datasets.utils import check_integrity
from health_multimodal.image import ImageInferenceEngine
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
from health_multimodal.image.model.model import get_biovil_resnet
from health_multimodal.text.utils import get_cxr_bert_inference
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
RESIZE = 512
CENTER_CROP_SIZE = 512
@unique
class ClassType(str, Enum):
"""Enum for the different types of CXR abnormality classes."""
PNEUMONIA = "pneumonia"
NO_PNEUMONIA = "no_pneumonia"
def _get_vlp_inference_engine() -> ImageTextInferenceEngine:
image_inference = ImageInferenceEngine(
image_model=get_biovil_resnet(pretrained=True),
transform=create_chest_xray_transform_for_inference(resize=RESIZE, center_crop_size=CENTER_CROP_SIZE))
img_txt_inference = ImageTextInferenceEngine(
image_inference_engine=image_inference,
text_inference_engine=get_cxr_bert_inference(),
)
return img_txt_inference
def _get_default_text_prompts_for_pneumonia() -> Tuple[List, List]:
"""
Get the default text prompts for presence and absence of pneumonia
"""
pos_query = ['Findings consistent with pneumonia', 'Findings suggesting pneumonia',
'This opacity can represent pneumonia', 'Findings are most compatible with pneumonia']
neg_query = ['There is no pneumonia', 'No evidence of pneumonia',
'No evidence of acute pneumonia', 'No signs of pneumonia']
return pos_query, neg_query
def save_img_from_url(image_url: str, local_path: Union[str, Path], md5: str = None) -> None:
"""
Pull an image from a URL and save it to a local path
"""
img_data = requests.get(image_url, timeout=30).content
with open(local_path, 'wb') as handler:
handler.write(img_data)
if md5 is not None:
assert check_integrity(local_path, md5)
def test_zero_shot_pneumonia_classification() -> None:
"""
Checks latent similarity between text prompts and image embeddings for presence and absence of pneumonia
"""
input_data = [
("https://openi.nlm.nih.gov/imgs/512/173/1777/CXR1777_IM-0509-1001.png", "f140126fff5d7d9f4a9402afadcbbf99", ClassType.PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/6/808/CXR808_IM-2341-2001.png", "ee19699d4305d17beecad94762a2ebcc", ClassType.PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/342/3951/CXR3951_IM-2019-1001.png", "786d5d854b1f6be1d6a0c3392794497a", ClassType.PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/16/2422/CXR2422_IM-0965-1001.png", "84dd31c1b0cbaaf8e1c0004d8917a78d", ClassType.PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/365/3172/CXR3172_IM-1494-1001.png", "dc476e73d3fd3b178a5303f48221e9d5", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/161/1765/CXR1765_IM-0499-1001.png", "fdc6d3753f853352b35f5edf4ee0873c", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/327/3936/CXR3936_IM-2007-1001.png", "ef537c81a0d8c0ae618625970c122ecc", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/76/1279/CXR1279_IM-0185-1001.png", "9d03d740dcb0e7e068eb5eb73355262e", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/5/5/CXR5_IM-2117-1003002.png", "204a2c83f94a4d4e74b6cea43caabdf2", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/219/3828/CXR3828_IM-1932-1001.png", "1260a14f197527030fa0cb6b2d6950b8", ClassType.NO_PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/16/818/CXR818_IM-2349-1001.png", "2756b3a746e54e72f1efbbed72c2f83b", ClassType.PNEUMONIA), # noqa: E501
("https://openi.nlm.nih.gov/imgs/512/358/2363/CXR2363_IM-0926-1001.png", "bcdc990161a5234e08d0595ed8a0bbf0", ClassType.PNEUMONIA)] # noqa: E501
img_txt_inference = _get_vlp_inference_engine()
positive_prompts, negative_prompts = _get_default_text_prompts_for_pneumonia()
for cxr_url, md5, label_str in input_data:
suffix = Path(cxr_url).suffix
with tempfile.NamedTemporaryFile(suffix=suffix) as f:
image_path = Path(f.name)
save_img_from_url(cxr_url, image_path, md5=md5)
positive_score = img_txt_inference.get_similarity_score_from_raw_data(
image_path=image_path,
query_text=positive_prompts)
negative_score = img_txt_inference.get_similarity_score_from_raw_data(
image_path=image_path,
query_text=negative_prompts)
if label_str == ClassType.PNEUMONIA:
assert positive_score > negative_score
else:
assert negative_score > positive_score

Просмотреть файл

@ -1,16 +1,254 @@
# This environment definition contains all packages to run hi-ml and hi-ml-azure development work, building and
# testing
# WARNING - DO NOT EDIT THIS FILE MANUALLY
# To update, please modify 'primary_deps.yml' and then run the locking script 'create_and_lock_environment.sh'
name: himl
channels:
- defaults
- pytorch
- defaults
dependencies:
- pip=20.1.1
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- ca-certificates=2022.07.19=h06a4308_0
- certifi=2022.9.24=py37h06a4308_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- intel-openmp=2022.1.0=h9e868ea_3769
- libedit=3.1.20210910=h7f8727e_0
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuv=1.40.0=h7b6447c_0
- mkl=2022.1.0=hc2b9512_224
- ncurses=6.3=h5eee18b_3
- openssl=1.1.1q=h7f8727e_0
- pip=20.1.1=py37_1
- python=3.7.3
- pytorch=1.10.0
- cudatoolkit=11.3.1
- pytorch-mutex=1.0=cuda
- readline=7.0=h7b6447c_5
- setuptools=63.4.1=py37h06a4308_0
- sqlite=3.33.0=h62c20be_0
- tk=8.6.12=h1ccaba5_0
- xz=5.2.6=h5eee18b_0
- zlib=1.2.12=h5eee18b_3
- pip:
- -r ../hi-ml-azure/run_requirements.txt
- -r run_requirements.txt
- -r ../build_requirements.txt
- -r ../test_requirements.txt
- absl-py==1.2.0
- adal==1.2.7
- aiohttp==3.8.3
- aiosignal==1.2.0
- alabaster==0.7.12
- alembic==1.8.1
- applicationinsights==0.11.10
- argcomplete==2.0.0
- astroid==2.12.11
- async-timeout==4.0.2
- asynctest==0.13.0
- attrs==22.1.0
- azure-ai-ml==1.0.0
- azure-common==1.1.28
- azure-core==1.26.0
- azure-graphrbac==0.61.1
- azure-identity==1.11.0
- azure-mgmt-authorization==2.0.0
- azure-mgmt-containerregistry==10.0.0
- azure-mgmt-core==1.3.2
- azure-mgmt-keyvault==10.1.0
- azure-mgmt-resource==21.2.0
- azure-mgmt-storage==20.0.0
- azure-storage-blob==12.10.0
- azure-storage-file-datalake==12.9.0
- azure-storage-file-share==12.10.0
- azureml-core==1.46.0
- azureml-dataprep==4.5.7
- azureml-dataprep-native==38.0.0
- azureml-dataprep-rslex==2.11.4
- azureml-dataset-runtime==1.46.0
- azureml-mlflow==1.46.0
- azureml-telemetry==1.46.0
- azureml-tensorboard==1.46.0
- azureml-train-core==1.46.0
- azureml-train-restclients-hyperdrive==1.46.0
- babel==2.10.3
- backports-tempfile==1.0
- backports-weakref==1.0.post1
- bcrypt==4.0.1
- black==22.1.0
- bleach==5.0.1
- cachetools==5.2.0
- cffi==1.15.1
- cfgv==3.3.1
- charset-normalizer==2.1.1
- click==8.1.3
- cloudpickle==2.2.0
- colorama==0.4.4
- conda-merge==0.2.0
- contextlib2==21.6.0
- coverage==6.3.2
- cryptography==37.0.4
- cycler==0.11.0
- databricks-cli==0.17.3
- dataclasses-json==0.5.2
- dill==0.3.5.1
- distlib==0.3.6
- distro==1.8.0
- docker==5.0.3
- docutils==0.16
- dotnetcore2==3.1.23
- entrypoints==0.4
- filelock==3.8.0
- flake8==5.0.2
- flask==2.2.2
- fonttools==4.37.4
- frozenlist==1.3.1
- fsspec==2022.8.2
- fusepy==3.0.1
- gitdb==4.0.9
- gitpython==3.1.29
- google-auth==2.12.0
- google-auth-oauthlib==0.4.6
- greenlet==1.1.3.post0
- grpcio==1.49.1
- gunicorn==20.1.0
- hi-ml-azure==0.2.7
- humanfriendly==10.0
- identify==2.5.6
- idna==3.4
- imagesize==1.4.1
- importlib-metadata==4.13.0
- importlib-resources==5.10.0
- iniconfig==1.1.1
- isodate==0.6.1
- isort==5.10.1
- itsdangerous==2.1.2
- jaraco-classes==3.2.3
- jeepney==0.8.0
- jinja2==3.0.2
- jmespath==1.0.1
- joblib==1.2.0
- jsonpickle==2.2.0
- jsonschema==4.16.0
- keyring==23.9.3
- kiwisolver==1.4.4
- knack==0.9.0
- lazy-object-proxy==1.7.1
- lxml==4.9.1
- mako==1.2.3
- markdown==3.4.1
- markdown-it-py==1.1.0
- markupsafe==2.1.1
- marshmallow==3.18.0
- marshmallow-enum==1.5.1
- matplotlib==3.5.3
- mccabe==0.7.0
- mdit-py-plugins==0.2.8
- mlflow==1.29.0
- mlflow-skinny==1.29.0
- more-itertools==8.14.0
- msal==1.20.0
- msal-extensions==1.0.0
- msrest==0.7.1
- msrestazure==0.6.4
- multidict==6.0.2
- mypy==0.931
- mypy-extensions==0.4.3
- myst-parser==0.15.2
- ndg-httpsclient==0.5.1
- nodeenv==1.7.0
- numpy==1.21.6
- oauthlib==3.2.1
- opencv-python-headless==4.6.0.66
- packaging==21.3
- pandas==1.3.5
- param==1.12.2
- paramiko==2.11.0
- pathspec==0.10.1
- pillow==9.2.0
- pkginfo==1.8.3
- pkgutil-resolve-name==1.3.10
- platformdirs==2.5.2
- pluggy==0.13.1
- portalocker==2.5.1
- pre-commit==2.19.0
- prometheus-client==0.14.1
- prometheus-flask-exporter==0.20.3
- protobuf==3.20.3
- py==1.11.0
- pyarrow==9.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycobertura==2.0.1
- pycodestyle==2.9.1
- pycparser==2.21
- pydash==5.1.1
- pydeprecate==0.3.2
- pyflakes==2.5.0
- pygments==2.13.0
- pyjwt==2.5.0
- pylint==2.15.0
- pynacl==1.5.0
- pyopenssl==22.1.0
- pyparsing==3.0.9
- pyrsistent==0.18.1
- pysocks==1.7.1
- pytest==6.2.2
- pytest-cov==2.11.1
- pytest-rerunfailures==10.2
- pytest-timeout==2.0.1
- python-dateutil==2.8.2
- pytorch-lightning==1.6.5
- pytz==2022.4
- pyyaml==6.0
- querystring-parser==1.2.4
- readme-renderer==37.2
- requests==2.28.1
- requests-oauthlib==1.3.1
- requests-toolbelt==0.10.0
- rfc3986==2.0.0
- rpdb==0.1.6
- rsa==4.9
- ruamel-yaml==0.17.21
- ruamel-yaml-clib==0.2.6
- scikit-learn==1.0.2
- scipy==1.7.3
- secretstorage==3.3.3
- six==1.16.0
- smmap==5.0.0
- snowballstemmer==2.2.0
- sphinx==4.1.2
- sphinx-autodoc-typehints==1.12.0
- sphinx-automodapi==0.13
- sphinx-rtd-theme==1.0.0
- sphinxcontrib-applehelp==1.0.2
- sphinxcontrib-devhelp==1.0.2
- sphinxcontrib-htmlhelp==2.0.0
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.3
- sphinxcontrib-serializinghtml==1.1.5
- sqlalchemy==1.4.41
- sqlparse==0.4.3
- strictyaml==1.6.1
- stringcase==1.2.0
- tabulate==0.9.0
- tensorboard==2.10.1
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- threadpoolctl==3.1.0
- toml==0.10.2
- tomli==2.0.1
- tomlkit==0.11.5
- torch==1.12.1
- torchmetrics==0.10.0
- torchvision==0.13.1
- tqdm==4.63.0
- twine==3.3.0
- typed-ast==1.5.4
- typing-extensions==4.4.0
- typing-inspect==0.8.0
- urllib3==1.26.12
- virtualenv==20.16.5
- webencodings==0.5.1
- websocket-client==1.4.1
- werkzeug==2.2.2
- wheel==0.36.2
- wrapt==1.14.1
- yarl==1.8.1
- zipp==3.9.0

16
hi-ml/primary_deps.yml Normal file
Просмотреть файл

@ -0,0 +1,16 @@
# This environment definition contains all packages to run hi-ml and hi-ml-azure development work, building and
# testing
name: himl
channels:
- defaults
- pytorch
dependencies:
- pip=20.1.1
- python=3.7.3
- pytorch=1.10.0
- cudatoolkit=11.3.1
- pip:
- -r ../hi-ml-azure/run_requirements.txt
- -r run_requirements.txt
- -r ../build_requirements.txt
- -r ../test_requirements.txt

Просмотреть файл

@ -4,7 +4,7 @@ jinja2==3.0.2
matplotlib>=3.4.3
opencv-python-headless>=4.5.1.48
pandas>=1.3.4
protobuf<=3.20.1
protobuf<4.0
pytorch-lightning>=1.6.0, <1.7
rpdb>=0.1.6
torchvision>=0.10.0

Просмотреть файл

@ -19,74 +19,74 @@ from setuptools import setup, find_namespace_packages # type: ignore
here = pathlib.Path(__file__).parent.resolve()
# Get the long description from the README file
long_description = (here / 'package_description.md').read_text(encoding='utf-8')
long_description = (here / "package_description.md").read_text(encoding="utf-8")
version = ''
version = ""
# If running from a GitHub Action then a standard set of environment variables will be
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
# Extract this tag and use it as a version string
# See also:
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
# https://github.com/pypa/gh-action-pypi-publish
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
github_ref = os.getenv('GITHUB_REF')
github_ref = os.getenv("GITHUB_REF")
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
# '0.1.2.post124'. Although this is discouraged, see:
# "0.1.2.post124". Although this is discouraged, see:
# https://www.python.org/dev/peps/pep-0440/#post-releases
# it is necessary here to avoid duplicate packages in Test.PyPI.
if not version:
# TODO: Replace this with more principled package version management for the package wheels built during local test
# runs, one which circumvents AzureML's apparent package caching:
build_number = os.getenv('GITHUB_RUN_NUMBER')
# runs, one which circumvents AzureML"s apparent package caching:
build_number = os.getenv("GITHUB_RUN_NUMBER")
if build_number:
version = '0.1.1.post' + build_number
version = "0.1.1.post" + build_number
else:
default_random_version_number = floor(random() * 10_000_000_000)
version = f'0.1.0.post{str(default_random_version_number)}'
version = f"0.1.0.post{str(default_random_version_number)}"
(here / 'package_name.txt').write_text('hi-ml')
(here / 'latest_version.txt').write_text(version)
(here / "package_name.txt").write_text("hi-ml")
(here / "latest_version.txt").write_text(version)
# Read run_requirements.txt to get install_requires
install_requires = (here / 'run_requirements.txt').read_text().split("\n")
install_requires = (here / "run_requirements.txt").read_text().split("\n")
# Remove any whitespace and blank lines
install_requires = [line.strip() for line in install_requires if line.strip()]
description = 'Microsoft Health Intelligence package containing high level ML components'
description = "Microsoft Health Futures package containing high level ML components"
setup(
name='hi-ml',
name="hi-ml",
version=version,
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/microsoft/hi-ml',
author="Microsoft Research Cambridge InnerEye Team ",
long_description_content_type="text/markdown",
url="https://github.com/microsoft/hi-ml",
author="Biomedical Imaging Team @ Microsoft Health Futures",
author_email="innereyedev@microsoft.com",
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Science/Research',
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.7'
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7"
],
keywords='InnerEye, HealthIntelligence, AzureML',
license='MIT License',
keywords="Health Futures, Health Intelligence, AzureML",
license="MIT License",
packages=find_namespace_packages(where="src"),
package_dir={"": "src"},
include_package_data=True,
install_requires=install_requires,
entry_points={
'console_scripts': [
'himl-runner = health_ml.runner:main'
"console_scripts": [
"himl-runner = health_ml.runner:main"
]
}
)

Просмотреть файл

@ -10,12 +10,13 @@ import param
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
from azureml.train.hyperdrive import HyperDriveConfig
from health_azure import create_crossval_hyperdrive_config
from health_azure.himl import create_grid_hyperdrive_config
from health_azure.himl import (create_grid_hyperdrive_config, create_crossval_hyperparam_args_v2,
create_grid_hyperparam_args_v2)
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
ENV_AMLT_SNAPSHOT_DIR, ENV_AMLT_AZ_BATCHAI_DIR,
is_amulet_job, get_amulet_aml_working_dir)
@ -249,8 +250,31 @@ class WorkflowParams(param.Parameterized):
metric_name="val/loss"
)
def get_crossval_hyperparam_args_v2(self) -> Dict[str, Any]:
"""
Wrapper function to create hyperparameter search arguments specifically for running cross validation
with AML SDK v2
:return: A dictionary of hyperparameter search arguments and values.
"""
return create_crossval_hyperparam_args_v2(num_splits=self.crossval_count,
cross_val_index_arg_name=self.CROSSVAL_INDEX_ARG_NAME,
metric_name="val/loss")
def get_grid_hyperparam_args_v2(self) -> Dict[str, Any]:
"""
Wrapper function to create hyperparameter search arguments specifically for running grid search
with AML SDK v2
:return: A dictionary of hyperparameter search arguments and values.
"""
return create_grid_hyperparam_args_v2(values=list(map(str, range(self.different_seeds))),
argument_name=self.RANDOM_SEED_ARG_NAME,
metric_name="val/loss")
class DatasetParams(param.Parameterized):
datastore: str = param.String(default="", doc="Datastore to look for data in")
azure_datasets: List[str] = param.List(default=[], class_=str,
doc="If provided, the ID of one or more datasets to use when running in"
" AzureML. This dataset must exist as a folder of the same name "

Просмотреть файл

@ -44,3 +44,5 @@ class ExperimentConfig(param.Parameterized):
"`INFO` or `DETAIL` for different levels of logging. "
"`DETAIL` may impact the application performance and thus "
"should only be used when debugging issues")
strictly_aml_v1: bool = param.Boolean(default=False, doc="If True, use AzureML v1 SDK. If False (default), use "
"the v2 of the SDK")

Просмотреть файл

@ -96,6 +96,14 @@ class LightningContainer(WorkflowParams,
raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_config' "
"in your model container.")
def get_parameter_tuning_args(self) -> Dict[str, Any]:
"""
Returns a dictionary of hyperperameter argument names and values as expected by a AML SDK v2 job
to perform hyperparameter search
"""
raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_args' "
"in your model container.")
def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
"""
This method allows overriding ExperimentConfig parameters from within a LightningContainer.
@ -185,6 +193,21 @@ class LightningContainer(WorkflowParams,
return self.get_different_seeds_hyperdrive_config()
return None
def get_hyperparam_args(self) -> Optional[Dict[str, Any]]:
"""
Returns a dictionary of hyperparameter search arguments that will be passed to an AML v2 command to
enable either hyperparameter tuning, cross validation, or running with different seeds.
:return: A dictionary of hyperparameter search arguments and values.
"""
if self.hyperdrive:
return self.get_parameter_tuning_args()
if self.is_crossvalidation_enabled:
return self.get_crossval_hyperparam_args_v2()
if self.different_seeds > 0:
return self.get_grid_hyperparam_args_v2()
return None
def load_model_checkpoint(self, checkpoint_path: Path) -> None:
"""
Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot

Просмотреть файл

@ -3,6 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple, TypeVar
@ -16,9 +17,9 @@ from pytorch_lightning.profiler import BaseProfiler, SimpleProfiler, AdvancedPro
from health_azure.utils import RUN_CONTEXT, is_running_in_azure_ml
from health_ml.lightning_container import LightningContainer
from health_ml.utils import AzureMLLogger, AzureMLProgressBar
from health_ml.utils import AzureMLProgressBar
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_FILE_NAME, EXPERIMENT_SUMMARY_FILE
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.lightning_loggers import StoringLogger, HimlMLFlowLogger
T = TypeVar('T')
@ -55,7 +56,8 @@ def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1,
multiple_trainloader_mode: str = "max_size_cycle",
azureml_run_for_logging: Optional[Run] = None) -> \
azureml_run_for_logging: Optional[Run] = None,
mlflow_run_for_logging: Optional[str] = None) -> \
Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
@ -91,9 +93,27 @@ def create_lightning_trainer(container: LightningContainer,
message += "s per node with DDP"
logging.info(f"Using {message}")
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=container.log_from_vm,
run=azureml_run_for_logging)
loggers = [tensorboard_logger, azureml_logger]
loggers: List[Any] = [tensorboard_logger]
if is_running_in_azure_ml():
mlflow_run_id = os.environ.get("MLFLOW_RUN_ID", None)
logging.info(f"Logging to MLFlow run with id: {mlflow_run_id}")
mlflow_logger = HimlMLFlowLogger(
run_id=mlflow_run_id
)
loggers.append(mlflow_logger)
else:
mlflow_run_dir = container.outputs_folder / "mlruns"
try:
mlflow_run_dir.mkdir(exist_ok=True)
mlflow_tracking_uri = "file:" + str(mlflow_run_dir)
mlflow_logger = HimlMLFlowLogger(run_id=mlflow_run_for_logging, tracking_uri=mlflow_tracking_uri)
loggers.append(mlflow_logger)
logging.info(f"Logging to MLFlow run with id: {mlflow_run_for_logging}. Local MLFlow logs are stored in "
f"{mlflow_tracking_uri}")
except FileNotFoundError as e:
logging.warning(f"Unable to initialise MLFlowLogger due to error: {e}")
storing_logger = StoringLogger()
loggers.append(storing_logger)
# Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.

Просмотреть файл

@ -63,7 +63,7 @@ class AttentionLayer(nn.Module):
attention_weights = transpose(attention_weights, 1, 0) # K x N
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
return(attention_weights, pooled_features)
return attention_weights, pooled_features
class GatedAttentionLayer(nn.Module):
@ -98,7 +98,7 @@ class GatedAttentionLayer(nn.Module):
attention_weights = transpose(attention_weights, 1, 0) # K x N
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
return(attention_weights, pooled_features)
return attention_weights, pooled_features
class CustomTransformerEncoderLayer(TransformerEncoderLayer):

Просмотреть файл

@ -20,7 +20,6 @@ from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK
aggregate_hyperdrive_metrics, get_metrics_for_childless_run,
ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK,
is_local_rank_zero, is_global_rank_zero, create_aml_run_object)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
@ -35,11 +34,27 @@ from health_ml.utils.common_utils import (
df_to_json,
seed_monai_if_available,
)
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
from health_ml.utils.type_annotations import PathOrString
def get_mlflow_run_id_from_previous_loggers(trainer: Optional[Trainer]) -> Optional[str]:
"""
If self.trainer has already been intialised with loggers, attempt to retrieve a HimlMLFLowLogger and
return the mlflow run_id associated with it, to allow continued logging to the same run. Otherwise, return None
:return: The mlflow run id from the existing HimlMLFlowLogger
"""
if trainer is None:
return None
try:
mlflow_logger = [logger for logger in trainer.loggers if isinstance(logger, HimlMLFlowLogger)][0]
return mlflow_logger.run_id
except IndexError:
return None
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
"""
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
@ -80,6 +95,7 @@ class MLRunner:
run_context=RUN_CONTEXT)
self.trainer: Optional[Trainer] = None
self.azureml_run_for_logging: Optional[Run] = None
self.mlflow_run_for_logging: Optional[str] = None
def set_run_tags_from_parent(self) -> None:
"""
@ -117,6 +133,7 @@ class MLRunner:
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
# This must happen before container setup because that could already read datasets.
input_datasets = azure_run_info.input_datasets
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
if len(input_datasets) > 0:
local_datasets: List[Path] = []
for i, dataset in enumerate(input_datasets):
@ -261,6 +278,7 @@ class MLRunner:
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have the same batch size in case of
# uneven inputs.
mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
self.container.max_num_gpus = 1
if self.container.run_inference_only:
@ -272,7 +290,8 @@ class MLRunner:
container=self.container,
resume_from_checkpoint=checkpoint_path,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging
azureml_run_for_logging=self.azureml_run_for_logging,
mlflow_run_for_logging=mlflow_run_id
)
return trainer
@ -319,6 +338,12 @@ class MLRunner:
if self.container.has_custom_test_step():
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
# internally, which replicates some samples to make sure all devices have some batch size in case of
# uneven inputs.
self.container.max_num_gpus = 1
checkpoint_path = (
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
)
@ -384,9 +409,6 @@ class MLRunner:
with logging_section("Model training"):
self.run_training()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# load model checkpoint for custom inference or additional validation step
if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
self.load_model_checkpoint()
@ -396,6 +418,9 @@ class MLRunner:
with logging_section("Model Validation to save plots on validation set"):
self.run_validation()
# Kill all processes besides rank 0
self.after_ddp_cleanup(old_environ)
# Run inference on a single device
with logging_section("Model inference"):
self.run_inference()

52
hi-ml/src/health_ml/runner.py Normal file → Executable file
Просмотреть файл

@ -1,3 +1,5 @@
#! /usr/bin/env python
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
@ -27,16 +29,17 @@ from health_azure.datasets import create_dataset_configs # noqa: E402
from health_azure.logging import logging_to_stdout # noqa: E402
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
from health_azure.amulet import prepare_amulet_job, is_amulet_job # noqa: E402
from health_azure.utils import (get_workspace, is_local_rank_zero, is_running_in_azure_ml, # noqa: E402
set_environment_variables_for_multi_node,
create_argparser, parse_arguments, ParserResult, apply_overrides)
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero, # noqa: E402
is_running_in_azure_ml, set_environment_variables_for_multi_node,
create_argparser, parse_arguments, ParserResult, apply_overrides,
filter_v2_input_output_args)
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
from health_ml.lightning_container import LightningContainer # noqa: E402
from health_ml.run_ml import MLRunner # noqa: E402
from health_ml.utils import fixed_paths # noqa: E402
from health_ml.utils.common_utils import (check_conda_environment, choose_conda_env_file, # noqa: E402
is_linux)
from health_ml.utils.common_utils import (DEFAULT_DOCKER_BASE_IMAGE, check_conda_environment, # noqa: E402
choose_conda_env_file, is_linux)
from health_ml.utils.config_loader import ModelConfigLoader # noqa: E402
@ -46,8 +49,6 @@ from health_ml.utils.config_loader import ModelConfigLoader # noqa: E402
runner_path = Path(sys.argv[0])
sys.argv[0] = str(runner_path.resolve())
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
def initialize_rpdb() -> None:
"""
@ -126,8 +127,12 @@ class Runner:
:return: ParserResult object containing args, overrides and settings
"""
# Filter out any args for passing inputs and outputs to scripts with AML SDK v2
args = sys.argv[1:]
filtered_args = filter_v2_input_output_args(args)
parser1 = create_runner_parser()
parser1_result = parse_arguments(parser1, args=sys.argv[1:])
parser1_result = parse_arguments(parser1, args=filtered_args)
experiment_config = ExperimentConfig(**parser1_result.args)
self.experiment_config = experiment_config
@ -229,7 +234,13 @@ class Runner:
except ValueError:
raise ValueError("Unable to submit the script to AzureML because no workspace configuration file "
"(config.json) was found.")
default_datastore = workspace.get_default_datastore().name if workspace is not None else ""
if self.lightning_container.datastore:
datastore = self.lightning_container.datastore
elif workspace:
datastore = workspace.get_default_datastore().name
else:
datastore = ""
local_datasets = self.lightning_container.local_datasets
all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
@ -240,10 +251,18 @@ class Runner:
create_dataset_configs(all_azure_dataset_ids=self.lightning_container.azure_datasets,
all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
all_local_datasets=all_local_datasets, # type: ignore
datastore=default_datastore,
datastore=datastore,
use_mounting=use_mounting)
if self.experiment_config.cluster and not is_running_in_azure_ml():
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
if self.experiment_config.strictly_aml_v1:
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
hyperparam_args = None
else:
hyperparam_args = self.lightning_container.get_hyperparam_args()
hyperdrive_config = None
ml_client = get_ml_client()
env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
logging.info(f"Using this Conda environment definition: {env_file}")
check_conda_environment(env_file)
@ -254,9 +273,10 @@ class Runner:
script_params=script_params,
conda_environment_file=env_file,
aml_workspace=workspace,
ml_client=ml_client,
compute_cluster_name=self.experiment_config.cluster,
environment_variables=environment_variables,
default_datastore=default_datastore,
default_datastore=datastore,
experiment_name=self.lightning_container.effective_experiment_name,
input_datasets=input_datasets, # type: ignore
num_nodes=self.experiment_config.num_nodes,
@ -266,15 +286,19 @@ class Runner:
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
docker_shm_size=self.experiment_config.docker_shm_size,
hyperdrive_config=hyperdrive_config,
hyperparam_args=hyperparam_args,
create_output_folders=False,
after_submission=after_submission_hook,
tags=self.additional_run_tags(script_params)
tags=self.additional_run_tags(script_params),
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
)
else:
azure_run_info = submit_to_azure_if_needed(
input_datasets=input_datasets, # type: ignore
submit_to_azureml=False,
environment_variables=environment_variables)
environment_variables=environment_variables,
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
)
if azure_run_info.run:
# This code is only reached inside Azure. Set display name again - this will now affect
# Hypdrive child runs (for other jobs, this has already been done after submission)

Просмотреть файл

@ -49,6 +49,7 @@ RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
# other
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
@unique

Просмотреть файл

@ -3,10 +3,14 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, Dict, Iterable, List, Optional
from argparse import Namespace
from typing import Any, Dict, Iterable, List, Optional, Union
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
import mlflow
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, MLFlowLogger
from pytorch_lightning.utilities.logger import _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from health_ml.utils.type_annotations import DictStrFloat, DictStrFloatOrFloatList
@ -109,3 +113,49 @@ class StoringLogger(LightningLoggerBase):
:return: A dictionary mapping from epoch number to metric name to metric value.
"""
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
class HimlMLFlowLogger(MLFlowLogger):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
"""
Override underlying log_hyperparams message to avoid trying to log hyperparameters that have already
been logged, thus causing MLFlow to raise an Exception.
:param params: The original hyperparameters to be logged.
"""
run = mlflow.get_run(self.run_id)
existing_hyperparams = run.data.params
params = _convert_params(params)
params = _flatten_dict(params)
for k, v in params.items():
if len(str(v)) > 250:
rank_zero_warn(
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
category=RuntimeWarning
)
continue
if k in existing_hyperparams:
continue
self.experiment.log_param(self.run_id, k, v)
def get_mlflow_run_id_from_trainer(trainer: Trainer) -> Optional[str]:
"""
If trainer has already been intialised with loggers, attempt to retrieve one of the type HimlMLFlowLogger,
and return its run_id property in order to log to the same run. Otherwise, return None.
:return: The mlflow run id from an existing HimlMLFlowLogger if available, else None.
"""
if trainer is None:
return None
try:
mlflow_logger = [logger for logger in trainer.loggers if isinstance(logger, HimlMLFlowLogger)][0]
return mlflow_logger.run_id
except IndexError:
return None

Просмотреть файл

@ -124,7 +124,7 @@ class AzureMLLogger(LightningLoggerBase):
return
if params is None:
return
params_final = self._preprocess_hyperparams(params)
params_final = _preprocess_hyperparams(params)
if len(params_final) > 0:
# Log hyperparameters as a table with 2 columns. Each "step" is one hyperparameter
self.run.log_table(self.HYPERPARAMS_NAME, {"name": list(params_final.keys()),
@ -150,24 +150,6 @@ class AzureMLLogger(LightningLoggerBase):
# Run.complete should only be called if we created an AzureML run here in the constructor.
self.run.complete()
def _preprocess_hyperparams(self, params: Any) -> Dict[str, str]:
"""
Converts arbitrary hyperparameters to a simple dictionary structure, in particular argparse Namespaces.
Nested dictionaries are converted to folder-like strings, like ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
All hyperparameter values are converted to strings, because Run.log_table can't deal with mixed datatypes.
:param params: The parameters to convert
:return: A dictionary mapping from string to string.
"""
# Convert from Namespace to dictionary
params = _convert_params(params)
# Convert nested dictionaries to folder-like structure
params = _flatten_dict(params)
# Convert anything that is not a primitive type to str
params_final = _sanitize_params(params)
if not isinstance(params_final, dict):
raise ValueError(f"Expected the converted hyperparameters to be a dictionary, but got {type(params)}")
return {str(key): str(value) for key, value in params_final.items()}
class AzureMLProgressBar(ProgressBarBase):
"""
@ -333,6 +315,25 @@ class AzureMLProgressBar(ProgressBarBase):
sys.stdout.flush()
def _preprocess_hyperparams(params: Any) -> Dict[str, str]:
"""
Converts arbitrary hyperparameters to a simple dictionary structure, in particular argparse Namespaces.
Nested dictionaries are converted to folder-like strings, like ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
All hyperparameter values are converted to strings, because Run.log_table can't deal with mixed datatypes.
:param params: The parameters to convert
:return: A dictionary mapping from string to string.
"""
# Convert from Namespace to dictionary
params = _convert_params(params)
# Convert nested dictionaries to folder-like structure
params = _flatten_dict(params)
# Convert anything that is not a primitive type to str
params_final = _sanitize_params(params)
if not isinstance(params_final, dict):
raise ValueError(f"Expected the converted hyperparameters to be a dictionary, but got {type(params)}")
return {str(key): str(value) for key, value in params_final.items()}
def log_on_epoch(module: LightningModule,
name: Optional[str] = None,
value: Optional[Any] = None,

Просмотреть файл

@ -48,7 +48,7 @@ def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],
# After applying a stochastic augmentation a second time it should have a different output
if stochastic:
augmented_img = data_augmentation(input_img) # type: ignore
assert not(torch.allclose(augmented_img, expected_output_img, atol=1e-04))
assert not torch.allclose(augmented_img, expected_output_img, atol=1e-04)
def test_stain_normalization() -> None:

Просмотреть файл

@ -11,15 +11,17 @@ from unittest.mock import DEFAULT, MagicMock, Mock, patch
from _pytest.logging import LogCaptureFixture
from pytorch_lightning import LightningModule
from azureml._restclient.constants import RunStatus
import mlflow
from pytorch_lightning import Trainer
from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner
from health_ml.run_ml import MLRunner, get_mlflow_run_id_from_previous_loggers
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
from health_azure.utils import is_global_rank_zero
from health_ml.utils.logging import AzureMLLogger
from testazure.utils_testazure import DEFAULT_WORKSPACE
from testhiml.utils.fixed_paths_for_tests import mock_run_id
@ -412,22 +414,7 @@ def test_log_on_vm(log_from_vm: bool) -> None:
assert runner.trainer.loggers is not None
assert len(runner.trainer.loggers) > 1
logger = runner.trainer.loggers[1]
assert isinstance(logger, AzureMLLogger)
if log_from_vm:
assert logger.run is not None
# Check that all user supplied data (experiment and display name) are respected.
assert logger.run.experiment is not None
assert logger.run.experiment.name == experiment_name
assert logger.run.display_name == tag
# Both trainig and inference metrics must be logged in the same Run object.
metrics = logger.run.get_metrics()
assert "test_mse" in metrics
assert "loss" in metrics
# The run must have been correctly marked as completed.
logger.run.wait_for_completion()
assert logger.run.status == RunStatus.COMPLETED
else:
assert logger.run is None
assert isinstance(logger, HimlMLFlowLogger)
def test_experiment_name() -> None:
@ -442,3 +429,21 @@ def test_experiment_name() -> None:
experiment_name = "unittest"
container.experiment = experiment_name
assert container.effective_experiment_name == experiment_name
def test_get_mlflow_run_id_from_previous_loggers() -> None:
trainer_without_loggers = Trainer()
run_id = get_mlflow_run_id_from_previous_loggers(trainer_without_loggers)
assert run_id is None
loggers_not_inc_mlflow = [StoringLogger()]
trainer_with_single_logger = Trainer(logger=loggers_not_inc_mlflow)
run_id = get_mlflow_run_id_from_previous_loggers(trainer_with_single_logger)
assert run_id is None
mock_run_id = "run_id_123"
loggers_inc_mlflow = [StoringLogger(), HimlMLFlowLogger(run_id=mock_run_id)]
trainer_with_loggers = Trainer(logger=loggers_inc_mlflow)
with patch.object(mlflow.tracking.client.TrackingServiceClient, "get_run"):
run_id = get_mlflow_run_id_from_previous_loggers(trainer_with_loggers)
assert run_id == mock_run_id

Просмотреть файл

@ -103,9 +103,10 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.check_conda_environment"):
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
with patch("health_ml.runner.get_ml_client"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
@ -134,8 +135,10 @@ def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
mock_get_env_files: MagicMock,
mock_runner: Runner
) -> None:
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig], submit_to_azureml: bool, # type: ignore
environment_variables: Dict[str, Any]) -> AzureRunInfo:
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig],
submit_to_azureml: bool, strictly_aml_v1: bool, # type: ignore
environment_variables: Dict[str, Any], # type: ignore
) -> AzureRunInfo:
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
return AzureRunInfo(input_datasets=datasets_input,
output_datasets=[],
@ -247,15 +250,16 @@ def _test_hyperdrive_submission(mock_runner: Runner,
expected_argument_name: str,
expected_argument_values: List[str]) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg]
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg, "--strictly_aml_v1=True"]
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
# start in that temp folder.
with change_working_folder_and_add_environment(mock_runner.project_root):
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
with patch("health_ml.runner.get_ml_client"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
mock_run_in_situ.assert_called_once()
mock_submit_to_aml.assert_called_once()
# call_args is a tuple of (args, kwargs)
@ -281,10 +285,11 @@ def test_submit_to_azure_docker(mock_runner: Runner) -> None:
# start in that temp folder.
with change_working_folder_and_add_environment(mock_runner.project_root):
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
with patch("health_ml.runner.get_ml_client"):
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_runner.run()
mock_run_in_situ.assert_called_once()
mock_submit_to_aml.assert_called_once()
# call_args is a tuple of (args, kwargs)

Просмотреть файл

@ -21,6 +21,7 @@ from azureml.core import Run
from health_azure import RUN_CONTEXT, create_aml_run_object
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, log_learning_rate, log_on_epoch
from health_ml.utils.logging import _preprocess_hyperparams
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
@ -240,8 +241,7 @@ def test_azureml_logger_hyperparams_processing() -> None:
"""
hyperparams = {"A long list": ["foo", 1.0, "abc"],
"foo": 1.0}
logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
actual = logger._preprocess_hyperparams(hyperparams)
actual = _preprocess_hyperparams(hyperparams)
assert actual == {"A long list": "['foo', 1.0, 'abc']", "foo": "1.0"}

Просмотреть файл

@ -1,8 +1,8 @@
black==22.1.0
coverage==6.3.2
flake8==4.0.1
flake8==5.0.2
mypy==0.931
pylint==2.12.2
pylint==2.15.0
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1

Просмотреть файл

@ -1,9 +1,9 @@
black==22.1.0
coverage==6.3.2
flake8==4.0.1
flake8==5.0.2
mypy==0.931
pre-commit==2.19.0
pylint==2.12.2
pylint==2.15.0
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1