Switching batch time loading diagnostics to hi-ml (#577)
- Using the BatchTimeCallback from hi-ml
- Switches to trigger PL profiling
Parent: 8495a2eec3
Commit: bf4cb628c6
.flake8
@@ -2,4 +2,4 @@
 ignore = E226,E302,E41,W391, E701, W291, E722, W503, E128, E126, E127, E731, E401
 max-line-length = 160
 max-complexity = 25
-exclude = fastMRI/ test_outputs/
+exclude = fastMRI/ test_outputs/ hi-ml/

@@ -4,8 +4,10 @@
 <content url="file://$MODULE_DIR$">
 <sourceFolder url="file://$MODULE_DIR$" isTestSource="false" />
 <sourceFolder url="file://$MODULE_DIR$/fastMRI" isTestSource="false" />
+<sourceFolder url="file://$MODULE_DIR$/hi-ml/hi-ml-azure/src" isTestSource="false" />
+<sourceFolder url="file://$MODULE_DIR$/hi-ml/hi-ml/src" isTestSource="false" />
 </content>
-<orderEntry type="jdk" jdkName="3.7 @ Ubuntu 20.04" jdkType="Python SDK" />
+<orderEntry type="jdk" jdkName="3.7 @ Ubuntu-20.04" jdkType="Python SDK" />
 <orderEntry type="sourceFolder" forTests="false" />
 </component>
 <component name="PackageRequirementsSettings">

@@ -13,6 +13,9 @@ created.
 ## Upcoming
 
 ### Added
+- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor
+GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via
+`BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`)
 - ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation.
 - ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
 module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).)
@@ -77,6 +80,8 @@ in inference-only runs when using lightning containers.
 
 ### Removed
 
+- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Removing the monitoring of batch loading time,
+use the `BatchTimeCallback` from `hi-ml` instead
 - ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
 - ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and
 `weights_url` can no longer be used to initialize a training run, only inference runs.

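For illustration (not part of the diff above): the three switches added in this changelog entry are plain `TrainerParams` fields, so they can be set programmatically on any InnerEye config or container that includes `TrainerParams`, or passed as command-line switches to the runner (e.g. `--monitor_gpu=True --pl_profiler=simple`). A minimal, hedged sketch of the programmatic form:

```python
from InnerEye.ML.deep_learning_config import TrainerParams

# Sketch only: TrainerParams is a param.Parameterized class, so the new fields can be
# assigned directly; configs and containers that include TrainerParams expose the same
# attributes.
params = TrainerParams()
params.monitor_gpu = True        # add GPUStatsMonitor to the trainer callbacks (GPU runs only)
params.monitor_loading = False   # disable the hi-ml BatchTimeCallback (enabled by default)
params.pl_profiler = "simple"    # forwarded to the Lightning Trainer's 'profiler' argument
```
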
@@ -106,7 +106,8 @@ def add_submodules_to_path() -> None:
 innereye_root = repository_root_directory()
 folders_to_add = [(innereye_root, "InnerEye"),
 (innereye_root / "fastMRI", "fastmri"),
-(innereye_root / "hi-ml" / "src", "health")]
+(innereye_root / "hi-ml" / "hi-ml-azure" / "src", "health_azure"),
+(innereye_root / "hi-ml" / "hi-ml" / "src", "health_ml")]
 for (folder, subfolder_that_must_exist) in folders_to_add:
 if (folder / subfolder_that_must_exist).is_dir():
 folder_str = str(folder)

@@ -3,7 +3,6 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from enum import Enum, unique
-
 # String prefixes when writing training or validation set metrics to a logger
 from typing import Union
 
@@ -45,8 +44,6 @@ class LoggingColumns(Enum):
 AccuracyAtThreshold05 = "accuracy_at_threshold_05"
 Loss = "loss"
 CrossEntropy = "cross_entropy"
-SecondsPerEpoch = "seconds_per_epoch"
-SecondsPerBatch = "seconds_per_batch"
 AreaUnderRocCurve = "area_under_roc_curve"
 AreaUnderPRCurve = "area_under_pr_curve"
 CrossValidationSplitIndex = "cross_validation_split_index"
@@ -100,8 +97,6 @@ class MetricType(Enum):
 EXPLAINED_VAR = "ExplainedVariance"
 
 # Common metrics
-SECONDS_PER_BATCH = "SecondsPerBatch"
-SECONDS_PER_EPOCH = "SecondsPerEpoch"
 SUBJECT_COUNT = "SubjectCount"
 LEARNING_RATE = "LearningRate"
 
@@ -114,8 +109,6 @@ INTERNAL_TO_LOGGING_COLUMN_NAMES = {
 MetricType.LOSS.value: LoggingColumns.Loss,
 MetricType.ACCURACY_AT_THRESHOLD_05.value: LoggingColumns.AccuracyAtThreshold05,
 MetricType.CROSS_ENTROPY.value: LoggingColumns.CrossEntropy,
-MetricType.SECONDS_PER_BATCH.value: LoggingColumns.SecondsPerBatch,
-MetricType.SECONDS_PER_EPOCH.value: LoggingColumns.SecondsPerEpoch,
 MetricType.AREA_UNDER_ROC_CURVE.value: LoggingColumns.AreaUnderRocCurve,
 MetricType.AREA_UNDER_PR_CURVE.value: LoggingColumns.AreaUnderPRCurve,
 MetricType.SUBJECT_COUNT.value: LoggingColumns.SubjectCount,

@@ -3,7 +3,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from pathlib import Path
-from typing import Dict, Iterable, Optional, Tuple, TypeVar, Union
+from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 
 T = TypeVar('T')
 PathOrString = Union[Path, str]
@@ -15,3 +15,4 @@ TupleFloat3 = Tuple[float, float, float]
 TupleFloat9 = Tuple[float, float, float, float, float, float, float, float, float]
 IntOrTuple3 = Union[int, TupleInt3, Iterable]
 DictStrFloat = Dict[str, float]
+DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]

@@ -100,7 +100,7 @@ class SSLContainer(LightningContainer):
 
 def setup(self) -> None:
 from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
-self.total_num_gpus = self.num_gpus_per_node * self.num_nodes
+self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes
 self._load_config()
 # If you're using the same data for training and linear head, allow the user to specify the dataset only
 # once. Or if you are doing just finetuning of linear head, the user should be able to specify dataset via

@@ -218,7 +218,7 @@ class WorkflowParams(param.Parameterized):
 doc="If set, enable/disable full image inference on test set after ensemble training.")
 weights_url: List[str] = param.List(default=[], class_=str,
 doc="If provided, a set of urls from which checkpoints will be downloaded"
-"and used for inference.")
+"and used for inference.")
 local_weights_path: List[Path] = param.List(default=[], class_=Path,
 doc="A list of checkpoints paths to use for inference, "
 "when the job is running outside Azure.")
@@ -590,6 +590,16 @@ class TrainerParams(param.Parameterized):
 param.Boolean(default=False,
 doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
 "Setting it to True comes with a performance hit.")
+monitor_gpu: bool = param.Boolean(default=False,
+doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
+"This will write GPU utilization metrics every 50 batches by default.")
+monitor_loading: bool = param.Boolean(default=True,
+doc="If True, add the BatchTimeCallback callback to the Lightning trainer "
+"object. This will monitor how long individual batches take to load.")
+pl_profiler: Optional[str] = \
+param.String(default=None,
+doc="The value to use for the 'profiler' argument for the Lightning trainer. "
+"Set to either 'simple', 'advanced', or 'pytorch'")
 
 @property
 def use_gpu(self) -> bool:
@@ -602,7 +612,6 @@ class TrainerParams(param.Parameterized):
 from InnerEye.ML.utils.ml_util import is_gpu_available
 return is_gpu_available()
 
-@property
 def num_gpus_per_node(self) -> int:
 """
 Computes the number of gpus to use for each node: either the number of gpus available on the device

@@ -3,7 +3,6 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import logging
-import numbers
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -20,21 +19,23 @@ from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, TRAIN_
 from InnerEye.Common.type_annotations import DictStrFloat
 from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.config import SegmentationModelBase
+from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
 from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, OutputParams, TrainerParams, \
 WorkflowParams
 from InnerEye.ML.lightning_container import LightningContainer
 from InnerEye.ML.lightning_loggers import StoringLogger
-from InnerEye.ML.metrics import EpochTimers, MAX_ITEM_LOAD_TIME_SEC, store_epoch_metrics
+from InnerEye.ML.metrics import store_epoch_metrics
 from InnerEye.ML.metrics_dict import DataframeLogger
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.utils import model_util
+from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
 from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
 from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
 from InnerEye.ML.utils.ml_util import RandomStateSnapshot, set_random_seed, validate_dataset_paths
 from InnerEye.ML.utils.model_util import generate_and_print_model_summary
 from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset
-from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
-from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
+from health_ml.utils import log_on_epoch
 
+
 class TrainAndValDataLightning(LightningDataModule):
 """
@@ -220,9 +221,6 @@ class InnerEyeLightning(LightningModule):
 self.l_rate_scheduler: Optional[_LRScheduler] = None
 self.cross_validation_split_index = config.cross_validation_split_index
 self.effective_random_seed = config.get_effective_random_seed()
-# Timers for monitoring data loading time
-self.train_timers = EpochTimers()
-self.val_timers = EpochTimers()
 # This should be re-assigned on the outside, to a logger that is hooked up with the Trainer object.
 self.storing_logger = StoringLogger()
 # This will be initialized correctly in epoch_start
@@ -260,14 +258,11 @@ class InnerEyeLightning(LightningModule):
 assert isinstance(self.trainer, Trainer)
 return self.trainer.accelerator_connector.use_ddp
 
-def on_train_epoch_start(self) -> None:
-self.train_timers.reset()
-
 def training_epoch_end(self, outputs: List[Any]) -> None:
 # Write out all the metrics that have been accumulated in the StoringLogger in the previous epoch.
 # Metrics for the very last epoch are written in on_train_end
 self.read_epoch_results_from_logger_and_store(epoch=self.current_epoch - 1)
-self.training_or_validation_epoch_end(is_training=True)
+self.training_or_validation_epoch_end(is_training=True)  # type: ignore
 
 def on_validation_epoch_start(self) -> None:
 """
@@ -275,10 +270,6 @@ class InnerEyeLightning(LightningModule):
 that any randomization when loading validation data is consistent during training. In particular, this ensures
 that drawing random patches for segmentation model training is giving a validation set that does not fluctuate.
 """
-self.val_timers.reset()
-# In Lightning, the validation epoch is running "inside" the training. If we get here, it means that training
-# is done for this epoch, even though the on_training_epoch hook has not yet been called.
-self.train_timers.epoch_end()
 # Store the random number generator state, so that the next training epoch starts from here.
 self.random_state = RandomStateSnapshot.snapshot_random_state()
 # reset the random state for validation, so that we get consistent behaviour when drawing random patches
@@ -286,9 +277,6 @@ class InnerEyeLightning(LightningModule):
 seed = self.effective_random_seed
 set_random_seed(seed, "Validation")
 
-def on_validation_epoch_end(self) -> None:
-self.val_timers.epoch_end()
-
 def validation_epoch_end(self, outputs: List[Any]) -> None:
 """
 Resets the random number generator state to what it was before the current validation epoch started.
@@ -297,7 +285,7 @@ class InnerEyeLightning(LightningModule):
 # reset the random state for training, so that we get continue from where we were before the validation step.
 assert self.random_state is not None
 self.random_state.restore_random_state()
-self.training_or_validation_epoch_end(is_training=False)
+self.training_or_validation_epoch_end(is_training=False)  # type: ignore
 
 @rank_zero_only
 def on_train_end(self) -> None:
@@ -314,50 +302,11 @@ class InnerEyeLightning(LightningModule):
 Training and Validation metrics.
 """
 if epoch >= 0:
-if epoch in self.storing_logger.results:
+if epoch in self.storing_logger.results_per_epoch:
 for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]:
 metrics = self.storing_logger.extract_by_prefix(epoch, prefix)
 self.store_epoch_results(metrics, epoch, is_training)
 
-@rank_zero_only
-def training_or_validation_epoch_end(self, is_training: bool) -> None:
-"""
-This is a hook called at the end of a training or validation epoch. In here, we can still write
-metrics to a logger.
-:param is_training: If True, this is called at the end of a training epoch. If False, this is at the
-end of a validation epoch.
-"""
-if not is_training:
-# In validation epochs, mark that it has been completed. Training epochs are marked completed already
-# at the start of the validation epoch.
-self.val_timers.epoch_end()
-# Write all IO stats here, so that the order on the console is Train start, train end, val start, val end.
-self.write_and_log_epoch_time(is_training=True)
-self.write_and_log_epoch_time(is_training=False)
-
-def write_and_log_epoch_time(self, is_training: bool) -> None:
-"""
-Reads the IO timers for either the training or validation epoch, writes them to the console, and logs the
-time per epoch.
-:param is_training: If True, show and log the data for the training epoch. If False, use the data for the
-validation epoch.
-"""
-timers = self.get_timers(is_training=is_training)
-epoch_time_seconds = timers.total_epoch_time
-status = "training" if is_training else "validation"
-logging.info(f"Epoch {self.current_epoch} {status} took {epoch_time_seconds:0.2f}sec, of which waiting for "
-f"data took {timers.total_load_time:0.2f} sec total.")
-if timers.num_load_time_exceeded > 0 and timers.should_warn_in_this_epoch:
-logging.warning("The dataloaders were not fast enough to always supply the next batch in less than "
-f"{MAX_ITEM_LOAD_TIME_SEC}sec.")
-logging.warning(
-f"In this epoch, {timers.num_load_time_exceeded} out of {timers.num_batches} batches exceeded the load "
-f"time threshold. Total loading time for the slow batches was {timers.total_extra_load_time:0.2f}sec.")
-# This metric is only written at rank zero, and hence must no be synchronized across workers. If attempted,
-# training will get stuck.
-self.log_on_epoch(MetricType.SECONDS_PER_EPOCH, epoch_time_seconds, is_training=is_training,
-sync_dist_override=False)
-
 def log_on_epoch(self,
 name: Union[MetricType, str],
 value: Any,
@@ -375,20 +324,19 @@ class InnerEyeLightning(LightningModule):
 :param name: The name of the metric to log
 :param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
 :param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
-:param reduce_fx: The reduce function to apply after synchronizing the tensors across GPUs.
+:param reduce_fx: The reduce function to apply to step values. Default: torch.mean
 :param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs. This must be
 a value recognized by sync_ddp: Either 'None' to use 'sum' as aggregate, or 'mean' or 'avg'
 """
 metric_name = name if isinstance(name, str) else name.value
-if isinstance(value, numbers.Number):
-value = torch.tensor(value, dtype=torch.float, device=self.device)
 prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
 sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override
-self.log(prefix + metric_name, value,
-sync_dist=sync_dist,
-on_step=False, on_epoch=True,
-reduce_fx=reduce_fx,
-sync_dist_op=sync_dist_op)
+log_on_epoch(self,
+name=prefix + metric_name,
+value=value,
+sync_dist=sync_dist,
+reduce_fx=reduce_fx,
+sync_dist_op=sync_dist_op)
 
 def store_epoch_results(self, metrics: DictStrFloat, epoch: int, is_training: bool) -> None:
 """
@@ -414,18 +362,6 @@ class InnerEyeLightning(LightningModule):
 self.checkpoint_loading_message = f"Loading checkpoint that was created at ({', '.join(present_keys)})"
 logging.info(self.checkpoint_loading_message)
 
-def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
-self.batch_start(batch_idx=batch_idx, is_training=True)
-
-def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
-self.batch_start(batch_idx=batch_idx, is_training=False)
-
-def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
-self.batch_end(is_training=True)
-
-def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
-self.batch_end(is_training=False)
-
 def training_step(self,  # type: ignore
 sample: Dict[str, Any],
 batch_index: int) -> Any:
@@ -450,45 +386,6 @@ class InnerEyeLightning(LightningModule):
 """
 raise NotImplementedError("This method must be overwritten in a derived class.")
 
-@rank_zero_only
-def batch_start(self, batch_idx: int, is_training: bool) -> None:
-"""
-Shared code to keep track of IO-related metrics when loading a minibatch. This is only done on rank zero.
-:param batch_idx: The index of the current minibatch.
-:param is_training: If true, this has been called from `on_train_batch_start`, otherwise it has been called from
-`on_validation_batch_start`.
-:return:
-"""
-timers = self.get_timers(is_training=is_training)
-message_prefix = f"Epoch {self.current_epoch} {'training' if is_training else 'validation'}"
-timers.batch_start(batch_index=batch_idx, epoch=self.current_epoch, message_prefix=message_prefix)
-
-@rank_zero_only
-def batch_end(self, is_training: bool) -> None:
-"""
-Shared code to keep track of IO-related metrics when loading a minibatch.
-:param is_training: If true, this has been called from `on_train_batch_end`, otherwise it has been called from
-`on_validation_batch_end`.
-"""
-timers = self.get_timers(is_training=is_training)
-batch_time = timers.batch_end()
-# This metric is only written at rank 0, and hence must not be synchronized. Trying to synchronize will
-# block training.
-self.log_on_epoch(MetricType.SECONDS_PER_BATCH, batch_time, is_training=is_training, sync_dist_override=False)
-
-def get_timers(self, is_training: bool) -> EpochTimers:
-"""
-Gets the object that holds all IO-related metrics and timers, for either the validation or the training epoch.
-"""
-return self.train_timers if is_training else self.val_timers
-
-def reset_timers(self) -> None:
-"""
-Resets all timers and counters for IO-related metrics, for both the validation and the training epoch.
-"""
-self.train_timers.reset()
-self.val_timers.reset()
-
 def write_loss(self, is_training: bool, loss: torch.Tensor) -> None:
 """
 Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".

@@ -2,14 +2,14 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+import logging
 from typing import Any, Dict, Iterable, List, Optional
 
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
 
-from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
 from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX
-from InnerEye.Common.type_annotations import DictStrFloat
+from InnerEye.Common.type_annotations import DictStrFloat, DictStrFloatOrFloatList
 
 
 class StoringLogger(LightningLoggerBase):
@@ -20,28 +20,40 @@ class StoringLogger(LightningLoggerBase):
 
 def __init__(self) -> None:
 super().__init__()
-self.results: Dict[int, DictStrFloat] = {}
+self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {}
 self.hyperparams: Any = None
 # Fields to store diagnostics for unit testing
 self.train_diagnostics: List[Any] = []
 self.val_diagnostics: List[Any] = []
+self.results_without_epoch: List[DictStrFloat] = []
 
 @rank_zero_only
 def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None:
+logging.debug(f"StoringLogger step={step}: {metrics}")
 epoch_name = "epoch"
 if epoch_name not in metrics:
-raise ValueError("Each of the logged metrics should have an 'epoch' key.")
+# Metrics without an "epoch" key are logged during testing, for example
+self.results_without_epoch.append(metrics)
+return
 epoch = int(metrics[epoch_name])
 del metrics[epoch_name]
-if epoch in self.results:
-current_results = self.results[epoch]
-overlapping_keys = set(metrics.keys()).intersection(current_results.keys())
-if len(overlapping_keys) > 0:
-raise ValueError(f"Unable to log metric with same name twice for epoch {epoch}: "
-f"{', '.join(overlapping_keys)}")
-current_results.update(metrics)
+for key, value in metrics.items():
+if isinstance(value, int):
+metrics[key] = float(value)
+if epoch in self.results_per_epoch:
+current_results = self.results_per_epoch[epoch]
+for key, value in metrics.items():
+if key in current_results:
+logging.debug(f"StoringLogger: appending results for metric {key}")
+current_metrics = current_results[key]
+if isinstance(current_metrics, list):
+current_metrics.append(value)
+else:
+current_results[key] = [current_metrics, value]
+else:
+current_results[key] = value
 else:
-self.results[epoch] = metrics
+self.results_per_epoch[epoch] = metrics  # type: ignore
 
 @rank_zero_only
 def log_hyperparams(self, params: Any) -> None:
@@ -61,7 +73,7 @@ class StoringLogger(LightningLoggerBase):
 """
 Gets the epochs for which the present object holds any results.
 """
-return self.results.keys()
+return self.results_per_epoch.keys()
 
 def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat:
 """
@@ -73,7 +85,7 @@ class StoringLogger(LightningLoggerBase):
 have a name starting with `prefix`, and strip off the prefix.
 :return: A metrics dictionary.
 """
-epoch_results = self.results.get(epoch, None)
+epoch_results = self.results_per_epoch.get(epoch, None)
 if epoch_results is None:
 raise KeyError(f"No results are stored for epoch {epoch}")
 filtered = {}
@@ -83,8 +95,8 @@ class StoringLogger(LightningLoggerBase):
 # filter is supplied and really matches the metric name
 if (not prefix_filter) or key.startswith(prefix_filter):
 stripped_key = key[len(prefix_filter):]
-filtered[stripped_key] = value
-return filtered
+filtered[stripped_key] = value  # type: ignore
+return filtered  # type: ignore
 
 def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]:
 """
@@ -107,7 +119,14 @@ class StoringLogger(LightningLoggerBase):
 :return: A list of floating point numbers, with one entry per entry in the the training or validation results.
 """
 full_metric_name = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_type
-return [self.results[epoch][full_metric_name] for epoch in self.epochs]
+result = []
+for epoch in self.epochs:
+value = self.results_per_epoch[epoch][full_metric_name]
+if not isinstance(value, float):
+raise ValueError(f"Expected a floating point value for metric {full_metric_name}, but got: "
+f"{value}")
+result.append(value)
+return result
 
 def get_train_metric(self, metric_type: str) -> List[float]:
 """
@@ -138,33 +157,3 @@ class StoringLogger(LightningLoggerBase):
 Gets the full set of validation metrics that the logger stores, as a list of dictionaries per epoch.
 """
 return list(self.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values())
-
-
-class AzureMLLogger(LightningLoggerBase):
-"""
-A Pytorch Lightning logger that stores metrics in the current AzureML run. If the present run is not
-inside AzureML, nothing gets logged.
-"""
-
-def __init__(self) -> None:
-super().__init__()
-self.is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
-
-@rank_zero_only
-def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
-if self.is_azureml_run:
-for key, value in metrics.items():
-RUN_CONTEXT.log(key, value)
-
-@rank_zero_only
-def log_hyperparams(self, params: Any) -> None:
-pass
-
-def experiment(self) -> Any:
-return None
-
-def name(self) -> Any:
-return ""
-
-def version(self) -> int:
-return 0

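For illustration (not part of the diff above): a condensed version of the new `StoringLogger` behaviour, mirroring the `test_storing_logger` unit test added further down in this commit. When the same metric name is logged more than once for one epoch, the stored value becomes a list instead of raising an error:

```python
from InnerEye.ML.lightning_loggers import StoringLogger

logger = StoringLogger()
logger.log_metrics({"epoch": 0, "train/loss": 0.9})
logger.log_metrics({"epoch": 0, "train/loss": 0.7})  # same key again: values are collected in a list
assert logger.results_per_epoch == {0: {"train/loss": [0.9, 0.7]}}
assert logger.extract_by_prefix(epoch=0, prefix_filter="train/") == {"loss": [0.9, 0.7]}
```
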
@@ -165,7 +165,6 @@ class SegmentationLightning(InnerEyeLightning):
 for name, value in voxel_count.compute_all():
 self.log(name, value)
 voxel_count.reset()
-super().training_or_validation_epoch_end(is_training=is_training)
 
 
 def get_subject_output_file_per_rank(rank: int) -> str:
@@ -292,7 +291,6 @@ class ScalarLightning(InnerEyeLightning):
 metric.reset()
 logger = self.train_subject_outputs_logger if is_training else self.val_subject_outputs_logger  # type: ignore
 logger.flush()
-super().training_or_validation_epoch_end(is_training)
 
 def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:  # type: ignore
 """

@@ -6,16 +6,15 @@ from __future__ import annotations
 
 import logging
 import math
-import time
-from dataclasses import dataclass, field
-from typing import List, Optional, Sequence, Set
+from dataclasses import dataclass
+from typing import List, Optional, Sequence
 
 import SimpleITK as sitk
 import numpy as np
-from numpy.core.numeric import NaN
 import torch
 import torch.nn.functional as F
 from azureml.core import Run
+from numpy.core.numeric import NaN
 
 from InnerEye.Azure.azure_util import get_run_context_or_default
 from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
@@ -27,15 +26,11 @@ from InnerEye.ML.metrics_dict import (DataframeLogger, INTERNAL_TO_LOGGING_COLUM
 from InnerEye.ML.scalar_config import ScalarLoss
 from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array
 from InnerEye.ML.utils.io_util import reverse_tuple_float3
-from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error,
-r2_score, is_missing_ground_truth)
+from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, is_missing_ground_truth,
+mean_absolute_error, r2_score)
 from InnerEye.ML.utils.ml_util import check_size_matches
 from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels
 
-MAX_ITEM_LOAD_TIME_SEC = 0.5
-MAX_LOAD_TIME_WARNINGS = 3
-MAX_LOAD_TIME_EPOCHS = 5
-
 
 @dataclass(frozen=True)
 class InferenceMetrics:
@@ -81,96 +76,6 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
 })
 
 
-@dataclass
-class EpochTimers:
-"""
-Contains all information necessary to compute the IO metrics: Epoch times, batch times, loading times.
-"""
-epoch_start_time: float = time.time()
-epoch_end_time: float = time.time()
-batch_start_time: float = time.time()
-num_load_time_warnings: int = 0
-num_load_time_exceeded: int = 0
-total_extra_load_time: float = 0.0
-total_load_time: float = 0.0
-num_batches: int = 0
-load_time_warning_epochs: Set[int] = field(default_factory=set)
-
-def reset(self) -> None:
-"""
-Resets all timers to the current time, and all counters to 0. The set of epochs for which warnings about
-load time were produced will not be reset.
-"""
-current_time = time.time()
-self.epoch_start_time = current_time
-self.epoch_end_time = current_time
-self.batch_start_time = current_time
-self.num_load_time_warnings = 0
-self.num_load_time_exceeded = 0
-self.total_extra_load_time = 0.0
-self.total_load_time = 0.0
-self.num_batches = 0
-
-def epoch_end(self) -> None:
-"""
-Stores the present time in the epoch_end_time field of the object.
-"""
-self.epoch_end_time = time.time()
-
-@property
-def total_epoch_time(self) -> float:
-"""
-Gets the time in seconds between epoch start and epoch end.
-"""
-return self.epoch_end_time - self.epoch_start_time
-
-@property
-def should_warn_in_this_epoch(self) -> bool:
-"""
-Returns True if warnings about loading time should be printed in the present epoch. Returns False if
-this warning has been printed already in more than MAX_LOAD_TIME_EPOCHS epochs.
-:return:
-"""
-return len(self.load_time_warning_epochs) <= MAX_LOAD_TIME_EPOCHS
-
-def batch_start(self, batch_index: int, epoch: int, message_prefix: str) -> float:
-"""
-Called when a minibatch of data has been loaded. This computes the time it took to load the minibatch,
-and adds it to the internal bookkeeping.
-:return: The time it took to load the minibatch, in seconds.
-"""
-item_finish_time = time.time()
-item_load_time = item_finish_time - self.batch_start_time
-self.total_load_time += item_load_time
-# Having slow minibatch loading is OK in the very first batch of the every epoch, where processes
-# are spawned. Later, the load time should be zero.
-if batch_index == 0:
-logging.info(f"{message_prefix}: Loaded the first minibatch of data in {item_load_time:0.2f} sec.")
-elif item_load_time > MAX_ITEM_LOAD_TIME_SEC:
-self.load_time_warning_epochs.add(epoch)
-self.num_load_time_exceeded += 1
-self.total_extra_load_time += item_load_time
-if self.num_load_time_warnings < MAX_LOAD_TIME_WARNINGS and self.should_warn_in_this_epoch:
-logging.warning(f"{message_prefix}: Loading minibatch {batch_index} took {item_load_time:0.2f} sec. "
-"This can mean that there are not enough data loader worker processes, or that there "
-"is a performance problem in loading. This warning will be printed at most "
-f"{MAX_LOAD_TIME_WARNINGS} times in at most {MAX_LOAD_TIME_EPOCHS} epochs.")
-self.num_load_time_warnings += 1
-return item_load_time
-
-def batch_end(self) -> float:
-"""
-Called after a minibatch has been processed (training or validation step completed). Returns the time it took
-to process the current batch (including loading).
-:return: The time it took to process the current batch, in seconds.
-"""
-current_time = time.time()
-elapsed = current_time - self.batch_start_time
-self.batch_start_time = current_time
-self.num_batches += 1
-return elapsed
-
-
 def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> float:
 """
 Symmetric surface distances taking into account the image spacing
@@ -366,9 +271,6 @@ def store_epoch_metrics(metrics: DictStrFloat,
 hue_suffix = "/" + tokens[1]
 else:
 raise ValueError(f"Expected key to have format 'metric_name[/optional_suffix_for_hue]', got {key}")
 
-if metric_name == MetricType.SECONDS_PER_BATCH.value or metric_name == MetricType.SECONDS_PER_EPOCH.value:
-continue
 if metric_name in INTERNAL_TO_LOGGING_COLUMN_NAMES.keys():
 logger_row[INTERNAL_TO_LOGGING_COLUMN_NAMES[metric_name].value + hue_suffix] = value
 else:

@@ -818,8 +818,6 @@ class DataframeLogger:
 df = pd.DataFrame.from_records(self.records, columns=columns)
 special_formatting = {
 MetricType.LEARNING_RATE.value: ".6e",
-MetricType.SECONDS_PER_EPOCH.value: ".2f",
-MetricType.SECONDS_PER_BATCH.value: ".2f",
 }
 for column, column_format in special_formatting.items():
 if column in df:

@@ -6,11 +6,10 @@ import logging
 import os
 import sys
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, TypeVar
+from typing import Any, Dict, List, Optional, Tuple, TypeVar
 
-from health_azure.utils import is_global_rank_zero, is_local_rank_zero
-from pytorch_lightning import LightningModule, Trainer, seed_everything
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
+from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.plugins import DDPPlugin
 
@@ -22,9 +21,11 @@ from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME
 from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER
 from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
 from InnerEye.ML.lightning_container import LightningContainer
-from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger
+from InnerEye.ML.lightning_loggers import StoringLogger
 from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
 get_subject_output_file_per_rank
+from health_azure.utils import is_global_rank_zero, is_local_rank_zero
+from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback
 
 TEMP_PREFIX = "temp/"
 
@@ -78,7 +79,7 @@ def create_lightning_trainer(container: LightningContainer,
 resume_from_checkpoint: Optional[Path] = None,
 num_nodes: int = 1,
 **kwargs: Dict[str, Any]) -> \
-Tuple[Trainer, Optional[StoringLogger]]:
+Tuple[Trainer, StoringLogger]:
 """
 Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
 and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
@@ -89,18 +90,7 @@ def create_lightning_trainer(container: LightningContainer,
 :param kwargs: Any additional keyowrd arguments will be passed to the constructor of Trainer.
 :return: A tuple [Trainer object, diagnostic logger]
 """
-# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
-# models, this still appears to be the best way of choosing them because validation loss on the relatively small
-# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
-# not for the HeadAndNeck model.
-last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)
-
-# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
-# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
-# recovery_checkpoints_save_last_k.
-recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
-
-num_gpus = container.num_gpus_per_node
+num_gpus = container.num_gpus_per_node()
 effective_num_gpus = num_gpus * num_nodes
 # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of GPU memory).
 if effective_num_gpus > 1:
@@ -115,12 +105,8 @@ def create_lightning_trainer(container: LightningContainer,
 logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'")
 tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
 loggers = [tensorboard_logger, AzureMLLogger()]
-storing_logger: Optional[StoringLogger]
-if isinstance(container, InnerEyeContainer):
-storing_logger = StoringLogger()
-loggers.append(storing_logger)
-else:
-storing_logger = None
+storing_logger = StoringLogger()
+loggers.append(storing_logger)
 # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
 precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
 # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
@@ -133,18 +119,42 @@ def create_lightning_trainer(container: LightningContainer,
 else:
 deterministic = False
 benchmark = True
-# If the users provides additional callbacks via get_trainer_arguments (for custom
-# containers
-callbacks = [last_checkpoint_callback, recovery_checkpoint_callback]
+
+# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
+# models, this still appears to be the best way of choosing them because validation loss on the relatively small
+# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
+# not for the HeadAndNeck model.
+last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)
+# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
+# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
+# recovery_checkpoints_save_last_k.
+recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
+callbacks: List[Callback] = [
+last_checkpoint_callback,
+recovery_checkpoint_callback,
+]
+if container.monitor_loading:
+callbacks.append(BatchTimeCallback())
+if num_gpus > 0 and container.monitor_gpu:
+logging.info("Adding monitoring for GPU utilization")
+callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
+# Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
 if "callbacks" in kwargs:
-callbacks.append(kwargs.pop("callbacks"))  # type: ignore
+more_callbacks = kwargs.pop("callbacks")
+if isinstance(more_callbacks, list):
+callbacks.extend(more_callbacks)  # type: ignore
+else:
+callbacks.append(more_callbacks)  # type: ignore
 is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
 progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
-if progress_bar_refresh_rate is None and is_azureml_run:
-# When running in AzureML, the default progress bar clutters the output files with thousands of lines.
-progress_bar_refresh_rate = 50
-logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
-f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
+if is_azureml_run:
+if progress_bar_refresh_rate is None:
+progress_bar_refresh_rate = 50
+logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
+f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
+callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
+write_to_logging_info=True,
+print_timestamp=False))
 # Read out additional model-specific args here.
 # We probably want to keep essential ones like numgpu and logging.
 trainer = Trainer(default_root_dir=str(container.outputs_folder),
@@ -162,6 +172,7 @@ def create_lightning_trainer(container: LightningContainer,
 precision=precision,
 sync_batchnorm=True,
 terminate_on_nan=container.detect_anomaly,
+profiler=container.pl_profiler,
 resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
 **kwargs)
 return trainer, storing_logger
@@ -185,7 +196,7 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
 
 def model_train(checkpoint_path: Optional[Path],
 container: LightningContainer,
-num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
+num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
 """
 The main training loop. It creates the Pytorch model based on the configuration options passed in,
 creates a Pytorch Lightning trainer, and trains the model.

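For illustration (not part of the diff above): the callback handling in `create_lightning_trainer` now merges extra callbacks that a container supplies through `get_trainer_arguments`, whether they arrive as a single callback or as a list. A hedged sketch of a container using that hook (the class name and the choice of `EarlyStopping` are illustrative only):

```python
from typing import Any, Dict

from pytorch_lightning.callbacks import EarlyStopping

from InnerEye.ML.lightning_container import LightningContainer


class ContainerWithExtraCallbacks(LightningContainer):  # illustrative name
    def get_trainer_arguments(self) -> Dict[str, Any]:
        # Anything returned under "callbacks" is appended/extended onto the callback list
        # that create_lightning_trainer builds (checkpointing, BatchTimeCallback, etc.).
        return {"callbacks": [EarlyStopping(monitor="val/loss")]}
```
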
@@ -43,6 +43,7 @@ from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, DeepLearningConf
 FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
 from InnerEye.ML.lightning_base import InnerEyeContainer
 from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
+from InnerEye.ML.lightning_loggers import StoringLogger
 from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.model_inference_config import ModelInferenceConfig
@@ -188,6 +189,7 @@ class MLRunner:
 self.post_cross_validation_hook = post_cross_validation_hook
 self.model_deployment_hook = model_deployment_hook
 self.output_subfolder = output_subfolder
+self.storing_logger: Optional[StoringLogger] = None
 self._has_setup_run = False
 
 def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
@@ -384,9 +386,10 @@ class MLRunner:
 # train a new model if required
 if self.azure_config.train:
 with logging_section("Model training"):
-model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
-container=self.container,
-num_nodes=self.azure_config.num_nodes)
+_, storing_logger = model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
+container=self.container,
+num_nodes=self.azure_config.num_nodes)
+self.storing_logger = storing_logger
 # Since we have trained the model further, let the checkpoint_handler object know so it can handle
 # checkpoints correctly.
 self.checkpoint_handler.additional_training_done()

@@ -134,6 +134,8 @@ class Runner:
 self.model_config: Optional[DeepLearningConfig] = None
 self.azure_config: AzureConfig = AzureConfig()
 self.lightning_container: LightningContainer = None  # type: ignore
+# This field stores the MLRunner object that has been created in the most recent call to the run() method.
+self.ml_runner: Optional[MLRunner] = None
 
 def parse_and_load_model(self) -> ParserResult:
 """
@@ -379,9 +381,9 @@ class Runner:
 # Set environment variables for multi-node training if needed. This function will terminate early
 # if it detects that it is not in a multi-node environment.
 set_environment_variables_for_multi_node()
-ml_runner = self.create_ml_runner()
-ml_runner.setup(azure_run_info)
-ml_runner.run()
+self.ml_runner = self.create_ml_runner()
+self.ml_runner.setup(azure_run_info)
+self.ml_runner.run()
 
 def create_ml_runner(self) -> MLRunner:
 """

@@ -531,8 +531,8 @@ def test_add_foreground_dice() -> None:
 def test_dataframe_logger() -> None:
 fixed_columns = {"cross_validation_split_index": 1}
 records = [
-{"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5, MetricType.SECONDS_PER_EPOCH.value: 123.123456},
-{"bar": math.pi, MetricType.LEARNING_RATE.value: 1, MetricType.SECONDS_PER_EPOCH.value: 123.123456},
+{"bar": math.pi, MetricType.LEARNING_RATE.value: 1e-5},
+{"bar": math.pi, MetricType.LEARNING_RATE.value: 1},
 ]
 out_buffer = StringIO()
 df = DataframeLogger(csv_path=out_buffer, fixed_columns=fixed_columns)
@@ -540,6 +540,6 @@ def test_dataframe_logger() -> None:
 df.add_record(r)
 df.flush()
 assert out_buffer.getvalue().splitlines() == [
-'bar,LearningRate,SecondsPerEpoch,cross_validation_split_index',
-'3.141593,1.000000e-05,123.12,1',
-'3.141593,1.000000e+00,123.12,1']
+'bar,LearningRate,cross_validation_split_index',
+'3.141593,1.000000e-05,1',
+'3.141593,1.000000e+00,1']

@@ -69,8 +69,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
 val_results_per_epoch = model_training_result.val_results_per_epoch()
 assert len(train_results_per_epoch) == config.num_epochs
 assert len(val_results_per_epoch) == config.num_epochs
-assert len(train_results_per_epoch[0]) >= 11
-assert len(val_results_per_epoch[0]) >= 11
 
 for metric in [MetricType.ACCURACY_AT_THRESHOLD_05,
 MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
@@ -78,8 +76,6 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
 MetricType.AREA_UNDER_ROC_CURVE,
 MetricType.CROSS_ENTROPY,
 MetricType.LOSS,
-MetricType.SECONDS_PER_BATCH,
-MetricType.SECONDS_PER_EPOCH,
 MetricType.SUBJECT_COUNT]:
 assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training"
 assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation"
@@ -193,7 +189,6 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor
 assert f'{metric.value}/{class_name}' in train_results_per_epoch[0], f"{metric.value} not in training"
 assert f'{metric.value}/{class_name}' in val_results_per_epoch[0], f"{metric.value} not in validation"
 for metric in [MetricType.LOSS,
-MetricType.SECONDS_PER_EPOCH,
 MetricType.SUBJECT_COUNT]:
 assert metric.value in train_results_per_epoch[0], f"{metric.value} not in training"
 assert metric.value in val_results_per_epoch[0], f"{metric.value} not in validation"

@@ -12,12 +12,13 @@ import h5py
 import numpy as np
 import pandas as pd
 import pytest
+from health_ml.utils import BatchTimeCallback
 from torch.utils.data import DataLoader
 
 from InnerEye.Common import fixed_paths
 from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout
 from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
-from InnerEye.Common.metrics_constants import MetricType, TrackedMetrics, VALIDATION_PREFIX
+from InnerEye.Common.metrics_constants import MetricType, TRAIN_PREFIX, TrackedMetrics, VALIDATION_PREFIX
 from InnerEye.Common.output_directories import OutputFolderForTests
 from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, DATASET_CSV_FILE_NAME, \
 ModelExecutionMode, \
@@ -114,6 +115,20 @@ def _test_model_train(output_dirs: OutputFolderForTests,
 
 model_training_result, _ = model_train_unittest(train_config, dirs=output_dirs)
 assert isinstance(model_training_result, StoringLogger)
+# Check that all metrics from the BatchTimeCallback are present
+for epoch, epoch_results in model_training_result.results_per_epoch.items():
+for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]:
+for metric_type in [BatchTimeCallback.EPOCH_TIME,
+BatchTimeCallback.BATCH_TIME + " avg",
+BatchTimeCallback.BATCH_TIME + " max",
+BatchTimeCallback.EXCESS_LOADING_TIME]:
+expected = BatchTimeCallback.METRICS_PREFIX + prefix + metric_type
+assert expected in epoch_results, f"Expected {expected} in results for epoch {epoch}"
+# Excess loading time can be zero because that only measure batches over the threshold
+if metric_type != BatchTimeCallback.EXCESS_LOADING_TIME:
+value = epoch_results[expected]
+assert isinstance(value, float)
+assert value > 0.0, f"Time for {expected} should be > 0"
 
 actual_train_losses = model_training_result.get_train_metric(MetricType.LOSS.value)
 actual_val_losses = model_training_result.get_val_metric(MetricType.LOSS.value)
@@ -192,12 +207,6 @@ def _test_model_train(output_dirs: OutputFolderForTests,
 assert train_config.show_patch_sampling > 0
 assert len(list(sampling_folder.rglob("*.png"))) == 3 * train_config.show_patch_sampling
 
-# Time per epoch: Test that we have all these times logged.
-model_training_result.get_train_metric(MetricType.SECONDS_PER_EPOCH.value)
-model_training_result.get_val_metric(MetricType.SECONDS_PER_EPOCH.value)
-model_training_result.get_val_metric(MetricType.SECONDS_PER_BATCH.value)
-model_training_result.get_train_metric(MetricType.SECONDS_PER_BATCH.value)
-
 # # Test for saving of example images
 assert train_config.example_images_folder.is_dir() if train_config.store_dataset_sample else True
 example_files = list(train_config.example_images_folder.rglob("*.*"))
@@ -359,3 +368,35 @@ def test_aggregate_and_create_subject_metrics_file(test_output_dirs: OutputFolde
 written_lines = pd.read_csv(outputs_folder / mode / SUBJECT_METRICS_FILE_NAME)
 expected_lines = pd.read_csv(outputs_folder / mode / "expected_metrics.csv")
 assert written_lines.equals(expected_lines)
+
+
+def test_storing_logger() -> None:
+"""
+Test if the StoringLogger can correctly handle multiple metrics of the same name logged per epoch.
+"""
+logger = StoringLogger()
+key1 = "key"
+key2 = "key2"
+value1 = 3.14
+value2 = 2.71
+value3 = 100.0
+assert value1 != value2
+epoch = 1
+# Add metrics in the same epoch in two calls, so that we test both the cases where the epoch is already present,
+# and where not
+logger.log_metrics({"epoch": 1, key1: value1})
+logger.log_metrics({"epoch": 1, key2: value2})
+# All results for epoch 1 should be collated into a single dictionary
+assert logger.extract_by_prefix(epoch=epoch) == {key1: value1, key2: value2}
+# When updating a metric that already exists, the result should not be a float anymore but a list.
+logger.log_metrics({"epoch": epoch, key1: value3})
+assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3], key2: value2}
+# Add more metrics for key1, so that we also test the case that the results are already a list
+logger.log_metrics({"epoch": epoch, key1: value3})
+assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3, value3], key2: value2}
+# Add metrics that don't have an epoch key: This happens for example during testing with trainer.test
+other_metrics1 = {"foo": 1.0}
+other_metrics2 = {"foo": 2.0}
+logger.log_metrics(other_metrics1)
+logger.log_metrics(other_metrics2)
+assert logger.results_without_epoch == [other_metrics1, other_metrics2]

@@ -197,18 +197,18 @@ as a submodule, rather than a package from pypi. Any change to the package will
 and that costs 20min per run.
 
 * In the repository root, run `git submodule add https://github.com/microsoft/hi-ml`
-* In PyCharm's project browser, mark the folder `hi-ml/src` as Sources Root
-* Remove the entry for the `hi-ml` package from `environment.yml`
-* Modify the start of `InnerEye/ML/runner.py` to look like this:
-```python
-print(f"Starting InnerEye runner at {sys.argv[0]}")
-innereye_root = Path(__file__).absolute().parent.parent.parent
-if (innereye_root / "InnerEye").is_dir():
-innereye_root_str = str(innereye_root)
-if innereye_root_str not in sys.path:
-print(f"Adding InnerEye folder to sys.path: {innereye_root_str}")
-sys.path.insert(0, innereye_root_str)
-sys.path.append(str(innereye_root / "hi-ml" / "src"))
+* In PyCharm's project browser, mark the folders `hi-ml/hi-ml/src` and `hi-ml/hi-ml-azure/src` as Sources Root
+* Remove the entry for the `hi-ml` and `hi-ml-azure` packages from `environment.yml`
+* There is already code in `InnerEye.Common.fixed_paths.add_submodules_to_path` that will pick up the submodules and
+add them to `sys.path`.
+
+Once you are done testing your changes:
+* Remove the entry for `hi-ml` from `.gitmodules`
+* Execute these steps from the repository root:
+```shell
+git submodule deinit -f hi-ml
+rm -rf hi-ml
+rm -rf .git/modules/hi-ml
 ```
 
 Alternatively, you can consume a developer version of `hi-ml` from `test.pypi`:

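For illustration (not part of the diff above): with the `hi-ml` submodule checked out at the repository root as described, the existing helper is enough to make both packages importable. A small sketch:

```python
from InnerEye.Common.fixed_paths import add_submodules_to_path

# Adds <repo_root>/hi-ml/hi-ml/src and <repo_root>/hi-ml/hi-ml-azure/src to sys.path
# when those folders exist (see the change to add_submodules_to_path in this commit).
add_submodules_to_path()

from health_ml.utils import BatchTimeCallback  # noqa: E402  (resolved from the submodule)
from health_azure.utils import is_global_rank_zero  # noqa: E402
```
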
@@ -23,7 +23,8 @@ dependencies:
 - gitpython==3.1.7
 - gputil==1.4.0
 - h5py==2.10.0
-- hi-ml-azure>=0.1.9
+- hi-ml==0.1.10
+- hi-ml-azure==0.1.10
 - InnerEye-DICOM-RT==1.0.1
 - joblib==0.16.0
 - jupyter==1.0.0