Switching batch time loading diagnostics to hi-ml (#577)

- Using the BatchTimeCallback from hi-ml
- Switches to trigger PL profiling
This commit is contained in:
Anton Schwaighofer 2021-11-03 15:48:10 +00:00 коммит произвёл GitHub
Родитель 8495a2eec3
Коммит bf4cb628c6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
21 изменённых файлов: 204 добавлений и 356 удалений

Просмотреть файл

@ -2,4 +2,4 @@
ignore = E226,E302,E41,W391, E701, W291, E722, W503, E128, E126, E127, E731, E401
max-line-length = 160
max-complexity = 25
exclude = fastMRI/ test_outputs/
exclude = fastMRI/ test_outputs/ hi-ml/

Просмотреть файл

@ -4,8 +4,10 @@
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/fastMRI" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/hi-ml/hi-ml-azure/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/hi-ml/hi-ml/src" isTestSource="false" />
</content>
<orderEntry type="jdk" jdkName="3.7 @ Ubuntu 20.04" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="3.7 @ Ubuntu-20.04" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">

Просмотреть файл

@ -13,6 +13,9 @@ created.
## Upcoming
### Added
- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor
GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via
`BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`)
- ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation.
- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).)
@ -77,6 +80,8 @@ in inference-only runs when using lightning containers.
### Removed
- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Removing the monitoring of batch loading time,
use the `BatchTimeCallback` from `hi-ml` instead
- ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and
`weights_url` can no longer be used to initialize a training run, only inference runs.

Просмотреть файл

@ -106,7 +106,8 @@ def add_submodules_to_path() -> None:
innereye_root = repository_root_directory()
folders_to_add = [(innereye_root, "InnerEye"),
(innereye_root / "fastMRI", "fastmri"),
(innereye_root / "hi-ml" / "src", "health")]
(innereye_root / "hi-ml" / "hi-ml-azure" / "src", "health_azure"),
(innereye_root / "hi-ml" / "hi-ml" / "src", "health_ml")]
for (folder, subfolder_that_must_exist) in folders_to_add:
if (folder / subfolder_that_must_exist).is_dir():
folder_str = str(folder)

Просмотреть файл

@ -3,7 +3,6 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from enum import Enum, unique
# String prefixes when writing training or validation set metrics to a logger
from typing import Union
@ -45,8 +44,6 @@ class LoggingColumns(Enum):
AccuracyAtThreshold05 = "accuracy_at_threshold_05"
Loss = "loss"
CrossEntropy = "cross_entropy"
SecondsPerEpoch = "seconds_per_epoch"
SecondsPerBatch = "seconds_per_batch"
AreaUnderRocCurve = "area_under_roc_curve"
AreaUnderPRCurve = "area_under_pr_curve"
CrossValidationSplitIndex = "cross_validation_split_index"
@ -100,8 +97,6 @@ class MetricType(Enum):
EXPLAINED_VAR = "ExplainedVariance"
# Common metrics
SECONDS_PER_BATCH = "SecondsPerBatch"
SECONDS_PER_EPOCH = "SecondsPerEpoch"
SUBJECT_COUNT = "SubjectCount"
LEARNING_RATE = "LearningRate"
@ -114,8 +109,6 @@ INTERNAL_TO_LOGGING_COLUMN_NAMES = {
MetricType.LOSS.value: LoggingColumns.Loss,
MetricType.ACCURACY_AT_THRESHOLD_05.value: LoggingColumns.AccuracyAtThreshold05,
MetricType.CROSS_ENTROPY.value: LoggingColumns.CrossEntropy,
MetricType.SECONDS_PER_BATCH.value: LoggingColumns.SecondsPerBatch,
MetricType.SECONDS_PER_EPOCH.value: LoggingColumns.SecondsPerEpoch,
MetricType.AREA_UNDER_ROC_CURVE.value: LoggingColumns.AreaUnderRocCurve,
MetricType.AREA_UNDER_PR_CURVE.value: LoggingColumns.AreaUnderPRCurve,
MetricType.SUBJECT_COUNT.value: LoggingColumns.SubjectCount,

Просмотреть файл

@ -3,7 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Dict, Iterable, Optional, Tuple, TypeVar, Union
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
T = TypeVar('T')
PathOrString = Union[Path, str]
@ -15,3 +15,4 @@ TupleFloat3 = Tuple[float, float, float]
TupleFloat9 = Tuple[float, float, float, float, float, float, float, float, float]
IntOrTuple3 = Union[int, TupleInt3, Iterable]
DictStrFloat = Dict[str, float]
DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]

Просмотреть файл

@ -100,7 +100,7 @@ class SSLContainer(LightningContainer):
def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
self.total_num_gpus = self.num_gpus_per_node * self.num_nodes
self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes
self._load_config()
# If you're using the same data for training and linear head, allow the user to specify the dataset only
# once. Or if you are doing just finetuning of linear head, the user should be able to specify dataset via

Просмотреть файл

@ -218,7 +218,7 @@ class WorkflowParams(param.Parameterized):
doc="If set, enable/disable full image inference on test set after ensemble training.")
weights_url: List[str] = param.List(default=[], class_=str,
doc="If provided, a set of urls from which checkpoints will be downloaded"
"and used for inference.")
"and used for inference.")
local_weights_path: List[Path] = param.List(default=[], class_=Path,
doc="A list of checkpoints paths to use for inference, "
"when the job is running outside Azure.")
@ -590,6 +590,16 @@ class TrainerParams(param.Parameterized):
param.Boolean(default=False,
doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
"Setting it to True comes with a performance hit.")
monitor_gpu: bool = param.Boolean(default=False,
doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
"This will write GPU utilization metrics every 50 batches by default.")
monitor_loading: bool = param.Boolean(default=True,
doc="If True, add the BatchTimeCallback callback to the Lightning trainer "
"object. This will monitor how long individual batches take to load.")
pl_profiler: Optional[str] = \
param.String(default=None,
doc="The value to use for the 'profiler' argument for the Lightning trainer. "
"Set to either 'simple', 'advanced', or 'pytorch'")
@property
def use_gpu(self) -> bool:
@ -602,7 +612,6 @@ class TrainerParams(param.Parameterized):
from InnerEye.ML.utils.ml_util import is_gpu_available
return is_gpu_available()
@property
def num_gpus_per_node(self) -> int:
"""
Computes the number of gpus to use for each node: either the number of gpus available on the device

Просмотреть файл

@ -3,7 +3,6 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import numbers
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@ -20,21 +19,23 @@ from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, TRAIN_
from InnerEye.Common.type_annotations import DictStrFloat
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, OutputParams, TrainerParams, \
WorkflowParams
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.metrics import EpochTimers, MAX_ITEM_LOAD_TIME_SEC, store_epoch_metrics
from InnerEye.ML.metrics import store_epoch_metrics
from InnerEye.ML.metrics_dict import DataframeLogger
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.utils import model_util
from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
from InnerEye.ML.utils.ml_util import RandomStateSnapshot, set_random_seed, validate_dataset_paths
from InnerEye.ML.utils.model_util import generate_and_print_model_summary
from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset
from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
from health_ml.utils import log_on_epoch
class TrainAndValDataLightning(LightningDataModule):
"""
@ -220,9 +221,6 @@ class InnerEyeLightning(LightningModule):
self.l_rate_scheduler: Optional[_LRScheduler] = None
self.cross_validation_split_index = config.cross_validation_split_index
self.effective_random_seed = config.get_effective_random_seed()
# Timers for monitoring data loading time
self.train_timers = EpochTimers()
self.val_timers = EpochTimers()
# This should be re-assigned on the outside, to a logger that is hooked up with the Trainer object.
self.storing_logger = StoringLogger()
# This will be initialized correctly in epoch_start
@ -260,14 +258,11 @@ class InnerEyeLightning(LightningModule):
assert isinstance(self.trainer, Trainer)
return self.trainer.accelerator_connector.use_ddp
def on_train_epoch_start(self) -> None:
self.train_timers.reset()
def training_epoch_end(self, outputs: List[Any]) -> None:
# Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch.
# Metrics for the very last epoch are written in on_train_end
self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1)
self.training_or_validation_epoch_end(is_training=True)
self.training_or_validation_epoch_end(is_training=True) # type: ignore
def on_validation_epoch_start(self) -> None:
"""
@ -275,10 +270,6 @@ class InnerEyeLightning(LightningModule):
that any randomization when loading validation data is consistent during training. In particular, this ensures
that drawing random patches for segmentation model training is giving a validation set that does not fluctuate.
"""
self.val_timers.reset()
# In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training
# is done for this epoch, even though the on_training_epoch hook has not yet been called.
self.train_timers.epoch_end()
# Store the random number generator state, so that the next training epoch starts from here.
self.random_state = RandomStateSnapshot.snapshot_random_state()
# reset the random state for validation, so that we get consistent behaviour when drawing random patches
@ -286,9 +277,6 @@ class InnerEyeLightning(LightningModule):
seed = self.effective_random_seed
set_random_seed(seed, "Validation")
def on_validation_epoch_end(self) -> None:
self.val_timers.epoch_end()
def validation_epoch_end(self, outputs: List[Any]) -> None:
"""
Resets the random number generator state to what it was before the current validation epoch started.
@ -297,7 +285,7 @@ class InnerEyeLightning(LightningModule):
# reset the random state for training, so that we get continue from where we were before the validation step.
assert self.random_state is not None
self.random_state.restore_random_state()
self.training_or_validation_epoch_end(is_training=False)
self.training_or_validation_epoch_end(is_training=False) # type: ignore
@rank_zero_only
def on_train_end(self) -> None:
@ -314,50 +302,11 @@ class InnerEyeLightning(LightningModule):
Training and Validation metrics.
"""
if epoch >= 0:
if epoch in self.storing_logger.results:
if epoch in self.storing_logger.results_per_epoch:
for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]:
metrics = self.storing_logger.extract_by_prefix(epoch, prefix)
self.store_epoch_results(metrics, epoch, is_training)
@rank_zero_only
def training_or_validation_epoch_end(self, is_training: bool) -> None:
"""
This is a hook called at the end of a training or validation epoch. In here, we can still write
metrics to a logger.
:param is_training: If True, this is called at the end of a training epoch. If False, this is at the
end of a validation epoch.
"""
if not is_training:
# In validation epochs, mark that it has been completed. Training epochs are marked completed already
# at the start of the validation epoch.
self.val_timers.epoch_end()
# Write all IO stats here, so that the order on the console is Train start, train end, val start, val end.
self.write_and_log_epoch_time(is_training=True)
self.write_and_log_epoch_time(is_training=False)
def write_and_log_epoch_time(self, is_training: bool) -> None:
"""
Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the
time per epoch.
:param is_training: If True, show and log the data for the training epoch. If False, use the data for the
validation epoch.
"""
timers = self.get_timers(is_training=is_training)
epoch_time_seconds = timers.total_epoch_time
status = "training" if is_training else "validation"
logging.info(f"Epoch {self.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting for "
f"data took {timers.total_load_time:0.2f} sec total.")
if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch:
logging.warning("The dataloaders were not fast enough to always supply the next batch in less than "
f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
logging.warning(
f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load "
f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.")
# This metric is only written at rank zero, and hence must no be synchronized across workers. If attempted,
# training will get stuck.
self.log_on_epoch(MetricType.SECONDS_PER_EPOCH, epoch_time_seconds, is_training=is_training,
sync_dist_override=False)
def log_on_epoch(self,
name: Union[MetricType, str],
value: Any,
@ -375,20 +324,19 @@ class InnerEyeLightning(LightningModule):
:param name: The name of the metric to log
:param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
:param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
:param reduce_fx: The reduce function to apply after synchronizing the tensors across GPUs.
:param reduce_fx: The reduce function to apply to step values. Default: torch.mean
:param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs. This must be
a value recognized by sync_ddp: Either 'None' to use 'sum' as aggregate, or 'mean' or 'avg'
"""
metric_name = name if isinstance(name, str) else name.value
if isinstance(value, numbers.Number):
value = torch.tensor(value, dtype=torch.float, device=self.device)
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override
self.log(prefix + metric_name, value,
sync_dist=sync_dist,
on_step=False, on_epoch=True,
reduce_fx=reduce_fx,
sync_dist_op=sync_dist_op)
log_on_epoch(self,
name=prefix + metric_name,
value=value,
sync_dist=sync_dist,
reduce_fx=reduce_fx,
sync_dist_op=sync_dist_op)
def store_epoch_results(self, metrics: DictStrFloat, epoch: int, is_training: bool) -> None:
"""
@ -414,18 +362,6 @@ class InnerEyeLightning(LightningModule):
self.checkpoint_loading_message = f"Loading checkpoint that was created at ({', '.join(present_keys)})"
logging.info(self.checkpoint_loading_message)
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.batch_start(batch_idx=batch_idx, is_training=True)
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.batch_start(batch_idx=batch_idx, is_training=False)
def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.batch_end(is_training=True)
def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self.batch_end(is_training=False)
def training_step(self, # type: ignore
sample: Dict[str, Any],
batch_index: int) -> Any:
@ -450,45 +386,6 @@ class InnerEyeLightning(LightningModule):
"""
raise NotImplementedError("This method must be overwritten in a derived class.")
@rank_zero_only
def batch_start(self, batch_idx: int, is_training: bool) -> None:
"""
Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero.
:param batch_idx: The index of the current minibatch.
:param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from
`on_validation_batch_start`.
:return:
"""
timers = self.get_timers(is_training=is_training)
message_prefix = f"Epoch {self.current_epoch} {'training' if is_training else 'validation'}"
timers.batch_start(batch_index=batch_idx, epoch=self.current_epoch, message_prefix=message_prefix)
@rank_zero_only
def batch_end(self, is_training: bool) -> None:
"""
Shared code to keep track of IO-related metrics when loading a minibatch.
:param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from
`on_validation_batch_end`.
"""
timers = self.get_timers(is_training=is_training)
batch_time = timers.batch_end()
# This metric is only written at rank 0, and hence must not be synchronized. Trying to synchronize will
# block training.
self.log_on_epoch(MetricType.SECONDS_PER_BATCH, batch_time, is_training=is_training, sync_dist_override=False)
def get_timers(self, is_training: bool) -> EpochTimers:
"""
Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch.
"""
return self.train_timers if is_training else self.val_timers
def reset_timers(self) -> None:
"""
Resets all timers and counters for IO-related metrics, for both the validation and the training epoch.
"""
self.train_timers.reset()
self.val_timers.reset()
def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
"""
Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".

Просмотреть файл

@ -2,14 +2,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, Dict, Iterable, List, Optional
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX
from InnerEye.Common.type_annotations import DictStrFloat
from InnerEye.Common.type_annotations import DictStrFloat, DictStrFloatOrFloatList
class StoringLogger(LightningLoggerBase):
@ -20,28 +20,40 @@ class StoringLogger(LightningLoggerBase):
def __init__(self) -> None:
super().__init__()
self.results: Dict[int, DictStrFloat] = {}
self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {}
self.hyperparams: Any = None
# Fields to store diagnostics for unit testing
self.train_diagnostics: List[Any] = []
self.val_diagnostics: List[Any] = []
self.results_without_epoch: List[DictStrFloat] = []
@rank_zero_only
def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None:
logging.debug(f"StoringLogger step={step}: {metrics}")
epoch_name = "epoch"
if epoch_name not in metrics:
raise ValueError("Each of the logged metrics should have an 'epoch' key.")
# Metrics without an "epoch" key are logged during testing, for example
self.results_without_epoch.append(metrics)
return
epoch = int(metrics[epoch_name])
del metrics[epoch_name]
if epoch in self.results:
current_results = self.results[epoch]
overlapping_keys = set(metrics.keys()).intersection(current_results.keys())
if len(overlapping_keys) > 0:
raise ValueError(f"Unable to log metric with same name twice for epoch {epoch}: "
f"{', '.join(overlapping_keys)}")
current_results.update(metrics)
for key, value in metrics.items():
if isinstance(value, int):
metrics[key] = float(value)
if epoch in self.results_per_epoch:
current_results = self.results_per_epoch[epoch]
for key, value in metrics.items():
if key in current_results:
logging.debug(f"StoringLogger: appending results for metric {key}")
current_metrics = current_results[key]
if isinstance(current_metrics, list):
current_metrics.append(value)
else:
current_results[key] = [current_metrics, value]
else:
current_results[key] = value
else:
self.results[epoch] = metrics
self.results_per_epoch[epoch] = metrics # type: ignore
@rank_zero_only
def log_hyperparams(self, params: Any) -> None:
@ -61,7 +73,7 @@ class StoringLogger(LightningLoggerBase):
"""
Gets the epochs for which the present object holds any results.
"""
return self.results.keys()
return self.results_per_epoch.keys()
def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat:
"""
@ -73,7 +85,7 @@ class StoringLogger(LightningLoggerBase):
have a name starting with `prefix`, and strip off the prefix.
:return: A metrics dictionary.
"""
epoch_results = self.results.get(epoch, None)
epoch_results = self.results_per_epoch.get(epoch, None)
if epoch_results is None:
raise KeyError(f"No results are stored for epoch {epoch}")
filtered = {}
@ -83,8 +95,8 @@ class StoringLogger(LightningLoggerBase):
# filter is supplied and really matches the metric name
if (not prefix_filter) or key.startswith(prefix_filter):
stripped_key = key[len(prefix_filter):]
filtered[stripped_key] = value
return filtered
filtered[stripped_key] = value # type: ignore
return filtered # type: ignore
def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]:
"""
@ -107,7 +119,14 @@ class StoringLogger(LightningLoggerBase):
:return: A list of floating point numbers, with one entry per entry in the the training or validation results.
"""
full_metric_name = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_type
return [self.results[epoch][full_metric_name] for epoch in self.epochs]
result = []
for epoch in self.epochs:
value = self.results_per_epoch[epoch][full_metric_name]
if not isinstance(value, float):
raise ValueError(f"Expected a floating point value for metric {full_metric_name}, but got: "
f"{value}")
result.append(value)
return result
def get_train_metric(self, metric_type: str) -> List[float]:
"""
@ -138,33 +157,3 @@ class StoringLogger(LightningLoggerBase):
Gets the full set of validation metrics that the logger stores, as a list of dictionaries per epoch.
"""
return list(self.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values())
class AzureMLLogger(LightningLoggerBase):
"""
A Pytorch Lightning logger that stores metrics in the current AzureML run. If the present run is not
inside AzureML, nothing gets logged.
"""
def __init__(self) -> None:
super().__init__()
self.is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
if self.is_azureml_run:
for key, value in metrics.items():
RUN_CONTEXT.log(key, value)
@rank_zero_only
def log_hyperparams(self, params: Any) -> None:
pass
def experiment(self) -> Any:
return None
def name(self) -> Any:
return ""
def version(self) -> int:
return 0

Просмотреть файл

@ -165,7 +165,6 @@ class SegmentationLightning(InnerEyeLightning):
for name, value in voxel_count.compute_all():
self.log(name, value)
voxel_count.reset()
super().training_or_validation_epoch_end(is_training=is_training)
def get_subject_output_file_per_rank(rank: int) -> str:
@ -292,7 +291,6 @@ class ScalarLightning(InnerEyeLightning):
metric.reset()
logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger # type: ignore
logger.flush()
super().training_or_validation_epoch_end(is_training)
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: # type: ignore
"""

Просмотреть файл

@ -6,16 +6,15 @@ from __future__ import annotations
import logging
import math
import time
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Set
from dataclasses import dataclass
from typing import List, Optional, Sequence
import SimpleITK as sitk
import numpy as np
from numpy.core.numeric import NaN
import torch
import torch.nn.functional as F
from azureml.core import Run
from numpy.core.numeric import NaN
from InnerEye.Azure.azure_util import get_run_context_or_default
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
@ -27,15 +26,11 @@ from InnerEye.ML.metrics_dict import (DataframeLogger, INTERNAL_TO_LOGGING_COLUM
from InnerEye.ML.scalar_config import ScalarLoss
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array
from InnerEye.ML.utils.io_util import reverse_tuple_float3
from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error,
r2_score, is_missing_ground_truth)
from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, is_missing_ground_truth,
mean_absolute_error, r2_score)
from InnerEye.ML.utils.ml_util import check_size_matches
from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels
MAX_ITEM_LOAD_TIME_SEC = 0.5
MAX_LOAD_TIME_WARNINGS = 3
MAX_LOAD_TIME_EPOCHS = 5
@dataclass(frozen=True)
class InferenceMetrics:
@ -81,96 +76,6 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
})
@dataclass
class EpochTimers:
"""
Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times.
"""
epoch_start_time: float = time.time()
epoch_end_time: float = time.time()
batch_start_time: float = time.time()
num_load_time_warnings: int = 0
num_load_time_exceeded: int = 0
total_extra_load_time: float = 0.0
total_load_time: float = 0.0
num_batches: int = 0
load_time_warning_epochs: Set[int] = field(default_factory=set)
def reset(self) -> None:
"""
Resets all timers to the current time, and all counters to 0. The set of epochs for which warnings about
load time were produced will not be reset.
"""
current_time = time.time()
self.epoch_start_time = current_time
self.epoch_end_time = current_time
self.batch_start_time = current_time
self.num_load_time_warnings = 0
self.num_load_time_exceeded = 0
self.total_extra_load_time = 0.0
self.total_load_time = 0.0
self.num_batches = 0
def epoch_end(self) -> None:
"""
Stores the present time in the epoch_end_time field of the object.
"""
self.epoch_end_time = time.time()
@property
def total_epoch_time(self) -> float:
"""
Gets the time in seconds between epoch start and epoch end.
"""
return self.epoch_end_time - self.epoch_start_time
@property
def should_warn_in_this_epoch(self) -> bool:
"""
Returns True if warnings about loading time should be printed in the present epoch. Returns False if
this warning has been printed already in more than MAX_LOAD_TIME_EPOCHS epochs.
:return:
"""
return len(self.load_time_warning_epochs) <= MAX_LOAD_TIME_EPOCHS
def batch_start(self, batch_index: int, epoch: int, message_prefix: str) -> float:
"""
Called when a minibatch of data has been loaded. This computes the time it took to load the minibatch,
and adds it to the internal bookkeeping.
:return: The time it took to load the minibatch, in seconds.
"""
item_finish_time = time.time()
item_load_time = item_finish_time - self.batch_start_time
self.total_load_time += item_load_time
# Having slow minibatch loading is OK in the very first batch of the every epoch, where processes
# are spawned. Later, the load time should be zero.
if batch_index == 0:
logging.info(f"{message_prefix}: Loaded the first minibatch of data in {item_load_time:0.2f} sec.")
elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
self.load_time_warning_epochs.add(epoch)
self.num_load_time_exceeded += 1
self.total_extra_load_time += item_load_time
if self.num_load_time_warnings < MAX_LOAD_TIME_WARNINGS and self.should_warn_in_this_epoch:
logging.warning(f"{message_prefix}: Loading minibatch {batch_index} took {item_load_time:0.2f} sec. "
"This can mean that there are not enough data loader worker processes, or that there "
"is a performance problem in loading. This warning will be printed at most "
f"{MAX_LOAD_TIME_WARNINGS} times in at most {MAX_LOAD_TIME_EPOCHS} epochs.")
self.num_load_time_warnings += 1
return item_load_time
def batch_end(self) -> float:
"""
Called after a minibatch has been processed (training or validation step completed). Returns the time it took
to process the current batch (including loading).
:return: The time it took to process the current batch, in seconds.
"""
current_time = time.time()
elapsed = current_time - self.batch_start_time
self.batch_start_time = current_time
self.num_batches += 1
return elapsed
def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> float:
"""
Symmetric surface distances taking into account the image spacing
@ -366,9 +271,6 @@ def store_epoch_metrics(metrics: DictStrFloat,
hue_suffix = "/" + tokens[1]
else:
raise ValueError(f"Expected key to have format 'metric_name[/optional_suffix_for_hue]', got {key}")
if metric_name == MetricType.SECONDS_PER_BATCH.value or metric_name == MetricType.SECONDS_PER_EPOCH.value:
continue
if metric_name in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[metric_name].value + hue_suffix] = value
else:

Просмотреть файл

@ -818,8 +818,6 @@ class DataframeLogger:
df = pd.DataFrame.from_records(self.records, columns=columns)
special_formatting = {
MetricType.LEARNING_RATE.value: ".6e",
MetricType.SECONDS_PER_EPOCH.value: ".2f",
MetricType.SECONDS_PER_BATCH.value: ".2f",
}
for column, column_format in special_formatting.items():
if column in df:

Просмотреть файл

@ -6,11 +6,10 @@ import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from health_azure.utils import is_global_rank_zero, is_local_rank_zero
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
@ -22,9 +21,11 @@ from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME
from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER
from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
get_subject_output_file_per_rank
from health_azure.utils import is_global_rank_zero, is_local_rank_zero
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback
TEMP_PREFIX = "temp/"
@ -78,7 +79,7 @@ def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1,
**kwargs: Dict[str, Any]) -> \
Tuple[Trainer, Optional[StoringLogger]]:
Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
@ -89,18 +90,7 @@ def create_lightning_trainer(container: LightningContainer,
:param kwargs: Any additional keyowrd arguments will be passed to the constructor of Trainer.
:return: A tuple [Trainer object, diagnostic logger]
"""
# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
# models, this still appears to be the best way of choosing them because validation loss on the relatively small
# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
# not for the HeadAndNeck model.
last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)
# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
# recovery_checkpoints_save_last_k.
recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
num_gpus = container.num_gpus_per_node
num_gpus = container.num_gpus_per_node()
effective_num_gpus = num_gpus * num_nodes
# Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
if effective_num_gpus > 1:
@ -115,12 +105,8 @@ def create_lightning_trainer(container: LightningContainer,
logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'")
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
loggers = [tensorboard_logger, AzureMLLogger()]
storing_logger: Optional[StoringLogger]
if isinstance(container, InnerEyeContainer):
storing_logger = StoringLogger()
loggers.append(storing_logger)
else:
storing_logger = None
storing_logger = StoringLogger()
loggers.append(storing_logger)
# Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
# The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
@ -133,18 +119,42 @@ def create_lightning_trainer(container: LightningContainer,
else:
deterministic = False
benchmark = True
# If the users provides additional callbacks via get_trainer_arguments (for custom
# containers
callbacks = [last_checkpoint_callback, recovery_checkpoint_callback]
# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
# models, this still appears to be the best way of choosing them because validation loss on the relatively small
# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
# not for the HeadAndNeck model.
last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)
# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
# recovery_checkpoints_save_last_k.
recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
callbacks: List[Callback] = [
last_checkpoint_callback,
recovery_checkpoint_callback,
]
if container.monitor_loading:
callbacks.append(BatchTimeCallback())
if num_gpus > 0 and container.monitor_gpu:
logging.info("Adding monitoring for GPU utilization")
callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
# Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
if "callbacks" in kwargs:
callbacks.append(kwargs.pop("callbacks")) # type: ignore
more_callbacks = kwargs.pop("callbacks")
if isinstance(more_callbacks, list):
callbacks.extend(more_callbacks) # type: ignore
else:
callbacks.append(more_callbacks) # type: ignore
is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
if progress_bar_refresh_rate is None and is_azureml_run:
# When running in AzureML, the default progress bar clutters the output files with thousands of lines.
progress_bar_refresh_rate = 50
logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
if is_azureml_run:
if progress_bar_refresh_rate is None:
progress_bar_refresh_rate = 50
logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
write_to_logging_info=True,
print_timestamp=False))
# Read out additional model-specific args here.
# We probably want to keep essential ones like numgpu and logging.
trainer = Trainer(default_root_dir=str(container.outputs_folder),
@ -162,6 +172,7 @@ def create_lightning_trainer(container: LightningContainer,
precision=precision,
sync_batchnorm=True,
terminate_on_nan=container.detect_anomaly,
profiler=container.pl_profiler,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
**kwargs)
return trainer, storing_logger
@ -185,7 +196,7 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
def model_train(checkpoint_path: Optional[Path],
container: LightningContainer,
num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
creates a Pytorch Lightning trainer, and trains the model.

Просмотреть файл

@ -43,6 +43,7 @@ from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, DeepLearningConf
FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import ModelInferenceConfig
@ -188,6 +189,7 @@ class MLRunner:
self.post_cross_validation_hook = post_cross_validation_hook
self.model_deployment_hook = model_deployment_hook
self.output_subfolder = output_subfolder
self.storing_logger: Optional[StoringLogger] = None
self._has_setup_run = False
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
@ -384,9 +386,10 @@ class MLRunner:
# train a new model if required
if self.azure_config.train:
with logging_section("Model training"):
model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=self.container,
num_nodes=self.azure_config.num_nodes)
_, storing_logger = model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=self.container,
num_nodes=self.azure_config.num_nodes)
self.storing_logger = storing_logger
# Since we have trained the model further, let the checkpoint_handler object know so it can handle
# checkpoints correctly.
self.checkpoint_handler.additional_training_done()

Просмотреть файл

@ -134,6 +134,8 @@ class Runner:
self.model_config: Optional[DeepLearningConfig] = None
self.azure_config: AzureConfig = AzureConfig()
self.lightning_container: LightningContainer = None # type: ignore
# This field stores the MLRunner object that has been created in the most recent call to the run() method.
self.ml_runner: Optional[MLRunner] = None
def parse_and_load_model(self) -> ParserResult:
"""
@ -379,9 +381,9 @@ class Runner:
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
set_environment_variables_for_multi_node()
ml_runner = self.create_ml_runner()
ml_runner.setup(azure_run_info)
ml_runner.run()
self.ml_runner = self.create_ml_runner()
self.ml_runner.setup(azure_run_info)
self.ml_runner.run()
def create_ml_runner(self) -> MLRunner:
"""

Просмотреть файл

@ -531,8 +531,8 @@ def test_add_foreground_dice() -> None:
def test_dataframe_logger() -> None:
fixed_columns = {"cross_validation_split_index": 1}
records = [
{"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5, MetricType.SECONDS_PER_EPOCH.value: 123.123456},
{"bar": math.pi, MetricType.LEARNING_RATE.value: 1, MetricType.SECONDS_PER_EPOCH.value: 123.123456},
{"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5},
{"bar": math.pi, MetricType.LEARNING_RATE.value: 1},
]
out_buffer = StringIO()
df = DataframeLogger(csv_path=out_buffer, fixed_columns=fixed_columns)
@ -540,6 +540,6 @@ def test_dataframe_logger() -> None:
df.add_record(r)
df.flush()
assert out_buffer.getvalue().splitlines() == [
'bar,LearningRate,SecondsPerEpoch,cross_validation_split_index',
'3.141593,1.000000e-05,123.12,1',
'3.141593,1.000000e+00,123.12,1']
'bar,LearningRate,cross_validation_split_index',
'3.141593,1.000000e-05,1',
'3.141593,1.000000e+00,1']

Просмотреть файл

@ -69,8 +69,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
val_results_per_epoch = model_training_result.val_results_per_epoch()
assert len(train_results_per_epoch) == config.num_epochs
assert len(val_results_per_epoch) == config.num_epochs
assert len(train_results_per_epoch[0]) >= 11
assert len(val_results_per_epoch[0]) >= 11
for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
@ -78,8 +76,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
MetricType.AREA_UNDER_ROC_CURVE,
MetricType.CROSS_ENTROPY,
MetricType.LOSS,
MetricType.SECONDS_PER_BATCH,
MetricType.SECONDS_PER_EPOCH,
MetricType.SUBJECT_COUNT]:
assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training"
assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation"
@ -193,7 +189,6 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor
assert f'{metric.value}/{class_name}' in train_results_per_epoch[0], f"{metric.value} not in training"
assert f'{metric.value}/{class_name}' in val_results_per_epoch[0], f"{metric.value} not in validation"
for metric in [MetricType.LOSS,
MetricType.SECONDS_PER_EPOCH,
MetricType.SUBJECT_COUNT]:
assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training"
assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation"

Просмотреть файл

@ -12,12 +12,13 @@ import h5py
import numpy as np
import pandas as pd
import pytest
from health_ml.utils import BatchTimeCallback
from torch.utils.data import DataLoader
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import MetricType, TrackedMetrics, VALIDATION_PREFIX
from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, TrackedMetrics, VALIDATION_PREFIX
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, DATASET_CSV_FILE_NAME, \
ModelExecutionMode, \
@ -114,6 +115,20 @@ def _test_model_train(output_dirs: OutputFolderForTests,
model_training_result, _ = model_train_unittest(train_config, dirs=output_dirs)
assert isinstance(model_training_result, StoringLogger)
# Check that all metrics from the BatchTimeCallback are present
for epoch, epoch_results in model_training_result.results_per_epoch.items():
for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]:
for metric_type in [BatchTimeCallback.EPOCH_TIME,
BatchTimeCallback.BATCH_TIME + " avg",
BatchTimeCallback.BATCH_TIME + " max",
BatchTimeCallback.EXCESS_LOADING_TIME]:
expected = BatchTimeCallback.METRICS_PREFIX + prefix + metric_type
assert expected in epoch_results, f"Expected {expected} in results for epoch {epoch}"
# Excess loading time can be zero because that only measure batches over the threshold
if metric_type != BatchTimeCallback.EXCESS_LOADING_TIME:
value = epoch_results[expected]
assert isinstance(value, float)
assert value > 0.0, f"Time for {expected} should be > 0"
actual_train_losses = model_training_result.get_train_metric(MetricType.LOSS.value)
actual_val_losses = model_training_result.get_val_metric(MetricType.LOSS.value)
@ -192,12 +207,6 @@ def _test_model_train(output_dirs: OutputFolderForTests,
assert train_config.show_patch_sampling > 0
assert len(list(sampling_folder.rglob("*.png"))) == 3 * train_config.show_patch_sampling
# Time per epoch: Test that we have all these times logged.
model_training_result.get_train_metric(MetricType.SECONDS_PER_EPOCH.value)
model_training_result.get_val_metric(MetricType.SECONDS_PER_EPOCH.value)
model_training_result.get_val_metric(MetricType.SECONDS_PER_BATCH.value)
model_training_result.get_train_metric(MetricType.SECONDS_PER_BATCH.value)
# # Test for saving of example images
assert train_config.example_images_folder.is_dir() if train_config.store_dataset_sample else True
example_files = list(train_config.example_images_folder.rglob("*.*"))
@ -359,3 +368,35 @@ def test_aggregate_and_create_subject_metrics_file(test_output_dirs: OutputFolde
written_lines = pd.read_csv(outputs_folder / mode / SUBJECT_METRICS_FILE_NAME)
expected_lines = pd.read_csv(outputs_folder / mode / "expected_metrics.csv")
assert written_lines.equals(expected_lines)
def test_storing_logger() -> None:
"""
Test if the StoringLogger can correctly handle multiple metrics of the same name logged per epoch.
"""
logger = StoringLogger()
key1 = "key"
key2 = "key2"
value1 = 3.14
value2 = 2.71
value3 = 100.0
assert value1 != value2
epoch = 1
# Add metrics in the same epoch in two calls, so that we test both the cases where the epoch is already present,
# and where not
logger.log_metrics({"epoch": 1, key1: value1})
logger.log_metrics({"epoch": 1, key2: value2})
# All results for epoch 1 should be collated into a single dictionary
assert logger.extract_by_prefix(epoch=epoch) == {key1: value1, key2: value2}
# When updating a metric that already exists, the result should not be a float anymore but a list.
logger.log_metrics({"epoch": epoch, key1: value3})
assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3], key2: value2}
# Add more metrics for key1, so that we also test the case that the results are already a list
logger.log_metrics({"epoch": epoch, key1: value3})
assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3, value3], key2: value2}
# Add metrics that don't have an epoch key: This happens for example during testing with trainer.test
other_metrics1 = {"foo": 1.0}
other_metrics2 = {"foo": 2.0}
logger.log_metrics(other_metrics1)
logger.log_metrics(other_metrics2)
assert logger.results_without_epoch == [other_metrics1, other_metrics2]

Просмотреть файл

@ -197,18 +197,18 @@ as a submodule, rather than a package from pypi. Any change to the package will
and that costs 20min per run.
* In the repository root, run `git submodule add https://github.com/microsoft/hi-ml`
* In PyCharm's project browser, mark the folder `hi-ml/src` as Sources Root
* Remove the entry for the `hi-ml` package from `environment.yml`
* Modify the start of `InnerEye/ML/runner.py` to look like this:
```python
print(f"Starting InnerEye runner at {sys.argv[0]}")
innereye_root = Path(__file__).absolute().parent.parent.parent
if (innereye_root / "InnerEye").is_dir():
innereye_root_str = str(innereye_root)
if innereye_root_str not in sys.path:
print(f"Adding InnerEye folder to sys.path: {innereye_root_str}")
sys.path.insert(0, innereye_root_str)
sys.path.append(str(innereye_root / "hi-ml" / "src"))
* In PyCharm's project browser, mark the folders `hi-ml/hi-ml/src` and `hi-ml/hi-ml-azure/src` as Sources Root
* Remove the entry for the `hi-ml` and `hi-ml-azure` packages from `environment.yml`
* There is already code in `InnerEye.Common.fixed_paths.add_submodules_to_path` that will pick up the submodules and
add them to `sys.path`.
Once you are done testing your changes:
* Remove the entry for `hi-ml` from `.gitmodules`
* Execute these steps from the repository root:
```shell
git submodule deinit -f hi-ml
rm -rf hi-ml
rm -rf .git/modules/hi-ml
```
Alternatively, you can consume a developer version of `hi-ml` from `test.pypi`:

Просмотреть файл

@ -23,7 +23,8 @@ dependencies:
- gitpython==3.1.7
- gputil==1.4.0
- h5py==2.10.0
- hi-ml-azure>=0.1.9
- hi-ml==0.1.10
- hi-ml-azure==0.1.10
- InnerEye-DICOM-RT==1.0.1
- joblib==0.16.0
- jupyter==1.0.0