Enable logging from outside AzureML, logging bugfixes (#167)

This commit is contained in:
Anton Schwaighofer 2021-11-25 14:16:55 +00:00 committed by GitHub
Parent 145e7dc9a2
Commit c4ad965d23
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 524 additions and 185 deletions

View file

@ -15,13 +15,18 @@ the section headers (Added/Changed/...) and incrementing the package version.
### Added
- ([#159](https://github.com/microsoft/hi-ml/pull/159)) Add profiling for loading png image files as numpy arrays.
- ([#152](https://github.com/microsoft/hi-ml/pull/152)) Add a custom HTML reporting tool
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) Ability to log to an AzureML run when outside of AzureML
### Changed
- ([#164](https://github.com/microsoft/hi-ml/pull/164)) Look in more locations for stdout from the AzureML run.
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) The AzureMLLogger now has one mandatory argument that controls
whether it should also log to AzureML when running on a VM.
### Fixed
- ([#161](https://github.com/microsoft/hi-ml/pull/161)) Empty string as target folder for a dataset creates an invalid mounting path for the dataset in AzureML (fixes #160)
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) Fix bugs in logging hyperparameters: log them as a name/value
table rather than one column per hyperparameter, and log all hyperparameter values as strings.
### Removed

View file

@ -66,8 +66,8 @@ This will build all your documentation in `docs/build/html`.
* In the browser, navigate to the AzureML workspace that you want to use for running your tests.
* In the top right section, there will be a dropdown menu showing the name of your AzureML workspace. Expand that.
* In the panel, there is a link "Download config file". Click that.
* This will download a file `config.json`. Move that file to the root folder of your `hi-ml` repository. The file name
is already present in `.gitignore`, and will hence not be checked in.
* This will download a file `config.json`. Place a copy of that file in both of the folders `hi-ml/testhiml` and `hi-ml/testazure`.
The file `config.json` is already present in `.gitignore`, and will hence not be checked in.
## Creating and Deleting Docker Environments in AzureML

View file

@ -1,45 +1,118 @@
# Logging metrics when training models in AzureML
# Logging metrics when training models in and outside AzureML
This section describes the basics of logging to AzureML, and how this can be simplified when using PyTorch Lightning.
It also describes helper functions to make logging more consistent across your code.
This section describes the basics of logging to AzureML, and how this can be simplified when using PyTorch Lightning. It
also describes helper functions to make logging more consistent across your code.
## Basics
The mechanics of writing metrics to an ML training run inside of AzureML are described
[here](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics).
Using the `hi-ml-azure` toolbox, you can simplify that like this:
```python
from health_azure import RUN_CONTEXT
...
RUN_CONTEXT.log(name="name_of_the_metric", value=my_tensor.item())
```
Similarly you can log strings (via the `log_text` method) or figures (via the `log_image` method), see the
[documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-log-view-metrics).
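As a minimal sketch (not taken from the hi-ml documentation), this is how a matplotlib figure could be attached to the
present run; the figure name `loss_curve` and the plotted values are made up for illustration:
```python
import matplotlib.pyplot as plt
from health_azure import RUN_CONTEXT

# Build a small example figure and attach it to the current AzureML run.
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [1.0, 0.6, 0.4])
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
RUN_CONTEXT.log_image(name="loss_curve", plot=fig)
```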
## Using PyTorch Lightning
The `hi-ml` toolbox relies on `pytorch-lightning` for a lot of its functionality.
Logging of metrics is described in detail
The `hi-ml` toolbox relies on `pytorch-lightning` for a lot of its functionality. Logging of metrics is described in
detail
[here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html).
`hi-ml` provides a Lightning-ready logger object to use with AzureML. You can add that to your trainer as you would
add a Tensorboard logger, and afterwards see all metrics in both your Tensorboard files and in the AzureML UI.
This logger can be added to the `Trainer` object as follows:
`hi-ml` provides a Lightning-ready logger object to use with AzureML. You can add that to your trainer as you would add
a Tensorboard logger, and afterwards see all metrics in both your Tensorboard files and in the AzureML UI. This logger
can be added to the `Trainer` object as follows:
```python
from health_ml.utils import AzureMLLogger
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
tb_logger = TensorBoardLogger("logs/")
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
trainer = Trainer(logger=[tb_logger, azureml_logger])
```
You do not need to make any changes to your logging code to write to both loggers at the same time. This means that, if
your code correctly writes to Tensorboard in a local run, you can expect the metrics to come out correctly in the
AzureML UI as well after adding the `AzureMLLogger`.
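As an illustration (a sketch, not part of the original documentation), a module that logs its training loss through the
standard `self.log` call needs no AzureML-specific code at all:
```python
from pytorch_lightning import LightningModule

class MyModule(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss computation
        # This single call is picked up by every logger attached to the Trainer,
        # here both the TensorBoardLogger and the AzureMLLogger.
        self.log("train/loss", loss)
        return loss
```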
## Logging to AzureML when running outside AzureML
You may still need to run some of your training jobs on an individual VM, for example for small jobs or for
debugging. Keeping track of the results of those runs can be tricky, and comparing or sharing them even more so.
All results that you achieve in such runs outside AzureML can be written straight into AzureML using the
`AzureMLLogger`. Its behaviour is as follows:
* When instantiated inside a run in AzureML, it will write metrics straight to the present run.
* When instantiated outside an AzureML run, it will create a new `Run` object that writes its metrics straight through
to AzureML, even though the code is not running in AzureML.
This behaviour is controlled by the `enable_logging_outside_azure_ml` argument. With the following code snippet,
you can use the `AzureMLLogger` to write metrics to AzureML whether the code runs inside or outside AzureML:
```python
from health_ml.utils import AzureMLLogger
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
tb_logger = TensorBoardLogger("logs/")
azureml_logger = AzureMLLogger()
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=True)
trainer = Trainer(logger=[tb_logger, azureml_logger])
```
You do not need to make any changes to your logging code to write to both loggers at the same time. This means
that, if your code correctly writes to Tensorboard in a local run, you can expect the metrics to come out correctly
in the AzureML UI as well after adding the `AzureMLLogger`.
If this is executed on a VM outside an AzureML run, you will see additional information printed to the console like
this:
```text
Writing metrics to run ed52cfac-1b85-42ea-8ebe-2f90de21be6b in experiment azureml_logger.
To check progress, visit this URL: https://ml.azure.com/runs/ed52cfac-1b85-42ea-8ebe-2f90de21be...
```
Clicking on the URL will take you to the AzureML web page, where you can inspect the metrics that the run has written so
far.
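You can also read the metrics back programmatically later on. The following is a sketch only: it assumes that a
workspace `config.json` file can be found, and the run ID is a placeholder for the ID printed on the console:
```python
from azureml.core import Run
from health_azure import get_workspace

# Locate the workspace via config.json in the current folder or its parents,
# then retrieve the run by the ID that was printed to the console.
workspace = get_workspace()
run = Run.get(workspace, run_id="<run id printed on the console>")
print(run.get_metrics())
```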
There are a few points that you should note:
**Experiments**: Each run in AzureML is associated with an experiment. When executed in an AzureML run,
the `AzureMLLogger` will know which experiment to write to. Outside AzureML, on your VM, the logger will default to
using an experiment called `azureml_logger`. This means that runs inside and outside AzureML end up in different
experiments. You can customize this as in the following code snippet, so that the submitted runs and the runs outside
AzureML end up in the same experiment:
```python
from health_azure import submit_to_azure_if_needed
from health_ml.utils import AzureMLLogger
from pytorch_lightning import Trainer
experiment_name = "my_new_architecture"
submit_to_azure_if_needed(compute_cluster_name="nd24",
                          experiment_name=experiment_name)
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=True, experiment_name=experiment_name)
trainer = Trainer(logger=[azureml_logger])
```
**Snapshots**: The run that you are about to create can follow the usual pattern of AzureML runs, and can create a full
snapshot of all the code that was used in the experiment. This will greatly improve the reproducibility of your
experiments. By default, this behaviour is turned off, though. You can provide an additional argument to the logger,
like `AzureMLLogger(enable_logging_outside_azure_ml=True, snapshot_directory='/users/me/code/trainer')`, to include the
given folder in the snapshot. In addition, you can place a file called `.amlignore` in that folder to exclude previous
results, or large checkpoint files, from being included in the snapshot
(see [here for details](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots)).
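Putting this together, the following sketch creates a logger that writes to AzureML from a VM and includes a code
snapshot; the folder path is an example only:
```python
from health_ml.utils import AzureMLLogger
from pytorch_lightning import Trainer

# Include the training code in the snapshot; a .amlignore file in that folder
# can exclude previous results or large checkpoint files from the snapshot.
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=True,
                               snapshot_directory="/users/me/code/trainer")
trainer = Trainer(logger=[azureml_logger])
```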
## Making logging consistent when training with PyTorch Lightning
A common problem of training scripts is that the calls to the logging methods tend to run out of sync.
The `.log` method of a `LightningModule` has a lot of arguments, some of which need to be set correctly when running
on multiple GPUs.
A common problem of training scripts is that the calls to the logging methods tend to run out of sync. The `.log` method
of a `LightningModule` has a lot of arguments, some of which need to be set correctly when running on multiple GPUs.
To simplify that, there is a function `log_on_epoch` that turns synchronization across nodes on or off depending on the
number of GPUs, and always forces the metrics to be logged upon epoch completion. Use it as follows:
@ -48,6 +121,7 @@ number of GPUs, and always forces the metrics to be logged upon epoch completion
from health_ml.utils import log_on_epoch
from pytorch_lightning import LightningModule
class MyModule(LightningModule):
    def training_step(self, *args, **kwargs):
        ...
@ -60,15 +134,18 @@ class MyModule(LightningModule):
Logging learning rates is important for monitoring training, but again this can add overhead. To log learning rates
easily and consistently, we suggest either of two options:
* Add a `LearningRateMonitor` callback to your trainer, as described
[here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.LearningRateMonitor.html#pytorch_lightning.callbacks.LearningRateMonitor) (see the sketch after this list)
* Use the `hi-ml` function `log_learning_rate`
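For the first option, this is a minimal sketch of adding the callback alongside the AzureML logger:
```python
from health_ml.utils import AzureMLLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

# The callback logs the learning rate through all attached loggers, so the values
# also show up in AzureML when the AzureMLLogger is used.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
azureml_logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
trainer = Trainer(callbacks=[lr_monitor], logger=[azureml_logger])
```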
The `log_learning_rate` function can be used at any point in the training code, like this:
```python
from health_ml.utils import log_learning_rate
from pytorch_lightning import LightningModule
class MyModule(LightningModule):
    def training_step(self, *args, **kwargs):
        ...
@ -76,6 +153,7 @@ class MyModule(LightningModule):
        loss = my_loss(y_pred, y)
        return loss
```
`log_learning_rate` will log values from all learning rate schedulers, and all learning rates if a scheduler
returns multiple values. In this example, the logged metric will be `learning_rate` if there is a single scheduler
that outputs a single LR, or `learning_rate/1/0` to indicate the value coming from scheduler index 1, value index 0.
`log_learning_rate` will log values from all learning rate schedulers, and all learning rates if a scheduler returns
multiple values. In this example, the logged metric will be `learning_rate` if there is a single scheduler that outputs
a single LR, or `learning_rate/1/0` to indicate the value coming from scheduler index 1, value index 0.
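The metric name prefix can also be changed. The sketch below assumes, based on the tests elsewhere in this commit, that
`log_learning_rate` accepts an optional `name` argument; treat that argument as an assumption rather than documented API:
```python
from health_ml.utils import log_learning_rate
from pytorch_lightning import LightningModule

class MyModule(LightningModule):
    def training_step(self, *args, **kwargs):
        ...
        # Assumed optional argument: logs under "my_lr" instead of "learning_rate".
        # With a single scheduler producing a single value the metric is "my_lr",
        # otherwise it is suffixed with scheduler and value indices, e.g. "my_lr/1/0".
        log_learning_rate(self, name="my_lr")
```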

View file

@ -3,18 +3,20 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_azure.utils import (download_checkpoints_from_run_id, download_files_from_run_id,
download_from_datastore, fetch_run, get_most_recent_run, is_running_in_azure_ml,
set_environment_variables_for_multi_node, split_recovery_id, torch_barrier,
upload_to_datastore, RUN_CONTEXT, aggregate_hyperdrive_metrics)
from health_azure.datasets import DatasetConfig
from health_azure.himl import (AzureRunInfo, create_run_configuration, create_script_run, get_workspace, submit_run,
submit_to_azure_if_needed, create_crossval_hyperdrive_config)
from health_azure.himl import (AzureRunInfo, create_crossval_hyperdrive_config, create_run_configuration,
create_script_run, get_workspace, submit_run, submit_to_azure_if_needed)
from health_azure.utils import (RUN_CONTEXT, aggregate_hyperdrive_metrics, create_aml_run_object,
download_checkpoints_from_run_id, download_files_from_run_id, download_from_datastore,
fetch_run, get_most_recent_run, is_running_in_azure_ml,
set_environment_variables_for_multi_node, split_recovery_id, torch_barrier,
upload_to_datastore)
__all__ = [
"AzureRunInfo",
"DatasetConfig",
"RUN_CONTEXT",
"create_aml_run_object",
"create_run_configuration",
"create_script_run",
"download_files_from_run_id",

View file

@ -9,17 +9,17 @@ import hashlib
import json
import logging
import os
import pandas as pd
import param
import re
import tempfile
from argparse import ArgumentParser, OPTIONAL
from collections import defaultdict
from itertools import islice
from pathlib import Path
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Type, TypeVar, Union, Set
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
import conda_merge
import pandas as pd
import param
import ruamel.yaml
from azureml._restclient.constants import RunStatus
from azureml.core import Environment, Experiment, Run, Workspace, get_run
@ -500,7 +500,7 @@ def _find_file(file_name: str, stop_at_pythonpath: bool = True) -> Optional[Path
Recurse up the file system, starting at the current working directory, to find a file. Optionally stop when we hit
the PYTHONPATH root (defaults to stopping).
:param file_name: The fine name of the file to find.
:param file_name: The file name of the file to find.
:param stop_at_pythonpath: (Defaults to True.) Whether to stop at the PYTHONPATH root.
:return: The path to the file, or None if it cannot be found.
"""
@ -510,9 +510,10 @@ def _find_file(file_name: str, stop_at_pythonpath: bool = True) -> Optional[Path
file_name: str,
stop_at_pythonpath: bool,
pythonpaths: List[Path]) -> Optional[Path]:
for child in start_at.iterdir():
if child.is_file() and child.name == file_name:
return child
logging.debug(f"Searching for file {file_name} in {start_at}")
expected = start_at / file_name
if expected.is_file() and expected.name == file_name:
return expected
if start_at.parent == start_at or start_at in pythonpaths:
return None
return return_file_or_parent(start_at.parent, file_name, stop_at_pythonpath, pythonpaths)
@ -1381,3 +1382,80 @@ def download_files_from_hyperdrive_children(run: Run, remote_file_path: str, loc
downloaded_file_paths.append(str(downloaded_file_path))
return downloaded_file_paths
def create_aml_run_object(experiment_name: str,
run_name: Optional[str] = None,
workspace: Optional[Workspace] = None,
workspace_config_path: Optional[Path] = None,
snapshot_directory: Optional[PathOrString] = None) -> Run:
"""
Creates an AzureML Run object in the given workspace, or in the workspace given by the AzureML config file.
This Run object can be used to write metrics to AzureML, upload files, etc, when the code is not running in
AzureML. After finishing all operations, use `run.flush()` to write metrics to the cloud, and `run.complete()` or
`run.fail()`.
Example:
>>> run = create_aml_run_object(experiment_name="run_on_my_vm", run_name="try1")
>>> run.log("foo", 1.23)
>>> run.flush()
>>> run.complete()
:param experiment_name: The AzureML experiment that should hold the run that will be created.
:param run_name: An optional name for the run (this will be used as the display name in the AzureML UI)
:param workspace: If provided, use this workspace to create the run in. If not provided, use the workspace
specified by the `config.json` file in the folder or its parent folder(s).
:param workspace_config_path: If no AzureML workspace is provided, then load one using the information in this
config file.
:param snapshot_directory: The folder that should be included as the code snapshot. By default, no snapshot
is created (snapshot_directory=None or snapshot_directory=""). Set this to the folder that contains all the
code your experiment uses. You can use a file .amlignore to skip specific files or folders, akin to .gitignore.
:return: An AzureML Run object.
"""
actual_workspace = get_workspace(aml_workspace=workspace, workspace_config_path=workspace_config_path)
exp = Experiment(workspace=actual_workspace, name=experiment_name)
if snapshot_directory is None or snapshot_directory == "":
snapshot_directory = tempfile.mkdtemp()
return exp.start_logging(name=run_name, snapshot_directory=str(snapshot_directory)) # type: ignore
def aml_workspace_for_unittests() -> Workspace:
"""
Gets the default AzureML workspace that is used for unit testing. It first tries to locate a workspace config.json
file in the present folder or its parents, and create a workspace from that if found. If no config.json file
is found, the workspace details are read from environment variables. Authentication information is also read
from environment variables.
"""
config_json = _find_file(WORKSPACE_CONFIG_JSON)
if config_json is not None:
return Workspace.from_config(path=str(config_json))
else:
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=False)
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=False)
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=False)
auth = get_authentication()
return Workspace.get(name=workspace_name,
auth=auth,
subscription_id=subscription_id,
resource_group=resource_group)
class UnitTestWorkspaceWrapper:
"""
Wrapper around aml_workspace so that it is lazily loaded only once. Used for unit testing only.
"""
def __init__(self) -> None:
"""
Init.
"""
self._workspace: Workspace = None
@property
def workspace(self) -> Workspace:
"""
Lazily load the aml_workspace.
"""
if self._workspace is None:
self._workspace = aml_workspace_for_unittests()
return self._workspace

View file

@ -13,7 +13,7 @@ import time
from enum import Enum
from pathlib import Path
from random import randint
from typing import Dict, List, Optional, Union, Any, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest import mock
from unittest.mock import MagicMock, patch
from uuid import uuid4
@ -25,7 +25,7 @@ import pytest
from _pytest.capture import CaptureFixture
from _pytest.logging import LogCaptureFixture
from azureml._vendor.azure_storage.blob import Blob
from azureml.core import Experiment, ScriptRunConfig, Workspace
from azureml.core import Experiment, Run, ScriptRunConfig, Workspace
from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.core.environment import CondaDependencies
from azureml.data.azure_storage_datastore import AzureBlobDatastore
@ -34,9 +34,8 @@ import health_azure.utils as util
from health_azure import himl
from health_azure.himl import AML_IGNORE_FILE, append_to_amlignore
from testazure.test_himl import RunTarget, render_and_run_test_script
from testazure.util import (DEFAULT_WORKSPACE, change_working_directory, repository_root, MockRun,
DEFAULT_IGNORE_FOLDERS)
from testazure.utils_testazure import (DEFAULT_IGNORE_FOLDERS, DEFAULT_WORKSPACE, MockRun, change_working_directory,
repository_root)
RUN_ID = uuid4().hex
RUN_NUMBER = 42
@ -1486,3 +1485,27 @@ def test_aggregate_hyperdrive_metrics(_: MagicMock) -> None:
assert isinstance(epochs[0], list)
test_accuracies = df.loc["test/accuracy"]
assert isinstance(test_accuracies[0], float)
def test_create_run() -> None:
"""
Test if we can create an AML run object here in the test suite, write logs and read them back in.
"""
run_name = "foo"
experiment_name = "himl-tests"
run: Optional[Run] = None
try:
run = util.create_aml_run_object(experiment_name=experiment_name, run_name=run_name,
workspace=DEFAULT_WORKSPACE.workspace)
assert run is not None
assert run.name == run_name
assert run.experiment.name == experiment_name
metric_name = "mymetric"
metric_value = 1.234
run.log(metric_name, metric_value)
run.flush()
metrics = run.get_metrics(name=metric_name)
assert metrics[metric_name] == metric_value
finally:
if run is not None:
run.complete()

View file

@ -12,7 +12,7 @@ from typing import Dict, Optional
from jinja2 import Template
from testazure.util import himl_azure_root, DEFAULT_IGNORE_FOLDERS
from testazure.utils_testazure import himl_azure_root, DEFAULT_IGNORE_FOLDERS
here = Path(__file__).parent.resolve()

View file

@ -19,7 +19,7 @@ from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key,
_replace_string_datasets, get_datastore, get_or_create_dataset)
from testazure.util import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
from testazure.utils_testazure import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
def test_datasetconfig_init() -> None:

View file

@ -34,7 +34,7 @@ from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_da
from health_azure.utils import (ENVIRONMENT_VERSION, EXPERIMENT_RUN_SEPARATOR, WORKSPACE_CONFIG_JSON,
get_most_recent_run, get_workspace, is_running_in_azure_ml)
from testazure.test_data.make_tests import render_environment_yaml, render_test_script
from testazure.util import DEFAULT_DATASTORE, change_working_directory, check_config_json, repository_root
from testazure.utils_testazure import DEFAULT_DATASTORE, change_working_directory, check_config_json, repository_root
INEXPENSIVE_TESTING_CLUSTER_NAME = "lite-testing-ds2"
EXPECTED_QUEUED = "This command will be run in AzureML:"

View file

@ -9,7 +9,7 @@ import pytest
import subprocess
from health_azure import himl_download
from testazure.util import MockRun
from testazure.utils_testazure import MockRun
DOWNLOAD_SCRIPT_PATH = himl_download.__file__

View file

@ -13,7 +13,7 @@ from health_azure import himl_tensorboard, himl
from health_azure import utils as azure_util
from health_azure.himl_tensorboard import WrappedTensorboard
from testazure.test_himl import render_and_run_test_script, RunTarget
from testazure.util import DEFAULT_WORKSPACE
from testazure.utils_testazure import DEFAULT_WORKSPACE
TENSORBOARD_SCRIPT_PATH = himl_tensorboard.__file__

View file

@ -13,10 +13,8 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, Optional
from azureml.core import Workspace
from health_azure.utils import (ENV_RESOURCE_GROUP, ENV_SUBSCRIPTION_ID, ENV_WORKSPACE_NAME, get_authentication,
get_secret_from_environment, WORKSPACE_CONFIG_JSON)
from health_azure.utils import (ENV_RESOURCE_GROUP, ENV_SUBSCRIPTION_ID, ENV_WORKSPACE_NAME, WORKSPACE_CONFIG_JSON,
UnitTestWorkspaceWrapper)
DEFAULT_DATASTORE = "himldatasets"
FALLBACK_SINGLE_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538212_d2b07afd"
@ -50,46 +48,7 @@ def repository_root() -> Path:
return himl_azure_root().parent
def default_aml_workspace() -> Workspace:
"""
Gets the default AzureML workspace that is used for testing.
"""
config_json = repository_root() / WORKSPACE_CONFIG_JSON
if config_json.is_file():
return Workspace.from_config()
else:
workspace_name = get_secret_from_environment(ENV_WORKSPACE_NAME, allow_missing=False)
subscription_id = get_secret_from_environment(ENV_SUBSCRIPTION_ID, allow_missing=False)
resource_group = get_secret_from_environment(ENV_RESOURCE_GROUP, allow_missing=False)
auth = get_authentication()
return Workspace.get(name=workspace_name,
auth=auth,
subscription_id=subscription_id,
resource_group=resource_group)
class WorkspaceWrapper:
"""
Wrapper around aml_workspace so that it is lazily loaded, once.
"""
def __init__(self) -> None:
"""
Init.
"""
self._workspace: Workspace = None
@property
def workspace(self) -> Workspace:
"""
Lazily load the aml_workspace.
"""
if self._workspace is None:
self._workspace = default_aml_workspace()
return self._workspace
DEFAULT_WORKSPACE = WorkspaceWrapper()
DEFAULT_WORKSPACE = UnitTestWorkspaceWrapper()
@contextmanager

View file

@ -10,27 +10,76 @@ import operator
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, Mapping, Optional, Union
import torch
from azureml.core import Run, Workspace
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ProgressBarBase
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_only
from health_azure import is_running_in_azure_ml
from health_azure.utils import RUN_CONTEXT
from health_azure.utils import PathOrString, RUN_CONTEXT, create_aml_run_object
class AzureMLLogger(LightningLoggerBase):
"""
A Pytorch Lightning logger that stores metrics in the current AzureML run. If the present run is not
inside AzureML, nothing gets logged.
A Pytorch Lightning logger that stores metrics in the current AzureML run. This logger will always write metrics
to AzureML if the training run is executed in AzureML. It can optionally also write to AzureML if the training
run is executed somewhere else, for example on a VM outside of AzureML.
"""
def __init__(self) -> None:
HYPERPARAMS_NAME = "hyperparams"
"""
The name under which hyperparameters are written to the AzureML run.
"""
def __init__(self,
enable_logging_outside_azure_ml: bool,
experiment_name: str = "azureml_logger",
run_name: Optional[str] = None,
workspace: Optional[Workspace] = None,
workspace_config_path: Optional[Path] = None,
snapshot_directory: Optional[PathOrString] = None
) -> None:
"""
:param enable_logging_outside_azure_ml: If True, the AzureML logger will write metrics to AzureML even if
executed outside of an AzureML run (for example, when working on a separate virtual machine). If False,
the logger will only write metrics to AzureML if the code is actually running inside of AzureML.
:param experiment_name: The AzureML experiment that should hold the run when executed outside of AzureML.
:param run_name: An optional name for the run (this will be used as the display name in the AzureML UI). This
argument only matters when running outside of AzureML.
:param workspace: If provided, use this workspace to create the run in.
:param workspace_config_path: Use this path to read the workspace configuration JSON file. If not provided,
use the workspace specified by the `config.json` file in the current working directory or its parents.
:param snapshot_directory: The folder that should be included as the code snapshot. By default, no snapshot
is created. Set this to the folder that contains all the code your experiment uses. You can use a file
.amlignore to skip specific files or folders, akin to .gitignore.
"""
super().__init__()
self.is_running_in_azure_ml = is_running_in_azure_ml()
self.run: Optional[Run] = None
self.has_custom_run = False
if self.is_running_in_azure_ml:
self.run = RUN_CONTEXT
elif enable_logging_outside_azure_ml:
try:
self.run = create_aml_run_object(experiment_name=experiment_name,
run_name=run_name,
workspace=workspace,
workspace_config_path=workspace_config_path,
snapshot_directory=snapshot_directory)
print(f"Writing metrics to run {self.run.id} in experiment {self.run.experiment.name}.")
print(f"To check progress, visit this URL: {self.run.get_portal_url()}")
self.has_custom_run = True
except Exception:
logging.error("Unable to create an AzureML run to store the results.")
raise
else:
print("AzureMLLogger will not write any logs because it is running outside AzureML, and the "
"'enable_logging_outside_azure_ml' flag is set to False")
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
@ -43,33 +92,30 @@ class AzureMLLogger(LightningLoggerBase):
:param step: The trainer global step for logging.
"""
logging.debug(f"AzureMLLogger step={step}: {metrics}")
if self.run is None:
return
is_epoch_metric = "epoch" in metrics
if self.is_running_in_azure_ml:
for key, value in metrics.items():
# Log all epoch-level metrics without the step information
# All step-level metrics with step
RUN_CONTEXT.log(key, value, step=None if is_epoch_metric else step)
for key, value in metrics.items():
# Log all epoch-level metrics without the step information
# All step-level metrics with step
self.run.log(key, value, step=None if is_epoch_metric else step)
@rank_zero_only
def log_hyperparams(self, params: Union[argparse.Namespace, Dict[str, Any]]) -> None:
"""
Logs the given model hyperparameters to AzureML as a table. Namespaces are converted to dictionaries.
Nested dictionaries are flattened out.
Nested dictionaries are flattened out. The hyperparameters are then written as a table with two columns
"name" and "value".
"""
if not self.is_running_in_azure_ml:
if self.run is None:
return
if params is None:
return
# Convert from Namespace to dictionary
params = self._convert_params(params)
# Convert nested dictionaries to folder-like structure
params = self._flatten_dict(params)
# Convert anything that is not a primitive type to str
params = self._sanitize_params(params)
if not isinstance(params, dict):
raise ValueError(f"Expected the hyperparameters to be a dictionary, but got {type(params)}")
if len(params) > 0:
RUN_CONTEXT.log_table("hyperparams", params)
params_final = self._preprocess_hyperparams(params)
if len(params_final) > 0:
# Log hyperparameters as a table with 2 columns. Each "step" is one hyperparameter
self.run.log_table(self.HYPERPARAMS_NAME, {"name": list(params_final.keys()),
"value": list(params_final.values())})
def experiment(self) -> Any:
return None
@ -80,6 +126,29 @@ class AzureMLLogger(LightningLoggerBase):
def version(self) -> int:
return 0
def finalize(self, status: str) -> None:
if self.run is not None and self.has_custom_run:
# Run.complete should only be called if we created an AzureML run here in the constructor.
self.run.complete()
def _preprocess_hyperparams(self, params: Any) -> Dict[str, str]:
"""
Converts arbitrary hyperparameters to a simple dictionary structure, in particular argparse Namespaces.
Nested dictionaries are converted to folder-like strings, like ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.
All hyperparameter values are converted to strings, because Run.log_table can't deal with mixed datatypes.
:param params: The parameters to convert
:return: A dictionary mapping from string to string.
"""
# Convert from Namespace to dictionary
params = self._convert_params(params)
# Convert nested dictionaries to folder-like structure
params = self._flatten_dict(params)
# Convert anything that is not a primitive type to str
params_final = self._sanitize_params(params)
if not isinstance(params_final, dict):
raise ValueError(f"Expected the converted hyperparameters to be a dictionary, but got {type(params)}")
return {str(key): str(value) for key, value in params_final.items()}
class AzureMLProgressBar(ProgressBarBase):
"""

View file

@ -4,18 +4,31 @@
# ------------------------------------------------------------------------------------------
import logging
import math
import time
from argparse import Namespace
from datetime import datetime
from typing import Any, Dict
from pathlib import Path
from typing import Any, Dict, Optional
from unittest import mock
from unittest.mock import MagicMock
import pytest
import torch
from _pytest.capture import SysCapture
from _pytest.logging import LogCaptureFixture
from azureml._restclient.constants import RunStatus
from azureml.core import Run
from pytorch_lightning import Trainer
from health_azure import RUN_CONTEXT, create_aml_run_object
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, log_learning_rate, log_on_epoch
from testhiml.utils_testhiml import DEFAULT_WORKSPACE
def create_unittest_run_object(snapshot_directory: Optional[Path] = None) -> Run:
return create_aml_run_object(experiment_name="himl-tests",
workspace=DEFAULT_WORKSPACE.workspace,
snapshot_directory=snapshot_directory or ".")
def test_log_on_epoch() -> None:
@ -119,109 +132,214 @@ def test_log_learning_rate_multiple() -> None:
'foo/1/1': lr2[1]}}
def create_mock_logger() -> AzureMLLogger:
"""
Create an AzureMLLogger that has a run field set to a MagicMock.
"""
run_mock = MagicMock()
with mock.patch("health_ml.utils.logging.create_aml_run_object", return_value=run_mock):
return AzureMLLogger(enable_logging_outside_azure_ml=True)
def test_azureml_logger() -> None:
"""
Tests logging to an AzureML run via PytorchLightning
"""
logger = AzureMLLogger()
logger = create_mock_logger()
# On all build agents, this should not be detected as an AzureML run.
assert not logger.is_running_in_azure_ml
# No logging should happen when outside AzureML
with mock.patch("health_azure.utils.RUN_CONTEXT.log") as log_mock:
logger.log_metrics({"foo": 1.0})
assert log_mock.call_count == 0
# Pretend to be running in AzureML
logger.is_running_in_azure_ml = True
with mock.patch("health_azure.utils.RUN_CONTEXT.log") as log_mock:
logger.log_metrics({"foo": 1.0})
assert log_mock.call_count == 1
assert log_mock.call_args[0] == ("foo", 1.0), "Should be called with the unrolled dictionary of metrics"
assert logger.has_custom_run
logger.log_metrics({"foo": 1.0})
assert logger.run is not None
logger.run.log.assert_called_once_with("foo", 1.0, step=None)
# All the following methods of LightningLoggerBase are not implemented
assert logger.name() == ""
assert logger.version() == 0
assert logger.experiment() is None
# Finalizing should call the "Complete" method of the run
logger.finalize(status="foo")
logger.run.complete.assert_called_once()
def test_azureml_logger_hyperparams() -> None:
def test_azureml_log_hyperparameters1() -> None:
"""
Tests logging of hyperparameters to an AzureML
Test logging of hyperparameters
"""
logger = AzureMLLogger()
# On all build agents, this should not be detected as an AzureML run.
assert not logger.is_running_in_azure_ml
# No logging should happen when outside AzureML
with mock.patch("health_azure.utils.RUN_CONTEXT.log_table") as log_mock:
logger.log_hyperparams({"foo": 1.0})
assert log_mock.call_count == 0
# Pretend to be running in AzureML
logger.is_running_in_azure_ml = True
logger = create_mock_logger()
assert logger.run is not None
# No logging should happen with empty params
with mock.patch("health_azure.utils.RUN_CONTEXT.log_table") as log_mock:
logger.log_hyperparams(None) # type: ignore
assert log_mock.call_count == 0
logger.log_hyperparams({})
assert log_mock.call_count == 0
logger.log_hyperparams(Namespace())
assert log_mock.call_count == 0
logger.log_hyperparams(None) # type: ignore
assert logger.run.log.call_count == 0
logger.log_hyperparams({})
assert logger.run.log.call_count == 0
logger.log_hyperparams(Namespace())
assert logger.run.log.call_count == 0
# Logging of hyperparameters that are plain dictionaries
with mock.patch("health_azure.utils.RUN_CONTEXT.log_table") as log_mock:
fake_params = {"foo": 1.0}
logger.log_hyperparams(fake_params)
assert log_mock.call_count == 1
assert log_mock.call_args[0] == ("hyperparams", fake_params), "Should be called with hyperparams dictionary"
fake_params = {"foo": 1.0}
logger.log_hyperparams(fake_params)
# Dictionary should be logged as name/value pairs, one value per row
logger.run.log_table.assert_called_once_with("hyperparams", {'name': ['foo'], 'value': ["1.0"]})
def test_azureml_logger_hyperparams2() -> None:
def test_azureml_log_hyperparameters2() -> None:
"""
Tests logging of complex hyperparameters to AzureML
Logging of hyperparameters that are Namespace objects from the arg parser
"""
logger = create_mock_logger()
assert logger.run is not None
class Dummy:
def __str__(self) -> str:
return "dummy"
logger = AzureMLLogger()
# Pretend to be running in AzureML
logger.is_running_in_azure_ml = True
fake_namespace = Namespace(foo="bar", complex_object=Dummy())
logger.log_hyperparams(fake_namespace)
# Complex objects are converted to str
expected_dict: Dict[str, Any] = {'name': ['foo', 'complex_object'], 'value': ['bar', 'dummy']}
logger.run.log_table.assert_called_once_with("hyperparams", expected_dict)
# Logging of hyperparameters that are Namespace objects from the arg parser
with mock.patch("health_azure.utils.RUN_CONTEXT.log_table") as log_mock:
fake_namespace = Namespace(foo="bar", complex_object=Dummy())
logger.log_hyperparams(fake_namespace)
assert log_mock.call_count == 1
# Complex objects are converted to str
expected_dict: Dict[str, Any] = {"foo": "bar", "complex_object": "dummy"}
assert log_mock.call_args[0] == ("hyperparams", expected_dict)
# Logging of hyperparameters that are nested dictionaries. They should first be flattened, then each complex
# object converted to str
with mock.patch("health_azure.utils.RUN_CONTEXT.log_table") as log_mock:
fake_namespace = Namespace(foo={"bar": 1, "baz": {"level3": Namespace(a="17")}})
logger.log_hyperparams(fake_namespace)
assert log_mock.call_count == 1
expected_dict = {"foo/bar": 1, "foo/baz/level3/a": "17"}
assert log_mock.call_args[0] == ("hyperparams", expected_dict)
def test_azureml_log_hyperparameters3() -> None:
"""
Logging of hyperparameters that are nested dictionaries. They should first be flattened, then each complex
object converted to str.
"""
logger = create_mock_logger()
assert logger.run is not None
fake_namespace = Namespace(foo={"bar": 1, "baz": {"level3": Namespace(a="17")}})
logger.log_hyperparams(fake_namespace)
expected_dict = {"name": ["foo/bar", "foo/baz/level3/a"], "value": ["1", "17"]}
logger.run.log_table.assert_called_once_with("hyperparams", expected_dict)
def test_azureml_logger_many_hyperparameters(tmpdir: Path) -> None:
"""
Test if large number of hyperparameters are logged correctly.
Earlier versions of the code had a bug that only allowed a maximum of 15 hyperparams to be logged.
"""
many_hyperparams: Dict[str, Any] = {f"param{i}": i for i in range(0, 20)}
many_hyperparams["A long list"] = ["foo", 1.0, "abc"]
expected_metrics = {key: str(value) for key, value in many_hyperparams.items()}
logger: Optional[AzureMLLogger] = None
try:
logger = AzureMLLogger(enable_logging_outside_azure_ml=True, workspace=DEFAULT_WORKSPACE.workspace)
assert logger.run is not None
logger.log_hyperparams(many_hyperparams)
logger.run.flush()
time.sleep(1)
metrics = logger.run.get_metrics(name=AzureMLLogger.HYPERPARAMS_NAME)
print(f"metrics = {metrics}")
actual = metrics[AzureMLLogger.HYPERPARAMS_NAME]
assert actual["name"] == list(expected_metrics.keys())
assert actual["value"] == list(expected_metrics.values())
finally:
if logger:
logger.finalize("done")
def test_azureml_logger_hyperparams_processing() -> None:
"""
Test flattening of hyperparameters: Lists were not handled correctly in previous versions.
"""
hyperparams = {"A long list": ["foo", 1.0, "abc"],
"foo": 1.0}
logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
actual = logger._preprocess_hyperparams(hyperparams)
assert actual == {"A long list": "['foo', 1.0, 'abc']", "foo": "1.0"}
def test_azureml_logger_step() -> None:
"""
Test if the AzureML logger correctly handles epoch-level and step metrics
"""
logger = AzureMLLogger()
# Pretend to be running in AzureML
logger.is_running_in_azure_ml = True
with mock.patch("health_azure.utils.RUN_CONTEXT.log") as log_mock:
logger.log_metrics(metrics={"foo": 1.0, "epoch": 123}, step=78)
assert log_mock.call_count == 2
assert log_mock.call_args_list[0][0] == ("foo", 1.0)
assert log_mock.call_args_list[0][1] == {"step": None}, "For epoch-level metrics, no step should be provided"
assert log_mock.call_args_list[1][0] == ("epoch", 123)
assert log_mock.call_args_list[1][1] == {"step": None}, "For epoch-level metrics, no step should be provided"
with mock.patch("health_azure.utils.RUN_CONTEXT.log") as log_mock:
logger.log_metrics(metrics={"foo": 1.0}, step=78)
assert log_mock.call_count == 1
assert log_mock.call_args[0] == ("foo", 1.0), "Should be called with the unrolled dictionary of metrics"
assert log_mock.call_args[1] == {"step": 78}, "For step-level metrics, the step argument should be provided"
logger = create_mock_logger()
assert logger.run is not None
logger.log_metrics(metrics={"foo": 1.0, "epoch": 123}, step=78)
assert logger.run.log.call_count == 2
assert logger.run.log.call_args_list[0][0] == ("foo", 1.0)
assert logger.run.log.call_args_list[0][1] == {"step": None}, "For epoch-level metrics, no step should be provided"
assert logger.run.log.call_args_list[1][0] == ("epoch", 123)
assert logger.run.log.call_args_list[1][1] == {"step": None}, "For epoch-level metrics, no step should be provided"
logger.run.reset_mock() # type: ignore
logger.log_metrics(metrics={"foo": 1.0}, step=78)
logger.run.log.assert_called_once_with("foo", 1.0, step=78)
def test_azureml_logger_init1() -> None:
"""
Test the logic to choose the run, inside of the constructor of AzureMLLogger.
"""
# When running in AzureML, the RUN_CONTEXT should be used
with mock.patch("health_ml.utils.logging.is_running_in_azure_ml", return_value=True):
with mock.patch("health_ml.utils.logging.RUN_CONTEXT", "foo"):
logger = AzureMLLogger(enable_logging_outside_azure_ml=True)
assert logger.is_running_in_azure_ml
assert not logger.has_custom_run
assert logger.run == "foo"
# We should be able to call finalize without any effect (logger.run == "foo", which has no
# "Complete" method). When running in AzureML, the logger should not
# modify the run in any way, and in particular not complete it.
logger.finalize("nothing")
def test_azureml_logger_init2() -> None:
"""
Test the logic to choose the run, inside of the constructor of AzureMLLogger.
"""
# When disabling offline logging, the logger should be a no-op, and not log anything
logger = AzureMLLogger(enable_logging_outside_azure_ml=False)
assert logger.run is None
logger.log_metrics({"foo": 1.0})
logger.finalize(status="nothing")
def test_azureml_logger_actual_run() -> None:
"""
When running outside of AzureML, a new run should be created.
"""
logger = AzureMLLogger(enable_logging_outside_azure_ml=True, workspace=DEFAULT_WORKSPACE.workspace)
assert not logger.is_running_in_azure_ml
assert logger.run is not None
assert logger.run != RUN_CONTEXT
assert isinstance(logger.run, Run)
assert logger.run.experiment.name == "azureml_logger"
assert logger.has_custom_run
expected_metrics = {"foo": 1.0, "bar": 2.0}
logger.log_metrics(expected_metrics)
logger.run.flush()
actual_metrics = logger.run.get_metrics()
assert actual_metrics == expected_metrics
assert logger.run.status != RunStatus.COMPLETED
logger.finalize("nothing")
# The AzureML run has been completed now. Insert a mock to check if finalize calls run.complete on the run object.
logger.run = MagicMock()
logger.finalize("nothing")
logger.run.complete.assert_called_once_with()
def test_azureml_logger_init4() -> None:
"""
Test the logic to choose the run, inside of the constructor of AzureMLLogger.
"""
# Check that all arguments are respected
run_mock = MagicMock()
with mock.patch("health_ml.utils.logging.create_aml_run_object", return_value=run_mock) as mock_create:
logger = AzureMLLogger(enable_logging_outside_azure_ml=True,
experiment_name="exp",
run_name="run",
snapshot_directory="snapshot",
workspace="workspace", # type: ignore
workspace_config_path=Path("config_path"))
assert logger.has_custom_run
assert logger.run == run_mock
mock_create.assert_called_once_with(experiment_name="exp",
run_name="run",
snapshot_directory="snapshot",
workspace="workspace",
workspace_config_path=Path("config_path"))
def test_progress_bar_enable() -> None:

View file

@ -0,0 +1,7 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_azure.utils import UnitTestWorkspaceWrapper
DEFAULT_WORKSPACE = UnitTestWorkspaceWrapper()

View file

@ -6,7 +6,7 @@
],
"exclude": [
"hi-ml/testhiml/testhiml/utils/image_loading",
"hi-ml/testhiml/testhiml/utils/slide_image_loading/src",
"hi-ml/testhiml/testhiml/utils/slide_image_loading/src"
],
"useLibraryCodeForTypes": false,
"reportMissingImports": true,