ENH: Fix logging + model registration for Amulet runs (#804)
Adds functionality to properly log IE-DL outputs and register trained models when running jobs through amulet.
This commit is contained in:
Родитель
a6dcef0252
Коммит
78b7de2b5f
|
@ -37,3 +37,4 @@ InnerEyePrivateSettings.yml
|
|||
cifar-10-batches-py
|
||||
cifar-100-python
|
||||
!**/InnerEye/ML/Histopathology/datasets
|
||||
None/
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
.idea
|
||||
.git
|
||||
.azureml
|
||||
.pytest_cache
|
||||
.mypy_cache
|
||||
.github
|
||||
.amlignore
|
||||
.coveragerc
|
||||
.editorconfig
|
||||
.flake8
|
||||
.gitattributes
|
||||
.gitconfig
|
||||
.gitignore
|
||||
.gitmodules
|
||||
CODE_OF_CONDUCT.md
|
||||
GeoPol.xml
|
||||
most_recent_run.txt
|
||||
mypy.ini
|
||||
mypy_runner.py
|
||||
pull_request_template.md
|
||||
SECURITY.md
|
||||
__pycache__
|
||||
azure-pipelines
|
||||
/datasets
|
||||
docs
|
||||
sphinx-docs
|
||||
modelweights
|
||||
outputs
|
||||
logs
|
||||
test_outputs
|
||||
run_outputs
|
||||
# Test output from model registration
|
||||
TestsOutsidePackage/azureml-models
|
||||
tensorboard_runs
|
||||
InnerEyeTestVariables.txt
|
||||
InnerEyePrivateSettings.yml
|
||||
cifar-10-batches-py
|
||||
cifar-100-python
|
||||
!**/InnerEye/ML/Histopathology/datasets
|
||||
None/
|
||||
Tests/ML/test_data
|
|
@ -162,3 +162,8 @@ tensorboard_runs
|
|||
None
|
||||
cifar-10-batches-py
|
||||
cifar-10-python.tar.gz
|
||||
|
||||
## Amulet
|
||||
|
||||
.amltconfig
|
||||
amlt_job.yml
|
||||
|
|
|
@ -87,6 +87,7 @@ DEFAULT_TEST_ZIP_NAME = "test.zip"
|
|||
|
||||
# The property in the model registry that holds the name of the Python environment
|
||||
PYTHON_ENVIRONMENT_NAME = "python_environment_name"
|
||||
PYTHON_ENVIRONMENT_VERSION = "python_environment_version"
|
||||
|
||||
|
||||
def get_environment_yaml_file() -> Path:
|
||||
|
|
|
@ -5,11 +5,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum, unique
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import param
|
||||
from health_azure.utils import (
|
||||
ENV_AMLT_AZ_BATCHAI_DIR, ENV_AMLT_INPUT_OUTPUT, ENV_AMLT_PROJECT_NAME,
|
||||
ENV_AMLT_SNAPSHOT_DIR, get_amlt_aml_working_dir, is_amulet_job, is_global_rank_zero
|
||||
)
|
||||
from pandas import DataFrame
|
||||
from param import Parameterized
|
||||
|
||||
|
@ -19,10 +24,10 @@ from InnerEye.Common.common_util import ModelProcessing, is_windows
|
|||
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
|
||||
from InnerEye.Common.generic_parsing import GenericConfig
|
||||
from InnerEye.Common.type_annotations import PathOrString, T, TupleFloat2
|
||||
from InnerEye.ML.common import CHECKPOINT_FOLDER, DATASET_CSV_FILE_NAME, \
|
||||
ModelExecutionMode, VISUALIZATION_FOLDER, \
|
||||
create_unique_timestamp_id, get_best_checkpoint_path
|
||||
from health_azure.utils import is_global_rank_zero
|
||||
from InnerEye.ML.common import (
|
||||
CHECKPOINT_FOLDER, DATASET_CSV_FILE_NAME, VISUALIZATION_FOLDER,
|
||||
ModelExecutionMode, create_unique_timestamp_id, get_best_checkpoint_path
|
||||
)
|
||||
|
||||
|
||||
@unique
|
||||
|
@ -150,9 +155,26 @@ class DeepLearningFileSystemConfig(Parameterized):
|
|||
else:
|
||||
logging.info("Running inside AzureML.")
|
||||
logging.info("All results will be written to a subfolder of the project root folder.")
|
||||
run_folder = project_root
|
||||
outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
|
||||
logs_folder = project_root / DEFAULT_LOGS_DIR_NAME
|
||||
|
||||
if is_amulet_job():
|
||||
# Job submitted via Amulet
|
||||
amlt_root_folder = Path(os.environ[ENV_AMLT_INPUT_OUTPUT])
|
||||
project_name = os.environ[ENV_AMLT_PROJECT_NAME]
|
||||
snapshot_dir = get_amlt_aml_working_dir()
|
||||
assert snapshot_dir, \
|
||||
f"Either {ENV_AMLT_SNAPSHOT_DIR} or {ENV_AMLT_AZ_BATCHAI_DIR} must exist in env vars"
|
||||
print(f"Found the following environment variables set by Amulet: "
|
||||
f"AZURE_ML_INPUT_OUTPUT: {amlt_root_folder}, AZUREML_ARM_PROJECT_NAME: {project_name}")
|
||||
run_id = RUN_CONTEXT.id
|
||||
run_folder = amlt_root_folder / "projects" / project_name / "amlt-code" / run_id
|
||||
outputs_folder = snapshot_dir / DEFAULT_AML_UPLOAD_DIR
|
||||
logs_folder = snapshot_dir / DEFAULT_LOGS_DIR_NAME
|
||||
|
||||
else:
|
||||
run_folder = project_root
|
||||
outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
|
||||
logs_folder = project_root / DEFAULT_LOGS_DIR_NAME
|
||||
|
||||
logging.info(f"Run outputs folder: {outputs_folder}")
|
||||
logging.info(f"Logs folder: {logs_folder}")
|
||||
return DeepLearningFileSystemConfig(
|
||||
|
|
|
@ -16,7 +16,7 @@ import torch.multiprocessing
|
|||
from azureml._restclient.constants import RunStatus
|
||||
from azureml.core import Model, Run, model
|
||||
from health_azure import AzureRunInfo
|
||||
from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_global_rank_zero
|
||||
from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_amulet_job, is_global_rank_zero
|
||||
from pytorch_lightning import LightningModule, seed_everything
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -36,7 +36,7 @@ from InnerEye.Common.common_util import (
|
|||
SUBJECT_METRICS_FILE_NAME, ModelProcessing, change_working_directory, get_best_epoch_results_path,
|
||||
is_windows, logging_section, merge_conda_files, print_exception, remove_file_or_directory
|
||||
)
|
||||
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, PYTHON_ENVIRONMENT_NAME
|
||||
from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, PYTHON_ENVIRONMENT_NAME, PYTHON_ENVIRONMENT_VERSION
|
||||
from InnerEye.Common.type_annotations import PathOrString
|
||||
from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
|
||||
from InnerEye.ML.common import (
|
||||
|
@ -571,8 +571,13 @@ class MLRunner:
|
|||
run_to_register_on = RUN_CONTEXT
|
||||
logging.info(f"Registering the model on the current run {run_to_register_on.id}")
|
||||
logging.info(f"Uploading files in {final_model_folder} with prefix '{artifacts_path}'")
|
||||
final_model_folder_relative = final_model_folder.relative_to(Path.cwd())
|
||||
run_to_register_on.upload_folder(name=artifacts_path, path=str(final_model_folder_relative))
|
||||
|
||||
if is_amulet_job():
|
||||
final_model_upload_path = final_model_folder
|
||||
else:
|
||||
final_model_upload_path = final_model_folder.relative_to(Path.cwd())
|
||||
|
||||
run_to_register_on.upload_folder(name=artifacts_path, path=str(final_model_upload_path))
|
||||
# When registering the model on the run, we need to provide a relative path inside of the run's output
|
||||
# folder in `model_path`
|
||||
model = run_to_register_on.register_model(
|
||||
|
@ -584,10 +589,16 @@ class MLRunner:
|
|||
# on the model. We could add that as an immutable property, but with tags we have the option to modify
|
||||
# to a custom environment later.
|
||||
python_environment = RUN_CONTEXT.get_environment()
|
||||
assert python_environment.version == ENVIRONMENT_VERSION, \
|
||||
f"Expected all Python environments to have version '{ENVIRONMENT_VERSION}', but got: " \
|
||||
f"'{python_environment.version}"
|
||||
model.add_tags({PYTHON_ENVIRONMENT_NAME: python_environment.name})
|
||||
|
||||
if not is_amulet_job():
|
||||
# amulet jobs re-use environment names, so we can't fix to version 1
|
||||
assert python_environment.version == ENVIRONMENT_VERSION, \
|
||||
f"Expected all Python environments to have version '{ENVIRONMENT_VERSION}', but got: " \
|
||||
f"'{python_environment.version}"
|
||||
model.add_tags({
|
||||
PYTHON_ENVIRONMENT_NAME: python_environment.name,
|
||||
PYTHON_ENVIRONMENT_VERSION: python_environment.version,
|
||||
})
|
||||
# update the run's tags with the registered model information
|
||||
run_to_register_on.tag(MODEL_ID_KEY_NAME, model.id)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ dependencies:
|
|||
- blosc=1.21.0=h4ff587b_1
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2022.07.19=h06a4308_0
|
||||
- certifi=2022.6.15=py38h06a4308_0
|
||||
- certifi=2022.9.14=py38h06a4308_0
|
||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
||||
- ffmpeg=4.2.2=h20bf706_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
|
@ -67,7 +67,7 @@ dependencies:
|
|||
- pip:
|
||||
- absl-py==1.2.0
|
||||
- adal==1.2.7
|
||||
- aiohttp==3.8.1
|
||||
- aiohttp==3.8.3
|
||||
- aiosignal==1.2.0
|
||||
- alabaster==0.7.12
|
||||
- alembic==1.8.1
|
||||
|
@ -134,7 +134,7 @@ dependencies:
|
|||
- dotnetcore2==2.1.23
|
||||
- entrypoints==0.4
|
||||
- execnet==1.9.0
|
||||
- fastjsonschema==2.16.1
|
||||
- fastjsonschema==2.16.2
|
||||
- fastmri==0.2.0
|
||||
- flake8==3.8.3
|
||||
- flask==2.2.2
|
||||
|
@ -148,12 +148,12 @@ dependencies:
|
|||
- google-auth-oauthlib==0.4.6
|
||||
- gputil==1.4.0
|
||||
- greenlet==1.1.3
|
||||
- grpcio==1.48.1
|
||||
- grpcio==1.49.1
|
||||
- gunicorn==20.1.0
|
||||
- h5py==2.10.0
|
||||
- hi-ml==0.2.5
|
||||
- hi-ml-azure==0.2.5
|
||||
- humanize==4.3.0
|
||||
- humanize==4.4.0
|
||||
- idna==3.4
|
||||
- imageio==2.15.0
|
||||
- imagesize==1.4.1
|
||||
|
@ -161,7 +161,7 @@ dependencies:
|
|||
- importlib-resources==5.9.0
|
||||
- iniconfig==1.1.1
|
||||
- innereye-dicom-rt==1.0.3
|
||||
- ipykernel==6.15.3
|
||||
- ipykernel==6.16.0
|
||||
- ipython==7.31.1
|
||||
- ipython-genutils==0.2.0
|
||||
- ipywidgets==8.0.2
|
||||
|
@ -183,22 +183,22 @@ dependencies:
|
|||
- lightning-bolts==0.4.0
|
||||
- llvmlite==0.34.0
|
||||
- lxml==4.9.1
|
||||
- mako==1.2.2
|
||||
- mako==1.2.3
|
||||
- markdown==3.4.1
|
||||
- markdown-it-py==2.1.0
|
||||
- markupsafe==2.1.1
|
||||
- marshmallow==3.17.1
|
||||
- marshmallow==3.18.0
|
||||
- marshmallow-enum==1.5.1
|
||||
- matplotlib==3.3.0
|
||||
- mccabe==0.6.1
|
||||
- mdit-py-plugins==0.3.0
|
||||
- mdit-py-plugins==0.3.1
|
||||
- mdurl==0.1.2
|
||||
- mistune==2.0.4
|
||||
- mlflow==1.23.1
|
||||
- mlflow-skinny==1.28.0
|
||||
- mlflow-skinny==1.29.0
|
||||
- monai==0.6.0
|
||||
- more-itertools==8.14.0
|
||||
- msal==1.18.0
|
||||
- msal==1.19.0
|
||||
- msal-extensions==0.3.1
|
||||
- msrest==0.7.1
|
||||
- msrestazure==0.6.4
|
||||
|
@ -208,7 +208,7 @@ dependencies:
|
|||
- myst-parser==0.18.0
|
||||
- nbclient==0.6.8
|
||||
- nbconvert==7.0.0
|
||||
- nbformat==5.5.0
|
||||
- nbformat==5.6.1
|
||||
- ndg-httpsclient==0.5.1
|
||||
- nest-asyncio==1.5.5
|
||||
- networkx==2.8.6
|
||||
|
@ -257,9 +257,9 @@ dependencies:
|
|||
- python-dateutil==2.8.2
|
||||
- pytorch-lightning==1.6.5
|
||||
- pytz==2022.2.1
|
||||
- pywavelets==1.3.0
|
||||
- pywavelets==1.4.1
|
||||
- pyyaml==6.0
|
||||
- pyzmq==23.2.1
|
||||
- pyzmq==24.0.1
|
||||
- qtconsole==5.3.2
|
||||
- qtpy==2.2.0
|
||||
- querystring-parser==1.2.4
|
||||
|
@ -293,11 +293,11 @@ dependencies:
|
|||
- sphinxcontrib-qthelp==1.0.3
|
||||
- sphinxcontrib-serializinghtml==1.1.5
|
||||
- sqlalchemy==1.4.41
|
||||
- sqlparse==0.4.2
|
||||
- sqlparse==0.4.3
|
||||
- stopit==1.1.2
|
||||
- stringcase==1.2.0
|
||||
- tabulate==0.8.7
|
||||
- tenacity==8.0.1
|
||||
- tenacity==8.1.0
|
||||
- tensorboard==2.3.0
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- tensorboardx==2.1
|
||||
|
|
|
@ -11,7 +11,7 @@ dependencies:
|
|||
- blosc=1.21.0=h4ff587b_1
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2022.07.19=h06a4308_0
|
||||
- certifi=2022.6.15=py38h06a4308_0
|
||||
- certifi=2022.9.14=py38h06a4308_0
|
||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
||||
- ffmpeg=4.2.2=h20bf706_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
|
@ -67,7 +67,7 @@ dependencies:
|
|||
- pip:
|
||||
- absl-py==1.2.0
|
||||
- adal==1.2.7
|
||||
- aiohttp==3.8.1
|
||||
- aiohttp==3.8.3
|
||||
- aiosignal==1.2.0
|
||||
- alabaster==0.7.12
|
||||
- alembic==1.8.1
|
||||
|
@ -134,7 +134,7 @@ dependencies:
|
|||
- dotnetcore2==2.1.23
|
||||
- entrypoints==0.4
|
||||
- execnet==1.9.0
|
||||
- fastjsonschema==2.16.1
|
||||
- fastjsonschema==2.16.2
|
||||
- fastmri==0.2.0
|
||||
- flake8==3.8.3
|
||||
- flask==2.2.2
|
||||
|
@ -148,12 +148,12 @@ dependencies:
|
|||
- google-auth-oauthlib==0.4.6
|
||||
- gputil==1.4.0
|
||||
- greenlet==1.1.3
|
||||
- grpcio==1.48.1
|
||||
- grpcio==1.49.1
|
||||
- gunicorn==20.1.0
|
||||
- h5py==2.10.0
|
||||
- hi-ml==0.2.5
|
||||
- hi-ml-azure==0.2.5
|
||||
- humanize==4.3.0
|
||||
- humanize==4.4.0
|
||||
- idna==3.4
|
||||
- imageio==2.15.0
|
||||
- imagesize==1.4.1
|
||||
|
@ -161,7 +161,7 @@ dependencies:
|
|||
- importlib-resources==5.9.0
|
||||
- iniconfig==1.1.1
|
||||
- innereye-dicom-rt==1.0.3
|
||||
- ipykernel==6.15.3
|
||||
- ipykernel==6.16.0
|
||||
- ipython==7.31.1
|
||||
- ipython-genutils==0.2.0
|
||||
- ipywidgets==8.0.2
|
||||
|
@ -183,22 +183,22 @@ dependencies:
|
|||
- lightning-bolts==0.4.0
|
||||
- llvmlite==0.34.0
|
||||
- lxml==4.9.1
|
||||
- mako==1.2.2
|
||||
- mako==1.2.3
|
||||
- markdown==3.4.1
|
||||
- markdown-it-py==2.1.0
|
||||
- markupsafe==2.1.1
|
||||
- marshmallow==3.17.1
|
||||
- marshmallow==3.18.0
|
||||
- marshmallow-enum==1.5.1
|
||||
- matplotlib==3.3.0
|
||||
- mccabe==0.6.1
|
||||
- mdit-py-plugins==0.3.0
|
||||
- mdit-py-plugins==0.3.1
|
||||
- mdurl==0.1.2
|
||||
- mistune==2.0.4
|
||||
- mlflow==1.23.1
|
||||
- mlflow-skinny==1.28.0
|
||||
- mlflow-skinny==1.29.0
|
||||
- monai==0.6.0
|
||||
- more-itertools==8.14.0
|
||||
- msal==1.18.0
|
||||
- msal==1.19.0
|
||||
- msal-extensions==0.3.1
|
||||
- msrest==0.7.1
|
||||
- msrestazure==0.6.4
|
||||
|
@ -208,7 +208,7 @@ dependencies:
|
|||
- myst-parser==0.18.0
|
||||
- nbclient==0.6.8
|
||||
- nbconvert==7.0.0
|
||||
- nbformat==5.5.0
|
||||
- nbformat==5.6.1
|
||||
- ndg-httpsclient==0.5.1
|
||||
- nest-asyncio==1.5.5
|
||||
- networkx==2.8.6
|
||||
|
@ -257,9 +257,9 @@ dependencies:
|
|||
- python-dateutil==2.8.2
|
||||
- pytorch-lightning==1.6.5
|
||||
- pytz==2022.2.1
|
||||
- pywavelets==1.3.0
|
||||
- pywavelets==1.4.1
|
||||
- pyyaml==6.0
|
||||
- pyzmq==23.2.1
|
||||
- pyzmq==24.0.1
|
||||
- qtconsole==5.3.2
|
||||
- qtpy==2.2.0
|
||||
- querystring-parser==1.2.4
|
||||
|
@ -293,11 +293,11 @@ dependencies:
|
|||
- sphinxcontrib-qthelp==1.0.3
|
||||
- sphinxcontrib-serializinghtml==1.1.5
|
||||
- sqlalchemy==1.4.41
|
||||
- sqlparse==0.4.2
|
||||
- sqlparse==0.4.3
|
||||
- stopit==1.1.2
|
||||
- stringcase==1.2.0
|
||||
- tabulate==0.8.7
|
||||
- tenacity==8.0.1
|
||||
- tenacity==8.1.0
|
||||
- tensorboard==2.3.0
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- tensorboardx==2.1
|
||||
|
|
Загрузка…
Ссылка в новой задаче