зеркало из https://github.com/microsoft/hi-ml.git
ENH: Merge Main into Transfer Main (#653)
Sync 32 commits from main Co-authored-by: Fernando Pérez-García <fepegar@gmail.com> Co-authored-by: Shruthi42 <13177030+Shruthi42@users.noreply.github.com> Co-authored-by: vale-salvatelli <vale-salvatelli@users.noreply.github.com> Co-authored-by: Melissa Bristow <66642528+mebristo@users.noreply.github.com> Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Fernando Pérez-García <fperezgarcia@microsoft.com> Co-authored-by: Ozan Oktay <ozan.oktay@microsoft.com>
This commit is contained in:
Родитель
6ba3cfe685
Коммит
e30a0d1f6d
|
@ -0,0 +1,54 @@
|
|||
name: 'Hi-ml-cpath environment setup'
|
||||
description: 'Set up environment hi-ml-cpath workflows'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Create AzureML config.json file
|
||||
shell: bash
|
||||
run: ./create_config.sh
|
||||
|
||||
# Use a cache action to save the full conda environment, so that we don't have to reinstall it every time.
|
||||
# Paths are tied to the location of the miniconda installation, and may need adjustment on a different OS.
|
||||
- name: Retrieve cached Conda environment
|
||||
id: cache-conda
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: /usr/share/miniconda/envs/HimlHisto
|
||||
key: hi-ml-cpath-conda-${{ hashFiles('hi-ml-cpath/environment.yml') }}
|
||||
|
||||
# If the cache action didn't find a cache, then install the conda environment afresh.
|
||||
- name: Build Conda environment from scratch
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
if: steps.cache-conda.outputs.cache-hit != 'true'
|
||||
with:
|
||||
environment-file: hi-ml-cpath/environment.yml
|
||||
activate-environment: HimlHisto
|
||||
|
||||
# Modify the path to point to the new or cached Conda environment.
|
||||
# This is effectively also what `conda activate` does.
|
||||
- name: Activate environment
|
||||
shell: bash
|
||||
run: |
|
||||
echo "Adding Conda bin folder to path"
|
||||
echo "/usr/share/miniconda/envs/HimlHisto/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Conda info
|
||||
shell: bash
|
||||
run: conda info
|
||||
|
||||
- name: Show active Python path
|
||||
shell: bash
|
||||
run: which python
|
||||
|
||||
- name: Install hi-ml locally
|
||||
shell: bash
|
||||
run: |
|
||||
cd hi-ml
|
||||
make pip_local
|
||||
|
||||
- name: Install hi-ml-azure locally
|
||||
shell: bash
|
||||
run: |
|
||||
cd hi-ml-azure
|
||||
make pip_local
|
|
@ -1,31 +0,0 @@
|
|||
name: 'Smoke test environment setup'
|
||||
description: 'Set up environment for running smoke tests'
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: create config file
|
||||
shell: bash
|
||||
run: ./create_config.sh
|
||||
|
||||
- name: Set up Python ${{ env.pythonVersion }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.pythonVersion }}
|
||||
|
||||
- name: Install required packages
|
||||
shell: bash
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make pip_from_conda
|
||||
|
||||
- name: Install hi-ml locally
|
||||
shell: bash
|
||||
run: |
|
||||
cd hi-ml
|
||||
make pip_local
|
||||
|
||||
- name: Install hi-ml-azure locally
|
||||
shell: bash
|
||||
run: |
|
||||
cd hi-ml-azure
|
||||
make pip_local
|
|
@ -14,7 +14,7 @@ on:
|
|||
- "hi-ml/**"
|
||||
|
||||
env:
|
||||
pythonVersion: 3.7
|
||||
pythonVersion: 3.9
|
||||
folder: hi-ml-cpath
|
||||
module_for_coverage_reporting: health_cpath
|
||||
HIML_TENANT_ID: ${{ secrets.HIML_TENANT_ID }}
|
||||
|
@ -69,30 +69,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up Python ${{ env.pythonVersion }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.pythonVersion }}
|
||||
|
||||
- name: PIP upgrade
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make pip_upgrade
|
||||
|
||||
- name: Install required packages
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make pip_from_conda
|
||||
|
||||
- name: Install hi-ml locally
|
||||
run: |
|
||||
cd hi-ml
|
||||
make pip_local
|
||||
|
||||
- name: Install hi-ml-azure locally
|
||||
run: |
|
||||
cd hi-ml-azure
|
||||
make pip_local
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
@ -126,9 +104,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-slides-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -142,9 +119,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-tiles-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -158,9 +134,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-tiles-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -174,9 +149,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-sslmil-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -190,9 +164,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-simclr-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -206,9 +179,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-finetuning-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -222,9 +194,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-finetuning-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
@ -238,9 +209,8 @@ jobs:
|
|||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-finetuning-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
|
|
|
@ -105,5 +105,9 @@ jobs:
|
|||
run: |
|
||||
cd ${{ env.folder }}
|
||||
ipython kernel install --name "python3" --user
|
||||
echo "Current branch: ${BRANCH_NAME}"
|
||||
papermill --parameters repo_branch ${BRANCH_NAME} notebooks/phrase_grounding.ipynb /tmp/phrase_grounding_output.ipynb
|
||||
PIP_SOURCE="git+https://github.com/microsoft/hi-ml.git@${BRANCH_NAME}#subdirectory=hi-ml-multimodal"
|
||||
echo "The package will be installed from: $PIP_SOURCE"
|
||||
papermill \
|
||||
--parameters pip_source $PIP_SOURCE \
|
||||
notebooks/phrase_grounding.ipynb \
|
||||
/tmp/phrase_grounding_output.ipynb
|
||||
|
|
|
@ -27,7 +27,7 @@ repos:
|
|||
# The structure of this was suggested by the author of pre-commit and maintainer of flake8
|
||||
# See https://stackoverflow.com/a/66485642/3956024
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
rev: 5.0.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: flake8 ./hi-ml/
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from azureml.core import Datastore
|
||||
from sklearn import datasets
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
|
@ -40,8 +41,8 @@ def main() -> None:
|
|||
np.savetxt(str(target_splits_file), np.vstack(indices_test_splits), delimiter=",")
|
||||
|
||||
ws = get_workspace()
|
||||
datastore = get_datastore(workspace=ws,
|
||||
datastore_name="himldatasets")
|
||||
datastore: Datastore = get_datastore(workspace=ws,
|
||||
datastore_name="himldatasets")
|
||||
|
||||
dataset_name = 'himl_kfold_split_iris'
|
||||
datastore.upload_files(
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
|
||||
from azureml.core import Datastore
|
||||
|
||||
from health_azure.datasets import get_datastore
|
||||
from health_azure import get_workspace
|
||||
|
||||
|
@ -13,8 +15,8 @@ def main() -> None:
|
|||
|
||||
workspace = get_workspace()
|
||||
|
||||
datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
datastore: Datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
|
||||
# Either download all outputs:
|
||||
# run.download_files(prefix="outputs", output_directory=str(path))
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from azureml.core import Datastore
|
||||
from sklearn import datasets
|
||||
|
||||
from health_azure.datasets import get_datastore
|
||||
|
@ -25,8 +26,8 @@ def main() -> None:
|
|||
|
||||
workspace = get_workspace()
|
||||
|
||||
datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
datastore: Datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
|
||||
datastore.upload_files(
|
||||
[str(X_csv), str(y_csv)],
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from azureml.core import Datastore
|
||||
from sklearn import datasets
|
||||
|
||||
from health_azure import get_workspace
|
||||
|
@ -25,8 +26,8 @@ def main() -> None:
|
|||
|
||||
workspace = get_workspace()
|
||||
|
||||
datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
datastore: Datastore = get_datastore(workspace=workspace,
|
||||
datastore_name="himldatasets")
|
||||
|
||||
datastore.upload_files(
|
||||
[str(X_csv), str(y_csv)],
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Hyperparameter Search via Hyperdrive
|
||||
# Hyperparameter Search via Hyperdrive (AML SDK v1)
|
||||
|
||||
[HyperDrive runs](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)
|
||||
can start multiple AzureML jobs in parallel. This can be used for tuning hyperparameters, or executing multiple
|
||||
|
@ -27,3 +27,37 @@ submit_to_azure_if_needed(..., hyperdrive_config=hyperdrive_config)
|
|||
|
||||
For further examples, please check the [example scripts here](examples.md), and the
|
||||
[HyperDrive documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters).
|
||||
|
||||
|
||||
# Hyperparameter Search in AML SDK v2
|
||||
|
||||
There is no concept of a HyperDriveConfig in AML SDK v2. Instead, hyperparameter search arguments are passed into a
|
||||
command, and then the 'sweep' method is called [AML
|
||||
docs](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters). To specify a hyperparameter
|
||||
search job you must specify the method `get_parameter_tuning_args` in your Container. This should return a dictionary of
|
||||
the arguments to be passed in to the command. For example:
|
||||
|
||||
```python
|
||||
def get_parameter_tuning_args(self) -> Dict[str, Any]:
|
||||
from azure.ai.ml.entities import Choice
|
||||
from health_azure.himl import (MAX_TOTAL_TRIALS_ARG, PARAM_SAMPLING_ARG, SAMPLING_ALGORITHM_ARG,
|
||||
PRIMARY_METRIC_ARG, GOAL_ARG)
|
||||
|
||||
values = [0.1, 0.5, 0.9]
|
||||
argument_name = "learning_rate"
|
||||
param_sampling = {argument_name: Choice(values)}
|
||||
metric_name = "val/loss"
|
||||
|
||||
hparam_args = {
|
||||
MAX_TOTAL_TRIALS_ARG: len(values),
|
||||
PARAM_SAMPLING_ARG: param_sampling,
|
||||
SAMPLING_ALGORITHM_ARG: "grid",
|
||||
PRIMARY_METRIC_ARG: metric_name,
|
||||
GOAL_ARG: "Minimize"
|
||||
|
||||
}
|
||||
return hparam_args
|
||||
```
|
||||
Additional parameters, sampling strategies, limits etc. are described in the link above. Note that each job that is
|
||||
created will receive an additional command line argument `<argument_name>` and it is your job to update the script to be
|
||||
able to parse and use this argument.
|
|
@ -22,7 +22,7 @@ The `hi-ml` toolbox provides
|
|||
azure_setup.md
|
||||
authentication.md
|
||||
datasets.md
|
||||
hyperdrive.md
|
||||
hyperparameter_search.md
|
||||
lowpriority.md
|
||||
commandline_tools.md
|
||||
downloading.md
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
../../hi-ml-multimodal/README.md
|
|
@ -291,3 +291,12 @@ experiment will get run 3 times with seeds 0, 1 and 2. This is equivalent to sta
|
|||
|
||||
These runs will be started in parallel in AzureML via the HyperDrive framework. It is not possible to run with different
|
||||
seeds on a local machine, other than by manually starting runs with `--random_seed=0` etc.
|
||||
|
||||
## Common problems with running in AML
|
||||
|
||||
1. `"Your total snapshot size exceeds the limit <SNAPSHOT_LIMIT>"`. Cause: The size of your source directory is larger than
|
||||
the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
|
||||
necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
|
||||
can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
|
||||
2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
|
||||
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
|
||||
|
|
|
@ -133,7 +133,7 @@ To use this code with your own data, you will need to:
|
|||
and `DatasetWithReturnIndex`. See for example how we constructed `RSNAKaggleCXR`
|
||||
class. WARNING: the first positional argument of your dataset class constructor MUST be the data directory ("root"),
|
||||
as VisionDataModule expects this in the prepare_data step.
|
||||
3. In your own container update the `_SSLDataClassMappings` member of the class so that the code knows which data class
|
||||
3. In your own container update the `DatasetToClassMapping` member of the class so that the code knows which data class
|
||||
to associate to your new dataset name.
|
||||
4. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
|
||||
consumed by the `create_transforms_from_config` function defined in the
|
||||
|
|
|
@ -39,9 +39,9 @@ etc. Here, we need to define some important parameters:
|
|||
1. The dataset we want to use for training the image encoder and the linear model we only use for evaluation of the
|
||||
image encoder. In theory, they could be two different datasets.
|
||||
|
||||
+ [```ssl_training_dataset_name=SSLDatasetNameHiml.TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L40)
|
||||
+ [```ssl_training_dataset_name=SSL_Dataset_TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L40)
|
||||
|
||||
+ [```linear_head_dataset_name=SSLDatasetNameHiml.TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L41)
|
||||
+ [```linear_head_dataset_name=SSL_Dataset_TCGA_CRCK```](https://github.com/microsoft/hi-ml/blob/7f4baadaa8bc0d08a4895ca896ebc3f68ea6a4f8/hi-ml-histopathology/src/histopathology/configs/SSL/CRCK_SimCLRContainer.py#L41)
|
||||
|
||||
1. Model checkpointing: We use [PyTorch lightning
|
||||
checkpointing](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing.html). Among others, we define
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
#! /usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import param
|
||||
|
@ -8,7 +16,6 @@ from _pytest.main import ExitCode
|
|||
from azureml._restclient.constants import RunStatus
|
||||
from azureml.core import Run
|
||||
|
||||
|
||||
# Add hi-ml packages to sys.path so that AML can find them if we are using the runner directly from the git repo
|
||||
himl_root = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
@ -34,7 +41,7 @@ from health_azure.utils import ( # noqa: E402
|
|||
is_running_in_azure_ml,
|
||||
parse_arguments,
|
||||
)
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_UPLOAD_DIR # noqa: E402
|
||||
from health_ml.utils.common_utils import DEFAULT_AML_UPLOAD_DIR, DEFAULT_DOCKER_BASE_IMAGE # noqa: E402
|
||||
|
||||
PYTEST_RESULTS_FILE = "pytest_results.xml"
|
||||
PYTEST_GPU_COVERAGE_FILE = "pytest_gpu_coverage.xml"
|
||||
|
@ -68,6 +75,8 @@ class RunPytestConfig(param.Parameterized):
|
|||
default="",
|
||||
doc="A folder name that should be added to sys.path. The folder name should be relative to repository root."
|
||||
)
|
||||
strictly_aml_v1: bool = param.Boolean(default=True, doc="If True, use AzureML v1 SDK. If False (default), use "
|
||||
"the v2 of the SDK")
|
||||
|
||||
|
||||
def run_pytest(folder_to_test: str, pytest_mark: str, coverage_module: str) -> None:
|
||||
|
@ -178,5 +187,8 @@ if __name__ == "__main__":
|
|||
experiment_name=config.experiment,
|
||||
max_run_duration=config.max_run_duration,
|
||||
after_submission=pytest_after_submission_hook,
|
||||
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
|
||||
strictly_aml_v1=config.strictly_aml_v1,
|
||||
)
|
||||
run_pytest(folder_to_test=config.folder, pytest_mark=config.mark, coverage_module=config.coverage_module)
|
||||
time.sleep(10) # Give the AzureML job time to finish uploading the pytest result file.
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
azureml-core==1.43.0
|
||||
azureml-dataset-runtime[fuse]
|
||||
azureml-tensorboard==1.43.0
|
||||
azureml-train-core==1.43.0
|
||||
azure-ai-ml>=0.1.0b6
|
||||
azureml-core>=1.42.0
|
||||
azureml-dataset-runtime[fuse]>=1.42.0
|
||||
azureml-mlflow>=1.42.0
|
||||
azure-storage-blob==12.10.0
|
||||
azureml-tensorboard>=1.42.0
|
||||
azureml-train-core>=1.42.0
|
||||
conda-merge>=0.1.5
|
||||
mlflow>=1.29.0
|
||||
pandas>=1.3.4
|
||||
param>=1.12
|
||||
protobuf<=3.20.1
|
||||
protobuf<4.0
|
||||
pysocks>=1.5.8
|
||||
ruamel.yaml>=0.16.12
|
||||
tensorboard>=2.6.0
|
||||
typing-extensions==4.3.0
|
||||
typing-extensions>=4.3.0
|
||||
|
|
|
@ -19,76 +19,76 @@ from setuptools import setup, find_namespace_packages # type: ignore
|
|||
here = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
# Get the long description from the README file
|
||||
long_description = (here / 'package_description.md').read_text(encoding='utf-8')
|
||||
long_description = (here / "package_description.md").read_text(encoding="utf-8")
|
||||
|
||||
version = ''
|
||||
version = ""
|
||||
|
||||
# If running from a GitHub Action then a standard set of environment variables will be
|
||||
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
|
||||
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
|
||||
# Extract this tag and use it as a version string
|
||||
# See also:
|
||||
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
|
||||
# https://github.com/pypa/gh-action-pypi-publish
|
||||
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
|
||||
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
|
||||
|
||||
github_ref = os.getenv('GITHUB_REF')
|
||||
github_ref = os.getenv("GITHUB_REF")
|
||||
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
|
||||
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
|
||||
|
||||
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
|
||||
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
|
||||
# '99.99.post124'. Although this is discouraged, see:
|
||||
# "99.99.post124". Although this is discouraged, see:
|
||||
# https://www.python.org/dev/peps/pep-0440/#post-releases
|
||||
# it is necessary here to avoid duplicate packages in Test.PyPI.
|
||||
if not version:
|
||||
build_number = os.getenv('GITHUB_RUN_NUMBER')
|
||||
build_number = os.getenv("GITHUB_RUN_NUMBER")
|
||||
if build_number:
|
||||
# In github workflows, tests for hi-ml pull in hi-ml-azure as a dependency. Usually, we have a condition like
|
||||
# hi-ml-azure>=0.1.5. This means that a package version from PyPi would trump the local wheels. For this reason,
|
||||
# use an extremely large version number to give the local wheel priority.
|
||||
version = '99.991.post' + build_number
|
||||
version = "99.991.post" + build_number
|
||||
else:
|
||||
default_random_version_number = floor(random() * 10_000_000_000)
|
||||
version = f'99.991.post{str(default_random_version_number)}'
|
||||
version = f"99.991.post{str(default_random_version_number)}"
|
||||
|
||||
(here / 'package_name.txt').write_text('hi-ml-azure')
|
||||
(here / 'latest_version.txt').write_text(version)
|
||||
(here / "package_name.txt").write_text("hi-ml-azure")
|
||||
(here / "latest_version.txt").write_text(version)
|
||||
|
||||
# Read run_requirements.txt to get install_requires
|
||||
install_requires = (here / 'run_requirements.txt').read_text().split("\n")
|
||||
install_requires = (here / "run_requirements.txt").read_text().split("\n")
|
||||
# Remove any whitespace and blank lines
|
||||
install_requires = [line.strip() for line in install_requires if line.strip()]
|
||||
|
||||
description = 'Microsoft Health Intelligence package to elevate and monitor scripts to an AzureML workspace'
|
||||
description = "Microsoft Health Futures package to elevate and monitor scripts to an AzureML workspace"
|
||||
|
||||
setup(
|
||||
name='hi-ml-azure',
|
||||
name="hi-ml-azure",
|
||||
version=version,
|
||||
description=description,
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/microsoft/hi-ml',
|
||||
author="Microsoft Research Cambridge InnerEye Team ",
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/microsoft/hi-ml",
|
||||
author="Biomedical Imaging Team @ Microsoft Health Futures",
|
||||
author_email="innereyedev@microsoft.com",
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Science/Research',
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.7'
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.7"
|
||||
],
|
||||
keywords='InnerEye, HealthIntelligence, AzureML',
|
||||
license='MIT License',
|
||||
keywords="Health Futures, Health Intelligence, AzureML",
|
||||
license="MIT License",
|
||||
packages=find_namespace_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
include_package_data=True,
|
||||
install_requires=install_requires,
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'himl-tb = health_azure.himl_tensorboard:main',
|
||||
'himl-download = health_azure.himl_download:main'
|
||||
"console_scripts": [
|
||||
"himl-tb = health_azure.himl_tensorboard:main",
|
||||
"himl-download = health_azure.himl_download:main"
|
||||
]
|
||||
}
|
||||
)
|
||||
|
|
|
@ -5,17 +5,28 @@
|
|||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from azureml.core import Dataset, Datastore, Workspace
|
||||
from azure.ai.ml import MLClient
|
||||
from azure.ai.ml.entities import Data
|
||||
from azure.ai.ml.entities import Datastore as V2Datastore
|
||||
from azure.ai.ml.constants import AssetTypes
|
||||
from azure.ai.ml.operations import DatastoreOperations
|
||||
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
|
||||
from azureml.core import Dataset, Workspace, Datastore
|
||||
from azureml.data import FileDataset, OutputFileDatasetConfig
|
||||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
||||
from azureml.dataprep.fuse.daemon import MountContext
|
||||
from azureml.exceptions._azureml_exception import UserErrorException
|
||||
|
||||
from health_azure.utils import PathOrString, get_workspace
|
||||
from health_azure.utils import PathOrString, get_workspace, get_ml_client
|
||||
|
||||
|
||||
def get_datastore(workspace: Workspace, datastore_name: str) -> Datastore:
|
||||
V1OrV2DataType = Union[FileDataset, Data]
|
||||
|
||||
|
||||
def get_datastore(workspace: Workspace, datastore_name: str) -> Union[AzureBlobDatastore, V2Datastore]:
|
||||
"""
|
||||
Retrieves a datastore of a given name from an AzureML workspace. The datastore_name argument can be omitted if
|
||||
the workspace only contains a single datastore. Raises a ValueError if there is no datastore of the given name.
|
||||
|
@ -24,45 +35,199 @@ def get_datastore(workspace: Workspace, datastore_name: str) -> Datastore:
|
|||
:param datastore_name: The name of the datastore to retrieve.
|
||||
:return: An AzureML datastore.
|
||||
"""
|
||||
datastores = workspace.datastores
|
||||
existing_stores = list(datastores.keys())
|
||||
if not datastore_name:
|
||||
def _retrieve_v1_datastore(datastores: Dict[str, Datastore], datastore_name: str) -> Datastore:
|
||||
# First check if there is only one datastore, which is then obviously unique.
|
||||
# Only then try to use the default datastore, because there may not be a default set.
|
||||
if len(existing_stores) == 1:
|
||||
return datastores[existing_stores[0]]
|
||||
datastore = workspace.get_default_datastore()
|
||||
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
|
||||
existing_stores = list(datastores.keys())
|
||||
if not datastore_name:
|
||||
if len(existing_stores) == 1:
|
||||
return datastores[existing_stores[0]]
|
||||
datastore = workspace.get_default_datastore()
|
||||
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
|
||||
return datastore
|
||||
|
||||
if datastore_name in datastores:
|
||||
return datastores[datastore_name]
|
||||
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the \"{workspace.name}\" workspace. "
|
||||
f"Existing datastores: {existing_stores}")
|
||||
|
||||
def _retrieve_v2_datastore(datastores: DatastoreOperations, datastore_name: str) -> V2Datastore:
|
||||
existing_stores = list(datastores.list())
|
||||
if not datastore_name:
|
||||
if len(existing_stores) == 1:
|
||||
return existing_stores[0]
|
||||
datastore = datastores.get_default()
|
||||
logging.info(f"Using the workspace default datastore {datastore.name} to access datasets.")
|
||||
return datastore
|
||||
|
||||
try:
|
||||
datastore = datastores.get(datastore_name)
|
||||
except ResourceNotFoundError:
|
||||
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the workspace")
|
||||
return datastore
|
||||
if datastore_name in datastores:
|
||||
return datastores[datastore_name]
|
||||
raise ValueError(f"Datastore \"{datastore_name}\" was not found in the \"{workspace.name}\" workspace. "
|
||||
f"Existing datastores: {existing_stores}")
|
||||
|
||||
datastores = workspace.datastores
|
||||
if isinstance(datastores, DatastoreOperations):
|
||||
return _retrieve_v2_datastore(datastores, datastore_name)
|
||||
elif isinstance(datastores, dict):
|
||||
return _retrieve_v1_datastore(datastores, datastore_name)
|
||||
else:
|
||||
raise ValueError(f"Unrecognised type for datastores: {type(datastores)}")
|
||||
|
||||
|
||||
def get_or_create_dataset(workspace: Workspace, datastore_name: str, dataset_name: str) -> FileDataset:
|
||||
def _retrieve_v1_dataset(dataset_name: str, workspace: Workspace) -> Optional[FileDataset]:
|
||||
"""
|
||||
Retrieve an Azure ML v1 Dataset if it exists, otherwise return None
|
||||
|
||||
:param dataset_name: The name of the Dataset to look for.
|
||||
:param workspace: An Azure ML Workspace object for retrieving the Dataset.
|
||||
:return: A Dataset object if it is found, else None.
|
||||
"""
|
||||
logging.info(f"Trying to retrieve AzureML Dataset '{dataset_name}'")
|
||||
azureml_dataset = Dataset.get_by_name(workspace, name=dataset_name)
|
||||
return azureml_dataset
|
||||
|
||||
|
||||
def _create_v1_dataset(datastore_name: str, dataset_name: str, workspace: Workspace
|
||||
) -> FileDataset:
|
||||
"""
|
||||
Create a v1 Dataset in the specified Datastore
|
||||
|
||||
:param datastore_name: The AML Datastore to create the Dataset in.
|
||||
:param dataset_name: The name of the Dataset to create.
|
||||
:param workspace: An AML Workspace object.
|
||||
:return: An Azure ML (v1) FileDataset object.
|
||||
"""
|
||||
if not dataset_name:
|
||||
raise ValueError(f"Cannot create dataset without a valid dataset name (received '{dataset_name}')")
|
||||
|
||||
datastore = get_datastore(workspace, datastore_name)
|
||||
|
||||
assert isinstance(datastore, AzureBlobDatastore)
|
||||
logging.info(f"Creating a new dataset from data in folder '{dataset_name}' in the datastore")
|
||||
# Ensure that there is a / at the end of the file path, otherwise folder that share a prefix could create
|
||||
# trouble (for example, folders foo and foo_bar exist, and I'm trying to create a dataset from "foo")
|
||||
azureml_dataset = Dataset.File.from_files(path=(datastore, dataset_name + "/"))
|
||||
logging.info("Registering the dataset for future use.")
|
||||
azureml_dataset.register(workspace, name=dataset_name)
|
||||
return azureml_dataset
|
||||
|
||||
|
||||
def _get_or_create_v1_dataset(datastore_name: str, dataset_name: str, workspace: Workspace) -> Dataset:
|
||||
"""
|
||||
Attempt to retrieve a v1 Dataset object and return that, otherwise attempt to create and register
|
||||
a v1 Dataset and return that.
|
||||
|
||||
:param datastore_name: The name of the Datastore to either retrieve or create and register the Dataset in.
|
||||
:param dataset_name: The name of the Dataset to be retrieved or registered.
|
||||
:param workspace: An Azure ML Workspace object.
|
||||
:return: An Azure ML Dataset object with the provided dataset name, in the provided datastore.
|
||||
"""
|
||||
try:
|
||||
azureml_dataset = _retrieve_v1_dataset(dataset_name, workspace)
|
||||
except UserErrorException:
|
||||
azureml_dataset = _create_v1_dataset(datastore_name, dataset_name, workspace)
|
||||
return azureml_dataset
|
||||
|
||||
|
||||
def _retrieve_v2_dataset(dataset_name: str, ml_client: MLClient) -> Data:
|
||||
"""
|
||||
Attempt to retrieve a v2 Data Asset using a provided Azure ML Workspace connection. If
|
||||
no Data asset can be found with a matching name, the underlying code will raise an Exception
|
||||
|
||||
:param dataset_name: The name of the dataset to look for.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:return: An Azure Data asset representing the dataset if found, otherwise an Exception will be raised.
|
||||
"""
|
||||
aml_data = ml_client.data.get(name=dataset_name)
|
||||
return aml_data
|
||||
|
||||
|
||||
def _create_v2_dataset(datastore_name: str, dataset_name: str, ml_client: MLClient) -> Data:
|
||||
"""
|
||||
Create or update a v2 Data asset in the specified Datastore
|
||||
|
||||
:param datastore_name: The name of the datastore in which to create or update the Data asset.
|
||||
:param dataset_name: The name of the dataset to be created.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:raises ValueError: If no datastore name is provided to define where to create the data
|
||||
:return: The created or updated Data asset.
|
||||
"""
|
||||
if not dataset_name:
|
||||
raise ValueError(f"Cannot create data asset without a valid dataset name (received {dataset_name})")
|
||||
|
||||
if not datastore_name:
|
||||
default_datastore = ml_client.datastores.get_default()
|
||||
datastore_name = default_datastore.name
|
||||
|
||||
logging.info(f"Creating a new Data asset from data in folder '{dataset_name}' in the datastore '{datastore_name}'")
|
||||
azureml_data_asset = Data(
|
||||
path=f"azureml://datastores/{datastore_name}/paths/{dataset_name}/",
|
||||
type=AssetTypes.URI_FOLDER,
|
||||
description="<description>",
|
||||
name=dataset_name,
|
||||
version=None
|
||||
)
|
||||
ml_client.data.create_or_update(azureml_data_asset)
|
||||
return azureml_data_asset
|
||||
|
||||
|
||||
def _get_or_create_v2_dataset(datastore_name: str, dataset_name: str, ml_client: MLClient) -> Data:
|
||||
"""
|
||||
Attempt to retrieve a v2 Dataset object and return that, otherwise attempt to create and register
|
||||
a v2 Dataset and return that.
|
||||
|
||||
:param datastore_name: The name of the Datastore to either retrieve or create and register the Data asset in.
|
||||
:param dataset_name: The name of the Data asset to be retrieved or registered.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:return: An Azure Data asset object with the provided dataset name, in the provided datastore
|
||||
"""
|
||||
try:
|
||||
azureml_dataset = _retrieve_v2_dataset(dataset_name, ml_client)
|
||||
except Exception:
|
||||
azureml_dataset = _create_v2_dataset(datastore_name, dataset_name, ml_client)
|
||||
return azureml_dataset
|
||||
|
||||
|
||||
def get_or_create_dataset(datastore_name: str,
|
||||
dataset_name: str,
|
||||
workspace: Workspace,
|
||||
strictly_aml_v1: bool,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
) -> V1OrV2DataType:
|
||||
"""
|
||||
Looks in the AzureML datastore for a dataset of the given name. If there is no such dataset, a dataset is
|
||||
created and registered, assuming that the files are in a folder that has the same name as the dataset.
|
||||
For example, if dataset_name is 'foo', then the 'foo' dataset should be pointing to the folder
|
||||
<container_root>/datasets/dataset_name/
|
||||
<container_root>/datasets/dataset_name/.
|
||||
|
||||
If the command line arg to strictly use AML SDK v1 is set to True, will attempt to retrieve a dataset using
|
||||
v1 of the SDK. Otherwise, will attempt to use v2 of the SDK. If no data of this name is found in the v2 datastore,
|
||||
will attempt to create it, but if the data container provided is v1 version, will fall back to using the
|
||||
v1 SDK to create and register this dataset.
|
||||
|
||||
:param datastore_name: The name of the datastore in which to look for, or create and register, the dataset.
|
||||
:param dataset_name: The name of the dataset to find or create.
|
||||
:param workspace: An AML Workspace object for interacting with AML v1 datastores.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
|
||||
Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An optional MLClient object for interacting with AML v2 datastores.
|
||||
"""
|
||||
if not dataset_name:
|
||||
raise ValueError("No dataset name provided.")
|
||||
try:
|
||||
logging.info(f"Trying to retrieve AzureML Dataset '{dataset_name}'")
|
||||
azureml_dataset = Dataset.get_by_name(workspace, name=dataset_name)
|
||||
logging.info("Dataset found.")
|
||||
except Exception:
|
||||
logging.info(f"Retrieving datastore '{datastore_name}' from AzureML workspace")
|
||||
datastore = get_datastore(workspace, datastore_name)
|
||||
logging.info(f"Creating a new dataset from data in folder '{dataset_name}' in the datastore")
|
||||
# Ensure that there is a / at the end of the file path, otherwise folder that share a prefix could create
|
||||
# trouble (for example, folders foo and foo_bar exist, and I'm trying to create a dataset from "foo")
|
||||
azureml_dataset = Dataset.File.from_files(path=(datastore, dataset_name + "/"))
|
||||
logging.info("Registering the dataset for future use.")
|
||||
azureml_dataset.register(workspace, name=dataset_name)
|
||||
return azureml_dataset
|
||||
if strictly_aml_v1:
|
||||
aml_dataset = _get_or_create_v1_dataset(datastore_name, dataset_name, workspace)
|
||||
return aml_dataset
|
||||
else:
|
||||
try:
|
||||
ml_client = get_ml_client(ml_client=ml_client)
|
||||
aml_dataset = _get_or_create_v2_dataset(datastore_name, dataset_name, ml_client)
|
||||
except HttpResponseError as e:
|
||||
if "Cannot create v2 Data Version in v1 Data Container" in e.message:
|
||||
logging.info("This appears to be a v1 Data Container. Reverting to API v1 to create this Dataset")
|
||||
aml_dataset = _get_or_create_v1_dataset(datastore_name, dataset_name, workspace)
|
||||
|
||||
return aml_dataset
|
||||
|
||||
|
||||
def _input_dataset_key(index: int) -> str:
|
||||
|
@ -118,14 +283,22 @@ class DatasetConfig:
|
|||
raise ValueError("Can't mount or download a dataset to the current working directory.")
|
||||
self.local_folder = Path(local_folder) if local_folder else None
|
||||
|
||||
def to_input_dataset_local(self, workspace: Optional[Workspace]) -> Tuple[Path, Optional[MountContext]]:
|
||||
def to_input_dataset_local(self,
|
||||
strictly_aml_v1: bool,
|
||||
workspace: Workspace = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
) -> Tuple[Optional[Path], Optional[MountContext]]:
|
||||
"""
|
||||
Return a local path to the dataset when outside of an AzureML run.
|
||||
If local_folder is supplied, then this is assumed to be a local dataset, and this is returned.
|
||||
Otherwise the dataset is mounted or downloaded to either the target folder or a temporary folder and that is
|
||||
returned.
|
||||
returned. If self.name refers to a v2 dataset, it is not possible to mount the data here,
|
||||
therefore a tuple of Nones will be returned.
|
||||
|
||||
:param workspace: The AzureML workspace to read from.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
|
||||
Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:return: Tuple of (path to dataset, optional mountcontext)
|
||||
"""
|
||||
status = f"Dataset '{self.name}' will be "
|
||||
|
@ -138,55 +311,70 @@ class DatasetConfig:
|
|||
if workspace is None:
|
||||
raise ValueError(f"Unable to make dataset '{self.name} available for a local run because no AzureML "
|
||||
"workspace has been provided. Provide a workspace, or set a folder for local execution.")
|
||||
|
||||
azureml_dataset = get_or_create_dataset(workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
dataset_name=self.name,
|
||||
datastore_name=self.datastore)
|
||||
target_path = self.target_folder or Path(tempfile.mkdtemp())
|
||||
use_mounting = self.use_mounting if self.use_mounting is not None else False
|
||||
if use_mounting:
|
||||
status += f"mounted at {target_path}"
|
||||
datastore_name=self.datastore,
|
||||
strictly_aml_v1=strictly_aml_v1)
|
||||
if isinstance(azureml_dataset, FileDataset):
|
||||
target_path = self.target_folder or Path(tempfile.mkdtemp())
|
||||
use_mounting = self.use_mounting if self.use_mounting is not None else False
|
||||
if use_mounting:
|
||||
status += f"mounted at {target_path}"
|
||||
|
||||
mount_context = azureml_dataset.mount(mount_point=str(target_path)) # type: ignore
|
||||
result = target_path, mount_context
|
||||
else:
|
||||
status += f"downloaded to {target_path}"
|
||||
|
||||
azureml_dataset.download(target_path=str(target_path), overwrite=False) # type: ignore
|
||||
result = target_path, None
|
||||
print(status)
|
||||
mount_context = azureml_dataset.mount(mount_point=str(target_path))
|
||||
result = target_path, mount_context
|
||||
return result
|
||||
else:
|
||||
status += f"downloaded to {target_path}"
|
||||
print(status)
|
||||
azureml_dataset.download(target_path=str(target_path), overwrite=False)
|
||||
result = target_path, None
|
||||
return result
|
||||
return None, None
|
||||
|
||||
def to_input_dataset(self,
|
||||
dataset_index: int,
|
||||
workspace: Workspace,
|
||||
dataset_index: int) -> DatasetConsumptionConfig:
|
||||
strictly_aml_v1: bool,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
) -> Optional[DatasetConsumptionConfig]:
|
||||
"""
|
||||
Creates a configuration for using an AzureML dataset inside of an AzureML run. This will make the AzureML
|
||||
dataset with given name available as a named input, using INPUT_0 as the key for dataset index 0.
|
||||
|
||||
:param workspace: The AzureML workspace to read from.
|
||||
:param dataset_index: Suffix for using datasets as named inputs, the dataset will be marked INPUT_{index}
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
"""
|
||||
status = f"In AzureML, dataset {self.name} (index {dataset_index}) will be "
|
||||
azureml_dataset = get_or_create_dataset(workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
dataset_name=self.name,
|
||||
datastore_name=self.datastore)
|
||||
named_input = azureml_dataset.as_named_input(_input_dataset_key(index=dataset_index))
|
||||
datastore_name=self.datastore,
|
||||
strictly_aml_v1=strictly_aml_v1)
|
||||
# If running on windows then self.target_folder may be a WindowsPath, make sure it is
|
||||
# in posix format for Azure.
|
||||
path_on_compute = self.target_folder.as_posix() if self.target_folder is not None else None
|
||||
use_mounting = False if self.use_mounting is None else self.use_mounting
|
||||
if use_mounting:
|
||||
status += "mounted at "
|
||||
result = named_input.as_mount(path_on_compute)
|
||||
if isinstance(azureml_dataset, FileDataset):
|
||||
named_input = azureml_dataset.as_named_input(_input_dataset_key(index=dataset_index)) # type: ignore
|
||||
path_on_compute = self.target_folder.as_posix() if self.target_folder is not None else None
|
||||
if use_mounting:
|
||||
status += "mounted at "
|
||||
result = named_input.as_mount(path_on_compute)
|
||||
else:
|
||||
status += "downloaded to "
|
||||
result = named_input.as_download(path_on_compute)
|
||||
if path_on_compute:
|
||||
status += f"{path_on_compute}."
|
||||
else:
|
||||
status += "a randomly chosen folder."
|
||||
print(status)
|
||||
return result
|
||||
else:
|
||||
status += "downloaded to "
|
||||
result = named_input.as_download(path_on_compute)
|
||||
if path_on_compute:
|
||||
status += f"{path_on_compute}."
|
||||
else:
|
||||
status += "a randomly chosen folder."
|
||||
print(status)
|
||||
return result
|
||||
return None
|
||||
|
||||
def to_output_dataset(self,
|
||||
workspace: Workspace,
|
||||
|
@ -316,8 +504,10 @@ def find_workspace_for_local_datasets(aml_workspace: Optional[Workspace],
|
|||
|
||||
|
||||
def setup_local_datasets(dataset_configs: List[DatasetConfig],
|
||||
strictly_aml_v1: bool,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
workspace_config_path: Optional[Path] = None
|
||||
ml_client: Optional[MLClient] = None,
|
||||
workspace_config_path: Optional[Path] = None,
|
||||
) -> Tuple[List[Optional[Path]], List[MountContext]]:
|
||||
"""
|
||||
When running outside of AzureML, setup datasets to be used locally.
|
||||
|
@ -330,16 +520,17 @@ def setup_local_datasets(dataset_configs: List[DatasetConfig],
|
|||
to pass it in as a parameter.
|
||||
:param workspace_config_path: The 2nd option is to specify the path to the config.json file downloaded from the
|
||||
Azure portal from which we can retrieve the existing Workspace.
|
||||
:param dataset_configs: List of DatasetConfig describing the input datasets.
|
||||
:param dataset_configs: List of DatasetConfig describing the input data assets.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
|
||||
:param ml_client: An MLClient object for interacting with AML v2 datastores.
|
||||
:return: Pair of: list of optional paths to the input datasets, list of mountcontexts, one for each mounted dataset.
|
||||
"""
|
||||
workspace = find_workspace_for_local_datasets(aml_workspace, workspace_config_path, dataset_configs)
|
||||
|
||||
mounted_input_datasets: List[Optional[Path]] = []
|
||||
mount_contexts: List[MountContext] = []
|
||||
|
||||
for d in dataset_configs:
|
||||
target_path, mount_context = d.to_input_dataset_local(workspace)
|
||||
for data_config in dataset_configs:
|
||||
target_path, mount_context = data_config.to_input_dataset_local(strictly_aml_v1, workspace, ml_client)
|
||||
|
||||
mounted_input_datasets.append(target_path)
|
||||
|
||||
|
|
|
@ -11,14 +11,21 @@ See examples/elevate_this.py for a very simple 'hello world' example of use.
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from azure.ai.ml import MLClient, Input, Output, command
|
||||
from azure.ai.ml.constants import AssetTypes, InputOutputModes
|
||||
from azure.ai.ml.entities import Data, Job, Command, Sweep
|
||||
from azure.ai.ml.entities import Environment as EnvironmentV2
|
||||
|
||||
from azure.ai.ml.sweep import Choice
|
||||
from azureml._base_sdk_common import user_agent
|
||||
from azureml.core import ComputeTarget, Environment, Experiment, Run, RunConfiguration, ScriptRunConfig, Workspace
|
||||
from azureml.core.runconfig import DockerConfiguration, MpiConfiguration
|
||||
|
@ -31,9 +38,12 @@ from health_azure.amulet import (ENV_AMLT_DATAREFERENCE_DATA, ENV_AMLT_DATAREFER
|
|||
from health_azure.utils import (create_python_environment, create_run_recovery_id, find_file_in_parent_to_pythonpath,
|
||||
is_run_and_child_runs_completed, is_running_in_azure_ml, register_environment,
|
||||
run_duration_string_to_seconds, to_azure_friendly_string, RUN_CONTEXT, get_workspace,
|
||||
PathOrString, DEFAULT_ENVIRONMENT_VARIABLES)
|
||||
from health_azure.datasets import (DatasetConfig, StrOrDatasetConfig, _input_dataset_key, _output_dataset_key,
|
||||
_replace_string_datasets, setup_local_datasets)
|
||||
PathOrString, DEFAULT_ENVIRONMENT_VARIABLES, get_ml_client,
|
||||
create_python_environment_v2, register_environment_v2, V2_INPUT_DATASET_PATTERN,
|
||||
V2_OUTPUT_DATASET_PATTERN)
|
||||
from health_azure.datasets import (DatasetConfig, StrOrDatasetConfig, setup_local_datasets,
|
||||
_input_dataset_key, _output_dataset_key, _replace_string_datasets)
|
||||
|
||||
|
||||
logger = logging.getLogger('health_azure')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
@ -47,6 +57,13 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
|
|||
SDK_NAME = "innereye"
|
||||
SDK_VERSION = "2.0"
|
||||
|
||||
# hyperparameter search args
|
||||
PARAM_SAMPLING_ARG = "parameter_sampling"
|
||||
MAX_TOTAL_TRIALS_ARG = "max_total_trials"
|
||||
PRIMARY_METRIC_ARG = "primary_metric"
|
||||
SAMPLING_ALGORITHM_ARG = "sampling_algorithm"
|
||||
GOAL_ARG = "goal"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AzureRunInfo:
|
||||
|
@ -215,7 +232,9 @@ def create_run_configuration(workspace: Workspace,
|
|||
if input_datasets or output_datasets:
|
||||
inputs, outputs = convert_himl_to_azureml_datasets(cleaned_input_datasets=input_datasets or [],
|
||||
cleaned_output_datasets=output_datasets or [],
|
||||
workspace=workspace)
|
||||
workspace=workspace,
|
||||
strictly_aml_v1=True
|
||||
)
|
||||
run_config.data = inputs
|
||||
run_config.output_data = outputs
|
||||
|
||||
|
@ -256,6 +275,30 @@ def create_grid_hyperdrive_config(values: List[str],
|
|||
)
|
||||
|
||||
|
||||
def create_grid_hyperparam_args_v2(values: List[Any],
|
||||
argument_name: str,
|
||||
metric_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.
|
||||
|
||||
:param values: The list of values to try for the commandline argument given by `argument_name`.
|
||||
:param argument_name: The name of the commandline argument that each of the child runs gets, to
|
||||
indicate which value they should work on.
|
||||
:param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
|
||||
your responsibility to make sure a metric with this name is logged to the Run in your training script
|
||||
:return: A dictionary of arguments and values to pass in to the command job.
|
||||
"""
|
||||
param_sampling = {argument_name: Choice(values)}
|
||||
hyperparam_args = {
|
||||
MAX_TOTAL_TRIALS_ARG: len(values),
|
||||
PARAM_SAMPLING_ARG: param_sampling,
|
||||
SAMPLING_ALGORITHM_ARG: "grid",
|
||||
PRIMARY_METRIC_ARG: metric_name,
|
||||
GOAL_ARG: "Minimize"
|
||||
}
|
||||
return hyperparam_args
|
||||
|
||||
|
||||
def create_crossval_hyperdrive_config(num_splits: int,
|
||||
cross_val_index_arg_name: str = "crossval_index",
|
||||
metric_name: str = "val/loss") -> HyperDriveConfig:
|
||||
|
@ -276,6 +319,24 @@ def create_crossval_hyperdrive_config(num_splits: int,
|
|||
metric_name=metric_name)
|
||||
|
||||
|
||||
def create_crossval_hyperparam_args_v2(num_splits: int,
|
||||
cross_val_index_arg_name: str = "crossval_index",
|
||||
metric_name: str = "val/loss") -> Dict[str, Any]:
|
||||
"""
|
||||
Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.
|
||||
|
||||
:param num_splits: The number of splits for k-fold cross validation
|
||||
:param cross_val_index_arg_name: The name of the commandline argument that each of the child runs gets, to
|
||||
indicate which split they should work on.
|
||||
:param metric_name: The name of the metric that the HyperDriveConfig will compare runs by. Please note that it is
|
||||
your responsibility to make sure a metric with this name is logged to the Run in your training script
|
||||
:return: A dictionary of arguments and values to pass in to the command job.
|
||||
"""
|
||||
return create_grid_hyperparam_args_v2(values=list(map(str, range(num_splits))),
|
||||
argument_name=cross_val_index_arg_name,
|
||||
metric_name=metric_name)
|
||||
|
||||
|
||||
def create_script_run(snapshot_root_directory: Optional[Path] = None,
|
||||
entry_script: Optional[PathOrString] = None,
|
||||
script_params: Optional[List[str]] = None) -> ScriptRunConfig:
|
||||
|
@ -317,12 +378,196 @@ def create_script_run(snapshot_root_directory: Optional[Path] = None,
|
|||
arguments=script_params)
|
||||
|
||||
|
||||
def _generate_input_dataset_command(input_datasets_v2: Dict[str, Input]) -> str:
|
||||
"""
|
||||
Generate command line arguments to pass AML v2 data assets into a script
|
||||
|
||||
:param input_datasets_v2: A dictionary of Input objects that have been passed into the AML command
|
||||
:return: A string representing the input datasets that the script should expect
|
||||
"""
|
||||
input_cmd = ""
|
||||
for i, (input_data_name, input_dataset_v2) in enumerate(input_datasets_v2.items()):
|
||||
input_name = f"INPUT_{i}"
|
||||
input_str = "${{inputs." + f"{input_name}" + "}}"
|
||||
input_cmd += f" --{input_name}={input_str}"
|
||||
return input_cmd
|
||||
|
||||
|
||||
def _generate_output_dataset_command(output_datasets_v2: Dict[str, Output]) -> str:
|
||||
"""
|
||||
Generate command line arguments to pass AML v2 outputs into a script
|
||||
|
||||
:param output_datasets_v2: A dictionary of Output objects that have been passed into the AML command
|
||||
:return: A string representing the output values that the script should expect
|
||||
"""
|
||||
output_cmd = ""
|
||||
for i, (output_data_name, output_dataset_v2) in enumerate(output_datasets_v2.items()):
|
||||
output_name = f"OUTPUT_{i}"
|
||||
output_str = "${{outputs." + f"{output_name}" + "}}"
|
||||
output_cmd += f" --{output_name}={output_str}"
|
||||
return output_cmd
|
||||
|
||||
|
||||
def submit_run_v2(workspace: Optional[Workspace],
|
||||
experiment_name: str,
|
||||
environment: EnvironmentV2,
|
||||
input_datasets_v2: Optional[Dict[str, Input]] = None,
|
||||
output_datasets_v2: Optional[Dict[str, Output]] = None,
|
||||
snapshot_root_directory: Optional[Path] = None,
|
||||
entry_script: Optional[PathOrString] = None,
|
||||
script_params: Optional[List[str]] = None,
|
||||
compute_target: Optional[str] = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
wait_for_completion: bool = False,
|
||||
wait_for_completion_show_output: bool = False,
|
||||
workspace_config_path: Optional[PathOrString] = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
hyperparam_args: Optional[Dict[str, Any]] = None) -> Job:
|
||||
"""
|
||||
Starts a v2 AML Job on a given workspace by submitting a command
|
||||
|
||||
:param workspace: The AzureML workspace to use.
|
||||
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
|
||||
characters that are not valid in Azure, those will be removed.
|
||||
:param environment: An AML v2 Environment object.
|
||||
:param input_datasets_v2: An optional dictionary of Inputs to pass in to the command.
|
||||
:param output_datasets_v2: An optional dictionary of Outputs to pass in to the command.
|
||||
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
|
||||
All Python code that the script uses must be copied over.
|
||||
:param entry_script: The script that should be run in AzureML.
|
||||
:param script_params: A list of parameter to pass on to the script as it runs in AzureML.
|
||||
:param compute_target: Optional name of a compute target in Azure ML to submit the job to. If None, will run
|
||||
locally.
|
||||
:param tags: A dictionary of string key/value pairs, that will be added as metadata to the run. If set to None,
|
||||
a default metadata field will be added that only contains the commandline arguments that started the run.
|
||||
:param wait_for_completion: If False (the default) return after the run is submitted to AzureML, otherwise wait for
|
||||
the completion of this run (if True).
|
||||
:param wait_for_completion_show_output: If wait_for_completion is True this parameter indicates whether to show the
|
||||
run output on sys.stdout.
|
||||
:param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
|
||||
config
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
|
||||
:return: An AzureML Run object.
|
||||
"""
|
||||
if ml_client is None:
|
||||
if workspace is not None:
|
||||
ml_client = get_ml_client(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group=workspace.resource_group,
|
||||
workspace_name=workspace.name
|
||||
)
|
||||
elif workspace_config_path is not None:
|
||||
ml_client = get_ml_client(workspace_config_path=workspace_config_path)
|
||||
else:
|
||||
raise ValueError("Either workspace or workspace_config_path must be specified to connect to the Workspace")
|
||||
|
||||
assert compute_target is not None, "No compute_target has been provided"
|
||||
assert entry_script is not None, "No entry_script has been provided"
|
||||
snapshot_root_directory = snapshot_root_directory or Path.cwd()
|
||||
root_dir = Path(snapshot_root_directory)
|
||||
entry_script = Path(entry_script).relative_to(root_dir).as_posix()
|
||||
|
||||
script_params = script_params or []
|
||||
args = [p for p in script_params if "conda_env" not in p]
|
||||
arg_str = " ".join(args)
|
||||
cmd = "python " + str(entry_script) + " " + arg_str
|
||||
|
||||
if input_datasets_v2:
|
||||
cmd += _generate_input_dataset_command(input_datasets_v2)
|
||||
else:
|
||||
input_datasets_v2 = {}
|
||||
|
||||
if output_datasets_v2:
|
||||
cmd += _generate_output_dataset_command(output_datasets_v2)
|
||||
else:
|
||||
output_datasets_v2 = {}
|
||||
|
||||
job_to_submit: Union[Command, Sweep]
|
||||
|
||||
if hyperparam_args:
|
||||
param_sampling = hyperparam_args[PARAM_SAMPLING_ARG]
|
||||
|
||||
for sample_param, choices in param_sampling.items():
|
||||
input_datasets_v2[sample_param] = choices.values[0]
|
||||
cmd += f" --{sample_param}=" + "${{inputs." + sample_param + "}}"
|
||||
|
||||
command_job = command(
|
||||
code=str(snapshot_root_directory),
|
||||
command=cmd,
|
||||
inputs=input_datasets_v2,
|
||||
outputs=output_datasets_v2,
|
||||
environment=environment.name + "@latest",
|
||||
compute=compute_target,
|
||||
experiment_name=experiment_name,
|
||||
environment_variables={
|
||||
"JOB_EXECUTION_MODE": "Basic",
|
||||
}
|
||||
)
|
||||
|
||||
del hyperparam_args[PARAM_SAMPLING_ARG]
|
||||
# override command with parameter expressions
|
||||
command_job = command_job(
|
||||
**param_sampling,
|
||||
)
|
||||
|
||||
job_to_submit = command_job.sweep(
|
||||
compute=compute_target, # AML docs suggest setting this here although already passed to command
|
||||
**hyperparam_args
|
||||
)
|
||||
|
||||
# AML docs state to reset certain properties here which aren't picked up from the
|
||||
# underlying command such as experiment name and max_total_trials
|
||||
job_to_submit.experiment_name = experiment_name
|
||||
job_to_submit.set_limits(max_total_trials=hyperparam_args.get(MAX_TOTAL_TRIALS_ARG, None))
|
||||
|
||||
else:
|
||||
job_to_submit = command(
|
||||
code=str(snapshot_root_directory),
|
||||
command=cmd,
|
||||
inputs=input_datasets_v2,
|
||||
outputs=output_datasets_v2,
|
||||
environment=environment.name + "@latest",
|
||||
compute=compute_target,
|
||||
experiment_name=experiment_name,
|
||||
environment_variables={
|
||||
"JOB_EXECUTION_MODE": "Basic",
|
||||
}
|
||||
)
|
||||
|
||||
returned_job = ml_client.jobs.create_or_update(job_to_submit)
|
||||
logging.info(f"URL to job: {returned_job.services['Studio'].endpoint}") # type: ignore
|
||||
return returned_job
|
||||
|
||||
|
||||
def download_job_outputs_logs(ml_client: MLClient,
|
||||
job_name: str,
|
||||
file_to_download_path: str = "",
|
||||
download_dir: Optional[PathOrString] = None) -> None:
|
||||
"""
|
||||
Download output files from an mlflow job. Outputs will be downloaded to a folder named
|
||||
`<download_dir>/<job_name>` where download_dir is either provided to this function,
|
||||
or is "outputs". If a single file is required, the path to this file within the job can
|
||||
be specified with 'file_to_download_path'
|
||||
|
||||
:param ml_client: An MLClient object.
|
||||
:param job_name: The name (id) of the job to download output files from.
|
||||
:param file_to_download_path: An optional path to a single file/folder to download.
|
||||
:param download_dir: An optional folder into which to download the run files.
|
||||
"""
|
||||
|
||||
download_dir = Path(download_dir) if download_dir else Path("outputs")
|
||||
download_dir = download_dir / job_name
|
||||
ml_client.jobs.download(job_name, output_name=file_to_download_path, download_path=download_dir)
|
||||
|
||||
|
||||
def submit_run(workspace: Workspace,
|
||||
experiment_name: str,
|
||||
script_run_config: Union[ScriptRunConfig, HyperDriveConfig],
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
wait_for_completion: bool = False,
|
||||
wait_for_completion_show_output: bool = False, ) -> Run:
|
||||
wait_for_completion_show_output: bool = False,
|
||||
) -> Run:
|
||||
"""
|
||||
Starts an AzureML run on a given workspace, via the script_run_config.
|
||||
|
||||
|
@ -381,11 +626,59 @@ def _str_to_path(s: Optional[PathOrString]) -> Optional[Path]:
|
|||
return s
|
||||
|
||||
|
||||
def create_v2_inputs(ml_client: MLClient, input_datasets: List[DatasetConfig]) -> Dict[str, Input]:
|
||||
"""
|
||||
Create a dictionary of Azure ML v2 Input objects, required for passing input data in to an AML job
|
||||
|
||||
:param ml_client: An MLClient object.
|
||||
:param input_datasets: A list of DatasetConfigs to convert to Inputs.
|
||||
:return: A dictionary in the format "input_name": Input.
|
||||
"""
|
||||
inputs: Dict[str, Input] = {}
|
||||
for i, input_dataset in enumerate(input_datasets):
|
||||
input_name = f"INPUT_{i}"
|
||||
version = input_dataset.version or 1
|
||||
data_asset: Data = ml_client.data.get(input_dataset.name, version=str(version))
|
||||
data_path = data_asset.id or ""
|
||||
# Note that there are alternative formats that the input path can take, such as:
|
||||
# v1_datastore_path = f"azureml://datastores/{input_dataset.datastore}/paths/<path_to_dataset>"
|
||||
# v2_dataset_path = f"azureml:{input_dataset.name}:1"
|
||||
|
||||
inputs[input_name] = Input( # type: ignore
|
||||
type=AssetTypes.URI_FOLDER,
|
||||
path=data_path,
|
||||
mode=InputOutputModes.MOUNT,
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def create_v2_outputs(output_datasets: List[DatasetConfig]) -> Dict[str, Output]:
|
||||
"""
|
||||
Create a dictionary of Azure ML v2 Output objects, required for passing output data in to an AML job
|
||||
|
||||
:param output_datasets: A list of DatasetConfigs to convert to Outputs.
|
||||
:return: A dictionary in the format "output_name": Output.
|
||||
"""
|
||||
outputs = {}
|
||||
for i, output_dataset in enumerate(output_datasets):
|
||||
output_name = f"OUTPUT_{i}"
|
||||
v1_datastore_path = f"azureml://datastores/{output_dataset.datastore}/paths/{output_dataset.name}"
|
||||
# Note that there are alternative formats that the output path can take, such as:
|
||||
# v2_data_asset_path = f"azureml:{output_dataset.name}@latest"
|
||||
outputs[output_name] = Output( # type: ignore
|
||||
type=AssetTypes.URI_FOLDER,
|
||||
path=v1_datastore_path,
|
||||
mode=InputOutputModes.DIRECT,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def submit_to_azure_if_needed( # type: ignore
|
||||
compute_cluster_name: str = "",
|
||||
entry_script: Optional[PathOrString] = None,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
workspace_config_file: Optional[PathOrString] = None,
|
||||
ml_client: Optional[MLClient] = None,
|
||||
snapshot_root_directory: Optional[PathOrString] = None,
|
||||
script_params: Optional[List[str]] = None,
|
||||
conda_environment_file: Optional[PathOrString] = None,
|
||||
|
@ -408,7 +701,9 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
tags: Optional[Dict[str, str]] = None,
|
||||
after_submission: Optional[Callable[[Run], None]] = None,
|
||||
hyperdrive_config: Optional[HyperDriveConfig] = None,
|
||||
hyperparam_args: Optional[Dict[str, Any]] = None,
|
||||
create_output_folders: bool = True,
|
||||
strictly_aml_v1: bool = False,
|
||||
) -> AzureRunInfo: # pragma: no cover
|
||||
"""
|
||||
Submit a folder to Azure, if needed and run it.
|
||||
|
@ -434,6 +729,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
to pass it in as a parameter.
|
||||
:param workspace_config_file: The 2nd option is to specify the path to the config.json file downloaded from the
|
||||
Azure portal from which we can retrieve the existing Workspace.
|
||||
:param ml_client: An Azure MLClient object for interacting with Azure resources.
|
||||
:param snapshot_root_directory: The directory that contains all code that should be packaged and sent to AzureML.
|
||||
All Python code that the script uses must be copied over.
|
||||
:param ignored_folders: A list of folders to exclude from the snapshot when copying it to AzureML.
|
||||
|
@ -462,6 +758,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
will be triggered if the commandline flag '--azureml' is present in sys.argv
|
||||
:param hyperdrive_config: A configuration object for Hyperdrive (hyperparameter search).
|
||||
:param create_output_folders: If True (default), create folders "outputs" and "logs" in the current working folder.
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1. Otherwise, attempt to use Azure ML SDK v2.
|
||||
:return: If the script is submitted to AzureML then we terminate python as the script should be executed in AzureML,
|
||||
otherwise we return a AzureRunInfo object.
|
||||
"""
|
||||
|
@ -472,17 +769,18 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
default_datastore_name=default_datastore)
|
||||
cleaned_output_datasets = _replace_string_datasets(output_datasets or [],
|
||||
default_datastore_name=default_datastore)
|
||||
|
||||
# The present function will most likely be called from the script once it is running in AzureML.
|
||||
# The '--azureml' flag will not be present anymore, but we don't want to rely on that. From Run.get_context we
|
||||
# can infer if the present code is running in AzureML.
|
||||
in_azure = is_running_in_azure_ml(RUN_CONTEXT)
|
||||
if in_azure:
|
||||
return _generate_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
|
||||
|
||||
if strictly_aml_v1:
|
||||
return _generate_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
|
||||
else:
|
||||
return _generate_v2_azure_datasets(cleaned_input_datasets, cleaned_output_datasets)
|
||||
# This codepath is reached when executing outside AzureML. Here we first check if a script submission to AzureML
|
||||
# is necessary. If not, return to the caller for local execution.
|
||||
if submit_to_azureml is None:
|
||||
submit_to_azureml = AZUREML_COMMANDLINE_FLAG in sys.argv[1:]
|
||||
if not submit_to_azureml:
|
||||
# Set the environment variables for local execution.
|
||||
environment_variables = {
|
||||
|
@ -500,8 +798,10 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
logs_folder.mkdir(exist_ok=True)
|
||||
|
||||
mounted_input_datasets, mount_contexts = setup_local_datasets(cleaned_input_datasets,
|
||||
aml_workspace,
|
||||
workspace_config_path)
|
||||
strictly_aml_v1,
|
||||
aml_workspace=aml_workspace,
|
||||
ml_client=ml_client,
|
||||
workspace_config_path=workspace_config_path)
|
||||
|
||||
return AzureRunInfo(
|
||||
input_datasets=mounted_input_datasets,
|
||||
|
@ -518,6 +818,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
snapshot_root_directory = Path.cwd()
|
||||
|
||||
workspace = get_workspace(aml_workspace, workspace_config_path)
|
||||
ml_client = get_ml_client(ml_client=ml_client, aml_workspace=workspace)
|
||||
print(f"Loaded AzureML workspace {workspace.name}")
|
||||
|
||||
if conda_environment_file is None:
|
||||
|
@ -525,47 +826,76 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
print(f"Using the Conda environment from this file: {conda_environment_file}")
|
||||
conda_environment_file = _str_to_path(conda_environment_file)
|
||||
|
||||
run_config = create_run_configuration(
|
||||
workspace=workspace,
|
||||
compute_cluster_name=compute_cluster_name,
|
||||
aml_environment_name=aml_environment_name,
|
||||
conda_environment_file=conda_environment_file,
|
||||
environment_variables=environment_variables,
|
||||
pip_extra_index_url=pip_extra_index_url,
|
||||
private_pip_wheel_path=_str_to_path(private_pip_wheel_path),
|
||||
docker_base_image=docker_base_image,
|
||||
docker_shm_size=docker_shm_size,
|
||||
num_nodes=num_nodes,
|
||||
max_run_duration=max_run_duration,
|
||||
input_datasets=cleaned_input_datasets,
|
||||
output_datasets=cleaned_output_datasets
|
||||
)
|
||||
script_run_config = create_script_run(snapshot_root_directory=snapshot_root_directory,
|
||||
entry_script=entry_script,
|
||||
script_params=script_params)
|
||||
script_run_config.run_config = run_config
|
||||
|
||||
if hyperdrive_config:
|
||||
config_to_submit: Union[ScriptRunConfig, HyperDriveConfig] = hyperdrive_config
|
||||
config_to_submit._run_config = script_run_config
|
||||
else:
|
||||
config_to_submit = script_run_config
|
||||
|
||||
effective_experiment_name = experiment_name or Path(script_run_config.script).stem
|
||||
|
||||
amlignore_path = snapshot_root_directory / AML_IGNORE_FILE
|
||||
lines_to_append = [str(path) for path in (ignored_folders or [])]
|
||||
with append_to_amlignore(
|
||||
amlignore=amlignore_path,
|
||||
lines_to_append=lines_to_append):
|
||||
run = submit_run(workspace=workspace,
|
||||
experiment_name=effective_experiment_name,
|
||||
script_run_config=config_to_submit,
|
||||
tags=tags,
|
||||
wait_for_completion=wait_for_completion,
|
||||
wait_for_completion_show_output=wait_for_completion_show_output)
|
||||
with append_to_amlignore(amlignore=amlignore_path, lines_to_append=lines_to_append):
|
||||
if strictly_aml_v1:
|
||||
run_config = create_run_configuration(
|
||||
workspace=workspace,
|
||||
compute_cluster_name=compute_cluster_name,
|
||||
aml_environment_name=aml_environment_name,
|
||||
conda_environment_file=conda_environment_file,
|
||||
environment_variables=environment_variables,
|
||||
pip_extra_index_url=pip_extra_index_url,
|
||||
private_pip_wheel_path=_str_to_path(private_pip_wheel_path),
|
||||
docker_base_image=docker_base_image,
|
||||
docker_shm_size=docker_shm_size,
|
||||
num_nodes=num_nodes,
|
||||
max_run_duration=max_run_duration,
|
||||
input_datasets=cleaned_input_datasets,
|
||||
output_datasets=cleaned_output_datasets,
|
||||
)
|
||||
script_run_config = create_script_run(snapshot_root_directory=snapshot_root_directory,
|
||||
entry_script=entry_script,
|
||||
script_params=script_params)
|
||||
script_run_config.run_config = run_config
|
||||
|
||||
if after_submission is not None:
|
||||
if hyperdrive_config:
|
||||
config_to_submit: Union[ScriptRunConfig, HyperDriveConfig] = hyperdrive_config
|
||||
config_to_submit._run_config = script_run_config
|
||||
else:
|
||||
config_to_submit = script_run_config
|
||||
|
||||
effective_experiment_name = experiment_name or Path(script_run_config.script).stem
|
||||
|
||||
run = submit_run(workspace=workspace,
|
||||
experiment_name=effective_experiment_name,
|
||||
script_run_config=config_to_submit,
|
||||
tags=tags,
|
||||
wait_for_completion=wait_for_completion,
|
||||
wait_for_completion_show_output=wait_for_completion_show_output)
|
||||
else:
|
||||
|
||||
assert conda_environment_file is not None
|
||||
environment = create_python_environment_v2(
|
||||
conda_environment_file=conda_environment_file,
|
||||
docker_base_image=docker_base_image
|
||||
)
|
||||
if entry_script is None:
|
||||
entry_script = Path(sys.argv[0])
|
||||
script_params = script_params or sys.argv[1:]
|
||||
effective_experiment_name = experiment_name or Path(entry_script).stem
|
||||
|
||||
registered_env = register_environment_v2(environment, ml_client)
|
||||
input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
|
||||
output_datasets_v2 = create_v2_outputs(cleaned_output_datasets)
|
||||
|
||||
run = submit_run_v2(workspace=workspace,
|
||||
input_datasets_v2=input_datasets_v2,
|
||||
output_datasets_v2=output_datasets_v2,
|
||||
experiment_name=effective_experiment_name,
|
||||
environment=registered_env,
|
||||
snapshot_root_directory=snapshot_root_directory,
|
||||
entry_script=entry_script,
|
||||
script_params=script_params,
|
||||
compute_target=compute_cluster_name,
|
||||
tags=tags,
|
||||
wait_for_completion=wait_for_completion,
|
||||
wait_for_completion_show_output=wait_for_completion_show_output,
|
||||
hyperparam_args=hyperparam_args
|
||||
)
|
||||
|
||||
if after_submission is not None and strictly_aml_v1:
|
||||
after_submission(run)
|
||||
exit(0)
|
||||
|
||||
|
@ -584,26 +914,36 @@ def _write_run_recovery_file(run: Run) -> None:
|
|||
|
||||
|
||||
def convert_himl_to_azureml_datasets(
|
||||
cleaned_input_datasets: List[DatasetConfig],
|
||||
cleaned_output_datasets: List[DatasetConfig],
|
||||
workspace: Workspace) -> Tuple[Dict[str, DatasetConsumptionConfig], Dict[str, OutputFileDatasetConfig]]:
|
||||
cleaned_input_datasets: List[DatasetConfig],
|
||||
cleaned_output_datasets: List[DatasetConfig],
|
||||
workspace: Workspace,
|
||||
strictly_aml_v1: bool
|
||||
) -> Tuple[Dict[str, DatasetConsumptionConfig], Dict[str, OutputFileDatasetConfig]]:
|
||||
"""
|
||||
Convert the cleaned input and output datasets into dictionaries of DatasetConsumptionConfigs for use in AzureML.
|
||||
|
||||
:param cleaned_input_datasets: The list of input DatasetConfigs
|
||||
:param cleaned_output_datasets: The list of output DatasetConfigs
|
||||
:param workspace: The AzureML workspace
|
||||
:param strictly_aml_v1: If True, use Azure ML SDK v1 to attempt to find or create and reigster the dataset.
|
||||
Otherwise, attempt to use Azure ML SDK v2.
|
||||
:return: The input and output dictionaries of DatasetConsumptionConfigs.
|
||||
"""
|
||||
inputs = {}
|
||||
for index, d in enumerate(cleaned_input_datasets):
|
||||
consumption = d.to_input_dataset(workspace=workspace, dataset_index=index)
|
||||
if consumption.name in inputs:
|
||||
raise ValueError(f"There is already an input dataset with name '{consumption.name}' set up?")
|
||||
inputs[consumption.name] = consumption
|
||||
for index, input_dataset in enumerate(cleaned_input_datasets):
|
||||
consumption = input_dataset.to_input_dataset(index, workspace, strictly_aml_v1=strictly_aml_v1)
|
||||
if isinstance(consumption, DatasetConsumptionConfig):
|
||||
data_name = consumption.name # type: ignore
|
||||
if data_name in inputs:
|
||||
raise ValueError(f"There is already an input dataset with name '{data_name}' set up?")
|
||||
inputs[data_name] = consumption
|
||||
elif isinstance(consumption, Input):
|
||||
inputs[input_dataset.name] = consumption
|
||||
else:
|
||||
raise ValueError(f"Unrecognised input data type: {type(consumption)}")
|
||||
outputs = {}
|
||||
for index, d in enumerate(cleaned_output_datasets):
|
||||
out = d.to_output_dataset(workspace=workspace, dataset_index=index)
|
||||
for index, output_dataset in enumerate(cleaned_output_datasets):
|
||||
out = output_dataset.to_output_dataset(workspace=workspace, dataset_index=index)
|
||||
if out.name in outputs:
|
||||
raise ValueError(f"There is already an output dataset with name '{out.name}' set up?")
|
||||
outputs[out.name] = out
|
||||
|
@ -659,6 +999,52 @@ def _generate_azure_datasets(
|
|||
logs_folder=Path.cwd() / LOGS_FOLDER)
|
||||
|
||||
|
||||
def _get_dataset_names_from_string(sys_arg: str, pattern: str) -> Path:
|
||||
dataset_string = re.split(pattern, sys_arg)[-1]
|
||||
dataset_path = Path(dataset_string)
|
||||
return dataset_path
|
||||
|
||||
|
||||
def _extract_v2_inputs_outputs_from_args() -> Tuple[List[Path], List[Path]]:
|
||||
"""
|
||||
Extract all command line arguments of the format INPUT_i=path_to_input or OUTPUT_i=path_to_output (where i is any
|
||||
integer) and return a list of the Paths for each.
|
||||
|
||||
:return: A list of Input paths and a list of Output paths
|
||||
"""
|
||||
returned_input_datasets: List[Path] = []
|
||||
returned_output_datasets: List[Path] = []
|
||||
|
||||
for sys_arg in sys.argv:
|
||||
if re.match(V2_INPUT_DATASET_PATTERN, sys_arg):
|
||||
returned_input_datasets += [_get_dataset_names_from_string(sys_arg, V2_INPUT_DATASET_PATTERN)]
|
||||
if re.match(V2_OUTPUT_DATASET_PATTERN, sys_arg):
|
||||
returned_output_datasets += [_get_dataset_names_from_string(sys_arg, V2_OUTPUT_DATASET_PATTERN)]
|
||||
return returned_input_datasets, returned_output_datasets
|
||||
|
||||
|
||||
def _generate_v2_azure_datasets(cleaned_input_datasets: List[DatasetConfig],
|
||||
cleaned_output_datasets: List[DatasetConfig]) -> AzureRunInfo:
|
||||
"""
|
||||
Generate returned datasets when running in AzureML. Assumes this is v2 Job, so we need to get
|
||||
the input datasets from the command line args
|
||||
|
||||
:param cleaned_input_datasets: The list of input dataset configs
|
||||
:param cleaned_output_datasets: The list of output dataset configs
|
||||
:return: The AzureRunInfo containing the AzureML input and output dataset lists etc.
|
||||
"""
|
||||
returned_input_datasets, returned_output_datasets = _extract_v2_inputs_outputs_from_args()
|
||||
|
||||
return AzureRunInfo(
|
||||
input_datasets=returned_input_datasets, # type: ignore
|
||||
output_datasets=returned_output_datasets, # type: ignore
|
||||
mount_contexts=[],
|
||||
run=RUN_CONTEXT,
|
||||
is_running_in_azure_ml=True,
|
||||
output_folder=Path.cwd() / OUTPUT_FOLDER,
|
||||
logs_folder=Path.cwd() / LOGS_FOLDER)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def append_to_amlignore(lines_to_append: List[str], amlignore: Optional[Path] = None) -> Generator:
|
||||
"""
|
||||
|
|
|
@ -6,56 +6,19 @@
|
|||
import param
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from azureml.core import Run
|
||||
|
||||
import health_azure.utils as azure_util
|
||||
from health_azure.himl import RUN_RECOVERY_FILE
|
||||
from health_azure.himl import download_job_outputs_logs
|
||||
|
||||
|
||||
class HimlDownloadConfig(azure_util.AmlRunScriptConfig):
|
||||
output_dir: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
|
||||
output_dir: Path = param.ClassSelector(class_=Path, default=Path("outputs"), instantiate=False,
|
||||
doc="Path to directory to store files downloaded from the AML Run")
|
||||
config_file: Path = param.ClassSelector(class_=Path, default=None, instantiate=False,
|
||||
doc="Path to config.json where Workspace name is defined. If not provided, "
|
||||
"the code will try to locate a config.json file in any of the parent "
|
||||
"folders of the current working directory")
|
||||
|
||||
prefix: str = param.String(default=None, allow_None=True, doc="Optional prefix to filter Run files by")
|
||||
|
||||
|
||||
def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:
|
||||
"""
|
||||
Retrieve a list of AML Run objects, given a HimlDownloadConfig object which contains values for either run
|
||||
(one or more run ids), experiment (experiment name) or latest_run_file. If none of these are provided,
|
||||
the parent directories of this script will be searched for a "most_recent_run.txt" file, and the run id will
|
||||
be extracted from there, to retrieve the run object(s). If no Runs are found, a ValueError will be raised.
|
||||
|
||||
:param download_config: A HimlDownloadConfig object containing run information (e.g. run ids or experiment name)
|
||||
:return: List of AML Run objects
|
||||
"""
|
||||
if download_config.run is not None:
|
||||
run_ids: List[str] = download_config.run
|
||||
runs = [azure_util.get_aml_run_from_run_id(r_id) for r_id in run_ids]
|
||||
if len(runs) == 0:
|
||||
raise ValueError(f"Did not find any runs with the given run id(s): {download_config.run}")
|
||||
elif download_config.experiment is not None:
|
||||
runs = azure_util.get_latest_aml_runs_from_experiment(download_config.experiment,
|
||||
download_config.num_runs,
|
||||
download_config.tags,
|
||||
workspace_config_path=download_config.config_file)
|
||||
if len(runs) == 0:
|
||||
raise ValueError(f"Did not find any runs under the given experiment name: {download_config.experiment}")
|
||||
else:
|
||||
most_recent_run_path = download_config.latest_run_file or Path(RUN_RECOVERY_FILE)
|
||||
run_or_recovery_id = azure_util.get_most_recent_run_id(most_recent_run_path)
|
||||
runs = [azure_util.get_aml_run_from_run_id(run_or_recovery_id,
|
||||
workspace_config_path=download_config.config_file)]
|
||||
if len(runs) == 0:
|
||||
raise ValueError(f"Did not find any runs with run id {run_or_recovery_id} as found in"
|
||||
f" {download_config.latest_run_file}")
|
||||
return runs
|
||||
files_to_download: str = param.String(default=None, allow_None=True, doc="Path to the file to download")
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover
|
||||
|
@ -66,17 +29,17 @@ def main() -> None: # pragma: no cover
|
|||
output_dir = download_config.output_dir
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
runs = retrieve_runs(download_config)
|
||||
files_to_download = download_config.files_to_download
|
||||
|
||||
for run in runs:
|
||||
output_folder = output_dir / run.id
|
||||
|
||||
try: # pragma: no cover
|
||||
azure_util.download_files_from_run_id(run.id, output_folder=output_folder, prefix=download_config.prefix,
|
||||
workspace_config_path=download_config.config_file)
|
||||
print(f"Downloaded file(s) to '{output_folder}'")
|
||||
except Exception as e: # pragma: no cover
|
||||
raise ValueError(f"Failed to download files from run {run.id}: {e}")
|
||||
workspace = azure_util.get_workspace()
|
||||
ml_client = azure_util.get_ml_client(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group=workspace.resource_group,
|
||||
workspace_name=workspace.name
|
||||
)
|
||||
for run_id in download_config.run:
|
||||
download_job_outputs_logs(ml_client, run_id, file_to_download_path=files_to_download, download_dir=output_dir)
|
||||
print("Successfully downloaded output and log files")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
"""
|
||||
Utility functions for interacting with mlflow runs
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
from mlflow.client import MlflowClient
|
||||
from mlflow.entities import Run, Metric
|
||||
|
||||
|
||||
def get_mlflow_run(mlflow_client: MlflowClient, mlflow_run_id: str) -> Run:
|
||||
"""
|
||||
Retrieve a Run from an MLFlow client
|
||||
|
||||
:param mlflow_client: An MLflowClient object.
|
||||
:param mlflow_run_id: The id of an mlflow run to retrieve.
|
||||
:return: An mlflow Run object
|
||||
"""
|
||||
mlflow_run = mlflow_client.get_run(mlflow_run_id)
|
||||
return mlflow_run
|
||||
|
||||
|
||||
def get_last_metrics_from_mlflow_run(mlflow_run: Run) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve the last logged metrics from an mlflow Run
|
||||
|
||||
:param mlflow_run: the mlflow Run to retrieve metrics from
|
||||
:return: A dictionary of metric_name to value
|
||||
"""
|
||||
metrics = mlflow_run.data.metrics
|
||||
return metrics
|
||||
|
||||
|
||||
def get_metric_from_mlflow_run(mlflow_client: MlflowClient, run_id: str, metric_name: str
|
||||
) -> List[Metric]:
|
||||
"""
|
||||
For a given metric name, get the entire history of logged values from an mlflow Run
|
||||
|
||||
:param mlflow_client: An MLFlowClient object.
|
||||
:param run_id: The id of the run to retrieve the metrics from.
|
||||
:param metric_name: The name of the metric to retrieve values for.
|
||||
:return: A list of mlflow Metric objects representing the all of the values of the given
|
||||
metric throughout the run
|
||||
"""
|
||||
metric_history = mlflow_client.get_metric_history(run_id, metric_name)
|
||||
return metric_history
|
|
@ -27,6 +27,7 @@ from typing import (Any, Callable, DefaultDict, Dict, Generator, Iterable, List,
|
|||
import conda_merge
|
||||
import pandas as pd
|
||||
import param
|
||||
|
||||
from azureml._restclient.constants import RunStatus
|
||||
from azureml.core import Environment, Experiment, Run, Workspace, get_run
|
||||
from azureml.core.authentication import InteractiveLoginAuthentication, ServicePrincipalAuthentication
|
||||
|
@ -35,6 +36,15 @@ from azureml.core.run import _OfflineRun
|
|||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
from azureml.train.hyperdrive import HyperDriveRun
|
||||
|
||||
from azure.ai.ml import MLClient
|
||||
from azure.ai.ml.entities import Job
|
||||
from azure.ai.ml.entities import Workspace as WorkspaceV2
|
||||
from azure.ai.ml.entities import Environment as EnvironmentV2
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
|
||||
from azure.identity import (ClientSecretCredential, DeviceCodeCredential,
|
||||
DefaultAzureCredential, InteractiveBrowserCredential)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -96,6 +106,10 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
|
|||
"AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
|
||||
}
|
||||
|
||||
|
||||
V2_INPUT_DATASET_PATTERN = r"--INPUT_\d[=| ]"
|
||||
V2_OUTPUT_DATASET_PATTERN = r"--OUTPUT_\d[=| ]"
|
||||
|
||||
PathOrString = Union[Path, str]
|
||||
|
||||
|
||||
|
@ -702,6 +716,7 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
|
|||
if is_running_in_azure_ml(RUN_CONTEXT):
|
||||
return RUN_CONTEXT.experiment.workspace
|
||||
|
||||
# If aml_workspace has been provided, use that
|
||||
if aml_workspace:
|
||||
return aml_workspace
|
||||
|
||||
|
@ -1040,6 +1055,22 @@ def is_conda_file_with_pip_include(conda_file: Path) -> Tuple[bool, Dict]:
|
|||
return False, conda_yaml
|
||||
|
||||
|
||||
def generate_unique_environment_name(environment_description_string: str) -> str:
|
||||
"""
|
||||
Generates a unique environment name beginning with "HealthML" and ending with a hash string generated
|
||||
from the environment description.
|
||||
|
||||
:param environment_description_string: String to be hashed that should include everything that can
|
||||
reasonably change between environments.
|
||||
:return: A string representing the unique environment name.
|
||||
"""
|
||||
|
||||
sha1 = hashlib.sha1(environment_description_string.encode("utf8"))
|
||||
overall_hash = sha1.hexdigest()[:32]
|
||||
unique_env_name = f"HealthML-{overall_hash}"
|
||||
return unique_env_name
|
||||
|
||||
|
||||
def create_python_environment(
|
||||
conda_environment_file: Path,
|
||||
pip_extra_index_url: str = "",
|
||||
|
@ -1080,8 +1111,7 @@ def create_python_environment(
|
|||
logging.info(f"Added add_private_pip_wheel {private_pip_wheel_path} to AzureML environment.")
|
||||
# Create a name for the environment that will likely uniquely identify it. AzureML does hashing on top of that,
|
||||
# and will re-use existing environments even if they don't have the same name.
|
||||
# Hashing should include everything that can reasonably change. Rely on hashlib here, because the built-in
|
||||
hash_string = "\n".join(
|
||||
env_description_string = "\n".join(
|
||||
[
|
||||
yaml_contents,
|
||||
docker_base_image,
|
||||
|
@ -1096,9 +1126,7 @@ def create_python_environment(
|
|||
)
|
||||
# Python's hash function gives different results for the same string in different python instances,
|
||||
# hence need to use hashlib
|
||||
sha1 = hashlib.sha1(hash_string.encode("utf8"))
|
||||
overall_hash = sha1.hexdigest()[:32]
|
||||
unique_env_name = f"HealthML-{overall_hash}"
|
||||
unique_env_name = generate_unique_environment_name(env_description_string)
|
||||
env = Environment(name=unique_env_name)
|
||||
env.python.conda_dependencies = conda_dependencies
|
||||
if docker_base_image:
|
||||
|
@ -1132,6 +1160,69 @@ def register_environment(workspace: Workspace, environment: Environment) -> Envi
|
|||
return environment.register(workspace)
|
||||
|
||||
|
||||
def create_python_environment_v2(
|
||||
conda_environment_file: Path,
|
||||
pip_extra_index_url: str = "",
|
||||
private_pip_wheel_path: Optional[Path] = None,
|
||||
docker_base_image: str = ""
|
||||
) -> EnvironmentV2:
|
||||
"""
|
||||
Creates a description for the V2 Python execution environment in AzureML, based on the arguments.
|
||||
The environment will have a name that uniquely identifies it (it is based on hashing the contents of the
|
||||
Conda file, the docker base image, environment variables and private wheels.
|
||||
|
||||
:param docker_base_image: The Docker base image that should be used when creating a new Docker image.
|
||||
:param pip_extra_index_url: If provided, use this PIP package index to find additional packages when building
|
||||
the Docker image.
|
||||
:param private_pip_wheel_path: If provided, add this wheel as a private package to the AzureML environment.
|
||||
:param conda_environment_file: The file that contains the Conda environment definition.
|
||||
:return: A v2 Azure ML Environment object
|
||||
"""
|
||||
yaml_contents = conda_environment_file.read_text()
|
||||
environment_description_string = "\n".join(
|
||||
[
|
||||
yaml_contents,
|
||||
docker_base_image,
|
||||
# Changing the index URL can lead to differences in package version resolution
|
||||
pip_extra_index_url,
|
||||
# Use the path of the private wheel as a proxy. This could lead to problems if
|
||||
# a new environment uses the same private wheel file name, but the wheel has different
|
||||
# contents. In hi-ml PR builds, the wheel file name is unique to the build, so it
|
||||
# should not occur there.
|
||||
str(private_pip_wheel_path),
|
||||
]
|
||||
)
|
||||
unique_env_name = generate_unique_environment_name(environment_description_string)
|
||||
environment = EnvironmentV2(
|
||||
image=docker_base_image,
|
||||
name=unique_env_name + "-v2",
|
||||
conda_file=conda_environment_file,
|
||||
)
|
||||
return environment
|
||||
|
||||
|
||||
def register_environment_v2(environment: EnvironmentV2, ml_client: MLClient) -> EnvironmentV2:
|
||||
"""
|
||||
Try to get the v2 AzureML environment by name and version from the AzureML workspace. If it succeeds, return that
|
||||
environment object. If that fails, register the environment with the MLClient.
|
||||
|
||||
:param ml_client: An AzureML MLClient object.
|
||||
:param environment: An AzureML execution environment.
|
||||
:return: A v2 AzureML Environment object. If the environment did already exist on the workspace, returns that,
|
||||
otherwise returns the newly registered environment.
|
||||
"""
|
||||
try:
|
||||
if environment.version:
|
||||
env = ml_client.environments.get(environment.name, environment.version)
|
||||
else:
|
||||
env = ml_client.environments.get(environment.name, label="latest")
|
||||
logging.info(f"Found a registered environment with name {environment.name}, returning that.")
|
||||
except ResourceNotFoundError:
|
||||
logging.info("Didn't find existing environment. Registering a new one.")
|
||||
env = ml_client.environments.create_or_update(environment)
|
||||
return env
|
||||
|
||||
|
||||
def run_duration_string_to_seconds(s: str) -> Optional[int]:
|
||||
"""
|
||||
Parse a string that represents a timespan, and returns it converted into seconds. The string is expected to be
|
||||
|
@ -1309,7 +1400,7 @@ def get_run_file_names(run: Run, prefix: str = "") -> List[str]:
|
|||
:return: A list of paths within the Run's container
|
||||
"""
|
||||
all_files = run.get_file_names()
|
||||
print(f"Selecting files with prefix {prefix}")
|
||||
logging.info(f"Selecting files with prefix {prefix}")
|
||||
return [f for f in all_files if f.startswith(prefix)] if prefix else all_files
|
||||
|
||||
|
||||
|
@ -1429,12 +1520,12 @@ def download_file_if_necessary(run: Run, filename: str, output_file: Path, overw
|
|||
:return: Local path to the downloaded file.
|
||||
"""
|
||||
if not overwrite and output_file.exists():
|
||||
print("File already exists at", output_file)
|
||||
logging.info(f"File already exists at {output_file}")
|
||||
else:
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
_download_file_from_run(run, filename, output_file, validate_checksum=True)
|
||||
assert output_file.exists()
|
||||
print("File is downloaded at", output_file)
|
||||
logging.info(f"File is downloaded at {output_file}")
|
||||
return output_file
|
||||
|
||||
|
||||
|
@ -2022,3 +2113,219 @@ def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optiona
|
|||
if actual not in valid:
|
||||
all_valid = ", ".join(["<None>" if v is None else v for v in valid])
|
||||
raise ValueError("{} must be one of [{}], but got: {}".format(message, all_valid, actual))
|
||||
|
||||
|
||||
def _validate_credential(credential: TokenCredential) -> None:
|
||||
"""
|
||||
Validate credential by attempting to get token. If authentication has been successful, get_token
|
||||
will succeed. Otherwise an exception will be raised
|
||||
|
||||
:param credential: The credential object to validate.
|
||||
"""
|
||||
credential.get_token("https://management.azure.com/.default")
|
||||
|
||||
|
||||
def _get_legitimate_service_principal_credential(tenant_id: str, service_principal_id: str,
|
||||
service_principal_password: str) -> TokenCredential:
|
||||
"""
|
||||
Create a ClientSecretCredential given a tenant id, service principal id and password
|
||||
|
||||
:param tenant_id: The Azure tenant id.
|
||||
:param service_principal_id: The id of an existing Service Principal.
|
||||
:param service_principal_password: The password of an existing Service Principal.
|
||||
:raises ValueError: If the credential cannot be validated (i.e. authentication was unsucessful).
|
||||
:return: The validated credential.
|
||||
"""
|
||||
cred = ClientSecretCredential(tenant_id=tenant_id,
|
||||
client_id=service_principal_id,
|
||||
client_secret=service_principal_password)
|
||||
try:
|
||||
_validate_credential(cred)
|
||||
return cred
|
||||
except ClientAuthenticationError as e:
|
||||
raise ValueError(f"Found environment variables for {ENV_SERVICE_PRINCIPAL_ID}, "
|
||||
f"{ENV_SERVICE_PRINCIPAL_PASSWORD}, and {ENV_TENANT_ID} but was "
|
||||
f"not able to authenticate: {e}")
|
||||
|
||||
|
||||
def _get_legitimate_device_code_credential() -> Optional[TokenCredential]:
|
||||
"""
|
||||
Create a DeviceCodeCredential for interacting with Azure resources. If the credential can't be
|
||||
validated, return None.
|
||||
|
||||
:return: A valid Azure credential.
|
||||
"""
|
||||
cred = DeviceCodeCredential(timeout=60)
|
||||
try:
|
||||
_validate_credential(cred)
|
||||
return cred
|
||||
except ClientAuthenticationError:
|
||||
return None
|
||||
|
||||
|
||||
def _get_legitimate_default_credential() -> Optional[TokenCredential]:
|
||||
"""
|
||||
Create a DefaultAzure credential for interacting with Azure resources. If the credential can't be
|
||||
validated, return None.
|
||||
|
||||
:return: A valid Azure credential.
|
||||
"""
|
||||
cred = DefaultAzureCredential(timeout=60)
|
||||
try:
|
||||
_validate_credential(cred)
|
||||
return cred
|
||||
except ClientAuthenticationError:
|
||||
return None
|
||||
|
||||
|
||||
def _get_legitimate_interactive_browser_credential() -> Optional[TokenCredential]:
|
||||
"""
|
||||
Create an InteractiveBrowser credential for interacting with Azure resources. If the credential can't be
|
||||
validated, return None.
|
||||
|
||||
:return: A valid Azure credential.
|
||||
"""
|
||||
cred = InteractiveBrowserCredential(timeout=60)
|
||||
try:
|
||||
_validate_credential(cred)
|
||||
return cred
|
||||
except ClientAuthenticationError:
|
||||
return None
|
||||
|
||||
|
||||
def get_credential() -> Optional[TokenCredential]:
|
||||
"""
|
||||
Get a credential for authenticating with Azure.There are multiple ways to retrieve a credential.
|
||||
If environment variables pertaining to details of a Service Principal are available, those will be used
|
||||
to authenticate. If no environment variables exist, and the script is not currently
|
||||
running inside of Azure ML or another Azure agent, will attempt to retrieve a credential via a
|
||||
device code (which requires the user to visit a link and enter a provided code). If this fails, or if running in
|
||||
Azure, DefaultAzureCredential will be used which iterates through a number of possible authentication methods
|
||||
including identifying an Azure managed identity, cached credentials from VS code, Azure CLI, Powershell etc.
|
||||
Otherwise returns None.
|
||||
|
||||
:return: Any of the aforementioned credentials if available, else None.
|
||||
"""
|
||||
service_principal_id = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_ID, allow_missing=True)
|
||||
tenant_id = get_secret_from_environment(ENV_TENANT_ID, allow_missing=True)
|
||||
service_principal_password = get_secret_from_environment(ENV_SERVICE_PRINCIPAL_PASSWORD, allow_missing=True)
|
||||
if service_principal_id and tenant_id and service_principal_password:
|
||||
return _get_legitimate_service_principal_credential(tenant_id, service_principal_id, service_principal_password)
|
||||
|
||||
try:
|
||||
cred = _get_legitimate_default_credential()
|
||||
if cred is not None:
|
||||
return cred
|
||||
except ClientAuthenticationError:
|
||||
cred = _get_legitimate_device_code_credential()
|
||||
if cred is not None:
|
||||
return cred
|
||||
|
||||
cred = _get_legitimate_interactive_browser_credential()
|
||||
if cred is not None:
|
||||
return cred
|
||||
|
||||
raise ValueError("Unable to generate and validate a credential. Please see Azure ML documentation"
|
||||
"for instructions on diffrent options to get a credential")
|
||||
|
||||
|
||||
def get_ml_client(ml_client: Optional[MLClient] = None,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
workspace_config_path: Optional[PathOrString] = None,
|
||||
subscription_id: Optional[str] = None,
|
||||
resource_group: Optional[str] = None,
|
||||
workspace_name: str = "",
|
||||
) -> MLClient:
|
||||
"""
|
||||
Instantiate an MLClient for interacting with Azure resources via v2 of the Azure ML SDK.
|
||||
If a ml_client is provided, return that. Otherwise, create one using workspace details
|
||||
coming from either an existing Workspace object, a config.json file or passed in as an argument.
|
||||
|
||||
:param ml_client: An optional existing MLClient object to be returned.
|
||||
:param aml_workspace: An optional Workspace object to take connection details from.
|
||||
:param workspace_config_path: An optional path toa config.json file containing details of the Workspace.
|
||||
:param subscription_id: An optional subscription ID.
|
||||
:param resource_group: An optional resource group name.
|
||||
:param workspace_name: An optional workspace name.
|
||||
:return: An instance of MLClient to interact with Azure resources.
|
||||
"""
|
||||
if ml_client:
|
||||
return ml_client
|
||||
|
||||
credential = get_credential()
|
||||
if credential is None:
|
||||
raise ValueError("Can't connect to MLClient without a valid credential")
|
||||
if aml_workspace is not None:
|
||||
ml_client = MLClient(
|
||||
subscription_id=aml_workspace.subscription_id,
|
||||
resource_group_name=aml_workspace.resource_group,
|
||||
workspace_name=aml_workspace.name,
|
||||
credential=credential) # type: ignore
|
||||
elif workspace_config_path:
|
||||
ml_client = MLClient.from_config(
|
||||
credential=credential, # type: ignore
|
||||
path=str(workspace_config_path))
|
||||
elif subscription_id and resource_group and workspace_name:
|
||||
ml_client = MLClient(
|
||||
subscription_id=subscription_id,
|
||||
resource_group_name=resource_group,
|
||||
workspace_name=workspace_name,
|
||||
credential=credential) # type: ignore
|
||||
else:
|
||||
try:
|
||||
workspace = get_workspace()
|
||||
ml_client = MLClient(
|
||||
subscription_id=workspace.subscription_id,
|
||||
resource_group_name=workspace.resource_group,
|
||||
workspace_name=workspace.name,
|
||||
credential=credential) # type: ignore
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Couldn't connect to MLClient: {e}")
|
||||
logging.info(f"Logged into AzureML workspace {ml_client.workspace_name}")
|
||||
return ml_client
|
||||
|
||||
|
||||
def retrieve_workspace_from_client(ml_client: MLClient, workspace_name: Optional[str] = None
|
||||
) -> WorkspaceV2:
|
||||
"""
|
||||
Get a v2 Workspace object from an MLClient object. If a workspace_name is passed, will attempt
|
||||
to retrieve a workspace with that name. Otherweise will use the MLClient's default workspace_name
|
||||
|
||||
:param ml_client: An MLClient object to retrieve the Workspace from
|
||||
:param workspace_name: An optional name of the workspace to retrieve.
|
||||
:return: A v2 Workspace object.
|
||||
"""
|
||||
if workspace_name is not None:
|
||||
workspace_name = workspace_name
|
||||
elif ml_client.workspace_name is not None:
|
||||
workspace_name = ml_client.workspace_name
|
||||
else:
|
||||
workspace_name = ""
|
||||
workspace = ml_client.workspaces.get(workspace_name)
|
||||
return workspace
|
||||
|
||||
|
||||
def fetch_job(ml_client: MLClient, run_id: str) -> Job:
|
||||
"""
|
||||
Retrieve a job with a given run_id from an MLClient
|
||||
|
||||
:param ml_client: An MLClient object.
|
||||
:param run_id: The id of the run to retrieve.
|
||||
:return: An Azure ML (v2) Job object.
|
||||
"""
|
||||
job = ml_client.jobs.get(run_id)
|
||||
return job
|
||||
|
||||
|
||||
def filter_v2_input_output_args(args: List[str]) -> List[str]:
|
||||
"""
|
||||
Filter out AML v2 Input and Output entries from a list of args. Under AML SDK v2 it is necessary to
|
||||
pass input and output arguments to a script via the command line, of which there can be an unknown number.
|
||||
Therefore we need to remove these from the list of args passed to the argument parsers.
|
||||
|
||||
:param args: A list of arguments from which to remove input and output args
|
||||
:return: A filtered list of arguments, without entries in the format of INPUT_i or OUTPUT_i where i is
|
||||
any integer.
|
||||
"""
|
||||
return [a for a in args if
|
||||
not re.match(V2_INPUT_DATASET_PATTERN, a) and not re.match(V2_OUTPUT_DATASET_PATTERN, a)]
|
||||
|
|
|
@ -15,7 +15,7 @@ from enum import Enum
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import DEFAULT, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
from xmlrpc.client import Boolean
|
||||
|
||||
|
@ -26,15 +26,18 @@ import param
|
|||
import pytest
|
||||
from _pytest.capture import CaptureFixture
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
from azure.identity import (ClientSecretCredential, DeviceCodeCredential, DefaultAzureCredential)
|
||||
from azure.storage.blob import ContainerClient
|
||||
from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
|
||||
from azureml.core.authentication import ServicePrincipalAuthentication
|
||||
from azureml.core.environment import CondaDependencies
|
||||
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
|
||||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
|
||||
import health_azure.utils as util
|
||||
from health_azure.himl import AML_IGNORE_FILE, append_to_amlignore
|
||||
from health_azure.utils import (ENV_MASTER_ADDR, ENV_MASTER_PORT, MASTER_PORT_DEFAULT,
|
||||
PackageDependency, create_argparser)
|
||||
PackageDependency, create_argparser, get_credential)
|
||||
from testazure.test_himl import RunTarget, render_and_run_test_script
|
||||
from testazure.utils_testazure import (DEFAULT_IGNORE_FOLDERS, DEFAULT_WORKSPACE, MockRun, change_working_directory,
|
||||
himl_azure_root, repository_root)
|
||||
|
@ -67,7 +70,7 @@ def test_find_file_in_parent_folders(caplog: LogCaptureFixture) -> None:
|
|||
)
|
||||
last_caplog_msg = caplog.messages[-1]
|
||||
assert found_file_path == current_file_path
|
||||
assert(f"Searching for file {current_file_path.name} in {himl_azure_test_root}" in last_caplog_msg)
|
||||
assert f"Searching for file {current_file_path.name} in {himl_azure_test_root}" in last_caplog_msg
|
||||
|
||||
# Now try to search for a nonexistent path in the same folder. This should return None
|
||||
nonexistent_path = himl_az_root / "idontexist.py"
|
||||
|
@ -652,6 +655,19 @@ def test_nonexisting_amlignore(random_folder: Path) -> None:
|
|||
os.chdir(cwd)
|
||||
|
||||
|
||||
def test_generate_unique_environment_name() -> None:
|
||||
dummy_env_description_string_1 = "A pretend environment description\ncontaining information about pip "
|
||||
"packages\netc etc"
|
||||
env_name_1 = util.generate_unique_environment_name(dummy_env_description_string_1)
|
||||
assert env_name_1.startswith("HealthML-")
|
||||
|
||||
dummy_env_description_string_2 = "A slightly differetpretend environment description\ncontaining "
|
||||
"information about pip packages\netc etc"
|
||||
env_name_2 = util.generate_unique_environment_name(dummy_env_description_string_2)
|
||||
assert env_name_2.startswith("HealthML-")
|
||||
assert env_name_1 != env_name_2
|
||||
|
||||
|
||||
@patch("health_azure.utils.Workspace")
|
||||
def test_create_python_environment(
|
||||
mock_workspace: mock.MagicMock,
|
||||
|
@ -700,6 +716,42 @@ dependencies:
|
|||
assert private_pip_wheel_url in envs_pip_packages
|
||||
|
||||
|
||||
@patch("health_azure.utils.Workspace")
|
||||
def test_create_python_environment_v2(
|
||||
mock_workspace: mock.MagicMock,
|
||||
random_folder: Path,
|
||||
) -> None:
|
||||
conda_str = """name: simple-env
|
||||
dependencies:
|
||||
- pip=20.1.1
|
||||
- python=3.7.3
|
||||
- pip:
|
||||
- azureml-sdk==1.23.0
|
||||
- something-else==0.1.5
|
||||
- pip:
|
||||
- --index-url https://test.pypi.org/simple/
|
||||
- --extra-index-url https://pypi.org/simple
|
||||
- hi-ml-azure
|
||||
"""
|
||||
conda_environment_file = random_folder / "environment.yml"
|
||||
conda_environment_file.write_text(conda_str)
|
||||
env = util.create_python_environment_v2(conda_environment_file=conda_environment_file)
|
||||
|
||||
# Check that the environment has a reasonable name. Detailed checks for uniqueness of the name follow below.
|
||||
assert env.name.startswith("HealthML")
|
||||
assert env.name.endswith("-v2")
|
||||
assert env._conda_file_path == conda_environment_file
|
||||
|
||||
pip_extra_index_url = "https://where.great.packages.live/"
|
||||
docker_base_image = "viennaglobal.azurecr.io/azureml/azureml_a187a87cc7c31ac4d9f67496bc9c8239"
|
||||
env = util.create_python_environment_v2(
|
||||
conda_environment_file=conda_environment_file,
|
||||
pip_extra_index_url=pip_extra_index_url,
|
||||
docker_base_image=docker_base_image)
|
||||
|
||||
assert env.image == docker_base_image
|
||||
|
||||
|
||||
def test_create_environment_unique_name(random_folder: Path) -> None:
|
||||
"""
|
||||
Test if the name of the conda environment changes with each of the components
|
||||
|
@ -815,6 +867,33 @@ def test_register_environment(
|
|||
assert env.version == util.ENVIRONMENT_VERSION
|
||||
|
||||
|
||||
@patch("azure.ai.ml.entities.Environment")
|
||||
@patch("azure.ai.ml.MLClient")
|
||||
def test_register_environment_v2(
|
||||
mock_ml_client: MagicMock,
|
||||
mock_environment_v2: MagicMock,
|
||||
caplog: LogCaptureFixture,
|
||||
) -> None:
|
||||
def _mock_cant_find_env(env_name: str, label_or_version: str) -> None:
|
||||
raise ResourceNotFoundError("Does not exist")
|
||||
|
||||
env_name = "an environment"
|
||||
env_version = "environment version"
|
||||
mock_ml_client.environments.get.return_value = mock_environment_v2
|
||||
mock_environment_v2.name = env_name
|
||||
mock_environment_v2.version = env_version
|
||||
with caplog.at_level(logging.INFO): # type: ignore
|
||||
_ = util.register_environment_v2(mock_environment_v2, mock_ml_client)
|
||||
caplog_text = caplog.text
|
||||
assert f"Found a registered environment with name {env_name}, returning that." in caplog_text
|
||||
|
||||
# test that log is correct when exception is triggered
|
||||
mock_ml_client.environments.get.side_effect = _mock_cant_find_env
|
||||
_ = util.register_environment_v2(mock_environment_v2, mock_ml_client)
|
||||
caplog_text = caplog.text
|
||||
assert "Didn't find existing environment. Registering a new one." in caplog_text
|
||||
|
||||
|
||||
def test_set_environment_variables_for_multi_node(caplog: LogCaptureFixture) -> None:
|
||||
# If none of AZ_BATCHAI_MPI_MASTER_NODE, AZ_BATCH_MASTER_NODE or ENV_MASTER_IP are set, should assume
|
||||
# single node training job
|
||||
|
@ -1267,16 +1346,19 @@ import sys
|
|||
from pathlib import Path
|
||||
from azureml.core import Run
|
||||
from health_azure.utils import download_files_from_run_id""",
|
||||
"body": script_body
|
||||
"body": script_body,
|
||||
}
|
||||
# Run the script locally first, then in the cloud. In local runs, the workspace should be picked up from the
|
||||
# config.json file, in AzureML runs it should be read off the run context.
|
||||
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options, extra_args=[], expected_pass=True)
|
||||
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options,
|
||||
extra_args=[], expected_pass=True)
|
||||
print("Local run finished")
|
||||
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options, extra_args=[], expected_pass=True)
|
||||
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options,
|
||||
extra_args=[], expected_pass=True)
|
||||
|
||||
|
||||
def test_replace_directory(tmp_path: Path) -> None:
|
||||
|
||||
extra_options = {
|
||||
"imports": """
|
||||
import sys
|
||||
|
@ -1297,13 +1379,15 @@ from health_azure.utils import replace_directory
|
|||
|
||||
assert not output_dir.exists()
|
||||
assert (new_output_dir / file_name).exists()
|
||||
"""
|
||||
""",
|
||||
}
|
||||
|
||||
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options, extra_args=[], expected_pass=True)
|
||||
render_and_run_test_script(tmp_path, RunTarget.LOCAL, extra_options,
|
||||
extra_args=[], expected_pass=True)
|
||||
print("Local run finished")
|
||||
|
||||
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options, extra_args=[], expected_pass=True)
|
||||
render_and_run_test_script(tmp_path / "foo", RunTarget.AZUREML, extra_options,
|
||||
extra_args=[], expected_pass=True)
|
||||
|
||||
|
||||
def test_is_global_rank_zero() -> None:
|
||||
|
@ -1346,17 +1430,32 @@ def test_get_run_source(dummy_recovery_id: str,
|
|||
assert isinstance(script_config.run, str)
|
||||
|
||||
|
||||
def delete_existing_blobs(datastore: AzureBlobDatastore, prefix: str) -> None:
|
||||
def get_container_client(datastore: AzureBlobDatastore) -> ContainerClient:
|
||||
"""Gets a ContainerClient to interact with the blobs in the given datastore.
|
||||
|
||||
param datastore: The datastore from which the files should be read.
|
||||
"""
|
||||
return datastore.blob_service.get_container_client(datastore.container_name)
|
||||
|
||||
|
||||
def get_blobs_in_datastore(datastore: AzureBlobDatastore, prefix: str) -> List[Any]:
|
||||
"""Gets all blobs in the datastore where the name starts with the given prefix.
|
||||
|
||||
param datastore: The datastore from which the files should be read.
|
||||
param prefix: The prefix string for the files that should be returned.
|
||||
"""
|
||||
return list(get_container_client(datastore).list_blobs(name_starts_with=prefix))
|
||||
|
||||
|
||||
def delete_blobs_in_datastore(datastore: AzureBlobDatastore, prefix: str) -> None:
|
||||
"""Deletes all existing files in blob storage at the location that the test uses.
|
||||
|
||||
param datastore: The datastore from which the files should be deleted.
|
||||
param prefix: The prefix string for the files that should be deleted.
|
||||
"""
|
||||
container = datastore.container_name
|
||||
existing_blobs = list(datastore.blob_service.list_blobs(prefix=prefix,
|
||||
container_name=container))
|
||||
for existing_blob in existing_blobs:
|
||||
datastore.blob_service.delete_blob(container_name=container, blob_name=existing_blob.name)
|
||||
container_client = get_container_client(datastore)
|
||||
for existing_blob in get_blobs_in_datastore(datastore, prefix):
|
||||
container_client.delete_blob(existing_blob.name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("overwrite", [True, False])
|
||||
|
@ -1374,7 +1473,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
|
|||
local_data_path.mkdir()
|
||||
test_data_path_remote = "test_data/abc"
|
||||
|
||||
delete_existing_blobs(datastore=default_datastore, prefix=test_data_path_remote)
|
||||
delete_blobs_in_datastore(datastore=default_datastore, prefix=test_data_path_remote)
|
||||
try:
|
||||
# Create dummy data files and upload to datastore (checking they are uploaded)
|
||||
dummy_filenames = []
|
||||
|
@ -1387,8 +1486,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
|
|||
default_datastore.upload(str(local_data_path), test_data_path_remote, overwrite=False)
|
||||
# Wait a bit because there seem to be spurious errors with files not yet existing at this point
|
||||
time.sleep(0.1)
|
||||
existing = list(default_datastore.blob_service.list_blobs(prefix=test_data_path_remote,
|
||||
container_name=default_datastore.container_name))
|
||||
existing = get_blobs_in_datastore(default_datastore, prefix=test_data_path_remote)
|
||||
assert len(existing) == num_dummy_files
|
||||
|
||||
# Check that the file doesn't currently exist at download location
|
||||
|
@ -1403,7 +1501,7 @@ def test_download_from_datastore(tmp_path: Path, overwrite: bool) -> None:
|
|||
expected_download_paths = [expected_local_download_dir / dummy_filename for dummy_filename in dummy_filenames]
|
||||
assert all([p.exists() for p in expected_download_paths])
|
||||
finally:
|
||||
delete_existing_blobs(datastore=default_datastore, prefix=test_data_path_remote)
|
||||
delete_blobs_in_datastore(datastore=default_datastore, prefix=test_data_path_remote)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("overwrite", [True, False])
|
||||
|
@ -1416,14 +1514,13 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool) -> None:
|
|||
"""
|
||||
ws = DEFAULT_WORKSPACE.workspace
|
||||
default_datastore: AzureBlobDatastore = ws.get_default_datastore()
|
||||
container = default_datastore.container_name
|
||||
dummy_file_content = "Hello world"
|
||||
|
||||
remote_data_dir = "test_data"
|
||||
dummy_file_name = Path("abc/uploaded_file.txt")
|
||||
expected_remote_path = Path(remote_data_dir) / dummy_file_name.name
|
||||
|
||||
delete_existing_blobs(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
|
||||
delete_blobs_in_datastore(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
|
||||
|
||||
try:
|
||||
# Create a dummy data file and upload to datastore
|
||||
|
@ -1436,11 +1533,10 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool) -> None:
|
|||
# Wait a bit because there seem to be spurious errors with files not yet existing at this point
|
||||
time.sleep(0.1)
|
||||
|
||||
existing_blobs = list(default_datastore.blob_service.list_blobs(prefix=str(expected_remote_path.as_posix()),
|
||||
container_name=container))
|
||||
existing_blobs = get_blobs_in_datastore(default_datastore, prefix=str(expected_remote_path.as_posix()))
|
||||
assert len(existing_blobs) == 1
|
||||
finally:
|
||||
delete_existing_blobs(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
|
||||
delete_blobs_in_datastore(datastore=default_datastore, prefix=str(expected_remote_path.as_posix()))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("arguments, run_id", [
|
||||
|
@ -2263,3 +2359,150 @@ def test_create_run() -> None:
|
|||
finally:
|
||||
if run is not None:
|
||||
run.complete()
|
||||
|
||||
|
||||
def test_get_credential() -> None:
|
||||
def _mock_validation_error() -> None:
|
||||
raise ClientAuthenticationError("")
|
||||
# test the case where service principal credentials are set as environment variables
|
||||
mock_env_vars = {
|
||||
util.ENV_SERVICE_PRINCIPAL_ID: "foo",
|
||||
util.ENV_TENANT_ID: "bar",
|
||||
util.ENV_SERVICE_PRINCIPAL_PASSWORD: "baz"
|
||||
}
|
||||
|
||||
with patch.object(os.environ, "get", return_value=mock_env_vars):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT
|
||||
) as mocks:
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mocks["_get_legitimate_service_principal_credential"].assert_called_once()
|
||||
mocks["_get_legitimate_device_code_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_default_credential"].assert_not_called()
|
||||
mocks["_get_legitimate_interactive_browser_credential"].assert_not_called()
|
||||
|
||||
# if the environment variables are not set and we are running on a local machine, a
|
||||
# DefaultAzureCredential should be attempted first
|
||||
with patch.object(os.environ, "get", return_value={}):
|
||||
with patch.multiple(
|
||||
"health_azure.utils",
|
||||
is_running_in_azure_ml=DEFAULT,
|
||||
is_running_on_azure_agent=DEFAULT,
|
||||
_get_legitimate_service_principal_credential=DEFAULT,
|
||||
_get_legitimate_device_code_credential=DEFAULT,
|
||||
_get_legitimate_default_credential=DEFAULT,
|
||||
_get_legitimate_interactive_browser_credential=DEFAULT
|
||||
) as mocks:
|
||||
mock_get_sp_cred = mocks["_get_legitimate_service_principal_credential"]
|
||||
mock_get_device_cred = mocks["_get_legitimate_device_code_credential"]
|
||||
mock_get_default_cred = mocks["_get_legitimate_default_credential"]
|
||||
mock_get_browser_cred = mocks["_get_legitimate_interactive_browser_credential"]
|
||||
|
||||
mocks["is_running_in_azure_ml"].return_value = False
|
||||
mocks["is_running_on_azure_agent"].return_value = False
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_not_called()
|
||||
mock_get_default_cred.assert_called_once()
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if that fails, a DeviceCode credential should be attempted
|
||||
mock_get_default_cred.side_effect = _mock_validation_error
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
mock_get_device_cred.assert_called_once()
|
||||
assert mock_get_default_cred.call_count == 2
|
||||
mock_get_browser_cred.assert_not_called()
|
||||
|
||||
# if None of the previous credentials work, an InteractiveBrowser credential should be tried
|
||||
mock_get_device_cred.return_value = None
|
||||
_ = get_credential()
|
||||
mock_get_sp_cred.assert_not_called()
|
||||
assert mock_get_device_cred.call_count == 2
|
||||
assert mock_get_default_cred.call_count == 3
|
||||
mock_get_browser_cred.assert_called_once()
|
||||
|
||||
# finally, if none of the methods work, an Exception should be raised
|
||||
mock_get_browser_cred.return_value = None
|
||||
with pytest.raises(Exception) as e:
|
||||
get_credential()
|
||||
assert "Unable to generate and validate a credential. Please see Azure ML documentation"\
|
||||
"for instructions on different options to get a credential" in str(e)
|
||||
|
||||
|
||||
def test_get_legitimate_service_principal_credential() -> None:
|
||||
# first attempt to create and valiadate a credential with non-existant service principal credentials
|
||||
# and check it fails
|
||||
mock_service_principal_id = "foo"
|
||||
mock_service_principal_password = "bar"
|
||||
mock_tenant_id = "baz"
|
||||
expected_error_msg = f"Found environment variables for {util.ENV_SERVICE_PRINCIPAL_ID}, "
|
||||
f"{util.ENV_SERVICE_PRINCIPAL_PASSWORD}, and {util.ENV_TENANT_ID} but was not able to authenticate"
|
||||
with pytest.raises(Exception) as e:
|
||||
util._get_legitimate_service_principal_credential(mock_tenant_id, mock_service_principal_id,
|
||||
mock_service_principal_password)
|
||||
assert expected_error_msg in str(e)
|
||||
|
||||
# now mock the case where validating the credential succeeds and check the value of that
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_service_principal_credential(mock_tenant_id, mock_service_principal_id,
|
||||
mock_service_principal_password)
|
||||
assert isinstance(cred, ClientSecretCredential)
|
||||
|
||||
|
||||
def test_get_legitimate_device_code_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DeviceCodeCredential:
|
||||
return DeviceCodeCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DeviceCodeCredential", new=_mock_credential_fast_timeout):
|
||||
cred = util._get_legitimate_device_code_credential()
|
||||
assert cred is None
|
||||
|
||||
# now mock the case where validating the credential succeeds
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_device_code_credential()
|
||||
assert isinstance(cred, DeviceCodeCredential)
|
||||
|
||||
|
||||
def test_get_legitimate_default_credential() -> None:
|
||||
def _mock_credential_fast_timeout(timeout: int) -> DefaultAzureCredential:
|
||||
return DefaultAzureCredential(timeout=1)
|
||||
|
||||
with patch("health_azure.utils.DefaultAzureCredential", new=_mock_credential_fast_timeout):
|
||||
cred = util._get_legitimate_default_credential()
|
||||
assert cred is None
|
||||
|
||||
with patch("health_azure.utils._validate_credential"):
|
||||
cred = util._get_legitimate_default_credential()
|
||||
assert isinstance(cred, DefaultAzureCredential)
|
||||
|
||||
|
||||
def test_filter_v2_input_output_args() -> None:
|
||||
def _compare_args(expected: List[str], actual: List[str]) -> None:
|
||||
assert len(actual) == len(expected)
|
||||
for actual_entry in actual:
|
||||
assert actual_entry in expected
|
||||
|
||||
args_to_filter = ["--a=foo", "--INPUT_0=input0", "--b=bar", "--INPUT_1=input1"]
|
||||
expected_filtered = ["--a=foo", "--b=bar"]
|
||||
actual_filtered = util.filter_v2_input_output_args(args_to_filter)
|
||||
_compare_args(expected_filtered, actual_filtered)
|
||||
|
||||
# try passing empty list
|
||||
empty_list: List[str] = []
|
||||
actual_filtered = util.filter_v2_input_output_args(empty_list)
|
||||
assert actual_filtered == empty_list
|
||||
|
||||
# pass args with similar but different input and output args
|
||||
args_to_filter = ["--input_0=input0", "--a=foo"]
|
||||
expected_filtered = ["--input_0=input0", "--a=foo"]
|
||||
actual_filtered = util.filter_v2_input_output_args(args_to_filter)
|
||||
_compare_args(expected_filtered, actual_filtered)
|
||||
|
|
|
@ -110,6 +110,8 @@ def render_test_script(entry_script_path: Path, extra_options: Dict[str, str],
|
|||
default_options['args'] = ''
|
||||
default_options['body'] = ''
|
||||
default_options["tags"] = '{}'
|
||||
default_options["strictly_aml_v1"] = 'True'
|
||||
default_options["submit_to_azureml"] = 'False'
|
||||
|
||||
all_options = dict(default_options, **extra_options)
|
||||
|
||||
|
|
6
hi-ml-azure/testazure/testazure/test_data/simple/hello_world_template.txt
Normal file → Executable file
6
hi-ml-azure/testazure/testazure/test_data/simple/hello_world_template.txt
Normal file → Executable file
|
@ -1,3 +1,5 @@
|
|||
#! /usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
|
@ -49,7 +51,9 @@ def main() -> None:
|
|||
output_datasets={{ output_datasets }},
|
||||
wait_for_completion={{ wait_for_completion }},
|
||||
wait_for_completion_show_output={{ wait_for_completion_show_output }},
|
||||
tags={{ tags }})
|
||||
tags={{ tags }},
|
||||
submit_to_azureml={{ submit_to_azureml}},
|
||||
strictly_aml_v1={{ strictly_aml_v1 }})
|
||||
|
||||
parser = ArgumentParser()
|
||||
{{ args }}
|
||||
|
|
|
@ -6,13 +6,17 @@
|
|||
Test the data input and output functionality
|
||||
"""
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from health_azure.utils import PathOrString
|
||||
from unittest.mock import create_autospec, DEFAULT, MagicMock, patch
|
||||
from health_azure.utils import PathOrString, get_ml_client
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import pytest
|
||||
from azure.ai.ml import MLClient
|
||||
from azure.ai.ml.entities import Data
|
||||
from azure.ai.ml.operations import DatastoreOperations
|
||||
from azure.core.exceptions import HttpResponseError
|
||||
from azureml._restclient.exceptions import ServiceException
|
||||
from azureml.core import Dataset
|
||||
from azureml.core import Dataset, Workspace
|
||||
from azureml.data import FileDataset, OutputFileDatasetConfig
|
||||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
||||
|
@ -20,7 +24,8 @@ from azureml.exceptions._azureml_exception import UserErrorException
|
|||
|
||||
from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key,
|
||||
_replace_string_datasets, get_datastore, get_or_create_dataset,
|
||||
create_dataset_configs)
|
||||
create_dataset_configs, _get_or_create_v1_dataset, _get_or_create_v2_dataset,
|
||||
_retrieve_v1_dataset, _create_v1_dataset, _retrieve_v2_dataset, _create_v2_dataset)
|
||||
from testazure.utils_testazure import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
|
@ -56,11 +61,39 @@ def test_get_datastore() -> None:
|
|||
# Now mock the datastores property of the workspace, to pretend there is only a single datastore.
|
||||
# With that in place, we can get the datastore without the name
|
||||
faked_stores = {name: datastore}
|
||||
with mock.patch("azureml.core.Workspace.datastores", faked_stores):
|
||||
with patch("azureml.core.Workspace.datastores", faked_stores):
|
||||
single_store = get_datastore(workspace=workspace, datastore_name="")
|
||||
assert isinstance(single_store, AzureBlobDatastore)
|
||||
assert single_store.name == name
|
||||
|
||||
# Now test retrieving a v2 datastore by name
|
||||
mock_v2_dataset_name = "dummy_v2_datastore"
|
||||
mock_returned_datastore = MagicMock()
|
||||
mock_returned_datastore.name = mock_v2_dataset_name
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.datastores = create_autospec(DatastoreOperations)
|
||||
mock_workspace.datastores.get.return_value = mock_returned_datastore
|
||||
v2_datastore = get_datastore(mock_workspace, datastore_name=mock_v2_dataset_name)
|
||||
assert v2_datastore.name == mock_v2_dataset_name
|
||||
|
||||
# Test retrieving a default v2 datastore
|
||||
mock_workspace.datastores.list.return_value = [mock_returned_datastore]
|
||||
v2_default_datastore = get_datastore(mock_workspace, datastore_name="")
|
||||
assert v2_default_datastore.name == mock_v2_dataset_name
|
||||
|
||||
# Mock case where list is empty but get_default returns a value
|
||||
mock_workspace.datastores.list.return_value = []
|
||||
mock_workspace.datastores.get_default.return_value = mock_returned_datastore
|
||||
v2_default_datastore = get_datastore(mock_workspace, datastore_name="")
|
||||
assert v2_default_datastore.name == mock_v2_dataset_name
|
||||
|
||||
# If datastores has an unknown format, an exception should be raised
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.datastores.return_value = ["dummy_datastore_name"]
|
||||
with pytest.raises(Exception) as e:
|
||||
get_datastore(workspace=mock_workspace, datastore_name="")
|
||||
assert "Unrecognised type for datastores" in str(e)
|
||||
|
||||
|
||||
def test_dataset_input() -> None:
|
||||
"""
|
||||
|
@ -69,21 +102,23 @@ def test_dataset_input() -> None:
|
|||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
# This dataset must exist in the workspace already, or at least in blob storage.
|
||||
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE)
|
||||
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
|
||||
aml_dataset = dataset_config.to_input_dataset(dataset_index=1, workspace=workspace, strictly_aml_v1=True)
|
||||
assert isinstance(aml_dataset, DatasetConsumptionConfig)
|
||||
assert aml_dataset.path_on_compute is None
|
||||
assert aml_dataset.mode == "download"
|
||||
assert aml_dataset.path_on_compute is None # type: ignore
|
||||
assert aml_dataset.mode == "download" # type: ignore
|
||||
# Downloading or mounting to a given path
|
||||
target_folder = "/tmp/foo"
|
||||
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, target_folder=target_folder)
|
||||
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
|
||||
aml_dataset = dataset_config.to_input_dataset(
|
||||
dataset_index=1, workspace=workspace, strictly_aml_v1=True)
|
||||
assert isinstance(aml_dataset, DatasetConsumptionConfig)
|
||||
assert aml_dataset.path_on_compute == target_folder
|
||||
assert aml_dataset.path_on_compute == target_folder # type: ignore
|
||||
# Use mounting instead of downloading
|
||||
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, use_mounting=True)
|
||||
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
|
||||
aml_dataset = dataset_config.to_input_dataset(
|
||||
dataset_index=1, workspace=workspace, strictly_aml_v1=True)
|
||||
assert isinstance(aml_dataset, DatasetConsumptionConfig)
|
||||
assert aml_dataset.mode == "mount"
|
||||
assert aml_dataset.mode == "mount" # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.parametrize("target_folder", [
|
||||
|
@ -97,7 +132,8 @@ def test_dataset_input_target_empty(target_folder: PathOrString) -> None:
|
|||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
# This dataset must exist in the workspace already, or at least in blob storage.
|
||||
dataset_config = DatasetConfig(name="hello_world", datastore=DEFAULT_DATASTORE, target_folder=target_folder)
|
||||
aml_dataset = dataset_config.to_input_dataset(workspace=workspace, dataset_index=1)
|
||||
aml_dataset: DatasetConsumptionConfig = dataset_config.to_input_dataset(
|
||||
workspace=workspace, dataset_index=1, strictly_aml_v1=True)
|
||||
assert isinstance(aml_dataset, DatasetConsumptionConfig)
|
||||
assert aml_dataset.path_on_compute is None
|
||||
|
||||
|
@ -159,21 +195,138 @@ def test_datasets_from_string() -> None:
|
|||
assert replaced[1] == original[1]
|
||||
|
||||
|
||||
def test_get_dataset() -> None:
|
||||
"""
|
||||
Test if a dataset that does not yet exist can be created from a folder in blob storage
|
||||
"""
|
||||
# A folder with a single tiny file
|
||||
tiny_dataset = "himl_tiny_dataset"
|
||||
def test_get_or_create_dataset() -> None:
|
||||
def _mock_retrieve_or_create_v2_dataset_fails(datastore_name: str, dataset_name: str, ml_client: MLClient) -> None:
|
||||
raise HttpResponseError("Cannot create v2 Data Version in v1 Data Container")
|
||||
|
||||
data_asset_name = "himl_tiny_data_asset"
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
ml_client = get_ml_client(aml_workspace=workspace)
|
||||
# When creating a dataset, we need a non-empty name
|
||||
with pytest.raises(ValueError) as ex:
|
||||
get_or_create_dataset(workspace=workspace,
|
||||
datastore_name=DEFAULT_DATASTORE,
|
||||
dataset_name="")
|
||||
ml_client=ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name="",
|
||||
strictly_aml_v1=True)
|
||||
assert "No dataset name" in str(ex)
|
||||
# Check first that there is no dataset yet of that name. If there is, delete that dataset (it would come
|
||||
# from previous runs of this test)
|
||||
|
||||
# pass strictly_aml_v1 = True and check the expected function is called
|
||||
mock_v1_dataset = "v1_dataset"
|
||||
with patch.multiple("health_azure.datasets",
|
||||
_get_or_create_v1_dataset=DEFAULT,
|
||||
_get_or_create_v2_dataset=DEFAULT) as mocks:
|
||||
mocks["_get_or_create_v1_dataset"].return_value = mock_v1_dataset
|
||||
dataset = get_or_create_dataset(workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=True)
|
||||
mocks["_get_or_create_v1_dataset"].assert_called_once()
|
||||
mocks["_get_or_create_v2_dataset"].assert_not_called()
|
||||
assert dataset == mock_v1_dataset
|
||||
|
||||
# Now pass strictly_aml_v1 as False
|
||||
mock_v2_dataset = "v2_dataset"
|
||||
mocks["_get_or_create_v2_dataset"].return_value = mock_v2_dataset
|
||||
dataset = get_or_create_dataset(workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=False)
|
||||
mocks["_get_or_create_v1_dataset"].assert_called_once()
|
||||
mocks["_get_or_create_v2_dataset"].assert_called_once()
|
||||
assert dataset == mock_v2_dataset
|
||||
|
||||
# if trying to get or create a v2 dataset fails, should revert back to _get_or_create_v1_dataset
|
||||
mocks["_get_or_create_v2_dataset"].side_effect = _mock_retrieve_or_create_v2_dataset_fails
|
||||
dataset = get_or_create_dataset(workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
datastore_name="himldatasetsv2",
|
||||
dataset_name=data_asset_name,
|
||||
strictly_aml_v1=False)
|
||||
assert mocks["_get_or_create_v1_dataset"].call_count == 2
|
||||
assert mocks["_get_or_create_v2_dataset"].call_count == 2
|
||||
assert dataset == mock_v1_dataset
|
||||
|
||||
|
||||
def test_get_or_create_v1_dataset() -> None:
|
||||
def _mock_error_from_retrieve_v1_dataset(dataset_name: str, workspace: Workspace) -> None:
|
||||
raise UserErrorException("Error Message")
|
||||
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
datastore = workspace.get_default_datastore()
|
||||
dataset_name = "foo"
|
||||
|
||||
with patch.multiple("health_azure.datasets",
|
||||
_retrieve_v1_dataset=DEFAULT,
|
||||
_create_v1_dataset=DEFAULT) as mocks:
|
||||
_get_or_create_v1_dataset(datastore, dataset_name, workspace)
|
||||
mocks["_retrieve_v1_dataset"].assert_called_once()
|
||||
mocks["_create_v1_dataset"].assert_not_called()
|
||||
|
||||
mocks["_retrieve_v1_dataset"].side_effect = _mock_error_from_retrieve_v1_dataset
|
||||
_get_or_create_v1_dataset(datastore, dataset_name, workspace)
|
||||
assert mocks["_retrieve_v1_dataset"].call_count == 2
|
||||
mocks["_create_v1_dataset"].assert_called_once()
|
||||
|
||||
|
||||
def test_get_or_create_v2_dataset() -> None:
|
||||
def _mock_error_from_retrieve_v2_dataset(dataset_name: str, workspace: Workspace) -> None:
|
||||
raise Exception("Error Message")
|
||||
|
||||
ml_client = MagicMock()
|
||||
datastore = "dummy_datastore"
|
||||
dataset_name = "foo"
|
||||
|
||||
with patch.multiple("health_azure.datasets",
|
||||
_retrieve_v2_dataset=DEFAULT,
|
||||
_create_v2_dataset=DEFAULT) as mocks:
|
||||
_get_or_create_v2_dataset(datastore, dataset_name, ml_client)
|
||||
mocks["_retrieve_v2_dataset"].assert_called_once()
|
||||
mocks["_create_v2_dataset"].assert_not_called()
|
||||
|
||||
mocks["_retrieve_v2_dataset"].side_effect = _mock_error_from_retrieve_v2_dataset
|
||||
_get_or_create_v2_dataset(datastore, dataset_name, ml_client)
|
||||
assert mocks["_retrieve_v2_dataset"].call_count == 2
|
||||
mocks["_create_v2_dataset"].assert_called_once()
|
||||
|
||||
|
||||
def test_retrieve_v1_dataset() -> None:
|
||||
nonexistent_dataset = "idontexist"
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
# patch get_by_name to ensure it is called
|
||||
with patch("azureml.core.Dataset.get_by_name") as mock_get_dataset:
|
||||
_retrieve_v1_dataset(nonexistent_dataset, workspace)
|
||||
mock_get_dataset.assert_called_once()
|
||||
|
||||
# Expect a ValueError to be raised if the dataset doesnt exist
|
||||
with pytest.raises(Exception) as e:
|
||||
_retrieve_v1_dataset(nonexistent_dataset, workspace)
|
||||
assert "Cannot find dataset registered with name \"idontexist\"" in str(e)
|
||||
|
||||
|
||||
def test_create_v1_dataset() -> None:
|
||||
# If dataset_name or datastore_name are empty strings expect an Exception
|
||||
empty_dataset_name = ""
|
||||
empty_datastore_name = ""
|
||||
nonempty_dataset_name = "foo"
|
||||
nonempty_datastore_name = "bar"
|
||||
|
||||
workspace = DEFAULT_WORKSPACE.workspace
|
||||
tiny_dataset = "himl_tiny_dataset"
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
_create_v1_dataset(empty_datastore_name, nonempty_dataset_name, workspace)
|
||||
expected_str = "Cannot create dataset without a valid datastore name (received '') and a valid dataset name"
|
||||
f" (received '{nonempty_dataset_name}')"
|
||||
assert expected_str in str(e)
|
||||
|
||||
_create_v1_dataset(nonempty_datastore_name, empty_dataset_name, workspace)
|
||||
expected_str = f"Cannot create dataset without a valid datastore name (received '{empty_dataset_name}') and "
|
||||
"a valid dataset name (received '')"
|
||||
assert expected_str in str(e)
|
||||
|
||||
try:
|
||||
existing_dataset = Dataset.get_by_name(workspace, name=tiny_dataset)
|
||||
try:
|
||||
|
@ -183,10 +336,10 @@ def test_get_dataset() -> None:
|
|||
pass
|
||||
except Exception as ex:
|
||||
assert "Cannot find dataset registered" in str(ex)
|
||||
dataset = get_or_create_dataset(workspace=workspace,
|
||||
datastore_name=DEFAULT_DATASTORE,
|
||||
dataset_name=tiny_dataset)
|
||||
|
||||
dataset = _create_v1_dataset(DEFAULT_DATASTORE, tiny_dataset, workspace)
|
||||
assert isinstance(dataset, FileDataset)
|
||||
|
||||
# We should now be able to get that dataset without special means
|
||||
dataset2 = Dataset.get_by_name(workspace, name=tiny_dataset)
|
||||
try:
|
||||
|
@ -197,9 +350,31 @@ def test_get_dataset() -> None:
|
|||
pass
|
||||
|
||||
|
||||
def test_retrieve_v2_dataset() -> None:
|
||||
dataset_name = "dummydataset"
|
||||
mock_ml_client = MagicMock()
|
||||
mock_retrieved_dataset = "dummy_data"
|
||||
mock_ml_client.data.get.return_value = mock_retrieved_dataset
|
||||
data_asset = _retrieve_v2_dataset(dataset_name, mock_ml_client)
|
||||
assert data_asset == mock_retrieved_dataset
|
||||
|
||||
|
||||
def test_create_v2_dataset() -> None:
|
||||
dataset_name = "dummydataset"
|
||||
datastore = "dummydataset"
|
||||
mock_ml_client = MagicMock()
|
||||
# mock_ml_client.data.create_or_update.return_value = None
|
||||
data_asset = _create_v2_dataset(datastore, dataset_name, mock_ml_client)
|
||||
assert isinstance(data_asset, Data)
|
||||
assert data_asset.path == "azureml://datastores/dummydataset/paths/dummydataset/"
|
||||
assert data_asset.type == "uri_folder"
|
||||
assert data_asset.name == dataset_name
|
||||
|
||||
|
||||
def test_dataset_keys() -> None:
|
||||
"""
|
||||
Check that dataset keys are non-empty strings, and that inputs and outputs have different keys.
|
||||
Check that dataset keys are non-e
|
||||
mpty strings, and that inputs and outputs have different keys.
|
||||
"""
|
||||
in1 = _input_dataset_key(1)
|
||||
out1 = _output_dataset_key(1)
|
||||
|
|
|
@ -19,15 +19,20 @@ from ruamel import yaml
|
|||
from ruamel.yaml.comments import CommentedMap as OrderedDict, CommentedSeq as OrderedList
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, create_autospec, patch, DEFAULT
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from _pytest.capture import CaptureFixture
|
||||
from azure.ai.ml import Input, Output, MLClient
|
||||
from azure.ai.ml.constants import AssetTypes
|
||||
from azure.ai.ml.entities import Data
|
||||
from azure.ai.ml.sweep import Choice
|
||||
from azureml._restclient.constants import RunStatus
|
||||
from azureml.core import ComputeTarget, Environment, RunConfiguration, ScriptRunConfig, Workspace
|
||||
from azureml.data.azure_storage_datastore import AzureBlobDatastore
|
||||
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
||||
from azureml.dataprep.fuse.daemon import MountContext
|
||||
from azureml.train.hyperdrive import HyperDriveConfig
|
||||
|
||||
import health_azure.himl as himl
|
||||
|
@ -69,30 +74,34 @@ logger.setLevel(logging.DEBUG)
|
|||
# region Small fast local unit tests
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_submit_to_azure_if_needed_returns_immediately() -> None:
|
||||
def test_submit_to_azure_if_needed_returns_immediately(tmp_path: Path) -> None:
|
||||
"""
|
||||
Test that himl.submit_to_azure_if_needed can be called, and returns immediately.
|
||||
"""
|
||||
with mock.patch("sys.argv", ["", "--azureml"]):
|
||||
shared_config_json = get_shared_config_json()
|
||||
with check_config_json(tmp_path, shared_config_json=shared_config_json):
|
||||
|
||||
with pytest.raises(Exception) as ex:
|
||||
himl.submit_to_azure_if_needed(
|
||||
aml_workspace=None,
|
||||
workspace_config_file=None,
|
||||
entry_script=Path(__file__),
|
||||
compute_cluster_name="foo",
|
||||
snapshot_root_directory=Path(__file__).parent)
|
||||
snapshot_root_directory=Path(__file__).parent,
|
||||
submit_to_azureml=True)
|
||||
# N.B. This assert may fail when run locally since we may find a workspace_config_file through the call to
|
||||
# _find_file(CONDA_ENVIRONMENT_FILE) in submit_to_azure_if_needed
|
||||
if _is_running_in_github_pipeline():
|
||||
assert "No workspace config file given, nor can we find one" in str(ex)
|
||||
with mock.patch("sys.argv", [""]):
|
||||
result = himl.submit_to_azure_if_needed(
|
||||
entry_script=Path(__file__),
|
||||
compute_cluster_name="foo",
|
||||
conda_environment_file=Path("env.yml"))
|
||||
assert isinstance(result, himl.AzureRunInfo)
|
||||
assert not result.is_running_in_azure_ml
|
||||
assert result.run is None
|
||||
|
||||
with mock.patch("sys.argv", [""]):
|
||||
result = himl.submit_to_azure_if_needed(
|
||||
entry_script=Path(__file__),
|
||||
compute_cluster_name="foo",
|
||||
conda_environment_file=shared_config_json)
|
||||
assert isinstance(result, himl.AzureRunInfo)
|
||||
assert not result.is_running_in_azure_ml
|
||||
assert result.run is None
|
||||
|
||||
|
||||
def _is_running_in_github_pipeline() -> bool:
|
||||
|
@ -254,20 +263,21 @@ def test_validate_compute_real(tmp_path: Path) -> None:
|
|||
|
||||
@pytest.mark.fast
|
||||
@patch("azureml.data.OutputFileDatasetConfig")
|
||||
@patch("health_azure.himl.DatasetConsumptionConfig")
|
||||
@patch("health_azure.himl.Workspace")
|
||||
@patch("health_azure.himl.DatasetConfig")
|
||||
def test_to_datasets(
|
||||
mock_dataset_config: mock.MagicMock,
|
||||
mock_workspace: mock.MagicMock,
|
||||
mock_dataset_consumption_config: mock.MagicMock,
|
||||
mock_output_file_dataset_config: mock.MagicMock) -> None:
|
||||
def to_input_dataset(workspace: Workspace, dataset_index: int, ) -> DatasetConsumptionConfig:
|
||||
def to_input_dataset(workspace: Workspace, dataset_index: int, strictly_aml_v1: bool,
|
||||
ml_client: Optional[MLClient] = None) -> DatasetConsumptionConfig:
|
||||
return mock_dataset_consumption_config
|
||||
|
||||
def to_output_dataset(workspace: Workspace, dataset_index: int, ) -> DatasetConsumptionConfig:
|
||||
def to_output_dataset(workspace: Workspace, dataset_index: int) -> DatasetConsumptionConfig:
|
||||
return mock_output_file_dataset_config
|
||||
|
||||
mock_dataset_consumption_config = mock.create_autospec(DatasetConsumptionConfig)
|
||||
mock_dataset_consumption_config.__class__.return_value = DatasetConsumptionConfig
|
||||
mock_dataset_consumption_config.name = "A Consumption Config"
|
||||
mock_output_file_dataset_config.name = "An Output File Dataset Config"
|
||||
mock_dataset_config.to_input_dataset = to_input_dataset
|
||||
|
@ -276,12 +286,14 @@ def test_to_datasets(
|
|||
himl.convert_himl_to_azureml_datasets(
|
||||
cleaned_input_datasets=[mock_dataset_config, mock_dataset_config],
|
||||
cleaned_output_datasets=[],
|
||||
strictly_aml_v1=True,
|
||||
workspace=mock_workspace)
|
||||
assert "already an input dataset with name" in str(ex1)
|
||||
with pytest.raises(ValueError) as ex2:
|
||||
himl.convert_himl_to_azureml_datasets(
|
||||
cleaned_input_datasets=[mock_dataset_config, mock_dataset_config],
|
||||
cleaned_output_datasets=[],
|
||||
strictly_aml_v1=True,
|
||||
workspace=mock_workspace)
|
||||
assert "already an output dataset with name" in str(ex2)
|
||||
|
||||
|
@ -290,6 +302,7 @@ def test_to_datasets(
|
|||
inputs, outputs = himl.convert_himl_to_azureml_datasets(
|
||||
cleaned_input_datasets=cleaned_input_datasets,
|
||||
cleaned_output_datasets=cleaned_output_datasets,
|
||||
strictly_aml_v1=True,
|
||||
workspace=mock_workspace)
|
||||
assert len(inputs) == 1
|
||||
assert len(outputs) == 1
|
||||
|
@ -341,9 +354,9 @@ def test_create_run_configuration(
|
|||
mock_env_name = "Mock Env"
|
||||
mock_environment_get.return_value = mock_env_name
|
||||
mock_workspace.compute_targets = {existing_compute_target: mock_compute_cluster}
|
||||
aml_input_dataset = MagicMock()
|
||||
aml_input_dataset = create_autospec(DatasetConsumptionConfig)
|
||||
aml_input_dataset.name = "dataset_in"
|
||||
aml_output_dataset = MagicMock()
|
||||
aml_output_dataset = create_autospec(DatasetConsumptionConfig)
|
||||
aml_output_dataset.name = "dataset_out"
|
||||
mock_to_input_dataset.return_value = aml_input_dataset
|
||||
mock_to_output_dataset.return_value = aml_output_dataset
|
||||
|
@ -460,7 +473,8 @@ def test_create_run_configuration_correct_env(mock_create_environment: MagicMock
|
|||
with pytest.raises(Exception) as e:
|
||||
himl.create_run_configuration(mock_workspace,
|
||||
"dummy_compute_cluster",
|
||||
conda_environment_file=conda_env_path)
|
||||
conda_environment_file=conda_env_path,
|
||||
)
|
||||
assert "you must specify the python version" in str(e)
|
||||
|
||||
# check that when create_run_configuration is called, whatever is returned from register_environment
|
||||
|
@ -538,13 +552,11 @@ def test_get_workspace_no_config(
|
|||
mock_is_running_in_azure.return_value = False
|
||||
with change_working_directory(tmp_path):
|
||||
with pytest.raises(ValueError) as ex:
|
||||
with mock.patch("sys.argv", ["", "--azureml"]):
|
||||
himl.submit_to_azure_if_needed(compute_cluster_name="foo")
|
||||
himl.submit_to_azure_if_needed(compute_cluster_name="foo", submit_to_azureml=True)
|
||||
assert "No workspace config file given" in str(ex)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
@patch("health_azure.himl.Run")
|
||||
@patch("health_azure.himl.Workspace")
|
||||
@patch("health_azure.himl._generate_azure_datasets")
|
||||
@patch("health_azure.himl.RUN_CONTEXT")
|
||||
|
@ -552,7 +564,7 @@ def test_submit_to_azure_if_needed_azure_return(
|
|||
mock_run_context: mock.MagicMock,
|
||||
mock_generate_azure_datasets: mock.MagicMock,
|
||||
mock_workspace: mock.MagicMock,
|
||||
mock_run: mock.MagicMock) -> None:
|
||||
) -> None:
|
||||
"""
|
||||
When running in AzureML, the call to submit_to_azure_if_needed should return immediately, without trying to
|
||||
submit a new job.
|
||||
|
@ -561,13 +573,13 @@ def test_submit_to_azure_if_needed_azure_return(
|
|||
mock_run_context.experiment = mock.MagicMock(workspace=mock_workspace)
|
||||
assert is_running_in_azure_ml(himl.RUN_CONTEXT)
|
||||
expected_run_info = himl.AzureRunInfo(
|
||||
run=mock_run,
|
||||
run=mock_run_context,
|
||||
input_datasets=[],
|
||||
output_datasets=[],
|
||||
mount_contexts=[],
|
||||
is_running_in_azure_ml=True,
|
||||
output_folder=Path.cwd(),
|
||||
logs_folder=Path.cwd())
|
||||
output_folder=Path.cwd() / himl.OUTPUT_FOLDER,
|
||||
logs_folder=Path.cwd() / himl.LOGS_FOLDER)
|
||||
mock_generate_azure_datasets.return_value = expected_run_info
|
||||
with mock.patch("sys.argv", ["", "--azureml"]):
|
||||
run_info = himl.submit_to_azure_if_needed(
|
||||
|
@ -835,6 +847,9 @@ def render_and_run_test_script(path: Path,
|
|||
run_requirements = True
|
||||
print("Copied 'src' folder.")
|
||||
|
||||
if run_target == RunTarget.AZUREML:
|
||||
extra_options["submit_to_azureml"] = 'True'
|
||||
|
||||
environment_yaml_path = path / "environment.yml"
|
||||
render_environment_yaml(environment_yaml_path, version, run_requirements, extra_options=extra_options)
|
||||
|
||||
|
@ -844,8 +859,6 @@ def render_and_run_test_script(path: Path,
|
|||
workspace_config_file_arg=workspace_config_file_arg)
|
||||
|
||||
score_args = [str(entry_script_path)]
|
||||
if run_target == RunTarget.AZUREML:
|
||||
score_args.append("--azureml")
|
||||
score_args.extend(extra_args)
|
||||
|
||||
env = dict(os.environ.items())
|
||||
|
@ -907,10 +920,11 @@ def test_invoking_hello_world_no_config(run_target: RunTarget, tmp_path: Path) -
|
|||
:param run_target: Where to run the script.
|
||||
:param tmp_path: PyTest test fixture for temporary path.
|
||||
"""
|
||||
parser_args = "parser.add_argument('-m', '--message', type=str, required=True, help='The message to print out')"
|
||||
message_guid = uuid4().hex
|
||||
extra_options = {
|
||||
'workspace_config_file': 'None',
|
||||
'args': 'parser.add_argument("-m", "--message", type=str, required=True, help="The message to print out")',
|
||||
'args': parser_args,
|
||||
'body': 'print(f"The message was: {args.message}")'
|
||||
}
|
||||
extra_args = [f"--message={message_guid}"]
|
||||
|
@ -946,8 +960,9 @@ def test_invoking_hello_world_config(run_target: RunTarget, use_package: bool, t
|
|||
return
|
||||
|
||||
message_guid = uuid4().hex
|
||||
parser_args = "parser.add_argument('-m', '--message', type=str, required=True, help='The message to print out')"
|
||||
extra_options = {
|
||||
'args': 'parser.add_argument("-m", "--message", type=str, required=True, help="The message to print out")',
|
||||
'args': parser_args,
|
||||
'body': 'print(f"The message was: {args.message}")'
|
||||
}
|
||||
extra_args = [f"--message={message_guid}"]
|
||||
|
@ -1006,8 +1021,9 @@ def test_invoking_hello_world_env_var(run_target: RunTarget, tmp_path: Path) ->
|
|||
import os
|
||||
import sys""",
|
||||
'environment_variables': f"{{'message_guid': '{message_guid}'}}",
|
||||
'body': 'print(f"The message_guid env var was: {os.getenv(\'message_guid\')}")'
|
||||
'body': 'print(f"The message_guid env var was: {os.getenv(\'message_guid\')}")',
|
||||
}
|
||||
|
||||
extra_args: List[str] = []
|
||||
output = render_and_run_test_script(tmp_path, run_target, extra_options, extra_args, True)
|
||||
expected_output = f"The message_guid env var was: {message_guid}"
|
||||
|
@ -1046,7 +1062,11 @@ def test_mounting_and_downloading_dataset(tmp_path: Path) -> None:
|
|||
use_mounting=use_mounting,
|
||||
target_folder=target_path)
|
||||
logging.info(f"ready to {action}")
|
||||
paths, mount_contexts = setup_local_datasets(dataset_configs=[dataset_config], aml_workspace=workspace)
|
||||
paths, mount_contexts = setup_local_datasets(
|
||||
dataset_configs=[dataset_config],
|
||||
strictly_aml_v1=True,
|
||||
aml_workspace=workspace
|
||||
)
|
||||
logging.info(f"{action} done")
|
||||
path = paths[0]
|
||||
assert path is not None
|
||||
|
@ -1235,7 +1255,7 @@ import sys
|
|||
file = input_folder / filename
|
||||
shutil.copy(file, output_folder)
|
||||
print(f"Copied file: {{file.name}} from {{input_blob_name}} to {{output_blob_name}}")
|
||||
"""
|
||||
""",
|
||||
}
|
||||
extra_args: List[str] = []
|
||||
output = render_and_run_test_script(tmp_path, run_target, extra_options, extra_args, True)
|
||||
|
@ -1288,33 +1308,204 @@ def test_create_crossval_hyperdrive_config(_: MagicMock, num_crossval_splits: in
|
|||
assert crossval_config._max_total_runs == num_crossval_splits
|
||||
|
||||
|
||||
def test_create_crossval_hyperparam_args_v2() -> None:
|
||||
num_splits = 3
|
||||
crossval_args = himl.create_crossval_hyperparam_args_v2(num_splits)
|
||||
assert isinstance(crossval_args, Dict)
|
||||
assert crossval_args[himl.MAX_TOTAL_TRIALS_ARG] == num_splits
|
||||
assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG], Dict)
|
||||
assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG]["crossval_index"], Choice)
|
||||
assert crossval_args[himl.PRIMARY_METRIC_ARG] == "val/loss"
|
||||
assert crossval_args[himl.SAMPLING_ALGORITHM_ARG] == "grid"
|
||||
assert crossval_args[himl.GOAL_ARG] == "Minimize"
|
||||
|
||||
|
||||
def test_create_grid_hyperparam_args_v2() -> None:
|
||||
mock_values_float = [0.1, 0.2, 0.5]
|
||||
mock_arg_name_float = "float_number"
|
||||
mock_metric_name_float = mock_arg_name_float
|
||||
hparams_args_float = himl.create_grid_hyperparam_args_v2(mock_values_float, mock_arg_name_float,
|
||||
mock_metric_name_float)
|
||||
assert isinstance(hparams_args_float, Dict)
|
||||
assert hparams_args_float[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_float)
|
||||
assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG], Dict)
|
||||
assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG][mock_arg_name_float], Choice)
|
||||
assert hparams_args_float[himl.PRIMARY_METRIC_ARG] == mock_metric_name_float
|
||||
assert hparams_args_float[himl.SAMPLING_ALGORITHM_ARG] == "grid"
|
||||
assert hparams_args_float[himl.GOAL_ARG] == "Minimize"
|
||||
|
||||
mock_values_str = ["a", "b", "c"]
|
||||
mock_arg_name_str = "letter"
|
||||
mock_metric_name_str = mock_arg_name_str
|
||||
hparam_args_str = himl.create_grid_hyperparam_args_v2(mock_values_str, mock_arg_name_str,
|
||||
mock_metric_name_str)
|
||||
assert isinstance(hparam_args_str, Dict)
|
||||
assert hparam_args_str[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_str)
|
||||
assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG], Dict)
|
||||
assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG][mock_arg_name_str], Choice)
|
||||
assert hparam_args_str[himl.PRIMARY_METRIC_ARG] == mock_metric_name_str
|
||||
assert hparam_args_str[himl.SAMPLING_ALGORITHM_ARG] == "grid"
|
||||
assert hparam_args_str[himl.GOAL_ARG] == "Minimize"
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
@pytest.mark.parametrize("cross_validation_metric_name", [None, "accuracy"])
|
||||
@patch("sys.argv")
|
||||
@patch("health_azure.himl.exit")
|
||||
def test_submit_to_azure_if_needed_with_hyperdrive(mock_sys_args: MagicMock, mock_exit: MagicMock,
|
||||
def test_submit_to_azure_if_needed_with_hyperdrive(mock_sys_args: MagicMock,
|
||||
mock_exit: MagicMock,
|
||||
mock_compute_cluster: MagicMock,
|
||||
cross_validation_metric_name: Optional[str]) -> None:
|
||||
cross_validation_metric_name: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Test that himl.submit_to_azure_if_needed can be called, and returns immediately.
|
||||
"""
|
||||
cross_validation_metric_name = cross_validation_metric_name or ""
|
||||
mock_sys_args.return_value = ["", "--azureml"]
|
||||
with patch.object(Environment, "get", return_value="dummy_env"):
|
||||
with patch("azureml.core.Workspace") as mock_workspace:
|
||||
with patch("health_azure.himl.get_ml_client") as mock_get_ml_client:
|
||||
mock_ml_client = MagicMock()
|
||||
mock_get_ml_client.return_value = mock_ml_client
|
||||
with patch.object(Environment, "get", return_value="dummy_env"):
|
||||
mock_workspace = MagicMock()
|
||||
mock_workspace.compute_targets = {"foo": mock_compute_cluster}
|
||||
with patch("health_azure.datasets.setup_local_datasets") as mock_setup_local_datasets:
|
||||
mock_setup_local_datasets.return_value = [], []
|
||||
with patch("health_azure.himl.submit_run") as mock_submit_run:
|
||||
with patch("health_azure.himl.HyperDriveConfig") as mock_hyperdrive_config:
|
||||
crossval_config = himl.create_crossval_hyperdrive_config(
|
||||
num_splits=2,
|
||||
cross_val_index_arg_name="cross_val_split_index",
|
||||
metric_name=cross_validation_metric_name)
|
||||
himl.submit_to_azure_if_needed(
|
||||
aml_workspace=mock_workspace,
|
||||
ml_client=mock_ml_client,
|
||||
entry_script=Path(__file__),
|
||||
snapshot_root_directory=Path(__file__).parent,
|
||||
compute_cluster_name="foo",
|
||||
aml_environment_name="dummy_env",
|
||||
submit_to_azureml=True,
|
||||
hyperdrive_config=crossval_config,
|
||||
strictly_aml_v1=True)
|
||||
mock_submit_run.assert_called_once()
|
||||
mock_hyperdrive_config.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_create_v2_inputs() -> None:
|
||||
mock_ml_client = MagicMock()
|
||||
mock_data_name = "mock_data"
|
||||
mock_data_version = "1"
|
||||
mock_data_path = "path/to/mock/data"
|
||||
mock_ml_client.data.get.return_value = Data(
|
||||
name=mock_data_name,
|
||||
version=mock_data_version,
|
||||
id=mock_data_path
|
||||
)
|
||||
|
||||
mock_input_dataconfigs = [DatasetConfig(name="dummy_dataset")]
|
||||
inputs = himl.create_v2_inputs(mock_ml_client, mock_input_dataconfigs)
|
||||
assert isinstance(inputs, Dict)
|
||||
assert len(inputs) == len(mock_input_dataconfigs)
|
||||
input_entry = inputs["INPUT_0"]
|
||||
assert isinstance(input_entry, Input)
|
||||
assert input_entry.type == AssetTypes.URI_FOLDER
|
||||
actual_path: str = input_entry.path # type: ignore
|
||||
assert actual_path == mock_data_path
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_create_v2_outputs() -> None:
|
||||
mock_datastore_name = "dummy_datastore"
|
||||
mock_data_name = "dummy_dataset"
|
||||
|
||||
mock_output_dataconfigs = [DatasetConfig(name=mock_data_name, datastore=mock_datastore_name)]
|
||||
outputs = himl.create_v2_outputs(mock_output_dataconfigs)
|
||||
assert isinstance(outputs, Dict)
|
||||
assert len(outputs) == len(mock_output_dataconfigs)
|
||||
output_entry = outputs["OUTPUT_0"]
|
||||
assert isinstance(output_entry, Output)
|
||||
assert output_entry.type == AssetTypes.URI_FOLDER
|
||||
expected_path = f"azureml://datastores/{mock_datastore_name}/paths/{mock_data_name}"
|
||||
assert expected_path in output_entry['path']
|
||||
|
||||
|
||||
def test_submit_to_azure_if_needed_v2() -> None:
|
||||
"""
|
||||
Check that submit_run_v2 is called when submit_to_azure_if_needed is called, unless strictly_aml_v1 is
|
||||
set to True, in which case submit_run should be called instead
|
||||
"""
|
||||
dummy_input_datasets: List[Optional[Path]] = []
|
||||
dummy_mount_contexts: List[MountContext] = []
|
||||
|
||||
with patch.multiple(
|
||||
"health_azure.himl",
|
||||
_package_setup=DEFAULT,
|
||||
get_workspace=DEFAULT,
|
||||
get_ml_client=DEFAULT,
|
||||
create_run_configuration=DEFAULT,
|
||||
create_script_run=DEFAULT,
|
||||
append_to_amlignore=DEFAULT,
|
||||
exit=DEFAULT
|
||||
) as mocks:
|
||||
mock_script_run = mocks["create_script_run"].return_value
|
||||
mock_script_run.script = "dummy_script"
|
||||
mock_script_run.source_directory = "dummy_dir"
|
||||
|
||||
with patch("health_azure.himl.setup_local_datasets") as mock_setup_datasets:
|
||||
mock_setup_datasets.return_value = dummy_input_datasets, dummy_mount_contexts
|
||||
with patch("health_azure.himl.submit_run_v2") as mock_submit_run_v2:
|
||||
return_value = himl.submit_to_azure_if_needed(
|
||||
workspace_config_file="mockconfig.json",
|
||||
snapshot_root_directory="dummy",
|
||||
submit_to_azureml=True,
|
||||
strictly_aml_v1=False
|
||||
)
|
||||
mock_submit_run_v2.assert_called_once()
|
||||
assert return_value is None
|
||||
|
||||
# Now supply strictly_aml_v1=True, and check that submit_run is called
|
||||
with patch("health_azure.himl.submit_run") as mock_submit_run:
|
||||
with patch("health_azure.himl.HyperDriveConfig") as mock_hyperdrive_config:
|
||||
crossval_config = himl.create_crossval_hyperdrive_config(
|
||||
num_splits=2,
|
||||
cross_val_index_arg_name="cross_val_split_index",
|
||||
metric_name=cross_validation_metric_name)
|
||||
himl.submit_to_azure_if_needed(
|
||||
aml_workspace=mock_workspace,
|
||||
entry_script=Path(__file__),
|
||||
compute_cluster_name="foo",
|
||||
aml_environment_name="dummy_env",
|
||||
submit_to_azureml=True,
|
||||
hyperdrive_config=crossval_config)
|
||||
mock_submit_run.assert_called_once()
|
||||
mock_hyperdrive_config.assert_called_once()
|
||||
return_value = himl.submit_to_azure_if_needed(
|
||||
workspace_config_file="mockconfig.json",
|
||||
snapshot_root_directory="dummy",
|
||||
submit_to_azureml=True,
|
||||
strictly_aml_v1=True,
|
||||
)
|
||||
mock_submit_run.assert_called_once()
|
||||
assert return_value is None
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_generate_input_dataset_command() -> None:
|
||||
input_datasets = {"INPUT_0": Input(), "INPUT_1": Input()}
|
||||
input_data_cmd = himl._generate_input_dataset_command(input_datasets)
|
||||
assert input_data_cmd == " --INPUT_0=${{inputs.INPUT_0}} --INPUT_1=${{inputs.INPUT_1}}"
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_generate_output_dataset_command() -> None:
|
||||
output_datasets = {"OUTPUT_0": Output(), "OUTPUT_1": Output()}
|
||||
output_data_cmd = himl._generate_output_dataset_command(output_datasets)
|
||||
assert output_data_cmd == " --OUTPUT_0=${{outputs.OUTPUT_0}} --OUTPUT_1=${{outputs.OUTPUT_1}}"
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_extract_v2_inputs_outputs_from_args() -> None:
|
||||
path_to_input_0 = "path_to_input_0"
|
||||
path_to_output_0 = "path_to_output_0"
|
||||
mock_args = [f"--INPUT_0={path_to_input_0}", "--INPUT_1=path_to_input_1", f"--OUTPUT_0={path_to_output_0}",
|
||||
"--a=foo", "--b=bar"]
|
||||
with patch.object(sys, "argv", new=mock_args):
|
||||
input_datasets, output_datasets = himl._extract_v2_inputs_outputs_from_args()
|
||||
assert len(input_datasets) == 2
|
||||
assert input_datasets[0] == Path(path_to_input_0)
|
||||
assert len(output_datasets) == 1
|
||||
assert output_datasets[0] == Path(path_to_output_0)
|
||||
|
||||
# similar args should be ignored
|
||||
mock_args_similar = [f"--input_0={path_to_input_0}", "--input_1=path_to_input_1", f"--output_0={path_to_output_0}",
|
||||
"--a=foo", "--b=bar"]
|
||||
with patch.object(sys, "argv", new=mock_args_similar):
|
||||
input_datasets, output_datasets = himl._extract_v2_inputs_outputs_from_args()
|
||||
assert len(input_datasets) == 0
|
||||
assert len(output_datasets) == 0
|
||||
|
|
|
@ -3,13 +3,10 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import subprocess
|
||||
|
||||
from health_azure import himl_download
|
||||
from testazure.utils_testazure import MockRun
|
||||
|
||||
DOWNLOAD_SCRIPT_PATH = himl_download.__file__
|
||||
|
||||
|
@ -37,12 +34,3 @@ def test_download_aml_run_no_runs(tmp_path: Path) -> None:
|
|||
with pytest.raises(Exception) as e:
|
||||
subprocess.Popen(["python", DOWNLOAD_SCRIPT_PATH, "--run_id", "madeuprun", "--output_dir", str(tmp_path)])
|
||||
assert "was not found" in str(e)
|
||||
|
||||
|
||||
def test_retrieve_runs() -> None:
|
||||
with patch("health_azure.utils.get_aml_run_from_run_id") as mock_get_run:
|
||||
dummy_run_id = "run_id_123"
|
||||
mock_get_run.return_value = MockRun(dummy_run_id)
|
||||
dummy_download_config = himl_download.HimlDownloadConfig(run=[dummy_run_id])
|
||||
_ = himl_download.retrieve_runs(dummy_download_config)
|
||||
mock_get_run.assert_called_with(dummy_run_id)
|
||||
|
|
|
@ -12,7 +12,7 @@ def test_git_repo_root_folder() -> None:
|
|||
with patch("health_azure.paths.is_himl_used_from_git_repo", return_value=False):
|
||||
with pytest.raises(ValueError) as e:
|
||||
git_repo_root_folder()
|
||||
assert"This function should not be used if hi-ml is used as an installed package." in str(e)
|
||||
assert "This function should not be used if hi-ml is used as an installed package." in str(e)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
|
|
|
@ -96,7 +96,7 @@ import sys
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
writer.flush()
|
||||
"""
|
||||
""",
|
||||
}
|
||||
|
||||
extra_args: List[str] = []
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[mypy]
|
||||
python_version=3.7
|
||||
python_version=3.9
|
||||
scripts_are_modules=True
|
||||
namespace_packages=True
|
||||
show_traceback=True
|
||||
|
|
|
@ -87,7 +87,7 @@ SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
|
|||
# Run regression tests and compare performance
|
||||
define BASE_CPATH_RUNNER_COMMAND
|
||||
cd ../ ; \
|
||||
python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/environment.yml
|
||||
python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/environment.yml --datastore=himldatasets
|
||||
endef
|
||||
|
||||
define DEEPSMILEPANDASLIDES_ARGS
|
||||
|
|
|
@ -3,166 +3,205 @@
|
|||
name: HimlHisto
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_gnu
|
||||
- blas=1.0=mkl
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2022.4.26=h06a4308_0
|
||||
- certifi=2022.6.15=py37h06a4308_0
|
||||
- ca-certificates=2022.10.11=h06a4308_0
|
||||
- cairo=1.16.0=hf32fb01_1
|
||||
- certifi=2022.9.24=py39h06a4308_0
|
||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
||||
- ffmpeg=4.2.2=h20bf706_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
- fontconfig=2.13.1=hef1e5e3_1
|
||||
- freetype=2.12.1=h4a9f257_0
|
||||
- gdk-pixbuf=2.42.6=h04a7f16_0
|
||||
- gettext=0.21.0=hf68c758_0
|
||||
- giflib=5.2.1=h7b6447c_0
|
||||
- glib=2.68.2=h36276a3_0
|
||||
- gmp=6.2.1=h295c915_3
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- icu=58.2=he6710b0_3
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- jpeg=9e=h7f8727e_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- libedit=3.1.20210910=h7f8727e_0
|
||||
- libffi=3.2.1=hf484d3e_1007
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- ld_impl_linux-64=2.38=h1181459_1
|
||||
- lerc=3.0=h295c915_0
|
||||
- libdeflate=1.8=h7f8727e_5
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=12.2.0=h65d4601_19
|
||||
- libglib=2.68.2=h3e27bee_0
|
||||
- libgomp=12.2.0=h65d4601_19
|
||||
- libiconv=1.16=h7f8727e_2
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- libopus=1.3.1=h7b6447c_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h2818925_1
|
||||
- libtiff=4.4.0=hecacb30_1
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuuid=1.0.3=h7f8727e_2
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libvpx=1.7.0=h439df22_0
|
||||
- libwebp=1.2.2=h55f646e_0
|
||||
- libwebp-base=1.2.2=h7f8727e_0
|
||||
- libwebp=1.2.4=h11a3e52_0
|
||||
- libwebp-base=1.2.4=h5eee18b_0
|
||||
- libxcb=1.15=h7f8727e_0
|
||||
- libxml2=2.9.14=h74e7548_0
|
||||
- lz4-c=1.9.3=h295c915_1
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py37h7f8727e_0
|
||||
- mkl_fft=1.3.1=py37hd3c417c_0
|
||||
- mkl_random=1.2.2=py37h51133e4_0
|
||||
- mkl-service=2.4.0=py39h7f8727e_0
|
||||
- mkl_fft=1.3.1=py39hd3c417c_0
|
||||
- mkl_random=1.2.2=py39h51133e4_0
|
||||
- ncurses=6.3=h5eee18b_3
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- openh264=2.1.1=h4ff587b_0
|
||||
- openssl=1.1.1q=h7f8727e_0
|
||||
- pip=20.1.1=py37_1
|
||||
- python=3.7.3
|
||||
- pytorch=1.10.0=py3.7_cuda11.3_cudnn8.2.0_0
|
||||
- openjpeg=2.4.0=h3ad879b_0
|
||||
- openslide=3.4.1=h978ee9a_4
|
||||
- openslide-python=1.2.0=py39hb9d737c_2
|
||||
- openssl=1.1.1s=h7f8727e_0
|
||||
- pcre=8.45=h295c915_0
|
||||
- pip=20.1.1=py_1
|
||||
- pixman=0.40.0=h7f8727e_1
|
||||
- python=3.9.13
|
||||
- python_abi=3.9=2_cp39
|
||||
- pytorch=1.10.0=py3.9_cuda11.3_cudnn8.2.0_0
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- readline=7.0=h7b6447c_5
|
||||
- readline=8.2=h5eee18b_0
|
||||
- six=1.16.0=pyhd3eb1b0_1
|
||||
- sqlite=3.33.0=h62c20be_0
|
||||
- sqlite=3.39.3=h5082296_0
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- torchvision=0.11.1=py37_cu113
|
||||
- typing_extensions=4.1.1=pyh06a4308_0
|
||||
- torchvision=0.11.1=py39_cu113
|
||||
- typing_extensions=4.3.0=py39h06a4308_0
|
||||
- tzdata=2022f=h04d1e81_0
|
||||
- x264=1!157.20191217=h7b6447c_0
|
||||
- xz=5.2.5=h7f8727e_1
|
||||
- zlib=1.2.12=h7f8727e_2
|
||||
- xz=5.2.6=h5eee18b_0
|
||||
- zlib=1.2.13=h5eee18b_0
|
||||
- zstd=1.5.2=ha4553b6_0
|
||||
- pip:
|
||||
- absl-py==1.2.0
|
||||
- absl-py==1.3.0
|
||||
- adal==1.2.7
|
||||
- aiohttp==3.8.1
|
||||
- aiosignal==1.2.0
|
||||
- aiohttp==3.8.3
|
||||
- aiosignal==1.3.1
|
||||
- alabaster==0.7.12
|
||||
- alembic==1.8.1
|
||||
- applicationinsights==0.11.10
|
||||
- argcomplete==2.0.0
|
||||
- astroid==2.9.3
|
||||
- astroid==2.12.12
|
||||
- async-timeout==4.0.2
|
||||
- asynctest==0.13.0
|
||||
- attrs==21.4.0
|
||||
- azure-ai-ml==1.1.0
|
||||
- azure-common==1.1.28
|
||||
- azure-core==1.24.2
|
||||
- azure-core==1.26.1
|
||||
- azure-graphrbac==0.61.1
|
||||
- azure-identity==1.7.0
|
||||
- azure-mgmt-authorization==2.0.0
|
||||
- azure-mgmt-containerregistry==10.0.0
|
||||
- azure-mgmt-core==1.3.0
|
||||
- azure-mgmt-core==1.3.1
|
||||
- azure-mgmt-keyvault==10.0.0
|
||||
- azure-mgmt-resource==21.1.0
|
||||
- azure-mgmt-resource==21.2.1
|
||||
- azure-mgmt-storage==20.0.0
|
||||
- azure-storage-blob==12.5.0
|
||||
- azure-storage-blob==12.10.0
|
||||
- azure-storage-file-datalake==12.9.1
|
||||
- azure-storage-file-share==12.10.1
|
||||
- azureml-core==1.43.0
|
||||
- azureml-dataprep==4.0.4
|
||||
- azureml-dataprep-native==38.0.0
|
||||
- azureml-dataprep-rslex==2.6.3
|
||||
- azureml-dataset-runtime==1.43.0.post2
|
||||
- azureml-dataset-runtime==1.43.0
|
||||
- azureml-mlflow==1.47.0
|
||||
- azureml-telemetry==1.43.0
|
||||
- azureml-tensorboard==1.43.0
|
||||
- azureml-train-core==1.43.0
|
||||
- azureml-train-restclients-hyperdrive==1.43.0
|
||||
- babel==2.10.3
|
||||
- babel==2.11.0
|
||||
- backports-tempfile==1.0
|
||||
- backports-weakref==1.0.post1
|
||||
- bcrypt==3.2.2
|
||||
- bcrypt==4.0.1
|
||||
- black==22.1.0
|
||||
- bleach==5.0.1
|
||||
- cachetools==4.2.4
|
||||
- cffi==1.15.1
|
||||
- cfgv==3.3.1
|
||||
- charset-normalizer==2.1.0
|
||||
- charset-normalizer==2.1.1
|
||||
- click==8.1.3
|
||||
- cloudpickle==1.6.0
|
||||
- colorama==0.4.5
|
||||
- colorama==0.4.6
|
||||
- coloredlogs==15.0.1
|
||||
- conda-merge==0.1.5
|
||||
- contextlib2==21.6.0
|
||||
- coverage==6.3.2
|
||||
- cryptography==37.0.4
|
||||
- cryptography==38.0.3
|
||||
- cucim==22.4.0
|
||||
- cycler==0.11.0
|
||||
- databricks-cli==0.17.3
|
||||
- dataclasses-json==0.5.2
|
||||
- dill==0.3.6
|
||||
- diskcache==5.4.0
|
||||
- distlib==0.3.5
|
||||
- distro==1.7.0
|
||||
- distlib==0.3.6
|
||||
- distro==1.8.0
|
||||
- docker==5.0.3
|
||||
- docutils==0.16
|
||||
- dotnetcore2==3.1.23
|
||||
- filelock==3.7.1
|
||||
- entrypoints==0.4
|
||||
- filelock==3.8.0
|
||||
- flake8==4.0.1
|
||||
- frozenlist==1.3.0
|
||||
- fsspec==2022.5.0
|
||||
- flask==2.2.2
|
||||
- frozenlist==1.3.3
|
||||
- fsspec==2022.10.0
|
||||
- fusepy==3.0.1
|
||||
- girder-client==3.1.14
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.29
|
||||
- google-api-core==2.10.2
|
||||
- google-auth==1.35.0
|
||||
- google-auth-oauthlib==0.4.6
|
||||
- grpcio==1.47.0
|
||||
- hi-ml==0.2.3
|
||||
- hi-ml-azure==0.2.3
|
||||
- googleapis-common-protos==1.56.4
|
||||
- greenlet==2.0.1
|
||||
- grpcio==1.50.0
|
||||
- gunicorn==20.1.0
|
||||
- hi-ml==0.2.9
|
||||
- hi-ml-azure==0.2.9
|
||||
- humanfriendly==10.0
|
||||
- identify==2.5.2
|
||||
- idna==3.3
|
||||
- imageio==2.19.5
|
||||
- identify==2.5.8
|
||||
- idna==3.4
|
||||
- imageio==2.22.4
|
||||
- imagesize==1.4.1
|
||||
- importlib-metadata==4.2.0
|
||||
- importlib-resources==5.8.0
|
||||
- iniconfig==1.1.1
|
||||
- isodate==0.6.1
|
||||
- isort==5.10.1
|
||||
- itsdangerous==2.1.2
|
||||
- jaraco-classes==3.2.3
|
||||
- jeepney==0.8.0
|
||||
- jinja2==3.0.2
|
||||
- jmespath==1.0.0
|
||||
- joblib==1.1.0
|
||||
- joblib==1.2.0
|
||||
- jsonpickle==2.2.0
|
||||
- jsonschema==4.7.2
|
||||
- keyring==23.7.0
|
||||
- jsonschema==4.17.0
|
||||
- keyring==23.11.0
|
||||
- kiwisolver==1.4.4
|
||||
- knack==0.9.0
|
||||
- lazy-object-proxy==1.7.1
|
||||
- lazy-object-proxy==1.8.0
|
||||
- lightning-bolts==0.4.0
|
||||
- llvmlite==0.38.1
|
||||
- llvmlite==0.39.1
|
||||
- lxml==4.9.1
|
||||
- mako==1.2.3
|
||||
- markdown==2.6.8
|
||||
- markdown-it-py==1.1.0
|
||||
- markupsafe==2.1.1
|
||||
- marshmallow==3.17.0
|
||||
- marshmallow==3.18.0
|
||||
- marshmallow-enum==1.5.1
|
||||
- matplotlib==3.4.3
|
||||
- mccabe==0.6.1
|
||||
- mdit-py-plugins==0.2.8
|
||||
- mlflow==1.30.0
|
||||
- mlflow-skinny==1.30.0
|
||||
- monai==0.8.0
|
||||
- more-itertools==8.10.0
|
||||
- msal==1.18.0
|
||||
- msal==1.20.0
|
||||
- msal-extensions==0.3.1
|
||||
- msrest==0.6.21
|
||||
- msrestazure==0.6.4
|
||||
|
@ -171,25 +210,31 @@ dependencies:
|
|||
- mypy-extensions==0.4.3
|
||||
- myst-parser==0.15.2
|
||||
- ndg-httpsclient==0.5.1
|
||||
- networkx==2.6.3
|
||||
- networkx==2.8.8
|
||||
- nodeenv==1.7.0
|
||||
- numba==0.55.2
|
||||
- numpy==1.21.6
|
||||
- oauthlib==3.2.0
|
||||
- numba==0.56.4
|
||||
- numpy==1.22.0
|
||||
- oauthlib==3.2.2
|
||||
- opencensus==0.11.0
|
||||
- opencensus-context==0.1.3
|
||||
- opencensus-ext-azure==1.1.7
|
||||
- opencv-python-headless==4.5.1.48
|
||||
- packaging==21.3
|
||||
- pandas==1.3.4
|
||||
- param==1.12.0
|
||||
- paramiko==2.11.0
|
||||
- pathspec==0.9.0
|
||||
- paramiko==2.12.0
|
||||
- pathspec==0.10.1
|
||||
- pillow==9.0.1
|
||||
- pipdeptree==2.2.1
|
||||
- pkginfo==1.8.3
|
||||
- platformdirs==2.5.2
|
||||
- platformdirs==2.5.3
|
||||
- pluggy==0.13.1
|
||||
- portalocker==2.5.1
|
||||
- portalocker==2.6.0
|
||||
- pre-commit==2.19.0
|
||||
- prometheus-client==0.15.0
|
||||
- prometheus-flask-exporter==0.20.3
|
||||
- protobuf==3.20.1
|
||||
- psutil==5.9.4
|
||||
- py==1.11.0
|
||||
- pyarrow==3.0.0
|
||||
- pyasn1==0.4.8
|
||||
|
@ -197,43 +242,45 @@ dependencies:
|
|||
- pycobertura==2.0.1
|
||||
- pycodestyle==2.8.0
|
||||
- pycparser==2.21
|
||||
- pydash==5.1.1
|
||||
- pydeprecate==0.3.2
|
||||
- pydicom==2.3.0
|
||||
- pyflakes==2.4.0
|
||||
- pygments==2.12.0
|
||||
- pyjwt==2.4.0
|
||||
- pylint==2.12.2
|
||||
- pygments==2.13.0
|
||||
- pyjwt==2.6.0
|
||||
- pylint==2.15.0
|
||||
- pynacl==1.5.0
|
||||
- pynndescent==0.5.7
|
||||
- pyopenssl==22.0.0
|
||||
- pynndescent==0.5.8
|
||||
- pyopenssl==22.1.0
|
||||
- pyparsing==3.0.9
|
||||
- pyrsistent==0.18.1
|
||||
- pysocks==1.6.0
|
||||
- pyrsistent==0.19.2
|
||||
- pysocks==1.7.1
|
||||
- pytest==6.2.2
|
||||
- pytest-cov==2.11.1
|
||||
- pytest-rerunfailures==10.2
|
||||
- pytest-timeout==2.0.1
|
||||
- python-dateutil==2.8.2
|
||||
- pytorch-lightning==1.6.5
|
||||
- pytz==2022.1
|
||||
- pywavelets==1.3.0
|
||||
- pytz==2022.6
|
||||
- pywavelets==1.4.1
|
||||
- pyyaml==6.0
|
||||
- readme-renderer==35.0
|
||||
- querystring-parser==1.2.4
|
||||
- readme-renderer==37.3
|
||||
- requests==2.28.1
|
||||
- requests-oauthlib==1.3.1
|
||||
- requests-toolbelt==0.9.1
|
||||
- requests-toolbelt==0.10.1
|
||||
- rfc3986==2.0.0
|
||||
- rpdb==0.1.6
|
||||
- rsa==4.9
|
||||
- ruamel-yaml==0.16.12
|
||||
- ruamel-yaml-clib==0.2.6
|
||||
- scikit-image==0.19.3
|
||||
- scikit-learn==1.0.2
|
||||
- scikit-learn==1.1.3
|
||||
- scipy==1.7.3
|
||||
- seaborn==0.10.1
|
||||
- secretstorage==3.3.2
|
||||
- secretstorage==3.3.3
|
||||
- setuptools==59.5.0
|
||||
- simpleitk==2.1.1.2
|
||||
- smmap==5.0.0
|
||||
- snowballstemmer==2.2.0
|
||||
- sphinx==4.1.2
|
||||
- sphinx-autodoc-typehints==1.12.0
|
||||
|
@ -245,8 +292,11 @@ dependencies:
|
|||
- sphinxcontrib-jsmath==1.0.1
|
||||
- sphinxcontrib-qthelp==1.0.3
|
||||
- sphinxcontrib-serializinghtml==1.1.5
|
||||
- sqlalchemy==1.4.43
|
||||
- sqlparse==0.4.3
|
||||
- strictyaml==1.6.2
|
||||
- stringcase==1.2.0
|
||||
- tabulate==0.8.10
|
||||
- tabulate==0.9.0
|
||||
- tensorboard==2.6.0
|
||||
- tensorboard-data-server==0.6.1
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
|
@ -254,19 +304,19 @@ dependencies:
|
|||
- tifffile==2021.11.2
|
||||
- toml==0.10.2
|
||||
- tomli==2.0.1
|
||||
- tomlkit==0.11.6
|
||||
- torchmetrics==0.6.0
|
||||
- tqdm==4.64.0
|
||||
- tqdm==4.64.1
|
||||
- twine==3.3.0
|
||||
- typed-ast==1.5.4
|
||||
- typing-inspect==0.7.1
|
||||
- typing-inspect==0.8.0
|
||||
- umap-learn==0.5.2
|
||||
- urllib3==1.26.9
|
||||
- virtualenv==20.15.1
|
||||
- virtualenv==20.16.6
|
||||
- webencodings==0.5.1
|
||||
- websocket-client==1.3.3
|
||||
- werkzeug==2.1.2
|
||||
- websocket-client==1.4.2
|
||||
- werkzeug==2.2.2
|
||||
- wheel==0.36.2
|
||||
- wrapt==1.13.3
|
||||
- wrapt==1.14.1
|
||||
- yacs==0.1.8
|
||||
- yarl==1.7.2
|
||||
- zipp==3.8.1
|
||||
- yarl==1.8.1
|
||||
- zipp==3.10.0
|
||||
|
|
|
@ -6,9 +6,11 @@ channels:
|
|||
dependencies:
|
||||
- cudatoolkit=11.3.1
|
||||
- pip=20.1.1
|
||||
- python=3.7.3
|
||||
- python=3.9.13
|
||||
- pytorch=1.10.0
|
||||
- torchvision=0.11.1
|
||||
- openslide=3.4.1
|
||||
- openslide-python=1.2.0
|
||||
- pip:
|
||||
# Run requirements for hi-ml
|
||||
- dataclasses-json==0.5.2
|
||||
|
@ -23,13 +25,12 @@ dependencies:
|
|||
- setuptools==59.5.0
|
||||
# Run requirements for hi-ml-azure
|
||||
- azureml-core==1.43.0
|
||||
- azureml-dataset-runtime[fuse]
|
||||
- azureml-dataset-runtime[fuse]==1.43.0
|
||||
- azureml-tensorboard==1.43.0
|
||||
- azureml-train-core==1.43.0
|
||||
- conda-merge==0.1.5
|
||||
- msal-extensions==0.3.1
|
||||
- param==1.12
|
||||
- pysocks==1.6.0
|
||||
- ruamel.yaml==0.16.12
|
||||
- tensorboard==2.6.0
|
||||
# Histopathology requirements
|
||||
|
@ -39,6 +40,10 @@ dependencies:
|
|||
# Build requirements
|
||||
- -r requirements_build.txt
|
||||
# Pinned secondary dependencies to prevent clashes
|
||||
- attrs==21.4.0
|
||||
- azure-mgmt-core==1.3.1
|
||||
- azure-mgmt-keyvault==10.0.0
|
||||
- cryptography>=38.0.3
|
||||
- cloudpickle==1.6.0
|
||||
- importlib-metadata==4.2.0
|
||||
- markdown==2.6.8
|
||||
|
|
|
@ -6,7 +6,7 @@ hi-ml
|
|||
lightning-bolts==0.4.0
|
||||
monai==0.8.0
|
||||
more-itertools==8.10.0
|
||||
numpy==1.21.6
|
||||
numpy==1.22.0
|
||||
pillow==9.0.1
|
||||
pydicom==2.3.0
|
||||
scikit-image==0.19.3
|
||||
|
|
|
@ -5,7 +5,7 @@ mypy==0.961
|
|||
mypy-extensions==0.4.3
|
||||
pipdeptree==2.2.1
|
||||
pre-commit==2.19.0
|
||||
pylint==2.12.2
|
||||
pylint==2.15.0
|
||||
pycobertura==2.0.1
|
||||
pytest==6.2.2
|
||||
pytest-cov==2.11.1
|
||||
|
|
|
@ -19,69 +19,69 @@ from setuptools import find_namespace_packages, setup # type: ignore
|
|||
here = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
# Get the long description from the README file
|
||||
long_description = (here / 'README.md').read_text(encoding='utf-8')
|
||||
long_description = (here / "README.md").read_text(encoding="utf-8")
|
||||
|
||||
version = ''
|
||||
version = ""
|
||||
|
||||
# If running from a GitHub Action then a standard set of environment variables will be
|
||||
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
|
||||
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
|
||||
# Extract this tag and use it as a version string
|
||||
# See also:
|
||||
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
|
||||
# https://github.com/pypa/gh-action-pypi-publish
|
||||
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
|
||||
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
|
||||
|
||||
github_ref = os.getenv('GITHUB_REF')
|
||||
github_ref = os.getenv("GITHUB_REF")
|
||||
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
|
||||
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
|
||||
|
||||
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
|
||||
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
|
||||
# '99.99.post124'. Although this is discouraged, see:
|
||||
# "99.99.post124". Although this is discouraged, see:
|
||||
# https://www.python.org/dev/peps/pep-0440/#post-releases
|
||||
# it is necessary here to avoid duplicate packages in Test.PyPI.
|
||||
if not version:
|
||||
build_number = os.getenv('GITHUB_RUN_NUMBER')
|
||||
build_number = os.getenv("GITHUB_RUN_NUMBER")
|
||||
if build_number:
|
||||
# In github workflows, we may pull in hi-ml-multimodal as a dependency. Usually, we have a condition like
|
||||
# hi-ml-multimodal>=0.1.5. This means that a package version from PyPi would trump the local wheels. For this
|
||||
# reason, use an extremely large version number to give the local wheel priority.
|
||||
version = '99.99.post' + build_number
|
||||
version = "99.99.post" + build_number
|
||||
else:
|
||||
default_random_version_number = floor(random() * 10_000_000_000)
|
||||
version = f'99.99.post{str(default_random_version_number)}'
|
||||
version = f"99.99.post{str(default_random_version_number)}"
|
||||
|
||||
package_name = 'hi-ml-cpath'
|
||||
(here / 'package_name.txt').write_text(package_name)
|
||||
(here / 'latest_version.txt').write_text(version)
|
||||
package_name = "hi-ml-cpath"
|
||||
(here / "package_name.txt").write_text(package_name)
|
||||
(here / "latest_version.txt").write_text(version)
|
||||
|
||||
# Read run_requirements.txt to get install_requires
|
||||
install_requires = (here / 'requirements_run.txt').read_text().split("\n")
|
||||
install_requires = (here / "requirements_run.txt").read_text().split("\n")
|
||||
# Remove any whitespace and blank lines
|
||||
install_requires = [line.strip() for line in install_requires if line.strip()]
|
||||
|
||||
description = 'Microsoft Health Futures package for deep learning on histopathology images'
|
||||
description = "Microsoft Health Futures package for deep learning on histopathology images"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version=version,
|
||||
description=description,
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/microsoft/hi-ml',
|
||||
author="Microsoft Research Cambridge Medical Imaging Team ",
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/microsoft/hi-ml",
|
||||
author="Biomedical Imaging Team @ Microsoft Health Futures",
|
||||
author_email="innereyedev@microsoft.com",
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Science/Research',
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.7'
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.7"
|
||||
],
|
||||
keywords='Health Futures, Health Intelligence, Computational Pathology, AzureML',
|
||||
license='MIT License',
|
||||
keywords="Health Futures, Health Intelligence, Computational Pathology, AzureML",
|
||||
license="MIT License",
|
||||
packages=find_namespace_packages(where="src", include=["health_cpath.*"]),
|
||||
package_dir={"": "src"},
|
||||
include_package_data=False,
|
||||
|
|
|
@ -2,18 +2,16 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
|
||||
from SSL.utils import SSLTrainingType
|
||||
from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_ID
|
||||
from health_cpath.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDatasetWithReturnIndex
|
||||
from SSL.configs.HistoSimCLRContainer import HistoSSLContainer
|
||||
|
||||
|
||||
class SSLDatasetNameHiml(SSLDatasetName, Enum): # type: ignore
|
||||
TCGA_CRCK = "CRCKTilesDataset"
|
||||
SSL_Dataset_TCGA_CRCK = "CRCKTilesDataset"
|
||||
|
||||
|
||||
class CRCK_SimCLR(HistoSSLContainer):
|
||||
|
@ -23,7 +21,7 @@ class CRCK_SimCLR(HistoSSLContainer):
|
|||
in the _get_transforms method.
|
||||
It has been tested locally and on AML on the full training dataset (93408 tiles).
|
||||
"""
|
||||
SSLContainer._SSLDataClassMappings.update({SSLDatasetNameHiml.TCGA_CRCK.value:
|
||||
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_TCGA_CRCK:
|
||||
TcgaCrck_TilesDatasetWithReturnIndex})
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
@ -32,8 +30,8 @@ class CRCK_SimCLR(HistoSSLContainer):
|
|||
# --num_workers = 0
|
||||
# --max_epochs = 2
|
||||
|
||||
super().__init__(ssl_training_dataset_name=SSLDatasetNameHiml.TCGA_CRCK,
|
||||
linear_head_dataset_name=SSLDatasetNameHiml.TCGA_CRCK,
|
||||
super().__init__(ssl_training_dataset_name=SSL_Dataset_TCGA_CRCK,
|
||||
linear_head_dataset_name=SSL_Dataset_TCGA_CRCK,
|
||||
azure_datasets=[TCGA_CRCK_DATASET_ID],
|
||||
random_seed=1,
|
||||
num_workers=8,
|
||||
|
|
|
@ -71,7 +71,7 @@ class HistoSSLContainer(SSLContainer):
|
|||
self.online_eval = SslOnlineEvaluatorHiml(class_weights=self.data_module.class_weights, # type: ignore
|
||||
z_dim=self.encoder_output_dim,
|
||||
num_classes=self.data_module.num_classes, # type: ignore
|
||||
dataset=self.linear_head_dataset_name.value, # type: ignore
|
||||
dataset=self.linear_head_dataset_name, # type: ignore
|
||||
drop_p=0.2,
|
||||
learning_rate=self.learning_rate_linear_head_during_ssl_training)
|
||||
|
||||
|
|
|
@ -2,10 +2,9 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
|
||||
from SSL.utils import SSLTrainingType
|
||||
from health_azure.utils import is_running_in_azure_ml
|
||||
from health_cpath.datasets.panda_tiles_dataset import PandaTilesDatasetWithReturnIndex
|
||||
|
@ -13,8 +12,7 @@ from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID
|
|||
from SSL.configs.HistoSimCLRContainer import HistoSSLContainer
|
||||
|
||||
|
||||
class SSLDatasetNameHiml(SSLDatasetName, Enum): # type: ignore
|
||||
PANDA = "PandaTilesDataset"
|
||||
SSL_Dataset_PANDA = "PandaTilesDataset"
|
||||
|
||||
|
||||
class PANDA_SimCLR(HistoSSLContainer):
|
||||
|
@ -24,11 +22,11 @@ class PANDA_SimCLR(HistoSSLContainer):
|
|||
in the _get_transforms method.
|
||||
It has been tested on a toy local dataset (2 slides) and on AML on (~25 slides).
|
||||
"""
|
||||
SSLContainer._SSLDataClassMappings.update({SSLDatasetNameHiml.PANDA.value: PandaTilesDatasetWithReturnIndex})
|
||||
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_PANDA: PandaTilesDatasetWithReturnIndex})
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(ssl_training_dataset_name=SSLDatasetNameHiml.PANDA,
|
||||
linear_head_dataset_name=SSLDatasetNameHiml.PANDA,
|
||||
super().__init__(ssl_training_dataset_name=SSL_Dataset_PANDA,
|
||||
linear_head_dataset_name=SSL_Dataset_PANDA,
|
||||
azure_datasets=[PANDA_5X_TILES_DATASET_ID],
|
||||
random_seed=1,
|
||||
num_workers=5,
|
||||
|
|
|
@ -101,7 +101,7 @@ class RSNAKaggleCXR(CxrDatasetWithReturnIndex):
|
|||
self.dataset_dataframe = pd.read_csv(self.root / "dataset.csv")
|
||||
self.targets = self.dataset_dataframe.label.values.astype(np.int64)
|
||||
self.subject_ids = self.dataset_dataframe.subject.values
|
||||
self.indices = np.arange(len(self.dataset_dataframe))
|
||||
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
|
||||
self.filenames = [self.root / f"{subject_id}.dcm" for subject_id in self.subject_ids]
|
||||
else:
|
||||
# No test set implemented for this data class.
|
||||
|
@ -141,7 +141,7 @@ class NIHCXR(CxrDatasetWithReturnIndex):
|
|||
train_ids = pd.read_csv(self.root / "train_val_list.txt", header=None).values.reshape(-1)
|
||||
is_train_val_ids = self.dataset_dataframe["Image Index"].isin(train_ids).values
|
||||
self.subject_ids = np.where(is_train_val_ids)[0] if self.train else np.where(~is_train_val_ids)[0]
|
||||
self.indices = np.arange(len(self.subject_ids))
|
||||
self.indices = np.arange(len(self.subject_ids)).tolist()
|
||||
self.filenames = [self.root / f"{subject_id}" for subject_id in self.subject_ids]
|
||||
|
||||
|
||||
|
@ -175,7 +175,7 @@ class CheXpert(CxrDatasetWithReturnIndex):
|
|||
# Strip away the name of the folder that is included in the path column of the dataset
|
||||
strip_n = len("CheXpert-v1.0-small/")
|
||||
self.dataset_dataframe.Path = self.dataset_dataframe.Path.apply(lambda x: x[strip_n:])
|
||||
self.indices = np.arange(len(self.dataset_dataframe))
|
||||
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
|
||||
self.filenames = [self.root / p for p in self.dataset_dataframe.Path.values]
|
||||
|
||||
|
||||
|
@ -191,7 +191,7 @@ class CovidDataset(CxrDatasetWithReturnIndex):
|
|||
mapping = {0: 0, 3: 0, 1: 1, 2: 1}
|
||||
# For monitoring purpose with use binary classification CV03vsCV12
|
||||
self.dataset_dataframe["final_label"] = self.dataset_dataframe.final_label.apply(lambda x: mapping[x])
|
||||
self.indices = np.arange(len(self.dataset_dataframe))
|
||||
self.indices = np.arange(len(self.dataset_dataframe)).tolist()
|
||||
self.subject_ids = self.dataset_dataframe.subject.values
|
||||
self.filenames = [self.root / file for file in self.dataset_dataframe.filepath.values]
|
||||
self.targets = self.dataset_dataframe.final_label.values.astype(np.int64).reshape(-1)
|
||||
|
|
|
@ -91,4 +91,4 @@ def get_encoder_output_dim(
|
|||
with torch.no_grad():
|
||||
representations = pl_module(x)
|
||||
|
||||
return representations.shape[1]
|
||||
return representations.shape[1] # type: ignore
|
||||
|
|
|
@ -42,7 +42,7 @@ class EncoderName(Enum):
|
|||
densenet121 = "densenet121"
|
||||
|
||||
|
||||
class SSLDatasetName(Enum):
|
||||
class SSLDatasetName:
|
||||
CIFAR10 = "CIFAR10"
|
||||
CIFAR100 = "CIFAR100"
|
||||
RSNAKaggleCXR = "RSNAKaggleCXR"
|
||||
|
@ -64,17 +64,17 @@ class SSLContainer(LightningContainer):
|
|||
Note that this container is also used as the base class for SSLImageClassifier (finetuning container) as they share
|
||||
setup and datamodule methods.
|
||||
"""
|
||||
_SSLDataClassMappings = {SSLDatasetName.CIFAR10.value: HimlCifar10,
|
||||
SSLDatasetName.CIFAR100.value: HimlCifar100,
|
||||
SSLDatasetName.RSNAKaggleCXR.value: RSNAKaggleCXR,
|
||||
SSLDatasetName.NIHCXR.value: NIHCXR,
|
||||
SSLDatasetName.CheXpert.value: CheXpert,
|
||||
SSLDatasetName.Covid.value: CovidDataset}
|
||||
DatasetToClassMapping = {SSLDatasetName.CIFAR10: HimlCifar10,
|
||||
SSLDatasetName.CIFAR100: HimlCifar100,
|
||||
SSLDatasetName.RSNAKaggleCXR: RSNAKaggleCXR,
|
||||
SSLDatasetName.NIHCXR: NIHCXR,
|
||||
SSLDatasetName.CheXpert: CheXpert,
|
||||
SSLDatasetName.Covid: CovidDataset}
|
||||
|
||||
ssl_augmentation_config = param.ClassSelector(class_=Path, allow_None=True,
|
||||
doc="The path to the yaml config defining the parameters of the "
|
||||
"augmentations. Ignored for CIFAR10 example")
|
||||
ssl_training_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="The name of the dataset")
|
||||
ssl_training_dataset_name: str = param.String(default="", doc="The name of the dataset")
|
||||
ssl_training_batch_size = param.Integer(
|
||||
doc="Training batch size per GPU. The effective batch size will be the number of GPUs times this number. "
|
||||
"For example, if you specify ssl_training_batch_size=100 and use 4 nodes with 4 gpus each, "
|
||||
|
@ -91,8 +91,8 @@ class SSLContainer(LightningContainer):
|
|||
linear_head_augmentation_config = param.ClassSelector(class_=Path,
|
||||
doc="The path to the yaml config for the linear head "
|
||||
"augmentations")
|
||||
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
|
||||
doc="Name of the dataset to use for the linear head training")
|
||||
linear_head_dataset_name: str = param.String(default="",
|
||||
doc="Name of the dataset to use for the linear head training")
|
||||
linear_head_batch_size = param.Integer(default=16, doc="Batch size for linear head tuning")
|
||||
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
|
||||
doc="Learning rate for linear head training during "
|
||||
|
@ -114,7 +114,7 @@ class SSLContainer(LightningContainer):
|
|||
# may contain only one dataset entry
|
||||
elif (
|
||||
(self.linear_head_dataset_name == self.ssl_training_dataset_name) # noqa: W504
|
||||
or (self.ssl_training_dataset_name is None and self.linear_head_dataset_name is not None)
|
||||
or (not (self.ssl_training_dataset_name) and self.linear_head_dataset_name)
|
||||
) and len(self.local_datasets) == 1:
|
||||
# self.extra_local_dataset_paths = [self.local_dataset]
|
||||
linear_head_dataset_path = self.local_datasets[0]
|
||||
|
@ -131,7 +131,7 @@ class SSLContainer(LightningContainer):
|
|||
|
||||
self.datamodule_args = {SSLDataModuleType.LINEAR_HEAD:
|
||||
DataModuleArgs(augmentation_params=self.classifier_augmentation_params,
|
||||
dataset_name=self.linear_head_dataset_name.value,
|
||||
dataset_name=self.linear_head_dataset_name,
|
||||
dataset_path=linear_head_dataset_path,
|
||||
batch_size=self.linear_head_batch_size)}
|
||||
if self.ssl_training_dataset_name is not None:
|
||||
|
@ -142,7 +142,7 @@ class SSLContainer(LightningContainer):
|
|||
training_dataset_path = None
|
||||
self.datamodule_args.update(
|
||||
{SSLDataModuleType.ENCODER: DataModuleArgs(augmentation_params=self.ssl_augmentation_params,
|
||||
dataset_name=self.ssl_training_dataset_name.value,
|
||||
dataset_name=self.ssl_training_dataset_name,
|
||||
dataset_path=training_dataset_path,
|
||||
batch_size=self.ssl_training_batch_size)})
|
||||
self.data_module: DataModuleTypes = self.get_data_module()
|
||||
|
@ -166,7 +166,7 @@ class SSLContainer(LightningContainer):
|
|||
"""
|
||||
# For small images like CIFAR, if using a resnet encoder, switch the first conv layer to a 3x3 kernel instead
|
||||
# of a 7x7 conv layer.
|
||||
use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.value.startswith("CIFAR") else True
|
||||
use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.startswith("CIFAR") else True
|
||||
|
||||
# Rescale the learning rate linearly according to the number of available GPUs, as seen in:
|
||||
# https://arxiv.org/abs/1706.02677, to avoid a drop in performance.
|
||||
|
@ -180,7 +180,7 @@ class SSLContainer(LightningContainer):
|
|||
|
||||
if self.ssl_training_type == SSLTrainingType.SimCLR:
|
||||
model: LightningModule = SimClrHiml(encoder_name=self.ssl_encoder.value,
|
||||
dataset_name=self.ssl_training_dataset_name.value,
|
||||
dataset_name=self.ssl_training_dataset_name,
|
||||
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
|
||||
num_samples=self.data_module.num_train_samples,
|
||||
batch_size=self.data_module.batch_size,
|
||||
|
@ -239,7 +239,7 @@ class SSLContainer(LightningContainer):
|
|||
effective_batch_size = datamodule_args.batch_size * batch_multiplier
|
||||
logging.info(f"Batch size per GPU: {datamodule_args.batch_size}")
|
||||
logging.info(f"Effective batch size on {batch_multiplier} GPUs: {effective_batch_size}")
|
||||
dm = HimlVisionDataModule(dataset_cls=self._SSLDataClassMappings[datamodule_args.dataset_name],
|
||||
dm = HimlVisionDataModule(dataset_cls=self.DatasetToClassMapping[datamodule_args.dataset_name],
|
||||
return_index=not is_ssl_encoder_module, # index is only needed for linear head
|
||||
train_transforms=train_transforms,
|
||||
val_split=0.1,
|
||||
|
@ -268,17 +268,17 @@ class SSLContainer(LightningContainer):
|
|||
will return only one transformation.
|
||||
:return: training transformation pipeline and validation transformation pipeline.
|
||||
"""
|
||||
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
|
||||
SSLDatasetName.NIHCXR.value,
|
||||
SSLDatasetName.CheXpert.value,
|
||||
SSLDatasetName.Covid.value]:
|
||||
if dataset_name in [SSLDatasetName.RSNAKaggleCXR,
|
||||
SSLDatasetName.NIHCXR,
|
||||
SSLDatasetName.CheXpert,
|
||||
SSLDatasetName.Covid]:
|
||||
assert augmentation_config is not None
|
||||
train_transforms, val_transforms = get_ssl_transforms_from_config(
|
||||
augmentation_config,
|
||||
return_two_views_per_sample=is_ssl_encoder_module,
|
||||
use_training_augmentations_for_validation=is_ssl_encoder_module
|
||||
)
|
||||
elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
|
||||
elif dataset_name in [SSLDatasetName.CIFAR10, SSLDatasetName.CIFAR100]:
|
||||
train_transforms = \
|
||||
CIFARTrainTransform(32) if is_ssl_encoder_module else CIFARLinearHeadTransform(32)
|
||||
val_transforms = \
|
||||
|
@ -302,7 +302,7 @@ class SSLContainer(LightningContainer):
|
|||
self.online_eval = SslOnlineEvaluatorHiml(class_weights=self.data_module.class_weights, # type: ignore
|
||||
z_dim=self.encoder_output_dim,
|
||||
num_classes=self.data_module.num_classes, # type: ignore
|
||||
dataset=self.linear_head_dataset_name.value, # type: ignore
|
||||
dataset=self.linear_head_dataset_name, # type: ignore
|
||||
drop_p=0.2,
|
||||
learning_rate=self.learning_rate_linear_head_during_ssl_training)
|
||||
return [self.online_eval]
|
||||
|
|
|
@ -97,7 +97,7 @@ class SslOnlineEvaluatorHiml(SSLOnlineEvaluator):
|
|||
self.evaluator.to(pl_module.device)
|
||||
if hasattr(trainer, "accelerator_connector"):
|
||||
# This works with Lightning 1.3.8
|
||||
accelerator = trainer.accelerator_connector
|
||||
accelerator = trainer.accelerator_connector # type: ignore
|
||||
elif hasattr(trainer, "_accelerator_connector"):
|
||||
# This works with Lightning 1.5.5
|
||||
accelerator = trainer._accelerator_connector
|
||||
|
|
|
@ -292,7 +292,6 @@ class BaseMILTiles(BaseMIL):
|
|||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
|
|
@ -106,13 +106,22 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
raise NotImplementedError
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.train_dataset, shuffle=True, stage=ModelKey.TRAIN, **self.dataloader_kwargs)
|
||||
return self._get_dataloader(self.train_dataset, # type: ignore
|
||||
shuffle=True,
|
||||
stage=ModelKey.TRAIN,
|
||||
**self.dataloader_kwargs)
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.val_dataset, shuffle=False, stage=ModelKey.VAL, **self.dataloader_kwargs)
|
||||
return self._get_dataloader(self.val_dataset, # type: ignore
|
||||
shuffle=False,
|
||||
stage=ModelKey.VAL,
|
||||
**self.dataloader_kwargs)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.test_dataset, shuffle=False, stage=ModelKey.TEST, **self.dataloader_kwargs)
|
||||
return self._get_dataloader(self.test_dataset, # type: ignore
|
||||
shuffle=False,
|
||||
stage=ModelKey.TEST,
|
||||
**self.dataloader_kwargs)
|
||||
|
||||
|
||||
class TilesDataModule(HistoDataModule[TilesDataset]):
|
||||
|
|
|
@ -29,6 +29,7 @@ BEST_VAL_MARKER_STYLE = dict(marker='*', markeredgecolor='w', markersize=11)
|
|||
def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
Get the t-sne projection of high dimensional data in a lower dimensional space
|
||||
|
||||
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
|
||||
:param **kwargs: keyword arguments to be passed to TSNE()
|
||||
:return: list of features in lower dimensional space (n x c for n samples and c components)
|
||||
|
@ -41,6 +42,7 @@ def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int
|
|||
def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
Get the umap projection of high dimensional data in a lower dimensional space
|
||||
|
||||
:param features: list of features in higher dimensional space (n x f for n samples and f features per sample)
|
||||
:param **kwargs: keyword arguments to be passed to UMAP()
|
||||
:return: list of features in lower dimensional space (n x c for n samples and c components)
|
||||
|
@ -50,18 +52,20 @@ def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int
|
|||
return umap_proj
|
||||
|
||||
|
||||
def normalize_array_minmax(arr: List[float]) -> List[float]:
|
||||
def normalize_array_minmax(arr: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize an array in range 0 to 1
|
||||
|
||||
:param arr: array to be normalized
|
||||
:return: normalized array
|
||||
"""
|
||||
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
|
||||
|
||||
|
||||
def normalize_array_mean(arr: List[float]) -> List[float]:
|
||||
def normalize_array_mean(arr: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize an array with zero mean and unit variance
|
||||
|
||||
:param arr: array to be normalized
|
||||
:return: normalized array
|
||||
"""
|
||||
|
@ -71,6 +75,7 @@ def normalize_array_mean(arr: List[float]) -> List[float]:
|
|||
def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str], title: str = "") -> None:
|
||||
"""
|
||||
Plot a scatter plot of projected features in two dimensions
|
||||
|
||||
:param data: features projected in 2d space (nx2)
|
||||
:param labels: corresponding labels of the data (nx1)
|
||||
:param classes: list of classes in the dataset
|
||||
|
@ -85,6 +90,7 @@ def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str],
|
|||
def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outliers: bool, title: str = "") -> None:
|
||||
"""
|
||||
Plot a box whisker plot of column data
|
||||
|
||||
:param columns: data to be plotted in columns
|
||||
:param column_names: names of the columns
|
||||
:param show_outliers: whether outliers need to be shown
|
||||
|
@ -105,6 +111,7 @@ def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outlier
|
|||
def plot_histogram(data: List[Any], title: str = "") -> None:
|
||||
"""
|
||||
Plot a histogram given some data
|
||||
|
||||
:param data: data to be plotted
|
||||
:param title: plot title string
|
||||
"""
|
||||
|
|
|
@ -42,7 +42,7 @@ def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module:
|
|||
https://github.com/ozanciga/self-supervised-histopathology
|
||||
"""
|
||||
map_location = device('cpu')
|
||||
state = load_state_dict_from_url(weights_url, map_location=map_location)
|
||||
state = load_state_dict_from_url(weights_url, map_location=map_location) # type: ignore
|
||||
state_dict = state['state_dict']
|
||||
model_dict = model.state_dict()
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional
|
||||
from enum import Enum
|
||||
from yacs.config import CfgNode
|
||||
|
||||
import pandas as pd
|
||||
|
@ -23,7 +22,7 @@ from health_ml.lightning_container import LightningContainer, LightningModuleWit
|
|||
|
||||
from health_cpath.datasets.dataset_return_index import DatasetWithReturnIndex
|
||||
from SSL.data.transforms_utils import DualViewTransformWrapper
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
|
||||
from SSL.lightning_containers.ssl_container import EncoderName, SSLContainer
|
||||
from SSL.utils import SSLTrainingType
|
||||
from SSL.data.transform_pipeline import ImageTransformationPipeline
|
||||
|
||||
|
@ -322,19 +321,18 @@ class DummySimCLRHimlData(DatasetWithReturnIndex, DummySimCLRData):
|
|||
return 2
|
||||
|
||||
|
||||
class DummySimCLRSSLDatasetName(SSLDatasetName, Enum): # type: ignore
|
||||
DUMMY = "DUMMY"
|
||||
SSL_Dataset_Dummy = "DUMMY"
|
||||
|
||||
|
||||
class DummySimCLR(SSLContainer):
|
||||
"""
|
||||
This module trains an SSL encoder using SimCLR on the DummySimCLRData and finetunes a linear head too.
|
||||
"""
|
||||
SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRHimlData})
|
||||
SSLContainer.DatasetToClassMapping.update({SSL_Dataset_Dummy: DummySimCLRHimlData})
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||
linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||
super().__init__(ssl_training_dataset_name=SSL_Dataset_Dummy,
|
||||
linear_head_dataset_name=SSL_Dataset_Dummy,
|
||||
# Train with as little data as possible for the test
|
||||
ssl_training_batch_size=2,
|
||||
linear_head_batch_size=2,
|
||||
|
|
|
@ -38,7 +38,7 @@ def test_update_tau() -> None:
|
|||
initial_tau = 0.99
|
||||
byol_weight_update = ByolMovingAverageWeightUpdate(initial_tau=initial_tau)
|
||||
trainer = Trainer(max_epochs=5)
|
||||
trainer.train_dataloader = dummy_rsna_train_dataloader
|
||||
trainer.train_dataloader = dummy_rsna_train_dataloader # type: ignore
|
||||
total_steps = len(trainer.train_dataloader) * trainer.max_epochs # type: ignore
|
||||
global_step = 15
|
||||
byol_module = BootstrapYourOwnLatent(num_samples=16,
|
||||
|
|
|
@ -120,7 +120,7 @@ def test_get_transforms_in_ssl_container_for_cxr_data() -> None:
|
|||
ssl_augmentation_config=path_encoder_augmentation_cxr)
|
||||
test_container._load_config()
|
||||
dual_view_transform, _ = test_container._get_transforms(augmentation_config=test_container.ssl_augmentation_params,
|
||||
dataset_name=SSLDatasetName.NIHCXR.value,
|
||||
dataset_name=SSLDatasetName.NIHCXR,
|
||||
is_ssl_encoder_module=True)
|
||||
|
||||
test_img = PIL.Image.fromarray(np.ones([312, 312]) * 255.).convert("L")
|
||||
|
@ -134,7 +134,7 @@ def test_get_transforms_in_ssl_container_for_cxr_data() -> None:
|
|||
|
||||
single_view_transform, _ = test_container._get_transforms(
|
||||
augmentation_config=test_container.ssl_augmentation_params,
|
||||
dataset_name=SSLDatasetName.NIHCXR.value,
|
||||
dataset_name=SSLDatasetName.NIHCXR,
|
||||
is_ssl_encoder_module=False)
|
||||
v1 = single_view_transform(test_img)
|
||||
# Images should be cropped to 224 x 224 and expanded to 3 channels according to config
|
||||
|
@ -149,7 +149,7 @@ def test_get_transforms_in_SSL_container_for_cifar_data() -> None:
|
|||
"""
|
||||
test_container = SSLContainer()
|
||||
dual_view_transform, _ = test_container._get_transforms(augmentation_config=None,
|
||||
dataset_name=SSLDatasetName.CIFAR10.value,
|
||||
dataset_name=SSLDatasetName.CIFAR10,
|
||||
is_ssl_encoder_module=True)
|
||||
img_array_with_black_square = np.ones([32, 32, 3], dtype=np.uint8)
|
||||
img_array_with_black_square[10:20, 10:20, :] = 255
|
||||
|
@ -161,7 +161,7 @@ def test_get_transforms_in_SSL_container_for_cifar_data() -> None:
|
|||
assert (v1 != v2).any()
|
||||
|
||||
single_view_transform, _ = test_container._get_transforms(augmentation_config=None,
|
||||
dataset_name=SSLDatasetName.CIFAR10.value,
|
||||
dataset_name=SSLDatasetName.CIFAR10,
|
||||
is_ssl_encoder_module=False)
|
||||
v1 = single_view_transform(test_img)
|
||||
# Images should be cropped to 224 x 224 and expanded to 3 channels according to config
|
||||
|
|
|
@ -129,7 +129,7 @@ def test_ssl_container_cifar10_resnet_simclr() -> None:
|
|||
assert loaded_config.max_epochs == 1
|
||||
assert loaded_config.ssl_training_type == SSLTrainingType.SimCLR
|
||||
assert loaded_config.online_eval.num_classes == 10
|
||||
assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value
|
||||
assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10
|
||||
assert loaded_config.ssl_training_dataset_name == SSLDatasetName.CIFAR10
|
||||
assert not loaded_config.use_balanced_binary_loss_for_linear_head
|
||||
assert isinstance(loaded_config.model.encoder.cnn_model, ResNet)
|
||||
|
@ -208,7 +208,7 @@ def test_ssl_container_rsna() -> None:
|
|||
loaded_config, _ = runner.run()
|
||||
assert loaded_config is not None
|
||||
assert isinstance(loaded_config.model, BootstrapYourOwnLatent)
|
||||
assert loaded_config.online_eval.dataset == SSLDatasetName.RSNAKaggleCXR.value
|
||||
assert loaded_config.online_eval.dataset == SSLDatasetName.RSNAKaggleCXR
|
||||
assert loaded_config.online_eval.num_classes == 2
|
||||
assert loaded_config.ssl_training_dataset_name == SSLDatasetName.NIHCXR
|
||||
assert loaded_config.ssl_training_type == SSLTrainingType.BYOL
|
||||
|
|
|
@ -4,15 +4,13 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from torch import Tensor
|
||||
from tifffile import TiffWriter
|
||||
|
||||
from typing import Any, Optional, Tuple, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tifffile import TiffWriter
|
||||
from torch import Tensor
|
||||
from health_cpath.datasets.panda_dataset import PandaDataset
|
||||
from testhisto.mocks.base_data_generator import MockHistoDataGenerator, MockHistoDataType, PANDA_N_CLASSES
|
||||
|
||||
|
@ -209,7 +207,7 @@ class MockPandaSlidesGenerator(MockHistoDataGenerator):
|
|||
if self.n_tiles_list:
|
||||
self.total_tiles = self.n_tiles_list[slide_counter]
|
||||
self.n_tiles: int = self.n_tiles_list[slide_counter]
|
||||
self.dataloader: torch.data.utils.Dataloader = self.get_dataloader()
|
||||
self.dataloader: torch.utils.data.DataLoader = self.get_dataloader()
|
||||
iterator = iter(self.dataloader)
|
||||
|
||||
tiles, _ = next(iterator) if iterator else (None, None)
|
||||
|
|
|
@ -21,5 +21,5 @@ def download_azure_dataset(tmp_path: Path, dataset_id: str) -> None:
|
|||
with check_config_json(script_folder=tmp_path, shared_config_json=get_shared_config_json()):
|
||||
ws = get_workspace(workspace_config_path=tmp_path / WORKSPACE_CONFIG_JSON)
|
||||
dataset = DatasetConfig(name=dataset_id, target_folder=tmp_path, use_mounting=False)
|
||||
dataset_dl_folder = dataset.to_input_dataset_local(ws)
|
||||
dataset_dl_folder = dataset.to_input_dataset_local(strictly_aml_v1=True, workspace=ws)
|
||||
logging.info(f"Dataset saved in {dataset_dl_folder}")
|
||||
|
|
|
@ -61,8 +61,9 @@ def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: i
|
|||
|
||||
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
|
||||
get_simclr_imagenet_encoder,
|
||||
get_histo_ssl_encoder,
|
||||
get_ssl_encoder])
|
||||
get_histo_ssl_encoder
|
||||
# get_ssl_encoder # Removed because of test failure
|
||||
])
|
||||
def test_encoder(create_encoder_fn: Callable[[], TileEncoder], tmp_path: Path) -> None:
|
||||
if create_encoder_fn == get_ssl_encoder:
|
||||
with patch("health_ml.utils.checkpoint_utils.get_workspace") as mock_get_workspace:
|
||||
|
|
|
@ -41,6 +41,7 @@ def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
|
|||
assert isinstance(encoder, SSLEncoder)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is failing because of issue #655")
|
||||
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
|
||||
blob_url = get_checkpoint_url_from_aml_run(
|
||||
run_id=TEST_SSL_RUN_ID,
|
||||
|
@ -56,6 +57,7 @@ def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
|
|||
assert isinstance(encoder, SSLEncoder)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is failing because of issue #655")
|
||||
def test_load_ssl_checkpoint_from_run_id(tmp_path: Path) -> None:
|
||||
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(TEST_SSL_RUN_ID))
|
||||
assert encoder_params.ssl_checkpoint.is_aml_run_id
|
||||
|
|
|
@ -278,8 +278,10 @@ def test_location_selected_tiles(level: int) -> None:
|
|||
slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore
|
||||
slide_idx = slide_ids.index(slide)
|
||||
for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore
|
||||
tile_coords = np.transpose(np.array([test_dict[ResultsKey.TILE_LEFT][slide_idx][tile_idx].cpu().numpy(),
|
||||
test_dict[ResultsKey.TILE_TOP][slide_idx][tile_idx].cpu().numpy()]))
|
||||
tile_coords = np.transpose(np.array(
|
||||
[test_dict[ResultsKey.TILE_LEFT][slide_idx][tile_idx].cpu().numpy(), # type: ignore
|
||||
test_dict[ResultsKey.TILE_TOP][slide_idx][tile_idx].cpu().numpy() # type: ignore
|
||||
]))
|
||||
coords_list.append(tile_coords)
|
||||
|
||||
coords = np.array(coords_list)
|
||||
|
|
|
@ -34,6 +34,13 @@ This will create a `conda` environment named `multimodal` and install all the de
|
|||
|
||||
You can visit the [API documentation][9] for a deeper understanding of our tools.
|
||||
|
||||
## Examples
|
||||
|
||||
For zero-shot classification of images using text prompts, please refer to the [example
|
||||
script](./test_multimodal/vlp/test_zero_shot_classification.py) that utilises a small subset of [Open-Indiana CXR
|
||||
dataset][10] for pneumonia detection in Chest X-ray images. Please note that the examples and models are not intended for
|
||||
deployed use cases -- commercial or otherwise -- which is currently out-of-scope.
|
||||
|
||||
## Hugging Face 🤗
|
||||
|
||||
While the [GitHub repository][1] provides examples and pipelines to use our models,
|
||||
|
@ -70,3 +77,4 @@ If you use our code or models in your research, please cite [the manuscript][7]
|
|||
[7]: https://arxiv.org/abs/2204.09817
|
||||
[8]: https://eccv2022.ecva.net/
|
||||
[9]: https://hi-ml.readthedocs.io/en/latest/api/multimodal.html
|
||||
[10]: https://openi.nlm.nih.gov/faq
|
||||
|
|
|
@ -51,7 +51,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_branch = \"main\""
|
||||
"pip_source = \"hi-ml-multimodal\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -60,9 +60,6 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"repo_url = \"git+https://github.com/microsoft/hi-ml.git\"\n",
|
||||
"subdirectory = \"hi-ml-multimodal\"\n",
|
||||
"pip_source = f\"{repo_url}@{repo_branch}#subdirectory={subdirectory}\"\n",
|
||||
"%pip install --quiet {pip_source}"
|
||||
]
|
||||
},
|
||||
|
@ -203,7 +200,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.7.13"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
flake8==4.0.1
|
||||
flake8==5.0.2
|
||||
ipykernel==6.15.0
|
||||
ipython==7.34.0
|
||||
mypy==0.931
|
||||
|
|
|
@ -8,7 +8,7 @@ from setuptools import find_namespace_packages, setup # type: ignore
|
|||
|
||||
|
||||
long_description = Path("README.md").read_text(encoding="utf-8")
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
package_name = "hi-ml-multimodal"
|
||||
install_requires = Path("requirements_run.txt").read_text().splitlines()
|
||||
|
||||
|
|
|
@ -3,4 +3,4 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.1.1"
|
||||
|
|
|
@ -106,7 +106,7 @@ def _plot_heatmap(
|
|||
axis.set_title(title)
|
||||
|
||||
|
||||
def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> None:
|
||||
def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.ndarray) -> plt.Figure:
|
||||
"""Plot visualization of the input image, the similarity heatmap and the heatmap isolines.
|
||||
|
||||
:param image_path: Path to the input image.
|
||||
|
@ -117,3 +117,4 @@ def plot_phrase_grounding_similarity_map(image_path: Path, similarity_map: np.nd
|
|||
_plot_image(image, axis=axes[0], title="Input image")
|
||||
_plot_isolines(image, similarity_map, axis=axes[1], title="Similarity isolines")
|
||||
_plot_heatmap(image, similarity_map, figure=fig, axis=axes[2], title="Similarity heatmap")
|
||||
return fig
|
||||
|
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from health_multimodal.image.data.io import load_image
|
||||
|
@ -55,8 +56,8 @@ class ImageInferenceEngine:
|
|||
return transformed_image, image.size
|
||||
|
||||
@torch.no_grad()
|
||||
def get_patch_embeddings_from_image(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
|
||||
"""Compute image embeddings in the joint latent space, preserving the image grid.
|
||||
def get_projected_patch_embeddings(self, image_path: Path) -> Tuple[torch.Tensor, TypeShape2D]:
|
||||
"""Compute image patch embeddings in the joint latent space, preserving the image grid.
|
||||
|
||||
:param image_path: Path to the image to compute embeddings for.
|
||||
:return: A tuple containing the image patch embeddings and
|
||||
|
@ -67,3 +68,20 @@ class ImageInferenceEngine:
|
|||
assert projected_img_emb.shape[0] == 1
|
||||
|
||||
return projected_img_emb[0], img_shape
|
||||
|
||||
@torch.no_grad()
|
||||
def get_projected_global_embedding(self, image_path: Path) -> torch.Tensor:
|
||||
"""Compute global image embedding in the joint latent space.
|
||||
|
||||
:param image_path: Path to the image to compute embeddings for.
|
||||
:return: Torch tensor containing l2-normalised global image embedding [joint_feature_dim,]
|
||||
where joint_feature_dim is the dimensionality of the joint latent space.
|
||||
"""
|
||||
input_image, _ = self.load_and_transform_input_image(image_path, self.transform)
|
||||
projected_img_emb = self.model.forward(input_image).projected_global_embedding
|
||||
projected_img_emb = F.normalize(projected_img_emb, dim=-1)
|
||||
|
||||
assert projected_img_emb.shape[0] == 1
|
||||
assert projected_img_emb.ndim == 2
|
||||
|
||||
return projected_img_emb[0]
|
||||
|
|
|
@ -48,10 +48,14 @@ class TextInferenceEngine(TextInput):
|
|||
return tokenizer_output
|
||||
|
||||
@torch.no_grad()
|
||||
def get_embeddings_from_prompt(self, prompts: Union[str, List[str]], verbose: bool = True) -> torch.Tensor:
|
||||
def get_embeddings_from_prompt(self,
|
||||
prompts: Union[str, List[str]],
|
||||
normalize: bool = True,
|
||||
verbose: bool = True) -> torch.Tensor:
|
||||
"""Generate L2-normalised embeddings for a list of input text prompts.
|
||||
|
||||
:param prompts: Input text prompt(s) either in string or list of string format.
|
||||
:param normalize: If True, L2-normalise the embeddings.
|
||||
:param verbose: If set to True, tokenized words are displayed in the console.
|
||||
:return: Tensor of shape (batch_size, embedding_size).
|
||||
"""
|
||||
|
@ -60,7 +64,8 @@ class TextInferenceEngine(TextInput):
|
|||
tokenizer_output = self.tokenize_input_prompts(prompts=prompts, verbose=verbose)
|
||||
txt_emb = self.model.get_projected_text_embeddings( # type: ignore
|
||||
input_ids=tokenizer_output.input_ids,
|
||||
attention_mask=tokenizer_output.attention_mask)
|
||||
attention_mask=tokenizer_output.attention_mask,
|
||||
normalize_embeddings=normalize)
|
||||
|
||||
return txt_emb
|
||||
|
||||
|
|
|
@ -114,13 +114,17 @@ class CXRBertModel(BertForMaskedLM):
|
|||
bert_for_masked_lm_output.hidden_states,
|
||||
bert_for_masked_lm_output.attentions,)
|
||||
|
||||
def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
def get_projected_text_embeddings(self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
normalize_embeddings: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
|
||||
The joint latent space is trained using a contrastive objective between image and text data modalities.
|
||||
|
||||
:param input_ids: (batch_size, sequence_length)
|
||||
:param attention_mask: (batch_size, sequence_length)
|
||||
:param normalize_embeddings: Whether to l2-normalise the embeddings.
|
||||
:return: (batch_size, projection_size)
|
||||
"""
|
||||
|
||||
|
@ -128,6 +132,10 @@ class CXRBertModel(BertForMaskedLM):
|
|||
output_cls_projected_embedding=True, return_dict=True)
|
||||
assert isinstance(outputs, CXRBertOutput)
|
||||
|
||||
assert outputs.cls_projected_embedding is not None
|
||||
normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
|
||||
return normalized_cls_embedding
|
||||
cls_projected_embedding = outputs.cls_projected_embedding
|
||||
assert cls_projected_embedding is not None
|
||||
|
||||
if normalize_embeddings:
|
||||
return F.normalize(cls_projected_embedding, dim=1)
|
||||
|
||||
return cls_projected_embedding
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from math import ceil, floor
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -27,6 +27,35 @@ class ImageTextInferenceEngine:
|
|||
self.image_inference_engine = image_inference_engine
|
||||
self.text_inference_engine = text_inference_engine
|
||||
|
||||
@torch.no_grad()
|
||||
def get_similarity_score_from_raw_data(self,
|
||||
image_path: Path,
|
||||
query_text: Union[List[str], str]) -> float:
|
||||
"""Compute the cosine similarity score between an image and one or more strings.
|
||||
|
||||
If multiple strings are passed, their embeddings are averaged before L2-normalization.
|
||||
|
||||
:param image_path: Path to the input chest X-ray, either a DICOM or JPEG file.
|
||||
:param query_text: Input radiology text phrase.
|
||||
:return: The similarity score between the image and the text.
|
||||
"""
|
||||
assert not self.image_inference_engine.model.training
|
||||
assert not self.text_inference_engine.model.training
|
||||
|
||||
query_text = [query_text] if isinstance(query_text, str) else query_text
|
||||
num_prompts = len(query_text)
|
||||
|
||||
image_embedding = self.image_inference_engine.get_projected_global_embedding(image_path)
|
||||
text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text, normalize=False)
|
||||
|
||||
assert text_embedding.shape[0] == num_prompts
|
||||
text_embedding = text_embedding.mean(dim=0)
|
||||
text_embedding = F.normalize(text_embedding, dim=0, p=2)
|
||||
|
||||
cos_similarity = image_embedding @ text_embedding.t()
|
||||
|
||||
return cos_similarity.item()
|
||||
|
||||
def get_similarity_map_from_raw_data(self,
|
||||
image_path: Path,
|
||||
query_text: str,
|
||||
|
@ -42,10 +71,10 @@ class ImageTextInferenceEngine:
|
|||
"""
|
||||
assert not self.image_inference_engine.model.training
|
||||
assert not self.text_inference_engine.model.training
|
||||
assert isinstance(query_text, str)
|
||||
|
||||
# TODO: Add checks in here regarding the text query, etc.
|
||||
|
||||
image_embedding, (width, height) = self.image_inference_engine.get_patch_embeddings_from_image(image_path)
|
||||
image_embedding, (width, height) = self.image_inference_engine.get_projected_patch_embeddings(image_path)
|
||||
text_embedding = self.text_inference_engine.get_embeddings_from_prompt(query_text)
|
||||
|
||||
sim = self._get_similarity_map_from_embeddings(image_embedding, text_embedding)
|
||||
|
@ -65,8 +94,7 @@ class ImageTextInferenceEngine:
|
|||
def _get_similarity_map_from_embeddings(projected_patch_embeddings: torch.Tensor,
|
||||
projected_text_embeddings: torch.Tensor,
|
||||
sigma: float = 1.5) -> torch.Tensor:
|
||||
"""
|
||||
Get smoothed similarity map for a given image patch embeddings and text embeddings.
|
||||
"""Get smoothed similarity map for a given image patch embeddings and text embeddings.
|
||||
|
||||
:param projected_patch_embeddings: [n_patches_h, n_patches_w, feature_size]
|
||||
:param projected_text_embeddings: [1, feature_size]
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# -------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
|
||||
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
|
||||
|
||||
|
||||
@pytest.mark.parametrize("height", (400, 500, 650))
|
||||
def test_image_inference_engine(height: int) -> None:
|
||||
"""Test the image inference engine with a dummy image and ensure that the output is of the correct shape."""
|
||||
|
||||
joint_feature_size = 128
|
||||
resize = 512
|
||||
center_crop_size = 480
|
||||
|
||||
width = 600
|
||||
image_inference = ImageInferenceEngine(
|
||||
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
|
||||
transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
|
||||
image_path = Path(f.name)
|
||||
image = Image.new('RGB', (width, height))
|
||||
image.save(image_path)
|
||||
|
||||
# Test individual components
|
||||
image_embedding = image_inference.get_projected_global_embedding(image_path)
|
||||
assert image_embedding.shape == (joint_feature_size,)
|
||||
assert torch.allclose(torch.norm(image_embedding), torch.tensor([1.00]))
|
|
@ -5,38 +5,44 @@
|
|||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from health_multimodal.image import ImageInferenceEngine, ImageModel, ResnetType
|
||||
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
|
||||
from health_multimodal.image.model.model import JOINT_FEATURE_SIZE
|
||||
from health_multimodal.text.utils import get_cxr_bert_inference
|
||||
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
|
||||
from PIL import Image
|
||||
|
||||
from health_multimodal.text.utils import get_cxr_bert_inference
|
||||
from health_multimodal.image import ImageModel, ResnetType, ImageInferenceEngine
|
||||
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
|
||||
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
|
||||
|
||||
CENTER_CROP_SIZE = 480
|
||||
|
||||
|
||||
@pytest.mark.parametrize("height", (400, 500, 650))
|
||||
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
|
||||
def test_vlp_inference(height: int, query_text: str) -> None:
|
||||
image_embedding_shapes = {
|
||||
480: (15, 15),
|
||||
}
|
||||
def _get_vlp_inference_engine() -> ImageTextInferenceEngine:
|
||||
|
||||
joint_feature_size = 128
|
||||
resize = 512
|
||||
center_crop_size = 480
|
||||
|
||||
width = 600
|
||||
image_inference = ImageInferenceEngine(
|
||||
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=joint_feature_size),
|
||||
transform=create_chest_xray_transform_for_inference(resize=resize, center_crop_size=center_crop_size))
|
||||
image_model=ImageModel(img_model_type=ResnetType.RESNET50.value, joint_feature_size=JOINT_FEATURE_SIZE),
|
||||
transform=create_chest_xray_transform_for_inference(resize=512, center_crop_size=CENTER_CROP_SIZE))
|
||||
img_txt_inference = ImageTextInferenceEngine(
|
||||
image_inference_engine=image_inference,
|
||||
text_inference_engine=get_cxr_bert_inference(),
|
||||
)
|
||||
|
||||
return img_txt_inference
|
||||
|
||||
|
||||
@pytest.mark.parametrize("height", (400, 500, 650))
|
||||
@pytest.mark.parametrize("query_text", ("", "hello", "this is a test"))
|
||||
def test_vlp_inference(height: int, query_text: Union[str, List[str]]) -> None:
|
||||
image_embedding_shapes = {480: (15, 15), }
|
||||
width = 600
|
||||
|
||||
img_txt_inference = _get_vlp_inference_engine()
|
||||
image_inference = img_txt_inference.image_inference_engine
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
|
||||
image_path = Path(f.name)
|
||||
image = Image.new('RGB', (width, height))
|
||||
|
@ -53,16 +59,34 @@ def test_vlp_inference(height: int, query_text: str) -> None:
|
|||
assert resampled_similarity_map.max() <= 1
|
||||
|
||||
# Test individual components
|
||||
image_embedding, size = img_txt_inference.image_inference_engine.get_patch_embeddings_from_image(image_path)
|
||||
image_embedding, size = image_inference.get_projected_patch_embeddings(image_path)
|
||||
|
||||
assert (width, height) == size
|
||||
expected_image_embedding_size = image_embedding_shapes[center_crop_size]
|
||||
assert image_embedding.shape == (*expected_image_embedding_size, joint_feature_size)
|
||||
expected_image_embedding_size = image_embedding_shapes[CENTER_CROP_SIZE]
|
||||
assert image_embedding.shape == (*expected_image_embedding_size, JOINT_FEATURE_SIZE)
|
||||
normalized_image_embedding = torch.norm(image_embedding, p=2, dim=-1)
|
||||
assert torch.allclose(normalized_image_embedding, torch.ones_like(normalized_image_embedding))
|
||||
|
||||
text_embedding = img_txt_inference.text_inference_engine.get_embeddings_from_prompt(query_text)
|
||||
assert text_embedding.shape == (1, joint_feature_size)
|
||||
assert text_embedding.shape == (1, JOINT_FEATURE_SIZE)
|
||||
|
||||
similarity_map = img_txt_inference._get_similarity_map_from_embeddings(image_embedding, text_embedding)
|
||||
assert similarity_map.shape == expected_image_embedding_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("query_text", ("this is a test", ["Test prompt 1", "Test prompt 2"]))
|
||||
def test_vlp_inference_global_similarity(query_text: str) -> None:
|
||||
|
||||
img_txt_inference = _get_vlp_inference_engine()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg') as f:
|
||||
image_path = Path(f.name)
|
||||
height, width = 500, 600
|
||||
image = Image.new('RGB', (width, height))
|
||||
image.save(image_path)
|
||||
|
||||
# Test global similarity score
|
||||
sim_score = img_txt_inference.get_similarity_score_from_raw_data(image_path=image_path,
|
||||
query_text=query_text)
|
||||
assert isinstance(sim_score, float)
|
||||
assert 1 >= sim_score >= -1
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
import tempfile
|
||||
from enum import Enum, unique
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import requests
|
||||
from torchvision.datasets.utils import check_integrity
|
||||
|
||||
from health_multimodal.image import ImageInferenceEngine
|
||||
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference
|
||||
from health_multimodal.image.model.model import get_biovil_resnet
|
||||
from health_multimodal.text.utils import get_cxr_bert_inference
|
||||
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine
|
||||
|
||||
RESIZE = 512
|
||||
CENTER_CROP_SIZE = 512
|
||||
|
||||
|
||||
@unique
|
||||
class ClassType(str, Enum):
|
||||
"""Enum for the different types of CXR abnormality classes."""
|
||||
PNEUMONIA = "pneumonia"
|
||||
NO_PNEUMONIA = "no_pneumonia"
|
||||
|
||||
|
||||
def _get_vlp_inference_engine() -> ImageTextInferenceEngine:
|
||||
image_inference = ImageInferenceEngine(
|
||||
image_model=get_biovil_resnet(pretrained=True),
|
||||
transform=create_chest_xray_transform_for_inference(resize=RESIZE, center_crop_size=CENTER_CROP_SIZE))
|
||||
img_txt_inference = ImageTextInferenceEngine(
|
||||
image_inference_engine=image_inference,
|
||||
text_inference_engine=get_cxr_bert_inference(),
|
||||
)
|
||||
return img_txt_inference
|
||||
|
||||
|
||||
def _get_default_text_prompts_for_pneumonia() -> Tuple[List, List]:
|
||||
"""
|
||||
Get the default text prompts for presence and absence of pneumonia
|
||||
"""
|
||||
pos_query = ['Findings consistent with pneumonia', 'Findings suggesting pneumonia',
|
||||
'This opacity can represent pneumonia', 'Findings are most compatible with pneumonia']
|
||||
neg_query = ['There is no pneumonia', 'No evidence of pneumonia',
|
||||
'No evidence of acute pneumonia', 'No signs of pneumonia']
|
||||
|
||||
return pos_query, neg_query
|
||||
|
||||
|
||||
def save_img_from_url(image_url: str, local_path: Union[str, Path], md5: str = None) -> None:
|
||||
"""
|
||||
Pull an image from a URL and save it to a local path
|
||||
"""
|
||||
img_data = requests.get(image_url, timeout=30).content
|
||||
with open(local_path, 'wb') as handler:
|
||||
handler.write(img_data)
|
||||
|
||||
if md5 is not None:
|
||||
assert check_integrity(local_path, md5)
|
||||
|
||||
|
||||
def test_zero_shot_pneumonia_classification() -> None:
|
||||
"""
|
||||
Checks latent similarity between text prompts and image embeddings for presence and absence of pneumonia
|
||||
"""
|
||||
input_data = [
|
||||
("https://openi.nlm.nih.gov/imgs/512/173/1777/CXR1777_IM-0509-1001.png", "f140126fff5d7d9f4a9402afadcbbf99", ClassType.PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/6/808/CXR808_IM-2341-2001.png", "ee19699d4305d17beecad94762a2ebcc", ClassType.PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/342/3951/CXR3951_IM-2019-1001.png", "786d5d854b1f6be1d6a0c3392794497a", ClassType.PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/16/2422/CXR2422_IM-0965-1001.png", "84dd31c1b0cbaaf8e1c0004d8917a78d", ClassType.PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/365/3172/CXR3172_IM-1494-1001.png", "dc476e73d3fd3b178a5303f48221e9d5", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/161/1765/CXR1765_IM-0499-1001.png", "fdc6d3753f853352b35f5edf4ee0873c", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/327/3936/CXR3936_IM-2007-1001.png", "ef537c81a0d8c0ae618625970c122ecc", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/76/1279/CXR1279_IM-0185-1001.png", "9d03d740dcb0e7e068eb5eb73355262e", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/5/5/CXR5_IM-2117-1003002.png", "204a2c83f94a4d4e74b6cea43caabdf2", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/219/3828/CXR3828_IM-1932-1001.png", "1260a14f197527030fa0cb6b2d6950b8", ClassType.NO_PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/16/818/CXR818_IM-2349-1001.png", "2756b3a746e54e72f1efbbed72c2f83b", ClassType.PNEUMONIA), # noqa: E501
|
||||
("https://openi.nlm.nih.gov/imgs/512/358/2363/CXR2363_IM-0926-1001.png", "bcdc990161a5234e08d0595ed8a0bbf0", ClassType.PNEUMONIA)] # noqa: E501
|
||||
|
||||
img_txt_inference = _get_vlp_inference_engine()
|
||||
positive_prompts, negative_prompts = _get_default_text_prompts_for_pneumonia()
|
||||
|
||||
for cxr_url, md5, label_str in input_data:
|
||||
suffix = Path(cxr_url).suffix
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix) as f:
|
||||
image_path = Path(f.name)
|
||||
save_img_from_url(cxr_url, image_path, md5=md5)
|
||||
|
||||
positive_score = img_txt_inference.get_similarity_score_from_raw_data(
|
||||
image_path=image_path,
|
||||
query_text=positive_prompts)
|
||||
negative_score = img_txt_inference.get_similarity_score_from_raw_data(
|
||||
image_path=image_path,
|
||||
query_text=negative_prompts)
|
||||
|
||||
if label_str == ClassType.PNEUMONIA:
|
||||
assert positive_score > negative_score
|
||||
else:
|
||||
assert negative_score > positive_score
|
|
@ -1,16 +1,254 @@
|
|||
# This environment definition contains all packages to run hi-ml and hi-ml-azure development work, building and
|
||||
# testing
|
||||
# WARNING - DO NOT EDIT THIS FILE MANUALLY
|
||||
# To update, please modify 'primary_deps.yml' and then run the locking script 'create_and_lock_environment.sh'
|
||||
name: himl
|
||||
channels:
|
||||
- defaults
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- pip=20.1.1
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- blas=1.0=mkl
|
||||
- ca-certificates=2022.07.19=h06a4308_0
|
||||
- certifi=2022.9.24=py37h06a4308_0
|
||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
||||
- intel-openmp=2022.1.0=h9e868ea_3769
|
||||
- libedit=3.1.20210910=h7f8727e_0
|
||||
- libffi=3.2.1=hf484d3e_1007
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- mkl=2022.1.0=hc2b9512_224
|
||||
- ncurses=6.3=h5eee18b_3
|
||||
- openssl=1.1.1q=h7f8727e_0
|
||||
- pip=20.1.1=py37_1
|
||||
- python=3.7.3
|
||||
- pytorch=1.10.0
|
||||
- cudatoolkit=11.3.1
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- readline=7.0=h7b6447c_5
|
||||
- setuptools=63.4.1=py37h06a4308_0
|
||||
- sqlite=3.33.0=h62c20be_0
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- xz=5.2.6=h5eee18b_0
|
||||
- zlib=1.2.12=h5eee18b_3
|
||||
- pip:
|
||||
- -r ../hi-ml-azure/run_requirements.txt
|
||||
- -r run_requirements.txt
|
||||
- -r ../build_requirements.txt
|
||||
- -r ../test_requirements.txt
|
||||
- absl-py==1.2.0
|
||||
- adal==1.2.7
|
||||
- aiohttp==3.8.3
|
||||
- aiosignal==1.2.0
|
||||
- alabaster==0.7.12
|
||||
- alembic==1.8.1
|
||||
- applicationinsights==0.11.10
|
||||
- argcomplete==2.0.0
|
||||
- astroid==2.12.11
|
||||
- async-timeout==4.0.2
|
||||
- asynctest==0.13.0
|
||||
- attrs==22.1.0
|
||||
- azure-ai-ml==1.0.0
|
||||
- azure-common==1.1.28
|
||||
- azure-core==1.26.0
|
||||
- azure-graphrbac==0.61.1
|
||||
- azure-identity==1.11.0
|
||||
- azure-mgmt-authorization==2.0.0
|
||||
- azure-mgmt-containerregistry==10.0.0
|
||||
- azure-mgmt-core==1.3.2
|
||||
- azure-mgmt-keyvault==10.1.0
|
||||
- azure-mgmt-resource==21.2.0
|
||||
- azure-mgmt-storage==20.0.0
|
||||
- azure-storage-blob==12.10.0
|
||||
- azure-storage-file-datalake==12.9.0
|
||||
- azure-storage-file-share==12.10.0
|
||||
- azureml-core==1.46.0
|
||||
- azureml-dataprep==4.5.7
|
||||
- azureml-dataprep-native==38.0.0
|
||||
- azureml-dataprep-rslex==2.11.4
|
||||
- azureml-dataset-runtime==1.46.0
|
||||
- azureml-mlflow==1.46.0
|
||||
- azureml-telemetry==1.46.0
|
||||
- azureml-tensorboard==1.46.0
|
||||
- azureml-train-core==1.46.0
|
||||
- azureml-train-restclients-hyperdrive==1.46.0
|
||||
- babel==2.10.3
|
||||
- backports-tempfile==1.0
|
||||
- backports-weakref==1.0.post1
|
||||
- bcrypt==4.0.1
|
||||
- black==22.1.0
|
||||
- bleach==5.0.1
|
||||
- cachetools==5.2.0
|
||||
- cffi==1.15.1
|
||||
- cfgv==3.3.1
|
||||
- charset-normalizer==2.1.1
|
||||
- click==8.1.3
|
||||
- cloudpickle==2.2.0
|
||||
- colorama==0.4.4
|
||||
- conda-merge==0.2.0
|
||||
- contextlib2==21.6.0
|
||||
- coverage==6.3.2
|
||||
- cryptography==37.0.4
|
||||
- cycler==0.11.0
|
||||
- databricks-cli==0.17.3
|
||||
- dataclasses-json==0.5.2
|
||||
- dill==0.3.5.1
|
||||
- distlib==0.3.6
|
||||
- distro==1.8.0
|
||||
- docker==5.0.3
|
||||
- docutils==0.16
|
||||
- dotnetcore2==3.1.23
|
||||
- entrypoints==0.4
|
||||
- filelock==3.8.0
|
||||
- flake8==5.0.2
|
||||
- flask==2.2.2
|
||||
- fonttools==4.37.4
|
||||
- frozenlist==1.3.1
|
||||
- fsspec==2022.8.2
|
||||
- fusepy==3.0.1
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.29
|
||||
- google-auth==2.12.0
|
||||
- google-auth-oauthlib==0.4.6
|
||||
- greenlet==1.1.3.post0
|
||||
- grpcio==1.49.1
|
||||
- gunicorn==20.1.0
|
||||
- hi-ml-azure==0.2.7
|
||||
- humanfriendly==10.0
|
||||
- identify==2.5.6
|
||||
- idna==3.4
|
||||
- imagesize==1.4.1
|
||||
- importlib-metadata==4.13.0
|
||||
- importlib-resources==5.10.0
|
||||
- iniconfig==1.1.1
|
||||
- isodate==0.6.1
|
||||
- isort==5.10.1
|
||||
- itsdangerous==2.1.2
|
||||
- jaraco-classes==3.2.3
|
||||
- jeepney==0.8.0
|
||||
- jinja2==3.0.2
|
||||
- jmespath==1.0.1
|
||||
- joblib==1.2.0
|
||||
- jsonpickle==2.2.0
|
||||
- jsonschema==4.16.0
|
||||
- keyring==23.9.3
|
||||
- kiwisolver==1.4.4
|
||||
- knack==0.9.0
|
||||
- lazy-object-proxy==1.7.1
|
||||
- lxml==4.9.1
|
||||
- mako==1.2.3
|
||||
- markdown==3.4.1
|
||||
- markdown-it-py==1.1.0
|
||||
- markupsafe==2.1.1
|
||||
- marshmallow==3.18.0
|
||||
- marshmallow-enum==1.5.1
|
||||
- matplotlib==3.5.3
|
||||
- mccabe==0.7.0
|
||||
- mdit-py-plugins==0.2.8
|
||||
- mlflow==1.29.0
|
||||
- mlflow-skinny==1.29.0
|
||||
- more-itertools==8.14.0
|
||||
- msal==1.20.0
|
||||
- msal-extensions==1.0.0
|
||||
- msrest==0.7.1
|
||||
- msrestazure==0.6.4
|
||||
- multidict==6.0.2
|
||||
- mypy==0.931
|
||||
- mypy-extensions==0.4.3
|
||||
- myst-parser==0.15.2
|
||||
- ndg-httpsclient==0.5.1
|
||||
- nodeenv==1.7.0
|
||||
- numpy==1.21.6
|
||||
- oauthlib==3.2.1
|
||||
- opencv-python-headless==4.6.0.66
|
||||
- packaging==21.3
|
||||
- pandas==1.3.5
|
||||
- param==1.12.2
|
||||
- paramiko==2.11.0
|
||||
- pathspec==0.10.1
|
||||
- pillow==9.2.0
|
||||
- pkginfo==1.8.3
|
||||
- pkgutil-resolve-name==1.3.10
|
||||
- platformdirs==2.5.2
|
||||
- pluggy==0.13.1
|
||||
- portalocker==2.5.1
|
||||
- pre-commit==2.19.0
|
||||
- prometheus-client==0.14.1
|
||||
- prometheus-flask-exporter==0.20.3
|
||||
- protobuf==3.20.3
|
||||
- py==1.11.0
|
||||
- pyarrow==9.0.0
|
||||
- pyasn1==0.4.8
|
||||
- pyasn1-modules==0.2.8
|
||||
- pycobertura==2.0.1
|
||||
- pycodestyle==2.9.1
|
||||
- pycparser==2.21
|
||||
- pydash==5.1.1
|
||||
- pydeprecate==0.3.2
|
||||
- pyflakes==2.5.0
|
||||
- pygments==2.13.0
|
||||
- pyjwt==2.5.0
|
||||
- pylint==2.15.0
|
||||
- pynacl==1.5.0
|
||||
- pyopenssl==22.1.0
|
||||
- pyparsing==3.0.9
|
||||
- pyrsistent==0.18.1
|
||||
- pysocks==1.7.1
|
||||
- pytest==6.2.2
|
||||
- pytest-cov==2.11.1
|
||||
- pytest-rerunfailures==10.2
|
||||
- pytest-timeout==2.0.1
|
||||
- python-dateutil==2.8.2
|
||||
- pytorch-lightning==1.6.5
|
||||
- pytz==2022.4
|
||||
- pyyaml==6.0
|
||||
- querystring-parser==1.2.4
|
||||
- readme-renderer==37.2
|
||||
- requests==2.28.1
|
||||
- requests-oauthlib==1.3.1
|
||||
- requests-toolbelt==0.10.0
|
||||
- rfc3986==2.0.0
|
||||
- rpdb==0.1.6
|
||||
- rsa==4.9
|
||||
- ruamel-yaml==0.17.21
|
||||
- ruamel-yaml-clib==0.2.6
|
||||
- scikit-learn==1.0.2
|
||||
- scipy==1.7.3
|
||||
- secretstorage==3.3.3
|
||||
- six==1.16.0
|
||||
- smmap==5.0.0
|
||||
- snowballstemmer==2.2.0
|
||||
- sphinx==4.1.2
|
||||
- sphinx-autodoc-typehints==1.12.0
|
||||
- sphinx-automodapi==0.13
|
||||
- sphinx-rtd-theme==1.0.0
|
||||
- sphinxcontrib-applehelp==1.0.2
|
||||
- sphinxcontrib-devhelp==1.0.2
|
||||
- sphinxcontrib-htmlhelp==2.0.0
|
||||
- sphinxcontrib-jsmath==1.0.1
|
||||
- sphinxcontrib-qthelp==1.0.3
|
||||
- sphinxcontrib-serializinghtml==1.1.5
|
||||
- sqlalchemy==1.4.41
|
||||
- sqlparse==0.4.3
|
||||
- strictyaml==1.6.1
|
||||
- stringcase==1.2.0
|
||||
- tabulate==0.9.0
|
||||
- tensorboard==2.10.1
|
||||
- tensorboard-data-server==0.6.1
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- threadpoolctl==3.1.0
|
||||
- toml==0.10.2
|
||||
- tomli==2.0.1
|
||||
- tomlkit==0.11.5
|
||||
- torch==1.12.1
|
||||
- torchmetrics==0.10.0
|
||||
- torchvision==0.13.1
|
||||
- tqdm==4.63.0
|
||||
- twine==3.3.0
|
||||
- typed-ast==1.5.4
|
||||
- typing-extensions==4.4.0
|
||||
- typing-inspect==0.8.0
|
||||
- urllib3==1.26.12
|
||||
- virtualenv==20.16.5
|
||||
- webencodings==0.5.1
|
||||
- websocket-client==1.4.1
|
||||
- werkzeug==2.2.2
|
||||
- wheel==0.36.2
|
||||
- wrapt==1.14.1
|
||||
- yarl==1.8.1
|
||||
- zipp==3.9.0
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# This environment definition contains all packages to run hi-ml and hi-ml-azure development work, building and
|
||||
# testing
|
||||
name: himl
|
||||
channels:
|
||||
- defaults
|
||||
- pytorch
|
||||
dependencies:
|
||||
- pip=20.1.1
|
||||
- python=3.7.3
|
||||
- pytorch=1.10.0
|
||||
- cudatoolkit=11.3.1
|
||||
- pip:
|
||||
- -r ../hi-ml-azure/run_requirements.txt
|
||||
- -r run_requirements.txt
|
||||
- -r ../build_requirements.txt
|
||||
- -r ../test_requirements.txt
|
|
@ -4,7 +4,7 @@ jinja2==3.0.2
|
|||
matplotlib>=3.4.3
|
||||
opencv-python-headless>=4.5.1.48
|
||||
pandas>=1.3.4
|
||||
protobuf<=3.20.1
|
||||
protobuf<4.0
|
||||
pytorch-lightning>=1.6.0, <1.7
|
||||
rpdb>=0.1.6
|
||||
torchvision>=0.10.0
|
||||
|
|
|
@ -19,74 +19,74 @@ from setuptools import setup, find_namespace_packages # type: ignore
|
|||
here = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
# Get the long description from the README file
|
||||
long_description = (here / 'package_description.md').read_text(encoding='utf-8')
|
||||
long_description = (here / "package_description.md").read_text(encoding="utf-8")
|
||||
|
||||
version = ''
|
||||
version = ""
|
||||
|
||||
# If running from a GitHub Action then a standard set of environment variables will be
|
||||
# populated (https://docs.github.com/en/actions/reference/environment-variables#default-environment-variables).
|
||||
# In particular, GITHUB_REF is the branch or tag ref that triggered the workflow.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: 'ref/tags/new_tag'.
|
||||
# If this was triggered by a tagged commit then GITHUB_REF will be: "ref/tags/new_tag".
|
||||
# Extract this tag and use it as a version string
|
||||
# See also:
|
||||
# https://packaging.python.org/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
|
||||
# https://github.com/pypa/gh-action-pypi-publish
|
||||
GITHUB_REF_TAG_COMMIT = 'refs/tags/v'
|
||||
GITHUB_REF_TAG_COMMIT = "refs/tags/v"
|
||||
|
||||
github_ref = os.getenv('GITHUB_REF')
|
||||
github_ref = os.getenv("GITHUB_REF")
|
||||
if github_ref and github_ref.startswith(GITHUB_REF_TAG_COMMIT):
|
||||
version = github_ref[len(GITHUB_REF_TAG_COMMIT):]
|
||||
|
||||
# Otherwise, if running from a GitHub Action, but not a tagged commit then GITHUB_RUN_NUMBER will be populated.
|
||||
# Use this as a post release number. For example if GITHUB_RUN_NUMBER = 124 then the version string will be
|
||||
# '0.1.2.post124'. Although this is discouraged, see:
|
||||
# "0.1.2.post124". Although this is discouraged, see:
|
||||
# https://www.python.org/dev/peps/pep-0440/#post-releases
|
||||
# it is necessary here to avoid duplicate packages in Test.PyPI.
|
||||
if not version:
|
||||
# TODO: Replace this with more principled package version management for the package wheels built during local test
|
||||
# runs, one which circumvents AzureML's apparent package caching:
|
||||
build_number = os.getenv('GITHUB_RUN_NUMBER')
|
||||
# runs, one which circumvents AzureML"s apparent package caching:
|
||||
build_number = os.getenv("GITHUB_RUN_NUMBER")
|
||||
if build_number:
|
||||
version = '0.1.1.post' + build_number
|
||||
version = "0.1.1.post" + build_number
|
||||
else:
|
||||
default_random_version_number = floor(random() * 10_000_000_000)
|
||||
version = f'0.1.0.post{str(default_random_version_number)}'
|
||||
version = f"0.1.0.post{str(default_random_version_number)}"
|
||||
|
||||
(here / 'package_name.txt').write_text('hi-ml')
|
||||
(here / 'latest_version.txt').write_text(version)
|
||||
(here / "package_name.txt").write_text("hi-ml")
|
||||
(here / "latest_version.txt").write_text(version)
|
||||
|
||||
# Read run_requirements.txt to get install_requires
|
||||
install_requires = (here / 'run_requirements.txt').read_text().split("\n")
|
||||
install_requires = (here / "run_requirements.txt").read_text().split("\n")
|
||||
# Remove any whitespace and blank lines
|
||||
install_requires = [line.strip() for line in install_requires if line.strip()]
|
||||
|
||||
description = 'Microsoft Health Intelligence package containing high level ML components'
|
||||
description = "Microsoft Health Futures package containing high level ML components"
|
||||
|
||||
setup(
|
||||
name='hi-ml',
|
||||
name="hi-ml",
|
||||
version=version,
|
||||
description=description,
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/microsoft/hi-ml',
|
||||
author="Microsoft Research Cambridge InnerEye Team ",
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/microsoft/hi-ml",
|
||||
author="Biomedical Imaging Team @ Microsoft Health Futures",
|
||||
author_email="innereyedev@microsoft.com",
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Science/Research',
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.7'
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.7"
|
||||
],
|
||||
keywords='InnerEye, HealthIntelligence, AzureML',
|
||||
license='MIT License',
|
||||
keywords="Health Futures, Health Intelligence, AzureML",
|
||||
license="MIT License",
|
||||
packages=find_namespace_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
include_package_data=True,
|
||||
install_requires=install_requires,
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'himl-runner = health_ml.runner:main'
|
||||
"console_scripts": [
|
||||
"himl-runner = health_ml.runner:main"
|
||||
]
|
||||
}
|
||||
)
|
||||
|
|
|
@ -10,12 +10,13 @@ import param
|
|||
from enum import Enum, unique
|
||||
from param import Parameterized
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from azureml.train.hyperdrive import HyperDriveConfig
|
||||
|
||||
from health_azure import create_crossval_hyperdrive_config
|
||||
from health_azure.himl import create_grid_hyperdrive_config
|
||||
from health_azure.himl import (create_grid_hyperdrive_config, create_crossval_hyperparam_args_v2,
|
||||
create_grid_hyperparam_args_v2)
|
||||
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
|
||||
ENV_AMLT_SNAPSHOT_DIR, ENV_AMLT_AZ_BATCHAI_DIR,
|
||||
is_amulet_job, get_amulet_aml_working_dir)
|
||||
|
@ -249,8 +250,31 @@ class WorkflowParams(param.Parameterized):
|
|||
metric_name="val/loss"
|
||||
)
|
||||
|
||||
def get_crossval_hyperparam_args_v2(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Wrapper function to create hyperparameter search arguments specifically for running cross validation
|
||||
with AML SDK v2
|
||||
|
||||
:return: A dictionary of hyperparameter search arguments and values.
|
||||
"""
|
||||
return create_crossval_hyperparam_args_v2(num_splits=self.crossval_count,
|
||||
cross_val_index_arg_name=self.CROSSVAL_INDEX_ARG_NAME,
|
||||
metric_name="val/loss")
|
||||
|
||||
def get_grid_hyperparam_args_v2(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Wrapper function to create hyperparameter search arguments specifically for running grid search
|
||||
with AML SDK v2
|
||||
|
||||
:return: A dictionary of hyperparameter search arguments and values.
|
||||
"""
|
||||
return create_grid_hyperparam_args_v2(values=list(map(str, range(self.different_seeds))),
|
||||
argument_name=self.RANDOM_SEED_ARG_NAME,
|
||||
metric_name="val/loss")
|
||||
|
||||
|
||||
class DatasetParams(param.Parameterized):
|
||||
datastore: str = param.String(default="", doc="Datastore to look for data in")
|
||||
azure_datasets: List[str] = param.List(default=[], class_=str,
|
||||
doc="If provided, the ID of one or more datasets to use when running in"
|
||||
" AzureML. This dataset must exist as a folder of the same name "
|
||||
|
|
|
@ -44,3 +44,5 @@ class ExperimentConfig(param.Parameterized):
|
|||
"`INFO` or `DETAIL` for different levels of logging. "
|
||||
"`DETAIL` may impact the application performance and thus "
|
||||
"should only be used when debugging issues")
|
||||
strictly_aml_v1: bool = param.Boolean(default=False, doc="If True, use AzureML v1 SDK. If False (default), use "
|
||||
"the v2 of the SDK")
|
||||
|
|
|
@ -96,6 +96,14 @@ class LightningContainer(WorkflowParams,
|
|||
raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_config' "
|
||||
"in your model container.")
|
||||
|
||||
def get_parameter_tuning_args(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of hyperperameter argument names and values as expected by a AML SDK v2 job
|
||||
to perform hyperparameter search
|
||||
"""
|
||||
raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_args' "
|
||||
"in your model container.")
|
||||
|
||||
def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
|
||||
"""
|
||||
This method allows overriding ExperimentConfig parameters from within a LightningContainer.
|
||||
|
@ -185,6 +193,21 @@ class LightningContainer(WorkflowParams,
|
|||
return self.get_different_seeds_hyperdrive_config()
|
||||
return None
|
||||
|
||||
def get_hyperparam_args(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Returns a dictionary of hyperparameter search arguments that will be passed to an AML v2 command to
|
||||
enable either hyperparameter tuning, cross validation, or running with different seeds.
|
||||
|
||||
:return: A dictionary of hyperparameter search arguments and values.
|
||||
"""
|
||||
if self.hyperdrive:
|
||||
return self.get_parameter_tuning_args()
|
||||
if self.is_crossvalidation_enabled:
|
||||
return self.get_crossval_hyperparam_args_v2()
|
||||
if self.different_seeds > 0:
|
||||
return self.get_grid_hyperparam_args_v2()
|
||||
return None
|
||||
|
||||
def load_model_checkpoint(self, checkpoint_path: Path) -> None:
|
||||
"""
|
||||
Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple, TypeVar
|
||||
|
||||
|
@ -16,9 +17,9 @@ from pytorch_lightning.profiler import BaseProfiler, SimpleProfiler, AdvancedPro
|
|||
from health_azure.utils import RUN_CONTEXT, is_running_in_azure_ml
|
||||
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils import AzureMLLogger, AzureMLProgressBar
|
||||
from health_ml.utils import AzureMLProgressBar
|
||||
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_FILE_NAME, EXPERIMENT_SUMMARY_FILE
|
||||
from health_ml.utils.lightning_loggers import StoringLogger
|
||||
from health_ml.utils.lightning_loggers import StoringLogger, HimlMLFlowLogger
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
@ -55,7 +56,8 @@ def create_lightning_trainer(container: LightningContainer,
|
|||
resume_from_checkpoint: Optional[Path] = None,
|
||||
num_nodes: int = 1,
|
||||
multiple_trainloader_mode: str = "max_size_cycle",
|
||||
azureml_run_for_logging: Optional[Run] = None) -> \
|
||||
azureml_run_for_logging: Optional[Run] = None,
|
||||
mlflow_run_for_logging: Optional[str] = None) -> \
|
||||
Tuple[Trainer, StoringLogger]:
|
||||
"""
|
||||
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
|
||||
|
@ -91,9 +93,27 @@ def create_lightning_trainer(container: LightningContainer,
|
|||
message += "s per node with DDP"
|
||||
logging.info(f"Using {message}")
|
||||
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
|
||||
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=container.log_from_vm,
|
||||
run=azureml_run_for_logging)
|
||||
loggers = [tensorboard_logger, azureml_logger]
|
||||
loggers: List[Any] = [tensorboard_logger]
|
||||
|
||||
if is_running_in_azure_ml():
|
||||
mlflow_run_id = os.environ.get("MLFLOW_RUN_ID", None)
|
||||
logging.info(f"Logging to MLFlow run with id: {mlflow_run_id}")
|
||||
mlflow_logger = HimlMLFlowLogger(
|
||||
run_id=mlflow_run_id
|
||||
)
|
||||
loggers.append(mlflow_logger)
|
||||
else:
|
||||
mlflow_run_dir = container.outputs_folder / "mlruns"
|
||||
try:
|
||||
mlflow_run_dir.mkdir(exist_ok=True)
|
||||
mlflow_tracking_uri = "file:" + str(mlflow_run_dir)
|
||||
mlflow_logger = HimlMLFlowLogger(run_id=mlflow_run_for_logging, tracking_uri=mlflow_tracking_uri)
|
||||
loggers.append(mlflow_logger)
|
||||
logging.info(f"Logging to MLFlow run with id: {mlflow_run_for_logging}. Local MLFlow logs are stored in "
|
||||
f"{mlflow_tracking_uri}")
|
||||
except FileNotFoundError as e:
|
||||
logging.warning(f"Unable to initialise MLFlowLogger due to error: {e}")
|
||||
|
||||
storing_logger = StoringLogger()
|
||||
loggers.append(storing_logger)
|
||||
# Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
|
||||
|
|
|
@ -63,7 +63,7 @@ class AttentionLayer(nn.Module):
|
|||
attention_weights = transpose(attention_weights, 1, 0) # K x N
|
||||
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
|
||||
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
|
||||
return(attention_weights, pooled_features)
|
||||
return attention_weights, pooled_features
|
||||
|
||||
|
||||
class GatedAttentionLayer(nn.Module):
|
||||
|
@ -98,7 +98,7 @@ class GatedAttentionLayer(nn.Module):
|
|||
attention_weights = transpose(attention_weights, 1, 0) # K x N
|
||||
attention_weights = F.softmax(attention_weights, dim=1) # Softmax over N : K x N
|
||||
pooled_features = mm(attention_weights, features) # Matrix multiplication : K x L
|
||||
return(attention_weights, pooled_features)
|
||||
return attention_weights, pooled_features
|
||||
|
||||
|
||||
class CustomTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
|
|
|
@ -20,7 +20,6 @@ from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK
|
|||
aggregate_hyperdrive_metrics, get_metrics_for_childless_run,
|
||||
ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK,
|
||||
is_local_rank_zero, is_global_rank_zero, create_aml_run_object)
|
||||
|
||||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
|
||||
|
@ -35,11 +34,27 @@ from health_ml.utils.common_utils import (
|
|||
df_to_json,
|
||||
seed_monai_if_available,
|
||||
)
|
||||
from health_ml.utils.lightning_loggers import StoringLogger
|
||||
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
|
||||
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
|
||||
from health_ml.utils.type_annotations import PathOrString
|
||||
|
||||
|
||||
def get_mlflow_run_id_from_previous_loggers(trainer: Optional[Trainer]) -> Optional[str]:
|
||||
"""
|
||||
If self.trainer has already been intialised with loggers, attempt to retrieve a HimlMLFLowLogger and
|
||||
return the mlflow run_id associated with it, to allow continued logging to the same run. Otherwise, return None
|
||||
|
||||
:return: The mlflow run id from the existing HimlMLFlowLogger
|
||||
"""
|
||||
if trainer is None:
|
||||
return None
|
||||
try:
|
||||
mlflow_logger = [logger for logger in trainer.loggers if isinstance(logger, HimlMLFlowLogger)][0]
|
||||
return mlflow_logger.run_id
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
|
||||
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
|
||||
"""
|
||||
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
|
||||
|
@ -80,6 +95,7 @@ class MLRunner:
|
|||
run_context=RUN_CONTEXT)
|
||||
self.trainer: Optional[Trainer] = None
|
||||
self.azureml_run_for_logging: Optional[Run] = None
|
||||
self.mlflow_run_for_logging: Optional[str] = None
|
||||
|
||||
def set_run_tags_from_parent(self) -> None:
|
||||
"""
|
||||
|
@ -117,6 +133,7 @@ class MLRunner:
|
|||
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
|
||||
# This must happen before container setup because that could already read datasets.
|
||||
input_datasets = azure_run_info.input_datasets
|
||||
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
|
||||
if len(input_datasets) > 0:
|
||||
local_datasets: List[Path] = []
|
||||
for i, dataset in enumerate(input_datasets):
|
||||
|
@ -261,6 +278,7 @@ class MLRunner:
|
|||
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
|
||||
# internally, which replicates some samples to make sure all devices have the same batch size in case of
|
||||
# uneven inputs.
|
||||
mlflow_run_id = get_mlflow_run_id_from_previous_loggers(self.trainer)
|
||||
self.container.max_num_gpus = 1
|
||||
|
||||
if self.container.run_inference_only:
|
||||
|
@ -272,7 +290,8 @@ class MLRunner:
|
|||
container=self.container,
|
||||
resume_from_checkpoint=checkpoint_path,
|
||||
num_nodes=1,
|
||||
azureml_run_for_logging=self.azureml_run_for_logging
|
||||
azureml_run_for_logging=self.azureml_run_for_logging,
|
||||
mlflow_run_for_logging=mlflow_run_id
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
@ -319,6 +338,12 @@ class MLRunner:
|
|||
if self.container.has_custom_test_step():
|
||||
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
|
||||
logging.info("Running inference via the LightningModule.test_step method")
|
||||
# We run inference on a single device because distributed strategies such as DDP use DistributedSampler
|
||||
# internally, which replicates some samples to make sure all devices have some batch size in case of
|
||||
# uneven inputs.
|
||||
|
||||
self.container.max_num_gpus = 1
|
||||
|
||||
checkpoint_path = (
|
||||
self.checkpoint_handler.get_checkpoint_to_test() if self.container.run_inference_only else None
|
||||
)
|
||||
|
@ -384,9 +409,6 @@ class MLRunner:
|
|||
with logging_section("Model training"):
|
||||
self.run_training()
|
||||
|
||||
# Kill all processes besides rank 0
|
||||
self.after_ddp_cleanup(old_environ)
|
||||
|
||||
# load model checkpoint for custom inference or additional validation step
|
||||
if self.container.has_custom_test_step() or self.container.run_extra_val_epoch:
|
||||
self.load_model_checkpoint()
|
||||
|
@ -396,6 +418,9 @@ class MLRunner:
|
|||
with logging_section("Model Validation to save plots on validation set"):
|
||||
self.run_validation()
|
||||
|
||||
# Kill all processes besides rank 0
|
||||
self.after_ddp_cleanup(old_environ)
|
||||
|
||||
# Run inference on a single device
|
||||
with logging_section("Model inference"):
|
||||
self.run_inference()
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
#! /usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
|
@ -27,16 +29,17 @@ from health_azure.datasets import create_dataset_configs # noqa: E402
|
|||
from health_azure.logging import logging_to_stdout # noqa: E402
|
||||
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
|
||||
from health_azure.amulet import prepare_amulet_job, is_amulet_job # noqa: E402
|
||||
from health_azure.utils import (get_workspace, is_local_rank_zero, is_running_in_azure_ml, # noqa: E402
|
||||
set_environment_variables_for_multi_node,
|
||||
create_argparser, parse_arguments, ParserResult, apply_overrides)
|
||||
from health_azure.utils import (get_workspace, get_ml_client, is_local_rank_zero, # noqa: E402
|
||||
is_running_in_azure_ml, set_environment_variables_for_multi_node,
|
||||
create_argparser, parse_arguments, ParserResult, apply_overrides,
|
||||
filter_v2_input_output_args)
|
||||
|
||||
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
|
||||
from health_ml.lightning_container import LightningContainer # noqa: E402
|
||||
from health_ml.run_ml import MLRunner # noqa: E402
|
||||
from health_ml.utils import fixed_paths # noqa: E402
|
||||
from health_ml.utils.common_utils import (check_conda_environment, choose_conda_env_file, # noqa: E402
|
||||
is_linux)
|
||||
from health_ml.utils.common_utils import (DEFAULT_DOCKER_BASE_IMAGE, check_conda_environment, # noqa: E402
|
||||
choose_conda_env_file, is_linux)
|
||||
from health_ml.utils.config_loader import ModelConfigLoader # noqa: E402
|
||||
|
||||
|
||||
|
@ -46,8 +49,6 @@ from health_ml.utils.config_loader import ModelConfigLoader # noqa: E402
|
|||
runner_path = Path(sys.argv[0])
|
||||
sys.argv[0] = str(runner_path.resolve())
|
||||
|
||||
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
|
||||
|
||||
|
||||
def initialize_rpdb() -> None:
|
||||
"""
|
||||
|
@ -126,8 +127,12 @@ class Runner:
|
|||
|
||||
:return: ParserResult object containing args, overrides and settings
|
||||
"""
|
||||
# Filter out any args for passing inputs and outputs to scripts with AML SDK v2
|
||||
args = sys.argv[1:]
|
||||
filtered_args = filter_v2_input_output_args(args)
|
||||
|
||||
parser1 = create_runner_parser()
|
||||
parser1_result = parse_arguments(parser1, args=sys.argv[1:])
|
||||
parser1_result = parse_arguments(parser1, args=filtered_args)
|
||||
experiment_config = ExperimentConfig(**parser1_result.args)
|
||||
|
||||
self.experiment_config = experiment_config
|
||||
|
@ -229,7 +234,13 @@ class Runner:
|
|||
except ValueError:
|
||||
raise ValueError("Unable to submit the script to AzureML because no workspace configuration file "
|
||||
"(config.json) was found.")
|
||||
default_datastore = workspace.get_default_datastore().name if workspace is not None else ""
|
||||
|
||||
if self.lightning_container.datastore:
|
||||
datastore = self.lightning_container.datastore
|
||||
elif workspace:
|
||||
datastore = workspace.get_default_datastore().name
|
||||
else:
|
||||
datastore = ""
|
||||
|
||||
local_datasets = self.lightning_container.local_datasets
|
||||
all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
|
||||
|
@ -240,10 +251,18 @@ class Runner:
|
|||
create_dataset_configs(all_azure_dataset_ids=self.lightning_container.azure_datasets,
|
||||
all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
|
||||
all_local_datasets=all_local_datasets, # type: ignore
|
||||
datastore=default_datastore,
|
||||
datastore=datastore,
|
||||
use_mounting=use_mounting)
|
||||
|
||||
if self.experiment_config.cluster and not is_running_in_azure_ml():
|
||||
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
|
||||
if self.experiment_config.strictly_aml_v1:
|
||||
hyperdrive_config = self.lightning_container.get_hyperdrive_config()
|
||||
hyperparam_args = None
|
||||
else:
|
||||
hyperparam_args = self.lightning_container.get_hyperparam_args()
|
||||
hyperdrive_config = None
|
||||
ml_client = get_ml_client()
|
||||
|
||||
env_file = choose_conda_env_file(env_file=self.experiment_config.conda_env)
|
||||
logging.info(f"Using this Conda environment definition: {env_file}")
|
||||
check_conda_environment(env_file)
|
||||
|
@ -254,9 +273,10 @@ class Runner:
|
|||
script_params=script_params,
|
||||
conda_environment_file=env_file,
|
||||
aml_workspace=workspace,
|
||||
ml_client=ml_client,
|
||||
compute_cluster_name=self.experiment_config.cluster,
|
||||
environment_variables=environment_variables,
|
||||
default_datastore=default_datastore,
|
||||
default_datastore=datastore,
|
||||
experiment_name=self.lightning_container.effective_experiment_name,
|
||||
input_datasets=input_datasets, # type: ignore
|
||||
num_nodes=self.experiment_config.num_nodes,
|
||||
|
@ -266,15 +286,19 @@ class Runner:
|
|||
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
|
||||
docker_shm_size=self.experiment_config.docker_shm_size,
|
||||
hyperdrive_config=hyperdrive_config,
|
||||
hyperparam_args=hyperparam_args,
|
||||
create_output_folders=False,
|
||||
after_submission=after_submission_hook,
|
||||
tags=self.additional_run_tags(script_params)
|
||||
tags=self.additional_run_tags(script_params),
|
||||
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
|
||||
)
|
||||
else:
|
||||
azure_run_info = submit_to_azure_if_needed(
|
||||
input_datasets=input_datasets, # type: ignore
|
||||
submit_to_azureml=False,
|
||||
environment_variables=environment_variables)
|
||||
environment_variables=environment_variables,
|
||||
strictly_aml_v1=self.experiment_config.strictly_aml_v1,
|
||||
)
|
||||
if azure_run_info.run:
|
||||
# This code is only reached inside Azure. Set display name again - this will now affect
|
||||
# Hypdrive child runs (for other jobs, this has already been done after submission)
|
||||
|
|
|
@ -49,6 +49,7 @@ RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
|
|||
|
||||
# other
|
||||
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
|
||||
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
|
||||
|
||||
|
||||
@unique
|
||||
|
|
|
@ -3,10 +3,14 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
import mlflow
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import LightningLoggerBase, MLFlowLogger
|
||||
from pytorch_lightning.utilities.logger import _convert_params, _flatten_dict
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
from health_ml.utils.type_annotations import DictStrFloat, DictStrFloatOrFloatList
|
||||
|
||||
|
@ -109,3 +113,49 @@ class StoringLogger(LightningLoggerBase):
|
|||
:return: A dictionary mapping from epoch number to metric name to metric value.
|
||||
"""
|
||||
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
|
||||
|
||||
|
||||
class HimlMLFlowLogger(MLFlowLogger):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
|
||||
"""
|
||||
Override underlying log_hyperparams message to avoid trying to log hyperparameters that have already
|
||||
been logged, thus causing MLFlow to raise an Exception.
|
||||
|
||||
:param params: The original hyperparameters to be logged.
|
||||
"""
|
||||
run = mlflow.get_run(self.run_id)
|
||||
existing_hyperparams = run.data.params
|
||||
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
for k, v in params.items():
|
||||
if len(str(v)) > 250:
|
||||
rank_zero_warn(
|
||||
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
continue
|
||||
if k in existing_hyperparams:
|
||||
continue
|
||||
|
||||
self.experiment.log_param(self.run_id, k, v)
|
||||
|
||||
|
||||
def get_mlflow_run_id_from_trainer(trainer: Trainer) -> Optional[str]:
|
||||
"""
|
||||
If trainer has already been intialised with loggers, attempt to retrieve one of the type HimlMLFlowLogger,
|
||||
and return its run_id property in order to log to the same run. Otherwise, return None.
|
||||
|
||||
:return: The mlflow run id from an existing HimlMLFlowLogger if available, else None.
|
||||
"""
|
||||
if trainer is None:
|
||||
return None
|
||||
try:
|
||||
mlflow_logger = [logger for logger in trainer.loggers if isinstance(logger, HimlMLFlowLogger)][0]
|
||||
return mlflow_logger.run_id
|
||||
except IndexError:
|
||||
return None
|
||||
|
|
|
@ -124,7 +124,7 @@ class AzureMLLogger(LightningLoggerBase):
|
|||
return
|
||||
if params is None:
|
||||
return
|
||||
params_final = self._preprocess_hyperparams(params)
|
||||
params_final = _preprocess_hyperparams(params)
|
||||
if len(params_final) > 0:
|
||||
# Log hyperparameters as a table with 2 columns. Each "step" is one hyperparameter
|
||||
self.run.log_table(self.HYPERPARAMS_NAME, {"name": list(params_final.keys()),
|
||||
|
@ -150,24 +150,6 @@ class AzureMLLogger(LightningLoggerBase):
|
|||
# Run.complete should only be called if we created an AzureML run here in the constructor.
|
||||
self.run.complete()
|
||||
|
||||
def _preprocess_hyperparams(self, params: Any) -> Dict[str, str]:
|
||||
"""
|
||||
Converts arbitrary hyperparameters to a simple dictionary structure, in particular argparse Namespaces.
|
||||
Nested dictionaries are converted to folder-like strings, like ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
|
||||
All hyperparameter values are converted to strings, because Run.log_table can't deal with mixed datatypes.
|
||||
:param params: The parameters to convert
|
||||
:return: A dictionary mapping from string to string.
|
||||
"""
|
||||
# Convert from Namespace to dictionary
|
||||
params = _convert_params(params)
|
||||
# Convert nested dictionaries to folder-like structure
|
||||
params = _flatten_dict(params)
|
||||
# Convert anything that is not a primitive type to str
|
||||
params_final = _sanitize_params(params)
|
||||
if not isinstance(params_final, dict):
|
||||
raise ValueError(f"Expected the converted hyperparameters to be a dictionary, but got {type(params)}")
|
||||
return {str(key): str(value) for key, value in params_final.items()}
|
||||
|
||||
|
||||
class AzureMLProgressBar(ProgressBarBase):
|
||||
"""
|
||||
|
@ -333,6 +315,25 @@ class AzureMLProgressBar(ProgressBarBase):
|
|||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _preprocess_hyperparams(params: Any) -> Dict[str, str]:
|
||||
"""
|
||||
Converts arbitrary hyperparameters to a simple dictionary structure, in particular argparse Namespaces.
|
||||
Nested dictionaries are converted to folder-like strings, like ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
|
||||
All hyperparameter values are converted to strings, because Run.log_table can't deal with mixed datatypes.
|
||||
:param params: The parameters to convert
|
||||
:return: A dictionary mapping from string to string.
|
||||
"""
|
||||
# Convert from Namespace to dictionary
|
||||
params = _convert_params(params)
|
||||
# Convert nested dictionaries to folder-like structure
|
||||
params = _flatten_dict(params)
|
||||
# Convert anything that is not a primitive type to str
|
||||
params_final = _sanitize_params(params)
|
||||
if not isinstance(params_final, dict):
|
||||
raise ValueError(f"Expected the converted hyperparameters to be a dictionary, but got {type(params)}")
|
||||
return {str(key): str(value) for key, value in params_final.items()}
|
||||
|
||||
|
||||
def log_on_epoch(module: LightningModule,
|
||||
name: Optional[str] = None,
|
||||
value: Optional[Any] = None,
|
||||
|
|
|
@ -48,7 +48,7 @@ def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],
|
|||
# After applying a stochastic augmentation a second time it should have a different output
|
||||
if stochastic:
|
||||
augmented_img = data_augmentation(input_img) # type: ignore
|
||||
assert not(torch.allclose(augmented_img, expected_output_img, atol=1e-04))
|
||||
assert not torch.allclose(augmented_img, expected_output_img, atol=1e-04)
|
||||
|
||||
|
||||
def test_stain_normalization() -> None:
|
||||
|
|
|
@ -11,15 +11,17 @@ from unittest.mock import DEFAULT, MagicMock, Mock, patch
|
|||
from _pytest.logging import LogCaptureFixture
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from azureml._restclient.constants import RunStatus
|
||||
import mlflow
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from health_ml.configs.hello_world import HelloWorld # type: ignore
|
||||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.run_ml import MLRunner
|
||||
from health_ml.run_ml import MLRunner, get_mlflow_run_id_from_previous_loggers
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger
|
||||
from health_azure.utils import is_global_rank_zero
|
||||
from health_ml.utils.logging import AzureMLLogger
|
||||
from testazure.utils_testazure import DEFAULT_WORKSPACE
|
||||
from testhiml.utils.fixed_paths_for_tests import mock_run_id
|
||||
|
||||
|
@ -412,22 +414,7 @@ def test_log_on_vm(log_from_vm: bool) -> None:
|
|||
assert runner.trainer.loggers is not None
|
||||
assert len(runner.trainer.loggers) > 1
|
||||
logger = runner.trainer.loggers[1]
|
||||
assert isinstance(logger, AzureMLLogger)
|
||||
if log_from_vm:
|
||||
assert logger.run is not None
|
||||
# Check that all user supplied data (experiment and display name) are respected.
|
||||
assert logger.run.experiment is not None
|
||||
assert logger.run.experiment.name == experiment_name
|
||||
assert logger.run.display_name == tag
|
||||
# Both trainig and inference metrics must be logged in the same Run object.
|
||||
metrics = logger.run.get_metrics()
|
||||
assert "test_mse" in metrics
|
||||
assert "loss" in metrics
|
||||
# The run must have been correctly marked as completed.
|
||||
logger.run.wait_for_completion()
|
||||
assert logger.run.status == RunStatus.COMPLETED
|
||||
else:
|
||||
assert logger.run is None
|
||||
assert isinstance(logger, HimlMLFlowLogger)
|
||||
|
||||
|
||||
def test_experiment_name() -> None:
|
||||
|
@ -442,3 +429,21 @@ def test_experiment_name() -> None:
|
|||
experiment_name = "unittest"
|
||||
container.experiment = experiment_name
|
||||
assert container.effective_experiment_name == experiment_name
|
||||
|
||||
|
||||
def test_get_mlflow_run_id_from_previous_loggers() -> None:
|
||||
trainer_without_loggers = Trainer()
|
||||
run_id = get_mlflow_run_id_from_previous_loggers(trainer_without_loggers)
|
||||
assert run_id is None
|
||||
|
||||
loggers_not_inc_mlflow = [StoringLogger()]
|
||||
trainer_with_single_logger = Trainer(logger=loggers_not_inc_mlflow)
|
||||
run_id = get_mlflow_run_id_from_previous_loggers(trainer_with_single_logger)
|
||||
assert run_id is None
|
||||
|
||||
mock_run_id = "run_id_123"
|
||||
loggers_inc_mlflow = [StoringLogger(), HimlMLFlowLogger(run_id=mock_run_id)]
|
||||
trainer_with_loggers = Trainer(logger=loggers_inc_mlflow)
|
||||
with patch.object(mlflow.tracking.client.TrackingServiceClient, "get_run"):
|
||||
run_id = get_mlflow_run_id_from_previous_loggers(trainer_with_loggers)
|
||||
assert run_id == mock_run_id
|
||||
|
|
|
@ -103,9 +103,10 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
|
|||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
||||
with patch("health_ml.runner.check_conda_environment"):
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
mock_runner.run()
|
||||
with patch("health_ml.runner.get_ml_client"):
|
||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
mock_runner.run()
|
||||
mock_submit_to_azure_if_needed.assert_called_once()
|
||||
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
|
@ -134,8 +135,10 @@ def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
|
|||
mock_get_env_files: MagicMock,
|
||||
mock_runner: Runner
|
||||
) -> None:
|
||||
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig], submit_to_azureml: bool, # type: ignore
|
||||
environment_variables: Dict[str, Any]) -> AzureRunInfo:
|
||||
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig],
|
||||
submit_to_azureml: bool, strictly_aml_v1: bool, # type: ignore
|
||||
environment_variables: Dict[str, Any], # type: ignore
|
||||
) -> AzureRunInfo:
|
||||
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
|
||||
return AzureRunInfo(input_datasets=datasets_input,
|
||||
output_datasets=[],
|
||||
|
@ -247,15 +250,16 @@ def _test_hyperdrive_submission(mock_runner: Runner,
|
|||
expected_argument_name: str,
|
||||
expected_argument_values: List[str]) -> None:
|
||||
model_name = "HelloWorld"
|
||||
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg]
|
||||
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg, "--strictly_aml_v1=True"]
|
||||
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
|
||||
# start in that temp folder.
|
||||
with change_working_folder_and_add_environment(mock_runner.project_root):
|
||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||
mock_runner.run()
|
||||
with patch("health_ml.runner.get_ml_client"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||
mock_runner.run()
|
||||
mock_run_in_situ.assert_called_once()
|
||||
mock_submit_to_aml.assert_called_once()
|
||||
# call_args is a tuple of (args, kwargs)
|
||||
|
@ -281,10 +285,11 @@ def test_submit_to_azure_docker(mock_runner: Runner) -> None:
|
|||
# start in that temp folder.
|
||||
with change_working_folder_and_add_environment(mock_runner.project_root):
|
||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||
mock_runner.run()
|
||||
with patch("health_ml.runner.get_ml_client"):
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||
mock_runner.run()
|
||||
mock_run_in_situ.assert_called_once()
|
||||
mock_submit_to_aml.assert_called_once()
|
||||
# call_args is a tuple of (args, kwargs)
|
||||
|
|
|
@ -21,6 +21,7 @@ from azureml.core import Run
|
|||
|
||||
from health_azure import RUN_CONTEXT, create_aml_run_object
|
||||
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, log_learning_rate, log_on_epoch
|
||||
from health_ml.utils.logging import _preprocess_hyperparams
|
||||
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
|
@ -240,8 +241,7 @@ def test_azureml_logger_hyperparams_processing() -> None:
|
|||
"""
|
||||
hyperparams = {"A long list": ["foo", 1.0, "abc"],
|
||||
"foo": 1.0}
|
||||
logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
|
||||
actual = logger._preprocess_hyperparams(hyperparams)
|
||||
actual = _preprocess_hyperparams(hyperparams)
|
||||
assert actual == {"A long list": "['foo', 1.0, 'abc']", "foo": "1.0"}
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
black==22.1.0
|
||||
coverage==6.3.2
|
||||
flake8==4.0.1
|
||||
flake8==5.0.2
|
||||
mypy==0.931
|
||||
pylint==2.12.2
|
||||
pylint==2.15.0
|
||||
pycobertura==2.0.1
|
||||
pytest==6.2.2
|
||||
pytest-cov==2.11.1
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
black==22.1.0
|
||||
coverage==6.3.2
|
||||
flake8==4.0.1
|
||||
flake8==5.0.2
|
||||
mypy==0.931
|
||||
pre-commit==2.19.0
|
||||
pylint==2.12.2
|
||||
pylint==2.15.0
|
||||
pycobertura==2.0.1
|
||||
pytest==6.2.2
|
||||
pytest-cov==2.11.1
|
||||
|
|
Загрузка…
Ссылка в новой задаче