Upgrade to PyTorch 1.6 (#199)
- Refactored LR schedulers
- Switched to the new built-in mixed precision (torch.cuda.amp)
Parent: 6acbdbcea0
Commit: 12668bdd71
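The diff below swaps NVIDIA apex for the torch.cuda.amp API that is built into PyTorch 1.6. For orientation, a minimal sketch of that API (standalone illustration, not InnerEye code; the model, optimizer, and batch are placeholders, and a CUDA device is assumed):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    scaler = GradScaler()
    inputs, targets = torch.randn(4, 10).cuda(), torch.randint(0, 2, (4,)).cuda()

    optimizer.zero_grad()
    with autocast():  # eligible ops run in float16
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid float16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration

This loss-scale/step/update triple is exactly the pattern that the new single_optimizer_step below implements.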
@@ -204,7 +204,8 @@ class DeepLearningConfig(GenericConfig, CudaAwareConfig):
                                              " from a checkpoint.")
     l_rate: float = param.Number(1e-4, doc="The initial learning rate", bounds=(0, None))
-    _min_l_rate: float = param.Number(0.0, doc="The minimum learning rate", bounds=(0.0, None))
+    _min_l_rate: float = param.Number(0.0, doc="The minimum learning rate for the Polynomial and Cosine schedulers.",
+                                      bounds=(0.0, None))
     l_rate_scheduler: LRSchedulerType = param.ClassSelector(default=LRSchedulerType.Polynomial,
                                                             class_=LRSchedulerType,
                                                             instantiate=False,
@@ -8,6 +8,8 @@ import os
 from time import time
 from typing import Optional, Tuple, TypeVar

+from torch.cuda.amp import GradScaler
+
 from InnerEye.Azure.azure_util import RUN_CONTEXT
 from InnerEye.Common.common_util import empty_string_to_none
 from InnerEye.Common.metrics_dict import MetricsDict
@@ -24,7 +26,7 @@ from InnerEye.ML.scalar_config import ScalarModelBase
 from InnerEye.ML.sequence_config import SequenceModelBase
 from InnerEye.ML.utils import ml_util, model_util
 from InnerEye.ML.utils.config_util import ModelConfigLoader
-from InnerEye.ML.utils.lr_scheduler import LRScheduler
+from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
 from InnerEye.ML.utils.metrics_util import create_summary_writers
 from InnerEye.ML.utils.ml_util import RandomStateSnapshot
 from InnerEye.ML.utils.model_util import ModelAndInfo, create_model_with_temperature_scaling, \
@@ -60,7 +62,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
     :raises TypeError: If the arguments are of the wrong type.
     :raises ValueError: When there are issues loading a previous checkpoint.
     """
-    # save the datasets csv for record
+    # Save the dataset files for later use in cross validation analysis
     config.write_dataset_files()

     # set the random seed for all libraries
@@ -93,12 +95,11 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
         # Print out a detailed breakdown of layers, memory consumption and time.
         generate_and_print_model_summary(config, model)

-    # Enable mixed precision training and data parallelization (no-op if already done).
+    # Prepare for mixed precision training and data parallelization (no-op if already done).
     # This relies on the information generated in the model summary.
-
     # We only want to do this if we didn't call load_checkpoint above, because attempting updating twice
     # causes an error.
-    models_and_optimizers = [model_util.update_model_for_mixed_precision_and_parallel(model_and_info, config)
+    models_and_optimizers = [model_util.update_model_for_multiple_gpus(model_and_info, config)
                              for model_and_info in models_and_optimizers]

     # Create the SummaryWriters for Tensorboard
@@ -111,7 +112,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
     mean_teacher_model = models_and_optimizers[1].model if len(models_and_optimizers) > 1 else None

     # Create LR scheduler
-    l_rate_scheduler = LRScheduler(config, optimizer)
+    l_rate_scheduler = SchedulerWithWarmUp(config, optimizer)

     # Training loop
     logging.info("Starting training")
@@ -124,6 +125,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
                                            tb_log_file_path=str(config.logs_folder / "diagnostics"))
         resource_monitor.start()

+    gradient_scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
     optimal_temperature_scale_values = []
     for epoch in config.get_train_epochs():
         logging.info("Starting epoch {}".format(epoch))
@@ -139,6 +141,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
                                              mean_teacher_model=mean_teacher_model,
                                              epoch=epoch,
                                              optimizer=optimizer,
+                                             gradient_scaler=gradient_scaler,
                                              epoch_learning_rate=epoch_lrs,
                                              summary_writers=writers,
                                              dataframe_loggers=config.metrics_data_frame_loggers,
@@ -12,6 +12,9 @@ import numpy as np
 import param
 import torch.cuda
 import torch.utils.data
+from torch import Tensor
+from torch.cuda import amp
+from torch.cuda.amp import GradScaler
 from torch.nn import MSELoss
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -33,7 +36,8 @@ from InnerEye.ML.models.losses.cross_entropy import CrossEntropyLoss
 from InnerEye.ML.models.losses.ece import ECELoss
 from InnerEye.ML.models.losses.mixture import MixtureLoss
 from InnerEye.ML.models.losses.soft_dice import SoftDiceLoss
-from InnerEye.ML.models.parallel.data_parallel import DataParallelCriterion, DataParallelModel
+from InnerEye.ML.models.parallel.data_parallel import DataParallelCriterion, DataParallelModel, \
+    execute_within_autocast_if_needed
 from InnerEye.ML.pipelines.forward_pass import SegmentationForwardPass, single_optimizer_step
 from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
 from InnerEye.ML.sequence_config import SequenceModelBase
@@ -71,6 +75,7 @@ class TrainValidateParameters(param.Parameterized, Generic[M]):
     in_training_mode: bool = param.Boolean(default=True)
     dataframe_loggers: MetricsDataframeLoggers = param.ClassSelector(class_=MetricsDataframeLoggers, instantiate=False)
     save_metrics: bool = param.Boolean(default=True)
+    gradient_scaler = param.ClassSelector(class_=GradScaler, instantiate=False)


 class ModelTrainingStepsBase(Generic[C, M], ABC):
@@ -142,7 +147,9 @@ class ModelTrainingStepsBase(Generic[C, M], ABC):
         """
         loss_function = self.create_loss_function()
         if self.model_config.use_data_parallel:
-            return DataParallelCriterion(loss_function, self.model_config.get_cuda_devices())
+            return DataParallelCriterion(module=loss_function,
+                                         device_ids=self.model_config.get_cuda_devices(),  # type:ignore
+                                         use_mixed_precision=self.model_config.use_mixed_precision)
         else:
             return loss_function

@@ -155,7 +162,7 @@ class ModelTrainingStepsBase(Generic[C, M], ABC):
         """
         # ensure that the labels are loaded into the GPU
         labels = self.model_config.get_gpu_tensor_if_possible(labels)
-        loss = self.forward_criterion(model_output, labels)
+        loss = self.forward_criterion_with_autocast(model_output, labels)
         if self.model_config.use_data_parallel:
             # Aggregate the loss values for each parallelized batch element.
             loss = torch.mean(loss)
@@ -171,6 +178,22 @@ class ModelTrainingStepsBase(Generic[C, M], ABC):
         """
         return self.criterion(model_output, labels)

+    def forward_criterion_with_autocast(self,
+                                        model_output: Union[torch.Tensor, List[torch.Tensor]],
+                                        labels: NumpyOrTorch) -> torch.Tensor:
+        """
+        Handles the forward pass for the loss function, possibly taking mixed precision into account.
+        :param model_output: A single Tensor, or a list if using DataParallelCriterion
+        :param labels: Labels to compute loss against.
+        :return: loss tensor. This can be a float16 or float32 tensor, which should be cast to float32 before further
+        use.
+        """
+        if self.model_config.use_mixed_precision:
+            with amp.autocast():
+                return self.forward_criterion(model_output, labels)
+        else:
+            return self.forward_criterion(model_output, labels)
+

 @dataclass
 class ScalarModelInputsAndLabels(Generic[E, T]):
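A standalone illustration (not part of the commit) of why the docstring above warns that the loss can come back as float16: under autocast, matmul-like ops produce float16 results, so callers cast back to float32 before further use, as the training steps below do:

    import torch
    from torch.cuda import amp

    a = torch.randn(8, 8, device="cuda")
    with amp.autocast():
        b = a @ a           # computed in float16 under autocast
    print(b.dtype)          # torch.float16
    print(b.float().dtype)  # torch.float32: cast back before e.g. calling .item()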
@@ -289,25 +312,48 @@ class ModelTrainingStepsForScalarModel(ModelTrainingStepsBase[F, DeviceAwareModu
         return self.model_config.get_gpu_tensor_if_possible(labels)

     def get_logits_and_posteriors(self, *model_inputs: torch.Tensor, use_mean_teacher_model: bool = False) \
-            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Returns a Tuple containing the logits and the final model output. Note that the logits might be
         distributed over multiple GPU if the model is an instance of DataParallel. In this case,
-        the gathered_logits and posteriors will be gathered to GPU_0.
-
+        the posteriors will be gathered to GPU_0.
         :param model_inputs: input to evaluate the model on
         :param use_mean_teacher_model: If True, logits and posteriors are produced for the mean teacher model. Else
         logits and posteriors are produced for the standard (student) model.
-        :return: Tuple (logits, gathered_logits, posteriors).
+        :return: Tuple (logits, posteriors).
         """
         if use_mean_teacher_model:
             logits = self.train_val_params.mean_teacher_model(*model_inputs)
         else:
             logits = self.train_val_params.model(*model_inputs)
-        gathered_logits = gather_tensor(logits)
-        posteriors = self.model_config.get_post_loss_logits_normalization_function()(gathered_logits)
-        return logits, gathered_logits, posteriors
+        posteriors = self.model_config.get_post_loss_logits_normalization_function()(gather_tensor(logits))
+        return logits, posteriors
+
+    def _compute_model_output_and_loss(self, model_inputs_and_labels: ScalarModelInputsAndLabels) -> \
+            Tuple[Tensor, Tensor, Tensor]:
+        """
+        Computes the output of the model for a given set of inputs and labels.
+        Returns a tuple of (logits, posteriors, loss). For multi-GPU computation, the logits are returned
+        as a list.
+        """
+        model = self.train_val_params.model
+        label_gpu = self.get_label_tensor(model_inputs_and_labels.labels)
+        if self.model_config.use_mixed_precision and self.model_config.use_gpu:
+            label_gpu = label_gpu.to(dtype=torch.float16)
+
+        def compute() -> Tuple[Tensor, Tensor, Tensor]:
+            if self.in_training_mode:
+                model.train()
+                logits, posteriors = self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs)
+            else:
+                model.eval()
+                with torch.no_grad():
+                    logits, posteriors = self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs)
+                model.train()
+            loss = self.compute_loss(logits, label_gpu)
+            return logits, posteriors, loss
+
+        return execute_within_autocast_if_needed(func=compute, use_autocast=self.model_config.use_mixed_precision)

     def forward_and_backward_minibatch(self, sample: Dict[str, Any],
                                        batch_index: int, epoch: int) -> ModelForwardAndBackwardsOutputs:
@@ -321,23 +367,13 @@ class ModelTrainingStepsForScalarModel(ModelTrainingStepsBase[F, DeviceAwareModu
         model = self.train_val_params.model
         mean_teacher_model = self.train_val_params.mean_teacher_model
         model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model_config, model, sample)

-        if self.in_training_mode:
-            model.train()
-            logits, gathered_logits, posteriors = \
-                self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs)
-        else:
-            model.eval()
-            with torch.no_grad():
-                logits, gathered_logits, posteriors = \
-                    self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs)
-            model.train()
-
-        label_gpu = self.get_label_tensor(model_inputs_and_labels.labels)
-        loss = self.compute_loss(logits, label_gpu)
-
+        logits, posteriors, loss = self._compute_model_output_and_loss(model_inputs_and_labels)
+        gathered_logits = gather_tensor(logits)
         if self.in_training_mode:
-            single_optimizer_step(self.model_config, loss, self.train_val_params.optimizer)
+            single_optimizer_step(loss,
+                                  self.train_val_params.optimizer,
+                                  self.train_val_params.gradient_scaler)
             if self.model_config.compute_mean_teacher_model:
                 self.update_mean_teacher_parameters()

@@ -346,9 +382,16 @@ class ModelTrainingStepsForScalarModel(ModelTrainingStepsBase[F, DeviceAwareModu
             # instead of the output of the student model.
             mean_teacher_model.eval()
             with torch.no_grad():
-                logits, gathered_logits, posteriors = self.get_logits_and_posteriors(
+                logits, posteriors = self.get_logits_and_posteriors(
                     *model_inputs_and_labels.model_inputs,
                     use_mean_teacher_model=True)
+                gathered_logits = gather_tensor(logits)
+
+        # Autocast may have returned float16 tensors. Documentation suggests to simply cast back to float32.
+        # If tensor was already float32, no overhead is incurred.
+        posteriors = posteriors.detach().float()
+        gathered_logits = gathered_logits.detach().float().cpu()
+        loss_scalar = loss.float().item()

         if self.train_val_params.save_metrics:
             if self._should_save_grad_cam_output(epoch=epoch, batch_index=batch_index):
@@ -357,15 +400,15 @@ class ModelTrainingStepsForScalarModel(ModelTrainingStepsBase[F, DeviceAwareModu
                     model_inputs_and_labels.model_inputs,
                     label_gpu)

-            self.metrics.add_metric(MetricType.LOSS, loss.item())
+            self.metrics.add_metric(MetricType.LOSS, loss_scalar)
             self.update_metrics(model_inputs_and_labels.subject_ids, posteriors, label_gpu)
             logging.debug(f"Batch {batch_index}: {self.metrics.to_string()}")
             minibatch_time = time.time() - start_time
             self.metrics.add_metric(MetricType.SECONDS_PER_BATCH, minibatch_time)

         return ModelForwardAndBackwardsOutputs(
-            loss=loss.item(),
-            logits=gathered_logits.detach().cpu(),
+            loss=loss_scalar,
+            logits=gathered_logits,
             labels=model_inputs_and_labels.labels
         )
@@ -504,7 +547,7 @@ class ModelTrainingStepsForSequenceModel(ModelTrainingStepsForScalarModel[Sequen
             _model = _model.get_module()

         def _forward_criterion(_logits: torch.Tensor, _labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-            loss = self.forward_criterion(_logits, _labels)
+            loss = self.forward_criterion_with_autocast(_logits, _labels).to(torch.float32)
             masked_model_outputs_and_labels = get_masked_model_outputs_and_labels(_logits, _labels)
             assert masked_model_outputs_and_labels is not None
             ece = ece_criterion(masked_model_outputs_and_labels.model_outputs.data.unsqueeze(dim=0),
@@ -540,7 +583,8 @@ class ModelTrainingStepsForSegmentation(ModelTrainingStepsBase[SegmentationModel
                                                        batch_size=self.model_config.train_batch_size,
                                                        optimizer=self.train_val_params.optimizer,
                                                        in_training_mode=self.train_val_params.in_training_mode,
-                                                       criterion=self.compute_loss)
+                                                       criterion=self.compute_loss,
+                                                       gradient_scaler=train_val_params.gradient_scaler)
         self.metrics = MetricsDict(hues=[BACKGROUND_CLASS_NAME] + model_config.ground_truth_ids)

     def create_loss_function(self) -> torch.nn.Module:
@@ -19,6 +19,7 @@ from InnerEye.ML.models.layers.basic import BasicLayer
 from InnerEye.ML.models.layers.identity import Identity
 from InnerEye.ML.models.layers.pooling_layers import AveragePooling, Gated3dPoolingLayer, \
     MaxPooling, MixPooling, ZAdaptive3dAvgLayer
+from InnerEye.ML.models.parallel.data_parallel import execute_within_autocast_if_needed
 from InnerEye.ML.scalar_config import AggregationType
 from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES, segmentation_to_one_hot
@@ -88,7 +89,7 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
         super().__init__()
         self.num_non_image_features = num_non_image_features
         self.imaging_feature_type = imaging_feature_type
-
+        self.use_mixed_precision = use_mixed_precision
         if isinstance(kernel_size_per_encoding_block, list):
             if len(kernel_size_per_encoding_block) != num_encoder_blocks:
                 raise ValueError(f"expected kernel_size_per_encoding_block to be of "
@@ -106,7 +107,6 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
             self.stride_size_per_encoding_block = [stride_size_per_encoding_block] * num_encoder_blocks
         self.conv_in_3d = np.any([k[0] != 1 for k in self.kernel_size_per_encoding_block]) \
             or np.any([s[0] != 1 for s in self.stride_size_per_encoding_block])
-        self.use_mixed_precision = use_mixed_precision
         self.padding_mode = padding_mode
         self.encode_channels_jointly = encode_channels_jointly
         self.num_image_channels = num_image_channels
@@ -199,42 +199,46 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
         :param item: ClassificationItem
         :return: Tensor
         """
+        use_gpu = self.is_model_on_gpu()
+        result_dtype = torch.float16 if self.use_mixed_precision and use_gpu else torch.float32
         if self.imaging_feature_type == ImagingFeatureType.Segmentation \
                 or self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
             if item.segmentations is None:
                 raise ValueError("Expected item.segmentations to not be None")
-            use_gpu = self.is_model_on_gpu()
-            result_dtype = torch.float16 if self.use_mixed_precision and use_gpu else torch.float32
             # Special case need for the loading of individual positions in the sequence model,
             # the images are loaded as [C, Z, X, Y] but the segmentation_to_one_hot expects [B, C, Z, X, Y]
-            if item.segmentations.ndimension() == 4:
-                input_tensors = [segmentation_to_one_hot(item.segmentations.unsqueeze(dim=0),
-                                                         use_gpu=use_gpu,
-                                                         result_dtype=result_dtype).squeeze(dim=0)]
-            else:
-                input_tensors = [
-                    segmentation_to_one_hot(item.segmentations, use_gpu=use_gpu, result_dtype=result_dtype)]
+            segmentation_multilabel = item.segmentations
+            is_4dim = segmentation_multilabel.ndimension() == 4
+            if is_4dim:
+                segmentation_multilabel = segmentation_multilabel.unsqueeze(dim=0)
+            segmentation_one_hot = segmentation_to_one_hot(segmentation_multilabel,
+                                                           use_gpu=use_gpu,
+                                                           result_dtype=result_dtype)
+            if is_4dim:
+                segmentation_one_hot = segmentation_one_hot.squeeze(dim=0)
+            input_tensors = [segmentation_one_hot]

             if self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
                 input_tensors.append(item.images.to(dtype=result_dtype, copy=True))
                 _dim = 0 if item.images.ndimension() == 4 else 1
                 input_tensors = [torch.cat(input_tensors, dim=_dim)]
         else:
-            input_tensors = [item.images]
+            input_tensors = [item.images.to(dtype=result_dtype, copy=True)]

         if self.image_and_non_image_features_aggregator:
             input_tensors.append(item.get_all_non_imaging_features())
         return input_tensors

     def forward(self, *item: torch.Tensor, **kwargs: Any) -> torch.Tensor:
-        x = item[0]
-        x = self.encode_and_aggregate(x)
-
-        # combine non image features if required
-        if self.image_and_non_image_features_aggregator:
-            x = self.image_and_non_image_features_aggregator(x, item[1].float())
-
-        return x
+        def _forward() -> torch.Tensor:
+            x = item[0]
+            x = self.encode_and_aggregate(x)
+            # combine non image features if required
+            if self.image_and_non_image_features_aggregator:
+                x = self.image_and_non_image_features_aggregator(x, item[1].float())
+            return x
+
+        return execute_within_autocast_if_needed(func=_forward, use_autocast=self.use_mixed_precision)

     def encode_and_aggregate(self, x: torch.Tensor) -> torch.Tensor:
         return encode_and_aggregate(encoder=self.encoder,
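The 4-dim special case above is easier to see with a small shape sketch. The stub below is a hypothetical stand-in for segmentation_to_one_hot (illustration only); the point is only the unsqueeze/squeeze bracketing around the batched call:

    import torch

    def to_one_hot_stub(seg: torch.Tensor) -> torch.Tensor:
        # Stand-in: the real helper requires a batched [B, C, Z, X, Y] input.
        assert seg.ndimension() == 5
        return seg.float()

    seg = torch.zeros(2, 4, 4, 4)         # [C, Z, X, Y]: single item, no batch axis
    is_4dim = seg.ndimension() == 4
    if is_4dim:
        seg = seg.unsqueeze(dim=0)        # -> [1, C, Z, X, Y]
    one_hot = to_one_hot_stub(seg)
    if is_4dim:
        one_hot = one_hot.squeeze(dim=0)  # back to the unbatched layout
    print(one_hot.shape)                  # torch.Size([2, 4, 4, 4])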
@@ -259,7 +263,8 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
                     kernel_size=self.kernel_size_per_encoding_block[i],
                     downsampling_stride=self.stride_size_per_encoding_block[i],
                     padding_mode=self.padding_mode,
-                    use_residual=False
+                    use_residual=False,
+                    depth=i,
                 )
             )
         return ModuleList(layers)
@@ -334,10 +339,13 @@ class ImageEncoderWithMlp(ImageEncoder):
         self.final_activation = final_activation

     def forward(self, *item: torch.Tensor, **kwargs: Any) -> torch.Tensor:
-        x = super().forward(*item)
-        # pass all the features to the MLP
-        x = self.classification_layer(x.view(-1, x.shape[1]))
-        return self.final_activation(x)
+        def _forward() -> torch.Tensor:
+            x = super(ImageEncoderWithMlp, self).forward(*item)
+            # pass all the features to the MLP
+            x = self.classification_layer(x.view(-1, x.shape[1]))
+            return self.final_activation(x)
+
+        return execute_within_autocast_if_needed(func=_forward, use_autocast=self.use_mixed_precision)


 def encode_and_aggregate(input_tensor: torch.Tensor,
@@ -201,6 +201,4 @@ class MultiSegmentationEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
             raise ValueError("Expected item.segmentations to not be None")
         use_gpu = self.is_model_on_gpu()
         result_dtype = torch.float16 if self.use_mixed_precision and use_gpu else torch.float32
-        return [segmentation_to_one_hot(item.segmentations,
-                                        use_gpu=self.is_model_on_gpu(),
-                                        result_dtype=result_dtype)]
+        return [segmentation_to_one_hot(item.segmentations, use_gpu=use_gpu, result_dtype=result_dtype)]
@@ -12,7 +12,8 @@ from InnerEye.Common.type_annotations import IntOrTuple3, TupleInt2
 from InnerEye.ML.config import PaddingMode
 from InnerEye.ML.models.architectures.base_model import BaseModel, CropSizeConstraints
 from InnerEye.ML.models.layers.basic import BasicLayer
-from InnerEye.ML.models.parallel.model_parallel import move_to_device, partition_layers
+from InnerEye.ML.models.parallel.model_parallel import get_device_from_parameters, move_to_device, \
+    partition_layers
 from InnerEye.ML.utils.layer_util import get_padding_from_kernel_size, get_upsampling_kernel_size, \
     initialise_layer_weights
@@ -70,7 +71,10 @@ class UNet3D(BaseModel):
                 activation(inplace=True))

         def forward(self, x: Any) -> Any:  # type: ignore
-            [x] = move_to_device(input_tensors=[x], target_device=next(self.parameters()).device)
+            # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
+            # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
+            # construct.
+            [x] = move_to_device([x], target_device=get_device_from_parameters(self))
             return self.upsample_block(x)

     class UNetEncodeBlockSynthesis(torch.nn.Module):
@@ -101,8 +105,11 @@ class UNet3D(BaseModel):
             self.apply(initialise_layer_weights)

         def forward(self, x: Any, skip_connection: Any) -> Any:  # type: ignore
+            # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
+            # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
+            # construct.
             [x, skip_connection] = move_to_device(input_tensors=[x, skip_connection],
-                                                  target_device=next(self.parameters()).device)
+                                                  target_device=get_device_from_parameters(self))
             x = self.conv1(x)
             x += self.conv2(skip_connection)
             x = self.activation_block(x)
@@ -146,7 +153,11 @@ class UNet3D(BaseModel):
                                      dilation=dilation, activation=activation)

         def forward(self, x: Any) -> Any:  # type: ignore
-            [x] = move_to_device(input_tensors=[x], target_device=next(self.parameters()).device)
+            # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
+            # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
+            # construct.
+            target_device = get_device_from_parameters(self)
+            [x] = move_to_device(input_tensors=[x], target_device=target_device)
             x = self.block1(x)
             return self.block2(x) + x if self.use_residual else self.block2(x)

@@ -232,8 +243,10 @@ class UNet3D(BaseModel):
             x = layer(x, skip_connections.pop()) if layer.concat else layer(x)
             if layer_id < self.num_downsampling_paths:  # type: ignore
                 skip_connections.append(x)

-        [x] = move_to_device(input_tensors=[x], target_device=next(self.output_layer.parameters()).device)
+        # When using the new DataParallel of PyTorch 1.6, self.parameters would be empty. Do not attempt to move
+        # the tensors in this case. If self.parameters is present, the module is used inside of a model parallel
+        # construct.
+        [x] = move_to_device(input_tensors=[x], target_device=get_device_from_parameters(self.output_layer))
         return self.output_layer(x)

     def get_all_child_layers(self) -> List[torch.nn.Module]:
@@ -241,6 +254,7 @@ class UNet3D(BaseModel):

     def partition_model(self, devices: List[torch.device]) -> None:
         if self.summary is None:
-            raise RuntimeError("Network summary is required to partition UNet3D. Call model.generate_model_summary() first.")
+            raise RuntimeError(
+                "Network summary is required to partition UNet3D. Call model.generate_model_summary() first.")

         partition_layers(self.get_all_child_layers(), summary=self.summary, target_devices=devices)
@@ -2,17 +2,33 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple, Union

 import torch
+from torch.cuda import amp
 from torch.nn.parallel.data_parallel import DataParallel
 from torch.nn.parallel.scatter_gather import gather, scatter_kwargs

+from InnerEye.Common.type_annotations import T
 from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
 from InnerEye.ML.utils.device_aware_module import E


+def execute_within_autocast_if_needed(func: Callable[[], T], use_autocast: bool) -> T:
+    """
+    Runs the given parameterless function, and returns the function result. If the use_autocast
+    flag is true, the function is evaluated inside of the torch.cuda.amp.autocast context manager,
+    that can automatically cast operations to mixed precision. If the flag is false, the function
+    is called as is.
+    :param func: The function that should be evaluated
+    :param use_autocast: If true, evaluate within the autocast context manager. If false, evaluate as is.
+    """
+    if use_autocast:
+        with amp.autocast():
+            return func()
+    else:
+        return func()
+
+
 class DataParallelModel(DataParallel, DeviceAwareModule):
     """
     Modifies the DataParallel class by updating the `gather` method. In this child class, Parallel outputs
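A possible use of the new helper (hypothetical snippet, assuming a CUDA device): the same parameterless closure runs unchanged with or without mixed precision, which is why the encoders and forward passes above wrap their bodies in a local function:

    import torch
    from InnerEye.ML.models.parallel.data_parallel import execute_within_autocast_if_needed

    def compute() -> torch.Tensor:
        x = torch.randn(4, 4, device="cuda")
        return (x @ x).sum()

    fp16_result = execute_within_autocast_if_needed(func=compute, use_autocast=True)   # runs under autocast
    fp32_result = execute_within_autocast_if_needed(func=compute, use_autocast=False)  # plain float32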
@@ -49,6 +65,22 @@ class DataParallelModel(DataParallel, DeviceAwareModule):
         return outputs


+class CriterionWithAutocast(torch.nn.Module):
+    """
+    A wrapper around a single module, that runs the forward pass in an autocast context manager.
+    """
+
+    def __init__(self, module: torch.nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self,  # type: ignore
+                *inputs: torch.Tensor,
+                **kwargs: Dict[str, Any]) -> torch.Tensor:
+        with amp.autocast():
+            return self.module(*inputs, **kwargs)
+
+
 # noinspection PyUnresolvedReferences
 class DataParallelCriterion(DataParallel):
     """
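A sketch of how the wrapper might be used in isolation (hypothetical; within this commit it is only instantiated inside DataParallelCriterion.forward below, so that each replicated loss module evaluates under autocast):

    import torch

    wrapped = CriterionWithAutocast(module=torch.nn.MSELoss())
    loss = wrapped(torch.randn(2, 3, device="cuda"), torch.randn(2, 3, device="cuda"))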
@@ -63,6 +95,13 @@ class DataParallelCriterion(DataParallel):
     >>> loss = criterion(y, target)
     """

+    def __init__(self,
+                 module: torch.nn.Module,
+                 device_ids: List[Union[int, torch.device]],
+                 use_mixed_precision: bool):
+        super().__init__(module=module, device_ids=device_ids)
+        self.use_mixed_precision = use_mixed_precision
+
     def forward(self,  # type: ignore
                 inputs: List[torch.Tensor],
                 *targets: Tuple[torch.Tensor],
@@ -74,7 +113,8 @@ class DataParallelCriterion(DataParallel):
         _targets, _kwargs = scatter_kwargs(targets, kwargs, self.device_ids, dim=self.dim)
         if len(self.device_ids) == 1:
             return self.module(inputs, *_targets[0], **_kwargs[0])
-        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])  # type: ignore
+        autocast_if_needed = CriterionWithAutocast(module=self.module) if self.use_mixed_precision else self.module
+        replicas = self.replicate(autocast_if_needed, self.device_ids[:len(inputs)])  # type: ignore

         input_tuples: List[Tuple[torch.Tensor, ...]] = [(i, *t) for i, t in zip(inputs, _targets)]
         outputs = torch.nn.parallel.parallel_apply(replicas, input_tuples, _kwargs)
|
@ -3,26 +3,38 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from collections import OrderedDict
|
||||
from typing import Generator, List, Optional
|
||||
from typing import Generator, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def move_to_device(input_tensors: List[torch.Tensor],
|
||||
target_device: torch.device,
|
||||
non_blocking: bool = False) -> Generator:
|
||||
target_device: Optional[torch.device],
|
||||
non_blocking: bool = False) -> Iterable[torch.Tensor]:
|
||||
"""
|
||||
Updates the memory location of tensors stored in a list.
|
||||
:param input_tensors: List of torch tensors
|
||||
:param target_device: Target device (e.g. cuda:0, cuda:1, etc)
|
||||
:param target_device: Target device (e.g. cuda:0, cuda:1, etc). If the device is None, the tensors are not moved.
|
||||
:param non_blocking: bool
|
||||
"""
|
||||
return (tensor if tensor.device == target_device
|
||||
return (tensor if tensor.device == target_device or target_device is None
|
||||
else tensor.to(target_device, non_blocking=non_blocking)
|
||||
for tensor in input_tensors)
|
||||
|
||||
|
||||
def get_device_from_parameters(module: torch.nn.Module) -> Optional[torch.device]:
|
||||
"""
|
||||
Reads out the device information from the first of the module's parameters.
|
||||
If the module does not have any parameters, return None.
|
||||
"""
|
||||
try:
|
||||
first_parameter = next(module.parameters())
|
||||
return first_parameter.device
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
def group_layers_with_balanced_memory(inputs: List[torch.nn.Module],
|
||||
num_groups: int,
|
||||
summary: Optional[OrderedDict]) -> Generator:
|
||||
|
|
|
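Behavior of the new helper in both cases (standalone illustration), and why move_to_device now accepts an Optional device:

    import torch
    from InnerEye.ML.models.parallel.model_parallel import get_device_from_parameters, move_to_device

    linear = torch.nn.Linear(2, 2)             # has parameters
    print(get_device_from_parameters(linear))  # cpu
    relu = torch.nn.ReLU()                     # parameterless, as DataParallel replicas in 1.6 can be
    print(get_device_from_parameters(relu))    # None
    # With target_device=None, move_to_device leaves the tensors untouched:
    [t] = move_to_device([torch.zeros(1)], target_device=get_device_from_parameters(relu))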
@@ -4,20 +4,20 @@
 # ------------------------------------------------------------------------------------------
 from __future__ import annotations

-from dataclasses import dataclass
-from typing import Callable, Optional
-
 import math
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Tuple
+
 import numpy as np
 import torch
-from apex import amp
-from torch import autograd
+from torch import Tensor, autograd
+from torch.cuda.amp import GradScaler
 # noinspection PyUnresolvedReferences
 from torch.optim import Optimizer  # type: ignore

 from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.deep_learning_config import DeepLearningConfig
 from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
+from InnerEye.ML.models.parallel.data_parallel import execute_within_autocast_if_needed
 from InnerEye.ML.utils import image_util, ml_util
@@ -46,7 +46,8 @@ class SegmentationForwardPass:
                  batch_size: int,
                  optimizer: Optional[Optimizer] = None,
                  in_training_mode: Optional[bool] = False,
-                 criterion: Optional[Callable] = None):
+                 criterion: Optional[Callable] = None,
+                 gradient_scaler: Optional[GradScaler] = None):
         super().__init__()
         self.model = model
         self.config = model_config
@@ -54,6 +55,7 @@ class SegmentationForwardPass:
         self.optimizer = optimizer
         self.detect_anomaly = model_config.detect_anomaly
         self.criterion_fn = criterion
+        self.gradient_scaler = gradient_scaler
         if in_training_mode and (optimizer is None or criterion is None):
             raise ValueError("When running in training mode, an optimizer and criterion must be provided.")
         self.in_training_mode = in_training_mode
@@ -116,8 +118,26 @@ class SegmentationForwardPass:
                 result = self._forward_pass(patches, mask, labels)
             if result.loss is not None and (math.isnan(result.loss) or math.isinf(result.loss)):
                 raise RuntimeError(f"The loss computation returned {result.loss}")
             return result
         return self._forward_pass(patches, mask, labels)

+    def _compute_loss(self, patches: Tensor, labels: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        """
+        Do a forward pass on the model with the patches as input. If labels are provided, compute the loss.
+        Return a tuple of (logits, loss).
+        """
+
+        def compute() -> Tuple[Any, Optional[Tensor]]:
+            loss: Optional[torch.Tensor] = None
+            logits = self.model(patches)
+            # If labels *is* None, loss will also be None, which will stop the code below working (and
+            # currently correctly triggers mypy errors).
+            if labels is not None and self.criterion_fn is not None:
+                loss = self.criterion_fn(logits, labels)
+            return logits, loss
+
+        return execute_within_autocast_if_needed(func=compute, use_autocast=True if self.gradient_scaler else False)
+
     def _forward_pass(self,
                       patches: torch.Tensor,
                       mask: torch.Tensor = None,
@@ -128,20 +148,14 @@ class SegmentationForwardPass:
         patches = self.config.get_gpu_tensor_if_possible(patches)
         if mask is not None:
             mask = self.config.get_gpu_tensor_if_possible(mask)
-        loss: Optional[torch.Tensor] = None
-
-        # do a forward pass on the model with the patches as input
-        # this will give outputs in format: Batches x Classes x Z x Y x X
-        logits = self.model(patches)
-        # If labels *is* None, loss will also be None, which will stop the code below working (and
-        # currently correctly triggers mypy errors).
-        if labels is not None and self.criterion_fn is not None:
-            loss = self.criterion_fn(logits, labels)
+        logits, loss = self._compute_loss(patches, labels)

         if self.in_training_mode:
             if loss is None:
                 raise ValueError("When running training, the labels must be present for loss computation.")
             assert self.optimizer is not None  # for mypy
-            single_optimizer_step(self.config, loss, self.optimizer)
+            single_optimizer_step(loss, self.optimizer, self.gradient_scaler)

         # Aggregate data parallel logits if multiple hardware are used in forward pass
         if isinstance(logits, list):
@@ -164,25 +178,28 @@ class SegmentationForwardPass:
                                                   loss=loss.item() if loss is not None else None)


-def single_optimizer_step(config: DeepLearningConfig,
-                          loss: torch.Tensor,
-                          optimizer: Optimizer) -> None:
+def single_optimizer_step(loss: torch.Tensor,
+                          optimizer: Optimizer,
+                          gradient_scaler: Optional[GradScaler]) -> None:
     """
     Wrapper function to make the optimizer take a single step, given a loss tensor with gradients.
     This will update the loss tensor with auto scaling for mixed
     precision training and anomaly detection to identify NaN values in gradient updates.
     :param loss: Torch tensor representing the training loss.
-    :param config: The object containing all relevant settings like use of mixed precision and anomaly detection.
     :param optimizer: The torch optimizer.
+    :param gradient_scaler: The Torch gradient scaler object to handle mixed precision training.
     """
     # zero the gradients for the next optimization step as these
     # will be taken from the loss gradients
     optimizer.zero_grad()
-    # compute the gradients w.r.t to the optimization variables and update the optimizer_type
-    if config.use_mixed_precision and config.use_gpu:
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
+    if gradient_scaler:
+        # Scales the loss, and calls backward() to create scaled gradients
+        gradient_scaler.scale(loss).backward()
+        # Unscales gradients and calls or skips optimizer.step()
+        gradient_scaler.step(optimizer)
+        # Updates the scale for next iteration
+        gradient_scaler.update()
     else:
         loss.backward()
-    # perform next optimization step
-    optimizer.step(closure=None)
+        optimizer.step(closure=None)
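The call sites change accordingly; a hedged sketch of the new pattern (loss, optimizer, and config are placeholders for the objects the training loop builds):

    # A scaler exists only when mixed precision is active on a GPU, matching the
    # `gradient_scaler = GradScaler() if ...` line in the training loop above.
    scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
    single_optimizer_step(loss, optimizer, gradient_scaler=scaler)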
@@ -105,7 +105,6 @@ class ScalarInferencePipeline(ScalarInferencePipelineBase):
         model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model_config, self.model, sample)
         subject_ids = model_inputs_and_labels.subject_ids
         labels = self.model_config.get_gpu_tensor_if_possible(model_inputs_and_labels.labels)
-
         model_output: torch.Tensor = self.model.forward(*model_inputs_and_labels.model_inputs)
         if isinstance(model_output, list):
             # Model output is a list if we are using data parallel. Here, this will be a degenerate list with
@@ -114,9 +113,8 @@ class ScalarInferencePipeline(ScalarInferencePipelineBase):

         # Apply any post loss normalization to logits
         model_output = self.model_config.get_post_loss_logits_normalization_function()(model_output)
-        result = ScalarInferencePipelineBase.Result(subject_ids, labels, model_output)
-
-        return result
+        # Cast labels and model outputs back to float32, if the model had been run in mixed precision
+        return ScalarInferencePipelineBase.Result(subject_ids, labels.float(), model_output.float())


 class ScalarEnsemblePipeline(ScalarInferencePipelineBase):
@@ -4,125 +4,102 @@
 # ------------------------------------------------------------------------------------------
 from __future__ import annotations

-from typing import Any, List, Optional, Dict, Union
+from typing import Any, Dict, List

-from torch.optim.lr_scheduler import ExponentialLR, LambdaLR, StepLR, MultiStepLR, CosineAnnealingLR
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR, _LRScheduler
 from torch.optim.optimizer import Optimizer

 from InnerEye.ML.deep_learning_config import DeepLearningConfig, LRSchedulerType, LRWarmUpType


-class LRWarmUp(_LRScheduler):
+def get_current_learning_rates(optimizer: Optimizer) -> List[float]:
     """
-    Base class for schedulers that implement learning rate warmup.
+    Reads the current values of the learning rate(s) for all parameter groups from the optimizer.
     """
+    return [group['lr'] for group in optimizer.param_groups]

-    def __init__(self, optimizer: Optimizer, warmup_epochs: int, last_epoch: int = -1):
+
+class LinearWarmUp(_LRScheduler):
+    """
+    Implements linear warmup up to a given initial learning rate.
+    """
+    def __init__(self, optimizer: Optimizer, warmup_epochs: int, final_lr: float, last_epoch: int = -1):
         if warmup_epochs < 0:
             raise ValueError("The number of warmup epochs must be >= 0.")
         self.warmup_epochs = warmup_epochs
+        self.final_lr = final_lr
         self.last_epoch = last_epoch
         super().__init__(optimizer, last_epoch)

+    def warmup_multiplier(self) -> float:
+        if self.warmup_epochs <= 0:
+            return 1.0
+        if self.last_epoch >= self.warmup_epochs:
+            return 1.0
+        return (self.last_epoch + 1) / (self.warmup_epochs + 1)
+
     def get_lr(self) -> List[float]:  # type: ignore
-        return [base_lr * min(self.warmup_function(), 1) for base_lr in self.base_lrs]  # type: ignore
-
-    def warmup_function(self) -> float:
-        raise NotImplementedError
-
-
-class NoLRWarmUp(LRWarmUp):
-    """
-    Identity class when there is no warmup step.
-    """
-
-    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
-        warmup_epochs = 0
-        super().__init__(optimizer, warmup_epochs, last_epoch)
-
-    def warmup_function(self) -> float:
-        return 1
-
-
-class LinearLRWarmUp(LRWarmUp):
-    """
-    Implements linear warmup.
-    """
-    def __init__(self, optimizer: Optimizer, warmup_epochs: int, last_epoch: int = -1):
-        if warmup_epochs < 1:
-            raise ValueError("The number of warmup epochs must be a positive integer.")
-        super().__init__(optimizer, warmup_epochs, last_epoch)
-
-    def warmup_function(self) -> float:
-        return min((self.last_epoch + 1) / self.warmup_epochs, 1)  # type: ignore
+        return [self.final_lr * self.warmup_multiplier()]


 class SchedulerWithWarmUp(_LRScheduler):
     """
-    LR Scheduler which runs first a warmup step and then a standard scheduler for LR decay.
-    TODO: For Pytorch 1.6, implement get_last_lr() so that it returns a list and not float.
+    LR Scheduler which runs a warmup schedule (linear ramp-up) for a few iterations, and then switches to one
+    of the normal schedulers.
     """

     def __init__(self, args: DeepLearningConfig, optimizer: Optimizer, last_epoch: int = -1):
         self.optimizer = optimizer
         self.last_epoch = last_epoch
-        self._warmup_scheduler = self.get_warmup(args)
+        self.warmup_epochs = 0 if args.l_rate_warmup == LRWarmUpType.NoWarmUp else args.l_rate_warmup_epochs
         self._scheduler = self.get_scheduler(args)
+        # This must be called after self.get_scheduler, because we want the optimizer to have the learning rate
+        # guided by the warmup schedule
+        self._warmup = LinearWarmUp(optimizer,
+                                    warmup_epochs=self.warmup_epochs,
+                                    final_lr=args.l_rate,
+                                    last_epoch=last_epoch)
+        self._last_lr = get_current_learning_rates(optimizer)
         self.min_l_rate = args.min_l_rate
         super().__init__(optimizer, last_epoch)

     def get_scheduler(self, args: DeepLearningConfig) -> _LRScheduler:
         """
-        Create a LR scheduler from the config params.
+        Create the LR scheduler that will be used after warmup, based on the config params.
         """
-
-        last_epoch = max(-1, self.last_epoch - args.l_rate_warmup_epochs)
-
         scheduler: _LRScheduler
+        epochs_after_warmup = args.num_epochs - self.warmup_epochs
         if args.l_rate_scheduler == LRSchedulerType.Exponential:
             scheduler = ExponentialLR(optimizer=self.optimizer,
                                       gamma=args.l_rate_exponential_gamma,
-                                      last_epoch=last_epoch)
+                                      last_epoch=self.last_epoch)
         elif args.l_rate_scheduler == LRSchedulerType.Step:
             scheduler = StepLR(optimizer=self.optimizer,
                                step_size=args.l_rate_step_step_size,
                                gamma=args.l_rate_step_gamma,
-                               last_epoch=last_epoch)
+                               last_epoch=self.last_epoch)
         elif args.l_rate_scheduler == LRSchedulerType.MultiStep:
-            assert args.l_rate_multi_step_milestones is not None  # for mypy, we have done this check elsewhere
+            assert args.l_rate_multi_step_milestones is not None
             scheduler = MultiStepLR(optimizer=self.optimizer,
                                     milestones=args.l_rate_multi_step_milestones,
                                     gamma=args.l_rate_multi_step_gamma,
-                                    last_epoch=last_epoch)
+                                    last_epoch=self.last_epoch)
         elif args.l_rate_scheduler == LRSchedulerType.Polynomial:
             x = args.min_l_rate / args.l_rate
             polynomial_decay: Any = lambda epoch: (1 - x) * (
-                    (1. - float(epoch) / args.num_epochs) ** args.l_rate_polynomial_gamma) + x
+                    (1. - float(epoch) / epochs_after_warmup) ** args.l_rate_polynomial_gamma) + x
             scheduler = LambdaLR(optimizer=self.optimizer,
                                  lr_lambda=polynomial_decay,
-                                 last_epoch=last_epoch)
+                                 last_epoch=self.last_epoch)
         elif args.l_rate_scheduler == LRSchedulerType.Cosine:
             scheduler = CosineAnnealingLR(optimizer=self.optimizer,
-                                          T_max=args.num_epochs,
+                                          T_max=epochs_after_warmup,
                                           eta_min=args.min_l_rate,
-                                          last_epoch=last_epoch)
+                                          last_epoch=self.last_epoch)
         else:
             raise ValueError("Unknown learning rate scheduler {}".format(args.l_rate_scheduler))
         return scheduler

-    def get_warmup(self, args: DeepLearningConfig) -> LRWarmUp:
-        """
-        Create a scheduler for warmup steps from the config params.
-        """
-
-        warmup: LRWarmUp
-        if args.l_rate_warmup == LRWarmUpType.NoWarmUp:
-            warmup = NoLRWarmUp(optimizer=self.optimizer,
-                                last_epoch=self.last_epoch)
-        elif args.l_rate_warmup == LRWarmUpType.Linear:
-            warmup = LinearLRWarmUp(optimizer=self.optimizer,
-                                    warmup_epochs=args.l_rate_warmup_epochs,
-                                    last_epoch=self.last_epoch)
-        else:
-            raise ValueError("Unknown learning rate warmup {}".format(args.l_rate_warmup))
-        return warmup
-
     def state_dict(self) -> Dict:
         """
         Added for completeness, since base class _LRScheduler implements this.
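Worked example of the ramp that LinearWarmUp produces, using plain arithmetic that mirrors warmup_multiplier above, for warmup_epochs=3 and final_lr=1e-3:

    warmup_epochs, final_lr = 3, 1e-3
    for epoch in range(5):
        multiplier = 1.0 if epoch >= warmup_epochs else (epoch + 1) / (warmup_epochs + 1)
        print(epoch, final_lr * multiplier)
    # 0 0.00025
    # 1 0.0005
    # 2 0.00075
    # 3 0.001
    # 4 0.001

Note that the ramp reaches final_lr only once warmup ends; during warmup, the multiplier is (epoch + 1) / (warmup_epochs + 1) < 1.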
@@ -132,10 +109,9 @@ class SchedulerWithWarmUp(_LRScheduler):
         The state dict does not include the state of the optimizer.
         """
         state_dict = {key: val for key, val in self.__dict__.items()
-                      if key != "_scheduler" and key != "_warmup_scheduler"
-                      and key != "optimizer"}
+                      if key != "_scheduler" and key != "optimizer" and key != "_warmup"}
         state_dict['_scheduler'] = self._scheduler.state_dict()
-        state_dict['_warmup_scheduler'] = self._warmup_scheduler.state_dict()
+        state_dict['_warmup'] = self._warmup.state_dict()
         return state_dict

     def load_state_dict(self, state_dict: Dict) -> None:
@@ -145,90 +121,29 @@ class SchedulerWithWarmUp(_LRScheduler):
         Initializes variables "_scheduler" and "_warmup_scheduler" separately, by calling load_state_dict
         for these variables.
         """
-        top_level = {key: val for key, val in state_dict.items()
-                     if key != "_scheduler" and key != "_warmup_scheduler"}
+        top_level = {key: val for key, val in state_dict.items() if key != "_scheduler" and key != "_warmup"}
         self.__dict__.update(top_level)
         self._scheduler.__dict__.update(state_dict["_scheduler"])
-        self._warmup_scheduler.__dict__.update(state_dict["_warmup_scheduler"])
-
-    def get_lr(self) -> List[float]:  # type: ignore
-        lr: Union[float, List[float]]
-        if self.last_epoch < self._warmup_scheduler.warmup_epochs:
-            lr = self._warmup_scheduler.get_lr()
-        else:
-            lr = self._scheduler.get_lr()
-        lrs: List[float] = [lr] if isinstance(lr, float) else lr
-        return lrs
+        self._warmup.__dict__.update(state_dict["_warmup"])

     def step(self, epoch: int = None) -> None:
-        target_epoch = epoch if epoch is not None else self.last_epoch + 1
-
-        if target_epoch < self._warmup_scheduler.warmup_epochs:
-            self._warmup_scheduler.step(epoch)
-        elif target_epoch == self._warmup_scheduler.warmup_epochs:
-            # don't step here, or we will miss the first value from the scheduler
-            pass
-        else:
-            scheduler_epoch = epoch - self._warmup_scheduler.warmup_epochs if epoch else None
-            self._scheduler.step(scheduler_epoch)
-
-        self.last_epoch = target_epoch
-
-
-class LRScheduler:
-    """
-    Wrapper around Torch LRScheduler functions with added functionality to restrict learning rate to a
-    minimum value based on the provided configurations.
-    """
-    _scheduler: SchedulerWithWarmUp
-    _min_lr: float = 0
-    _max_epochs: int = 0
-
-    def __init__(self, args: DeepLearningConfig, optimizer: Optimizer):
-        """
-        :param args: the config defining the model
-        :param optimizer: the optimizer to use for model training
-        """
-        self._min_lr = args.min_l_rate
-        self._max_epochs = args.num_epochs
-
-        # if loading from a checkpoint, then last epoch will be the checkpoint epoch
-        # otherwise -1 as no epochs have been trained.
-        # For pytorch version 1.3:
-        last_epoch = args.start_epoch if args.should_load_checkpoint_for_training() else -1
-        # For pytorch version 1.6:
-        # last_epoch = args.start_epoch - 1 if args.should_load_checkpoint_for_training() else -1
-
-        self._scheduler = SchedulerWithWarmUp(args, optimizer, last_epoch)
+        # self.step() is called in the _LRScheduler.__init__, as the very last operation, when self.last_epoch == -1
+        # Inside of the default implementation of self.step, it calls
+        # self.last_epoch += 1
+        # values = self.get_lr()
+        # The values are then set in the optimizer, and stored in self._last_lr
+        if epoch is not None:
+            raise ValueError("Calling scheduler.step with an epoch argument will be deprecated.")
+        # self.step is called from within the base class constructor, _LRScheduler.__init__
+        # The scheduler itself has already been initialized, and scheduler.step has also been called already in
+        # the respective constructor. Avoid calling it again here.
+        if self.last_epoch != -1:
+            if self.last_epoch < self._warmup.warmup_epochs:
+                self._warmup.step()
+            else:
+                self._scheduler.step()
+        self.last_epoch += 1
+        self._last_lr = get_current_learning_rates(self.optimizer)

     def get_last_lr(self) -> List[float]:
-        """
-        Get the current learning rate (making sure it is >= min_l_rate if provided in the config)
-        """
-        # For pytorch version 1.3:
-        lrs = self._scheduler.get_lr()  # type: ignore
-        # For pytorch version 1.6:
-        # lrs = self._scheduler.get_last_lr()  # type: ignore
-        return [max(self._min_lr, x) for x in lrs]
-
-    def state_dict(self) -> dict:
-        """
-        Get the current lr scheduler state
-        """
-        return self._scheduler.state_dict()
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        """
-        Load the given state into the lr scheduler
-        """
-        self._scheduler.load_state_dict(state_dict)
-
-    def step(self, epoch: Optional[int] = None) -> None:
-        """
-        Move the lr scheduler to the state corresponding to the provided epoch or next epoch.
-        """
-        if epoch is not None and epoch > self._max_epochs:
-            raise ValueError("epoch must be <= {}".format(self._max_epochs))
-        else:
-            # noinspection PyTypeChecker
-            self._scheduler.step(epoch)
+        return self._last_lr
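The rewrite follows the plain PyTorch 1.6 scheduler protocol: step() takes no epoch argument, and get_last_lr() returns a list of per-group rates. A standalone illustration with a stock scheduler:

    import torch
    from torch.optim.lr_scheduler import StepLR

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    sched = StepLR(opt, step_size=2, gamma=0.5)
    for _ in range(4):
        opt.step()
        sched.step()                # 1.6 style: no epoch argument
        print(sched.get_last_lr())  # [0.1], [0.05], [0.05], [0.025]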
@@ -9,7 +9,6 @@ from pathlib import Path
 from typing import Any, List, Optional, Union

 import torch
-from apex import amp
 from torch.optim.optimizer import Optimizer
 from torch.optim.rmsprop import RMSprop
@@ -34,7 +33,7 @@ from InnerEye.ML.utils.ml_util import RandomStateSnapshot, is_gpu_available
 from InnerEye.ML.utils.temperature_scaling import ModelWithTemperature
 from InnerEye.ML.visualizers.model_summary import ModelSummary

-BaseModelOrDataParallelModel = Union[BaseModel, DataParallelModel]
+BaseModelOrDataParallelModel = Union[DeviceAwareModule, DataParallelModel]


 @dataclass
@@ -59,12 +58,6 @@ class ModelAndInfo:
         assert self.model is not None
         self.model = self.model.cuda()

-    def apply_amp_output(self, amp_output: Any) -> None:
-        if isinstance(amp_output, tuple):
-            self.model, self.optimizer = amp_output
-        else:
-            self.model = amp_output
-
     def set_data_parallel(self, device_ids: Optional[List[Any]]) -> None:
         assert self.model is not None
         self.model = DataParallelModel(self.model, device_ids=device_ids)
@@ -131,17 +124,15 @@ def build_net(args: SegmentationModelBase) -> BaseModel:
     return network


-def update_model_for_mixed_precision_and_parallel(model_and_info: ModelAndInfo,
-                                                  args: ModelConfigBase,
-                                                  execution_mode: ModelExecutionMode = ModelExecutionMode.TRAIN) -> \
+def update_model_for_multiple_gpus(model_and_info: ModelAndInfo,
+                                   args: ModelConfigBase,
+                                   execution_mode: ModelExecutionMode = ModelExecutionMode.TRAIN) -> \
         ModelAndInfo:
     """
     Updates a given torch model as such input mini-batches are parallelized across the batch dimension to utilise
     multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to
-    perform full volume inference. Additionally, mixed precision training (amp) is utilised on both the model and
-    optimizer instances to improve the training performance.
-
-    :param model_and_info: The torch module object representing the network, and more
+    perform full volume inference.
+    :param model_and_info: The torch module object representing the network and the optimizer.
     :param args: The arguments object with attributes used to enable amp training and create the parallel model.
     :param execution_mode: mode, i.e. train or test
     :return: Updated torch model and optimizer.
@@ -158,19 +149,10 @@ def update_model_for_multiple_gpus(model_and_info: ModelAndInfo,
             devices = args.get_cuda_devices()
             assert devices is not None  # for mypy
             model_and_info.model.partition_model(devices=devices)  # type: ignore
-
-        # This is required to support sigmoid function
-        amp.register_float_function(torch, 'sigmoid')
-
-        # Activate automatic mixed precision
-        # With optimization GEMMs and convolutions are performed in FP16, see https://nvidia.github.io/apex/amp.html
-        amp_output = amp.initialize(model_and_info.model, model_and_info.optimizer, enabled=args.use_mixed_precision,
-                                    opt_level="O1", keep_batchnorm_fp32=None, loss_scale="dynamic", num_losses=1)
-        model_and_info.apply_amp_output(amp_output)
     else:
         logging.info("Making no adjustments to the model because no GPU was found.")

-    # Update model related config attributes (After AMP & Model Parallel Activated)
+    # Update model related config attributes (After Model Parallel Activated)
     args.adjust_after_mixed_precision_and_parallel(model_and_info.model)

     # DataParallel enables running the model with multiple gpus by splitting samples across GPUs
@@ -253,8 +235,6 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware
         model_inputs = get_scalar_model_inputs_and_labels(config, model, train_item_0).model_inputs
         # The model inputs may already be converted to float16, assuming that we would do mixed precision.
         # However, the model is not yet converted to float16 when this function is called, hence convert back to float32
-        if config.use_gpu:
-            model_inputs = [x.float() for x in model_inputs]
         summary = ModelSummary(model)
         summary.generate_summary(input_tensors=model_inputs, log_summaries_to_files=config.log_summaries_to_files)
     elif config.is_segmentation_model:
@ -365,7 +345,7 @@ def load_from_checkpoint_and_adjust(model_config: ModelConfigBase,
    if model_config.is_segmentation_model:
        # Generate the model summary, which is required for model partitioning across GPUs.
        summary_for_segmentation_models(model_config, model_and_info.model)
    return update_model_for_mixed_precision_and_parallel(
    return update_model_for_multiple_gpus(
        model_and_info, args=model_config, execution_mode=model_and_info.model_execution_mode)
@ -90,7 +90,7 @@ class ModelWithTemperature(DeviceAwareModule):
            loss.backward()
            return loss

        optimizer.step(eval_criterion)
        optimizer.step(eval_criterion)  # type: ignore

        after_temperature_loss, after_temperature_ece = criterion_fn(self.temperature_scale(logits), labels)
        print('Optimal temperature: {:.3f}'.format(self.temperature.item()))
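The optimizer.step(eval_criterion) call above uses the closure form of Optimizer.step: the closure re-evaluates the loss and backpropagates, and optimizers such as LBFGS (the usual choice for fitting a temperature parameter) may call it several times per step. A self-contained sketch with made-up tensors; the names below are illustrative, not the fields of ModelWithTemperature.

import torch

# Illustrative stand-ins; the real code optimizes self.temperature against validation logits.
temperature = torch.nn.Parameter(torch.ones(1) * 1.5)
logits = torch.randn(8, 3)
labels = torch.randint(0, 3, (8,))
optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=50)


def eval_criterion() -> torch.Tensor:
    # The closure zeroes gradients, computes the loss, and backpropagates;
    # LBFGS may invoke it several times within a single step().
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(logits / temperature, labels)
    loss.backward()
    return loss


optimizer.step(eval_criterion)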
@ -5,4 +5,3 @@ channels:
dependencies:
  - pip=20.1.1
  - python=3.7.3
  - pytorch=1.3.0
@ -10,6 +10,7 @@ from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
from InnerEye.ML.models.layers.identity import Identity
from InnerEye.ML.models.parallel.data_parallel import execute_within_autocast_if_needed


class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
@ -28,6 +29,7 @@ class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
        self.activation = activation
        self.last_encoder_layer: List[str] = ["_layers", "0"]
        self.conv_in_3d = False
        self.use_mixed_precision = False

    def get_last_encoder_layer_names(self) -> List[str]:
        return self.last_encoder_layer
@ -41,12 +43,18 @@ class DummyScalarModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
        return [item.images]

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        if x.shape[-3:] != self.expected_image_size_zyx:
            raise ValueError(f"Expected a tensor with trailing size {self.expected_image_size_zyx}, but got "
                             f"{x.shape}")
        def _forward() -> torch.Tensor:
            # Need to copy to a local variable, because we can't re-assign x here
            x2 = x
            if x2.shape[-3:] != self.expected_image_size_zyx:
                raise ValueError(f"Expected a tensor with trailing size {self.expected_image_size_zyx}, but got "
                                 f"{x2.shape}")

        for layer in self._layers.__iter__():
            x = layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return self.activation(x)
            for layer in self._layers.__iter__():
                x2 = layer(x2)
            x2 = x2.view(x2.size(0), -1)
            x2 = self.fc(x2)
            return self.activation(x2)

        # Models that will be used inside of DataParallel need to do their own autocast
        return execute_within_autocast_if_needed(_forward, use_autocast=self.use_mixed_precision)
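execute_within_autocast_if_needed is imported above, but its body is not part of this diff. Presumably it wraps the callable in torch.cuda.amp.autocast when mixed precision is enabled; the following is a sketch of what such a helper could look like, an assumption rather than the actual implementation.

from typing import Callable

import torch
from torch.cuda.amp import autocast


def execute_within_autocast_if_needed(func: Callable[[], torch.Tensor],
                                      use_autocast: bool) -> torch.Tensor:
    # Hypothetical body: run the forward pass under autocast only when mixed precision
    # is requested, so that each DataParallel replica casts on its own device.
    if use_autocast:
        with autocast():
            return func()
    return func()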
@ -32,7 +32,7 @@ from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.io_util import ImageAndSegmentations
from InnerEye.ML.utils.metrics_constants import LoggingColumns
from InnerEye.ML.utils.model_util import ModelAndInfo, create_model_with_temperature_scaling, \
    update_model_for_mixed_precision_and_parallel
    update_model_for_multiple_gpus
from InnerEye.ML.utils.split_dataset import DatasetSplits
from InnerEye.ML.visualizers.grad_cam_hooks import VisualizationMaps
from Tests.ML.util import get_default_azure_config
@ -127,7 +127,6 @@ class ToySequenceModel(SequenceModelBase):
            stride_size_per_encoding_block=(1, 2, 2),
            initial_feature_channels=4,
            num_encoder_blocks=3,
            use_mixed_precision=True
        )
        assert image_encoder is not None  # for mypy
        input_dims = image_encoder.final_num_feature_channels
@ -267,7 +266,7 @@ def test_visualization_with_sequence_model(use_combined_model: bool,
    config.num_epochs = 1

    model = create_model_with_temperature_scaling(config)
    update_model_for_mixed_precision_and_parallel(ModelAndInfo(model), config)
    update_model_for_multiple_gpus(ModelAndInfo(model), config)
    dataloader = SequenceDataset(config,
                                 data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                                                      batch_size=2)
@ -6,14 +6,15 @@ from typing import Any, List

import pytest
import torch
from torch import Tensor

from InnerEye.Common import common_util
from InnerEye.ML.models.architectures.base_model import BaseModel, CropSizeConstraints
from InnerEye.ML.models.losses.soft_dice import SoftDiceLoss
from InnerEye.ML.models.parallel.data_parallel import DataParallelCriterion
from InnerEye.ML.models.parallel.model_parallel import group_layers_with_balanced_memory, move_to_device, \
    partition_layers
from InnerEye.ML.utils.ml_util import is_gpu_available
from InnerEye.ML.models.parallel.model_parallel import group_layers_with_balanced_memory, \
    move_to_device, partition_layers
from InnerEye.ML.utils.ml_util import is_gpu_available, set_random_seed

no_gpu = not is_gpu_available()
no_or_single_gpu = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
@ -51,22 +52,26 @@ class SimpleModel(BaseModel):
@pytest.mark.gpu
@pytest.mark.skipif(no_gpu, reason="CUDA capable GPU is not available")
def test_move_to_device() -> None:
    def assert_device_matches(tensors: List[Tensor], target_device: torch.device) -> None:
        for tensor in tensors:
            assert tensor.device == target_device

    target_device = torch.device('cuda:0')
    input_tensor_1 = torch.tensor(3, device=torch.device('cpu'))
    input_tensor_2 = torch.tensor(3, device=torch.device('cuda:0'))
    [moved_tensor_1, moved_tensor_2] = move_to_device([input_tensor_1, input_tensor_2],
                                                      target_device=target_device)

    assert moved_tensor_1.device == target_device
    assert moved_tensor_2.device == target_device
    tensors = [input_tensor_1, input_tensor_2]
    moved = list(move_to_device(tensors, target_device=target_device))
    assert_device_matches(moved, target_device)

    if torch.cuda.device_count() > 1:
        target_device = torch.device('cuda:1')
        [moved_tensor_1, moved_tensor_2] = move_to_device([input_tensor_1, input_tensor_2],
                                                          target_device=target_device)
        moved = list(move_to_device(tensors, target_device=target_device))
        assert_device_matches(moved, target_device)

        assert moved_tensor_1.device == target_device
        assert moved_tensor_2.device == target_device
    # Not supplying a target device should leave the tensor untouched
    moved = list(move_to_device(tensors, target_device=None))
    assert moved[0].device == tensors[0].device
    assert moved[1].device == tensors[1].device


@pytest.mark.gpu
@ -96,7 +101,8 @@ def test_partition_layers() -> None:
    all_layers = model.get_all_child_layers()

    if summary is None:
        raise RuntimeError("Network summary is required to partition UNet3D. Call model.generate_model_summary() first.")
        raise RuntimeError(
            "Network summary is required to partition UNet3D. Call model.generate_model_summary() first.")

    partition_layers(layers=all_layers, summary=summary, target_devices=devices)
@ -108,11 +114,15 @@ def test_partition_layers() -> None:

@pytest.mark.gpu
@pytest.mark.skipif(no_gpu, reason="CUDA capable GPU is not available")
def test_dataparallel_criterion() -> None:
@pytest.mark.parametrize("use_mixed_precision", [False, True])
def test_dataparallel_criterion(use_mixed_precision: bool) -> None:
    set_random_seed(1)
    num_batches = torch.cuda.device_count()
    array_shape = [num_batches, 2, 8, 4, 8]
    segmentation = torch.rand(array_shape).cuda()
    ground_truth = torch.rand(array_shape).cuda()
    if use_mixed_precision:
        segmentation = segmentation.to(dtype=torch.float16)
    target_loss_values = torch.zeros(num_batches).cuda()

    # Sequential computation with multi-hardware parallelisation
@ -122,13 +132,20 @@ def test_dataparallel_criterion() -> None:
        target_loss_values[ii] = loss

    # Use parallel criterion
    parallel_loss_fn = DataParallelCriterion(SoftDiceLoss(), device_ids=range(torch.cuda.device_count()))
    parallel_loss_fn = DataParallelCriterion(loss_fn,
                                             device_ids=list(range(torch.cuda.device_count())),
                                             use_mixed_precision=use_mixed_precision)
    segmentation_as_parallel = [segmentation[ii:ii + 1].to("cuda:{}".format(ii)) for ii in range(num_batches)]
    computed_loss_values = \
        parallel_loss_fn(segmentation_as_parallel[0], ground_truth) if num_batches == 1 \
        else parallel_loss_fn(segmentation_as_parallel, ground_truth)

    assert isinstance(computed_loss_values, torch.Tensor)
    if num_batches == 1:
        assert torch.equal(target_loss_values[0], computed_loss_values)
        target_loss_values = target_loss_values[0]
    diff = (target_loss_values - computed_loss_values).abs().sum().item()
    # Even with autocast turned on, the result tensor always comes back as float32, and we can't run any more
    # detailed asserts on it. Best thing to do is to check if there are signs of lower precision computation.
    if use_mixed_precision:
        assert diff > 2e-5
    else:
        assert torch.equal(target_loss_values, computed_loss_values)
        assert diff < 1e-10
@ -32,19 +32,16 @@ from InnerEye.ML.visualizers.plot_cross_validation import EpochMetricValues, get
    unroll_aggregate_metrics
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.models.test_parallel import no_gpu
from Tests.ML.util import get_default_azure_config, machine_has_gpu


@pytest.mark.gpu
@pytest.mark.parametrize("use_mixed_precision", [False, True])
def test_train_classification_model(test_output_dirs: TestOutputDirectories,
                                    use_mixed_precision: bool,
                                    check_logs: bool = True) -> None:
                                    use_mixed_precision: bool) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and testing.
    This executes correctly only on CPU machines - there it will return the same results for both
    use_mixed_precision==True and ==False. It will fail on a SurfaceBook where it recognizes the GPU (loss values
    don't match when use_mixed_precision==True)
    Expected test results are stored for GPU with and without mixed precision.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting()
@ -64,8 +61,8 @@ def test_train_classification_model(test_output_dirs: TestOutputDirectories,
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
    use_mixed_precision_and_gpu = use_mixed_precision and machine_has_gpu
    if use_mixed_precision_and_gpu:
        expected_train_loss = [0.686615, 0.686467, 0.686322, 0.686174]
        expected_val_loss = [0.737038, 0.736720, 0.736338, 0.735957]
        expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
        expected_val_loss = [0.737039, 0.736721, 0.736339, 0.735957]
    else:
        expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
        expected_val_loss = [0.737061, 0.736690, 0.736321, 0.735952]
@ -84,8 +81,8 @@ def test_train_classification_model(test_output_dirs: TestOutputDirectories,
    assert list(test_results.epochs.keys()) == expected_epochs
    if use_mixed_precision_and_gpu:
        expected_metrics = {
            2: [0.635924, 0.736720],
            4: [0.636096, 0.735957],
            2: [0.635942, 0.736691],
            4: [0.636085, 0.735952],
        }
    else:
        expected_metrics = {
@ -95,7 +92,9 @@ def test_train_classification_model(test_output_dirs: TestOutputDirectories,
    for epoch in expected_epochs:
        assert test_results.epochs[epoch].values()[MetricType.CROSS_ENTROPY.value] == \
               pytest.approx(expected_metrics[epoch], abs=1e-6)
    if check_logs:
    # Run detailed logs file check only on CPU, it will contain slightly different metrics on GPU, but here
    # we want to mostly assert that the files look reasonable
    if not machine_has_gpu:
        # Check log EPOCH_METRICS_FILE_NAME
        epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
        # Auto-format will break the long header line, hence the strange way of writing it!
@ -140,12 +139,6 @@ def check_log_file(path: Path, expected_csv: str, ignore_columns: List[str]) ->
    pd.testing.assert_frame_equal(df_expected, df_epoch_metrics_actual, check_less_precise=True)


@pytest.mark.gpu
@pytest.mark.skipif(no_gpu, reason="CUDA capable GPU is not available")
def test_train_classification_model_with_amp(test_output_dirs: TestOutputDirectories) -> None:
    test_train_classification_model(test_output_dirs, True, check_logs=False)


@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
@pytest.mark.parametrize("model_name", ["DummyClassification", "DummyRegression"])
@pytest.mark.parametrize("number_of_offline_cross_validation_splits", [2])
@ -2,24 +2,34 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, List, Optional, Union

import numpy as np
import pytest
import torch
from torch.cuda.amp import GradScaler
from torch.nn import Identity

from InnerEye.Common import common_util
from InnerEye.Common.common_util import MetricsDataframeLoggers
from InnerEye.Common.output_directories import TestOutputDirectories
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.model_training import model_train
from InnerEye.ML.model_training_steps import ModelTrainingStepsForScalarModel, TrainValidateParameters, \
    get_scalar_model_inputs_and_labels
from InnerEye.ML.models.architectures.base_model import BaseModel, CropSizeConstraints
from InnerEye.ML.models.parallel.data_parallel import DataParallelModel
from InnerEye.ML.pipelines.forward_pass import SegmentationForwardPass
from InnerEye.ML.utils import ml_util, model_util
from InnerEye.ML.utils.io_util import ImageDataType
from InnerEye.ML.utils.metrics_util import SummaryWriters
from InnerEye.ML.utils.model_util import ModelAndInfo, create_model_with_temperature_scaling
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
from Tests.ML.models.architectures.DummyScalarModel import DummyScalarModel
from Tests.ML.util import machine_has_gpu, no_gpu_available


@ -112,7 +122,7 @@ def test_amp_activated(use_model_parallel: bool,
                       execution_mode: ModelExecutionMode,
                       use_mixed_precision: bool) -> None:
    """
    Tests the amp flag both for True and False states. Verifies that the mixed precision training functions as expected.
    Tests the mixed precision flag and the model parallel flag.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine"
@ -137,9 +147,9 @@ def test_amp_activated(use_model_parallel: bool,
    optimizer = model_util.create_optimizer(model_config, model)
    model_and_info = ModelAndInfo(model, optimizer)
    try:
        model_and_info_amp = model_util.update_model_for_mixed_precision_and_parallel(model_and_info,
                                                                                      model_config,
                                                                                      execution_mode)
        model_and_info_amp = model_util.update_model_for_multiple_gpus(model_and_info,
                                                                       model_config,
                                                                       execution_mode)
    except NotImplementedError as ex:
        if use_model_parallel:
            # The SimpleModel does not implement model partitioning, and should hence fail at this step.
@ -148,20 +158,29 @@ def test_amp_activated(use_model_parallel: bool,
        else:
            raise ValueError(f"Expected this call to succeed, but got: {ex}")

    # Check if the optimizer is updated with AMP mixed precision features. The attribute should be present
    # if and only if mixed precision is switched on.
    optimizer_amp = model_and_info_amp.optimizer
    assert optimizer_amp is not None
    assert hasattr(optimizer_amp, '_amp_stash') == use_mixed_precision
    assert hasattr(optimizer_amp, '_post_amp_backward') == use_mixed_precision

    # This is the same logic spelt out in update_model_for_multiple_gpus
    use_data_parallel = (execution_mode == ModelExecutionMode.TRAIN) or (not use_model_parallel)
    if use_data_parallel:
        assert isinstance(model_and_info.model, DataParallelModel)
    gradient_scaler = GradScaler() if use_mixed_precision else None
    criterion = lambda x, y: torch.tensor([0.0], requires_grad=True).cuda()
    pipeline = SegmentationForwardPass(model_and_info_amp.model,
                                       model_config,
                                       batch_size=1,
                                       optimizer=optimizer_amp,
                                       optimizer=optimizer,
                                       gradient_scaler=gradient_scaler,
                                       criterion=criterion)

    logits, _ = pipeline._compute_loss(image, labels)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
    else:
        assert first_logit.dtype == torch.float32
    # Verify that forward and backward passes do not throw an exception
    pipeline._forward_pass(patches=image, mask=mask, labels=labels)
@ -195,6 +214,7 @@ def test_mean_teacher_model() -> None:
    """
    Test training and weight updates of the mean teacher model computation.
    """

    def _get_parameters_of_model(model: Union[torch.nn.Module, DataParallelModel]) -> Any:
        """
        Returns the iterator of model parameters
@ -243,3 +263,64 @@ def test_mean_teacher_model() -> None:
    # Check the update of the parameters
    assert torch.all(alpha * initial_weight_mean_teacher_model + (1 - alpha) * student_model_weight == result_weight)

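The assert above is the mean teacher update rule itself: the teacher weights are an exponential moving average of the student weights. A tiny standalone sketch, with alpha and the weight tensors as illustrative values rather than the config's:

import torch

alpha = 0.999  # EMA decay; illustrative, the test takes it from the model config
teacher_weight = torch.zeros(3, 3)
student_weight = torch.randn(3, 3)
# One mean teacher update: teacher <- alpha * teacher + (1 - alpha) * student
teacher_weight = alpha * teacher_weight + (1 - alpha) * student_weight
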
@pytest.mark.gpu
@pytest.mark.skipif(no_gpu_available, reason="Testing AMP requires a GPU")
@pytest.mark.parametrize("use_mixed_precision", [False, True])
@pytest.mark.parametrize("execution_mode", [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL])
def test_amp_and_parallel_for_scalar_models(test_output_dirs: TestOutputDirectories,
                                            execution_mode: ModelExecutionMode,
                                            use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and data parallel for scalar models.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine"
    config = ClassificationModelForTesting()
    config.use_mixed_precision = use_mixed_precision
    model = DummyScalarModel(expected_image_size_zyx=config.expected_image_size_zyx,
                             activation=Identity())
    model.use_mixed_precision = use_mixed_precision
    model_and_info = ModelAndInfo(
        model=model,
        model_execution_mode=execution_mode
    )
    # This is the same logic spelt out in update_model_for_multiple_gpus:
    # execution_mode == ModelExecutionMode.TRAIN or (not use_model_parallel), which is always True in our case
    use_data_parallel = True
    model_and_info = model_util.update_model_for_multiple_gpus(model_and_info, config)
    if use_data_parallel:
        assert isinstance(model_and_info.model, DataParallelModel)
    data_loaders = config.create_data_loaders()
    gradient_scaler = GradScaler() if use_mixed_precision else None
    train_val_parameters: TrainValidateParameters = TrainValidateParameters(
        model=model_and_info.model,
        data_loader=data_loaders[execution_mode],
        in_training_mode=execution_mode == ModelExecutionMode.TRAIN,
        gradient_scaler=gradient_scaler,
        dataframe_loggers=MetricsDataframeLoggers(Path(test_output_dirs.root_dir)),
        summary_writers=SummaryWriters(train=None, val=None)  # type: ignore
    )
    training_steps = ModelTrainingStepsForScalarModel(config, train_val_parameters)
    sample = list(data_loaders[execution_mode])[0]
    model_input = get_scalar_model_inputs_and_labels(config, model, sample)
    logits, posteriors, loss = training_steps._compute_model_output_and_loss(model_input)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
        assert posteriors.dtype == torch.float16
        # BCEWithLogitsLoss outputs float32, even with float16 args
        assert loss.dtype == torch.float32
    else:
        assert first_logit.dtype == torch.float32
        assert posteriors.dtype == torch.float32
        assert loss.dtype == torch.float32
    # Verify that the forward pass does not throw. It would, for example, if it fails to gather tensors or to
    # convert float16 to float32
    _, _, _ = training_steps._compute_model_output_and_loss(model_input)
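The float16-to-float32 caveat in the final comment concerns gathering per-GPU outputs onto one device. A hedged sketch of such a gather-and-upcast step, as a generic illustration rather than code from this diff:

from typing import List, Union

import torch
from torch.nn.parallel import gather


def gather_and_upcast(logits: Union[torch.Tensor, List[torch.Tensor]],
                      target_device: int = 0) -> torch.Tensor:
    # A DataParallel-style forward returns one tensor per GPU; gather concatenates
    # them along the batch dimension on a single device.
    result = gather(logits, target_device, dim=0) if isinstance(logits, list) else logits
    # Under autocast the gathered tensor may still be float16; upcast before computing metrics.
    return result.float()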
@ -91,7 +91,7 @@ def test_invalid_stride_size() -> None:
    )
    with pytest.raises(ValueError) as ex:
        model = create_model_with_temperature_scaling(config)
        model_util.update_model_for_mixed_precision_and_parallel(ModelAndInfo(model), config)
        model_util.update_model_for_multiple_gpus(ModelAndInfo(model), config)
    assert "inference stride size must be smaller" in ex.value.args[0]
    assert str(config.inference_stride_size) in ex.value.args[0]
    assert str(config.test_crop_size) in ex.value.args[0]
@ -16,7 +16,7 @@ from InnerEye.ML.models.architectures.base_model import DeviceAwareModule
from InnerEye.ML.pipelines.scalar_inference import ScalarEnsemblePipeline, ScalarInferencePipeline, \
    ScalarInferencePipelineBase
from InnerEye.ML.scalar_config import EnsembleAggregationType
from InnerEye.ML.utils.model_util import ModelAndInfo, update_model_for_mixed_precision_and_parallel
from InnerEye.ML.utils.model_util import ModelAndInfo, update_model_for_multiple_gpus
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
from Tests.fixed_paths_for_tests import full_ml_test_data_path
@ -106,9 +106,9 @@ class ScalarOnesModel(DeviceAwareModule[ScalarItem, torch.Tensor]):
def test_predict_non_ensemble(batch_size: int, empty_labels: bool) -> None:
    config = ClassificationModelForTesting()
    model: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)
    update_model_for_mixed_precision_and_parallel(ModelAndInfo(model),
                                                  args=config,
                                                  execution_mode=ModelExecutionMode.TEST)
    update_model_for_multiple_gpus(ModelAndInfo(model),
                                   args=config,
                                   execution_mode=ModelExecutionMode.TEST)
    pipeline = ScalarInferencePipeline(model, config, 0, 0)
    actual_labels = torch.zeros((batch_size, 1)) * np.nan if empty_labels else torch.zeros((batch_size, 1))
    data = {"metadata": [GeneralSampleMetadata(id='2')] * batch_size,
@ -131,13 +131,13 @@ def test_predict_ensemble(batch_size: int) -> None:
    config = ClassificationModelForTesting()
    model_returns_0: Any = ScalarOnesModel(config.expected_image_size_zyx, 0.)
    model_returns_1: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)
    model_and_opt_0 = update_model_for_mixed_precision_and_parallel(ModelAndInfo(model_returns_0),
                                                                    args=config,
                                                                    execution_mode=ModelExecutionMode.TEST)
    model_and_opt_0 = update_model_for_multiple_gpus(ModelAndInfo(model_returns_0),
                                                     args=config,
                                                     execution_mode=ModelExecutionMode.TEST)
    model_returns_0 = model_and_opt_0.model
    model_and_opt_1 = update_model_for_mixed_precision_and_parallel(ModelAndInfo(model_returns_1),
                                                                    args=config,
                                                                    execution_mode=ModelExecutionMode.TEST)
    model_and_opt_1 = update_model_for_multiple_gpus(ModelAndInfo(model_returns_1),
                                                     args=config,
                                                     execution_mode=ModelExecutionMode.TEST)
    model_returns_1 = model_and_opt_1.model
    pipeline_0 = ScalarInferencePipeline(model_returns_0, config, 0, 0)
    pipeline_1 = ScalarInferencePipeline(model_returns_0, config, 0, 1)
@ -2,57 +2,64 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import List, Tuple, Any, Optional
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import pytest
import torch
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, \
    StepLR, _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, StepLR, MultiStepLR, LambdaLR, CosineAnnealingLR, _LRScheduler

from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import LRSchedulerType, LRWarmUpType
from InnerEye.ML.utils.lr_scheduler import LRScheduler
from InnerEye.ML.deep_learning_config import DeepLearningConfig, LRSchedulerType, LRWarmUpType
from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
from Tests.ML.configs.DummyModel import DummyModel


def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
    """
    Reads the current learning rate via get_last_lr, runs one scheduler step, and repeats. Returns the LR values.
    """
    lrs = []
    for _ in range(steps):
        lr = scheduler.get_last_lr()  # type: ignore
        assert isinstance(lr, list)
        assert len(lr) == 1
        lrs.append(lr[0])
        scheduler.step()
    return lrs

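A quick usage example for the enumerate_scheduler helper defined above, run against a plain StepLR:

import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR

optimizer = torch.optim.SGD([torch.ones(1, requires_grad=True)], lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
# Reads the current LR, steps, and repeats, returning one value per epoch.
assert np.allclose(enumerate_scheduler(scheduler, 3), [0.1, 0.01, 0.001])
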
def test_create_lr_scheduler_last_epoch() -> None:
    """
    Test to check if the lr scheduler is initialized to the correct epoch
    """
    expected_lrs_per_epoch = [0.001, 0.0005358867312681466]
    l_rate = 1e-3
    gamma = 0.5
    total_epochs = 5
    expected_lrs_per_epoch = [l_rate * (gamma ** i) for i in range(total_epochs)]
    config = DummyModel()
    config.l_rate = l_rate
    config.l_rate_scheduler = LRSchedulerType.Step
    config.l_rate_step_step_size = 1
    config.l_rate_step_gamma = gamma
    # create lr scheduler
    lr_scheduler, optimizer = _create_lr_scheduler_and_optimizer(config)
    initial_scheduler, initial_optimizer = _create_lr_scheduler_and_optimizer(config)
    # check lr scheduler initialization step
    assert np.isclose(lr_scheduler.get_last_lr(), expected_lrs_per_epoch[:1])
    initial_epochs = 3
    assert np.allclose(enumerate_scheduler(initial_scheduler, initial_epochs), expected_lrs_per_epoch[:initial_epochs])
    # create lr scheduler for recovery checkpoint
    config.start_epoch = 1
    lr_scheduler, _ = _create_lr_scheduler_and_optimizer(config, optimizer)
    config.start_epoch = initial_epochs
    recovery_scheduler, recovery_optimizer = _create_lr_scheduler_and_optimizer(config)
    # Both the scheduler and the optimizer need to be loaded from the checkpoint.
    recovery_scheduler.load_state_dict(initial_scheduler.state_dict())
    recovery_optimizer.load_state_dict(initial_optimizer.state_dict())
    assert recovery_scheduler.last_epoch == config.start_epoch
    # check lr scheduler initialization matches the checkpoint epoch
    # as training will start for start_epoch + 1 in this case
    lr = lr_scheduler.get_last_lr()
    assert np.isclose(lr, expected_lrs_per_epoch[1:])


@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
def test_min_and_initial_lr(lr_scheduler_type: LRSchedulerType) -> None:
    """
    Test if minimum learning rate threshold is applied as expected
    """
    config = DummyModel(num_epochs=2, l_rate=1e-3, min_l_rate=0.0009,
                        l_rate_scheduler=lr_scheduler_type,
                        l_rate_exponential_gamma=0.9,
                        l_rate_step_gamma=0.9,
                        l_rate_step_step_size=1,
                        l_rate_multi_step_gamma=0.7,
                        l_rate_multi_step_milestones=[1],
                        l_rate_polynomial_gamma=0.9)
    # create lr scheduler
    lr_scheduler, _ = _create_lr_scheduler_and_optimizer(config)
    assert lr_scheduler.get_last_lr()[0] == config.l_rate
    lr_scheduler.step(2)
    assert lr_scheduler.get_last_lr()[0] == config.min_l_rate
    assert np.allclose(enumerate_scheduler(recovery_scheduler, 2), expected_lrs_per_epoch[initial_epochs:])


@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
@ -74,21 +81,18 @@ def test_lr_monotonically_decreasing_function(lr_scheduler_type: LRSchedulerType

    # create lr scheduler
    lr_scheduler, _ = _create_lr_scheduler_and_optimizer(config)
    lr_list = []
    for _ in range(config.num_epochs):
        lr_scheduler.step()
        lr_list.append(lr_scheduler.get_last_lr()[0])

    lr_list = enumerate_scheduler(lr_scheduler, config.num_epochs)
    assert non_increasing(lr_list)


@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
@pytest.mark.parametrize("warmup_epochs", [0, 4, 5])
@pytest.mark.parametrize("warmup_epochs", [0, 3])
def test_warmup_against_original_schedule(lr_scheduler_type: LRSchedulerType, warmup_epochs: int) -> None:
    """
    Tests if LR scheduler with warmup matches the Pytorch implementation after the warmup stage is completed.
    """
    config = DummyModel(num_epochs=10,
    config = DummyModel(num_epochs=6,
                        l_rate=1e-2,
                        l_rate_scheduler=lr_scheduler_type,
                        l_rate_exponential_gamma=0.9,
                        l_rate_step_gamma=0.9,
@ -99,43 +103,52 @@ def test_warmup_against_original_schedule(lr_scheduler_type: LRSchedulerType, wa
                        l_rate_warmup=LRWarmUpType.Linear if warmup_epochs > 0 else LRWarmUpType.NoWarmUp,
                        l_rate_warmup_epochs=warmup_epochs)
    # create lr scheduler
    lr_scheduler, optimizer = _create_lr_scheduler_and_optimizer(config)
    lr_scheduler, optimizer1 = _create_lr_scheduler_and_optimizer(config)

    original_scheduler: Optional[_LRScheduler] = None
    optimizer2 = _create_dummy_optimizer(config)
    # This mimics the code in SchedulerWithWarmUp.get_scheduler and must be kept in sync
    if lr_scheduler_type == LRSchedulerType.Exponential:
        original_scheduler = ExponentialLR(optimizer=optimizer, gamma=config.l_rate_exponential_gamma)
        original_scheduler = ExponentialLR(optimizer=optimizer2, gamma=config.l_rate_exponential_gamma)
    elif lr_scheduler_type == LRSchedulerType.Step:
        original_scheduler = StepLR(optimizer=optimizer, step_size=config.l_rate_step_step_size,
        original_scheduler = StepLR(optimizer=optimizer2, step_size=config.l_rate_step_step_size,
                                    gamma=config.l_rate_step_gamma)
    elif lr_scheduler_type == LRSchedulerType.Cosine:
        original_scheduler = CosineAnnealingLR(optimizer, T_max=config.num_epochs, eta_min=config.min_l_rate)
        original_scheduler = CosineAnnealingLR(optimizer2, T_max=config.num_epochs, eta_min=config.min_l_rate)
    elif lr_scheduler_type == LRSchedulerType.MultiStep:
        assert config.l_rate_multi_step_milestones is not None  # for mypy
        original_scheduler = MultiStepLR(optimizer=optimizer, milestones=config.l_rate_multi_step_milestones,
        original_scheduler = MultiStepLR(optimizer=optimizer2, milestones=config.l_rate_multi_step_milestones,
                                         gamma=config.l_rate_multi_step_gamma)
    elif lr_scheduler_type == LRSchedulerType.Polynomial:
        x = config.min_l_rate / config.l_rate
        polynomial_decay: Any = lambda epoch: (1 - x) * (
                (1. - float(epoch) / config.num_epochs) ** config.l_rate_polynomial_gamma) + x
        original_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=polynomial_decay)
        original_scheduler = LambdaLR(optimizer=optimizer2, lr_lambda=polynomial_decay)
    else:
        raise ValueError("Scheduler has not been added to this test.")

    result_lr_list = []
    for _ in range(config.num_epochs):
        result_lr_list.append(lr_scheduler.get_last_lr()[0])
        lr_scheduler.step()

    expected_lr_list = []
    for i in range(warmup_epochs):
        expected_lr_list.append(config.l_rate * (i + 1) / warmup_epochs)
    for _ in range(config.num_epochs - warmup_epochs):
        # For pytorch version 1.6:
        # expected_lr_list.append(original_scheduler.get_last_lr())
        expected_lr_list.append(original_scheduler.get_lr()[0])  # type: ignore
        original_scheduler.step()  # type: ignore
    if warmup_epochs == 0:
        pass
    elif warmup_epochs == 3:
        # For the first config.l_rate_warmup_epochs, the learning rate is lower than the initial learning rate by a
        # linear factor
        expected_lr_list.extend([f * config.l_rate for f in [0.25, 0.5, 0.75]])
    else:
        raise NotImplementedError()
    expected_lr_list.extend(enumerate_scheduler(original_scheduler, config.num_epochs - warmup_epochs))
    print(f"Expected schedule with warmup: {expected_lr_list}")

    assert np.allclose(result_lr_list, expected_lr_list)
    lr_with_warmup_scheduler = enumerate_scheduler(lr_scheduler, config.num_epochs)
    print(f"Actual schedule: {lr_with_warmup_scheduler}")

    if ((lr_scheduler_type == LRSchedulerType.Polynomial or lr_scheduler_type == LRSchedulerType.Cosine)
            and warmup_epochs > 0):
        # Polynomial and Cosine schedulers will be squashed in time because the number of epochs is reduced
        # (both schedulers take a "length of training" argument, and that is now shorter). Skip comparing those.
        pass
    else:
        assert np.allclose(lr_with_warmup_scheduler, expected_lr_list, rtol=1e-5)


def _create_dummy_optimizer(config: SegmentationModelBase) -> Optimizer:
@ -143,24 +156,79 @@ def _create_dummy_optimizer(config: SegmentationModelBase) -> Optimizer:

def _create_lr_scheduler_and_optimizer(config: SegmentationModelBase, optimizer: Optimizer = None) \
        -> Tuple[LRScheduler, Optimizer]:
        -> Tuple[SchedulerWithWarmUp, Optimizer]:
    # create dummy optimizer
    if optimizer is None:
        optimizer = _create_dummy_optimizer(config)
    # create lr scheduler
    lr_scheduler = LRScheduler(config, optimizer)
    lr_scheduler = SchedulerWithWarmUp(config, optimizer)
    return lr_scheduler, optimizer


@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
# This construct is to work around an issue where mypy does not think that MultiplicativeLR exists in lr_scheduler
def multiplicative(optimizer: Optimizer) -> _LRScheduler:
    return lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.5)  # type: ignore


@pytest.mark.parametrize("scheduler_func, expected_values",
                         # A scheduler that reduces learning rate by a factor of 0.5 in each epoch
                         [(multiplicative, [1, 0.5, 0.25, 0.125, 0.0625]),
                          # A scheduler that reduces learning rate by a factor of 0.5 at epochs 2 and 4
                          (lambda optimizer: MultiStepLR(optimizer, [2, 4], gamma=0.5),
                           [1, 1, 0.5, 0.5, 0.25]),
                          (lambda optimizer: MultiStepLR(optimizer, [1, 2, 3, 4, 5], gamma=0.5),
                           [1, 0.5, 0.25, 0.125, 0.0625])
                          ])
def test_built_in_lr_scheduler(scheduler_func: Callable[[Optimizer], _LRScheduler],
                               expected_values: List[float]) -> None:
    """
    A test to check that the behaviour of the built-in learning rate schedulers is still what we think it is.
    """
    initial_lr = 1
    optimizer = torch.optim.Adam([torch.ones(2, 2, requires_grad=True)], lr=initial_lr)
    scheduler = scheduler_func(optimizer)
    lrs = []
    for _ in range(5):
        last_lr = scheduler.get_last_lr()  # type: ignore
        lrs.append(last_lr)
        # get_last_lr should not change the state when called twice
        assert scheduler.get_last_lr() == last_lr  # type: ignore
        scheduler.step()
    # Expected behaviour: First LR should be the initial LR set in the optimizers.
    assert lrs == [[v] for v in expected_values]


@pytest.mark.parametrize("warmup_epochs, expected_values",
                         [(0, [1, 1, 0.5, 0.5]),
                          (1, [0.5, 1, 1, 0.5]),
                          (2, [1 / 3, 2 / 3, 1, 1])])
def test_lr_scheduler_with_warmup(warmup_epochs: int, expected_values: List[float]) -> None:
    """
    Check that warmup is applied correctly to a multistep scheduler
    """
    initial_lr = 1
    optimizer = torch.optim.Adam([torch.ones(2, 2, requires_grad=True)], lr=initial_lr)
    config = DeepLearningConfig(l_rate=initial_lr,
                                l_rate_scheduler=LRSchedulerType.MultiStep,
                                l_rate_multi_step_milestones=[2, 4],
                                l_rate_multi_step_gamma=0.5,
                                l_rate_warmup_epochs=warmup_epochs,
                                l_rate_warmup=LRWarmUpType.Linear,
                                should_validate=False)
    scheduler = SchedulerWithWarmUp(config, optimizer)
    lrs = enumerate_scheduler(scheduler, 4)
    assert lrs == expected_values

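The expected values above imply a linear warmup factor of (epoch + 1) / (warmup_epochs + 1), reaching the full rate in the first epoch after warmup. A sketch of that ramp using LambdaLR and the enumerate_scheduler helper; the rule is inferred from the test values, not taken from the SchedulerWithWarmUp source:

import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_epochs = 2


def warmup_factor(epoch: int) -> float:
    # Linear ramp inferred from the expected values above: 1/3, 2/3, then 1.
    return min(1.0, (epoch + 1) / (warmup_epochs + 1))


optimizer = torch.optim.Adam([torch.ones(1, requires_grad=True)], lr=1.0)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_factor)
assert np.allclose(enumerate_scheduler(scheduler, 4), [1 / 3, 2 / 3, 1.0, 1.0])
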
# Exclude Polynomial scheduler because that uses lambdas, which we can't save to a state dict
@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType if x != LRSchedulerType.Polynomial])
@pytest.mark.parametrize("warmup_epochs", [0, 3, 4, 5])
@pytest.mark.parametrize("restart_from_epoch", [4])
def test_resume_from_saved_state(lr_scheduler_type: LRSchedulerType,
                                 warmup_epochs: int, restart_from_epoch: int) -> None:
def test_resume_from_saved_state(lr_scheduler_type: LRSchedulerType, warmup_epochs: int) -> None:
    """
    Tests if the LR scheduler, when restarted from an epoch, continues as expected.
    """
    config = DummyModel(num_epochs=10,
    restart_from_epoch = 4
    config = DummyModel(num_epochs=7,
                        l_rate_scheduler=lr_scheduler_type,
                        l_rate_exponential_gamma=0.9,
                        l_rate_step_gamma=0.9,
@ -170,67 +238,44 @@ def test_resume_from_saved_state(lr_scheduler_type: LRSchedulerType,
                        l_rate_polynomial_gamma=0.9,
                        l_rate_warmup=LRWarmUpType.Linear if warmup_epochs > 0 else LRWarmUpType.NoWarmUp,
                        l_rate_warmup_epochs=warmup_epochs)
    # create two lr schedulers
    lr_scheduler_1, optimizer_1 = _create_lr_scheduler_and_optimizer(config)
    lr_scheduler_2, optimizer_2 = _create_lr_scheduler_and_optimizer(config)

    expected_lr_list = []
    for _ in range(config.num_epochs):
        expected_lr_list.append(lr_scheduler_2.get_last_lr()[0])
        lr_scheduler_2.step()

    result_lr_list = []
    for _ in range(restart_from_epoch):
        result_lr_list.append(lr_scheduler_1.get_last_lr()[0])
        lr_scheduler_1.step()
    # This scheduler mimics what happens if we train for the full set of epochs
    scheduler_all_epochs, _ = _create_lr_scheduler_and_optimizer(config)
    expected_lr_list = enumerate_scheduler(scheduler_all_epochs, config.num_epochs)

    # Create a scheduler where training will be recovered
    scheduler1, optimizer1 = _create_lr_scheduler_and_optimizer(config)
    # Scheduler 1 is only run for 4 epochs, and then "restarted" to train the rest of the epochs.
    result_lr_list = enumerate_scheduler(scheduler1, restart_from_epoch)
    # resume state: This just means setting start_epoch in the config
    config.start_epoch = restart_from_epoch
    lr_scheduler_resume, _ = _create_lr_scheduler_and_optimizer(config, optimizer_1)
    for _ in range(config.num_epochs - restart_from_epoch):
        result_lr_list.append(lr_scheduler_resume.get_last_lr()[0])
        lr_scheduler_resume.step()

    assert result_lr_list == expected_lr_list
    scheduler_resume, optimizer_resume = _create_lr_scheduler_and_optimizer(config)
    # Load a "checkpoint" for both scheduler and optimizer
    scheduler_resume.load_state_dict(scheduler1.state_dict())
    optimizer_resume.load_state_dict(optimizer1.state_dict())
    result_lr_list.extend(enumerate_scheduler(scheduler_resume, config.num_epochs - restart_from_epoch))
    print(f"Actual schedule: {result_lr_list}")
    print(f"Expected schedule: {expected_lr_list}")
    assert len(result_lr_list) == len(expected_lr_list)
    assert np.allclose(result_lr_list, expected_lr_list)


@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
def test_save_and_load_state_dict(lr_scheduler_type: LRSchedulerType) -> None:

    def object_dict_same(lr1: LRScheduler, lr2: LRScheduler) -> bool:
    def object_dict_same(lr1: SchedulerWithWarmUp, lr2: SchedulerWithWarmUp) -> bool:
        """
        Tests to see if two LRScheduler objects are the same.
        This ignores lambdas if one of the schedulers is LambdaLR, since lambdas are not stored to the state dict.
        """
        # dict of object LR scheduler
        # ignore the _scheduler attribute, which is of type SchedulerWithWarmUp and compare it separately
        dict1 = {key: val for key, val in lr1.__dict__.items() if key != "_scheduler"}
        dict2 = {key: val for key, val in lr2.__dict__.items() if key != "_scheduler"}
        # ignore the _scheduler and _warmup objects, compare those separately
        dict1 = {key: val for key, val in lr1.__dict__.items() if key != "_scheduler" and key != "_warmup"}
        dict2 = {key: val for key, val in lr2.__dict__.items() if key != "_scheduler" and key != "_warmup"}

        # see if the SchedulerWithWarmUp object is the same
        warmup_and_scheduler1 = lr1.__dict__["_scheduler"]
        warmup_and_scheduler2 = lr2.__dict__["_scheduler"]

        # scheduler object
        scheduler1 = warmup_and_scheduler1.__dict__["_scheduler"]
        scheduler2 = warmup_and_scheduler2.__dict__["_scheduler"]
        # remove lambdas from scheduler dict
        scheduler1_dict = {key: val for key, val in scheduler1.__dict__.items() if key != "lr_lambdas"}
        scheduler2_dict = {key: val for key, val in scheduler2.__dict__.items() if key != "lr_lambdas"}

        # warmup object
        warmup1 = warmup_and_scheduler1.__dict__["_warmup_scheduler"]
        warmup2 = warmup_and_scheduler2.__dict__["_warmup_scheduler"]

        # Other variables in the object SchedulerWithWarmUp
        other_variables1 = {key: val for key, val in warmup_and_scheduler1.__dict__.items()
                            if key != "_scheduler" and key != "_warmup_scheduler"}
        other_variables2 = {key: val for key, val in warmup_and_scheduler2.__dict__.items()
                            if key != "_scheduler" and key != "_warmup_scheduler"}

        return dict1 == dict2 and other_variables1 == other_variables2 and \
               scheduler1_dict == scheduler2_dict and \
               warmup1.__dict__ == warmup2.__dict__
        # see if the underlying scheduler object is the same
        scheduler1_dict = {key: val for key, val in lr1._scheduler.__dict__.items() if key != "lr_lambdas"}
        scheduler2_dict = {key: val for key, val in lr2._scheduler.__dict__.items() if key != "lr_lambdas"}
        warmup1_dict = lr1._warmup.__dict__
        warmup2_dict = lr2._warmup.__dict__
        return dict1 == dict2 and scheduler1_dict == scheduler2_dict and warmup1_dict == warmup2_dict

    config = DummyModel(num_epochs=10,
                        l_rate_scheduler=lr_scheduler_type,
@ -271,32 +316,36 @@ def test_cosine_decay_function() -> None:
    # create lr scheduler
    test_epoch = 5
    lr_scheduler, _ = _create_lr_scheduler_and_optimizer(config)
    lr_scheduler.step(test_epoch)
    for _ in range(test_epoch):
        lr_scheduler.step()
    assert lr_scheduler.get_last_lr()[0] == 0.5 * config.l_rate

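The 0.5 factor asserted above follows from the cosine annealing formula lr(t) = eta_min + (l_rate - eta_min) * (1 + cos(pi * t / T_max)) / 2. A quick numeric check, assuming T_max = num_epochs = 10 and min_l_rate = 0, values which sit outside this hunk:

import math

# lr(t) = eta_min + (l_rate - eta_min) * (1 + cos(pi * t / T_max)) / 2
l_rate, eta_min, t_max, t = 1e-3, 0.0, 10, 5
lr = eta_min + (l_rate - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2
assert math.isclose(lr, 0.5 * l_rate)
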
@pytest.mark.parametrize("warmup_epochs, expected_lrs",
|
||||
[(0, np.array([1e-3, 1e-3, 1e-4, 1e-4, 1e-4, 1e-5, 1e-5, 1e-6, 1e-6, 1e-6])),
|
||||
(5, np.array([2e-4, 4e-4, 6e-4, 8e-4, 1e-3, 1e-3, 1e-3, 1e-4, 1e-4, 1e-4]))])
|
||||
def test_multistep_lr(warmup_epochs: int, expected_lrs: np.ndarray) -> None:
|
||||
"""
|
||||
Creates a MultiStep LR and check values are returned as expected
|
||||
"""
|
||||
|
||||
num_epochs = 10
|
||||
def test_multistep_lr() -> None:
|
||||
l_rate = 0.3
|
||||
config = DummyModel(l_rate_scheduler=LRSchedulerType.MultiStep,
|
||||
l_rate=l_rate,
|
||||
l_rate_multi_step_gamma=0.1,
|
||||
num_epochs=num_epochs,
|
||||
l_rate_multi_step_milestones=[2, 5, 7],
|
||||
l_rate_warmup=LRWarmUpType.Linear if warmup_epochs > 0 else LRWarmUpType.NoWarmUp,
|
||||
l_rate_warmup_epochs=warmup_epochs)
|
||||
num_epochs=10,
|
||||
l_rate_multi_step_milestones=[2],
|
||||
l_rate_warmup=LRWarmUpType.Linear,
|
||||
l_rate_warmup_epochs=5)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler, optimizer = _create_lr_scheduler_and_optimizer(config)
|
||||
def check_warmup(expected: List[float]) -> None:
|
||||
scheduler, _ = _create_lr_scheduler_and_optimizer(config)
|
||||
actual = enumerate_scheduler(scheduler, 4)
|
||||
assert actual == expected
|
||||
|
||||
lrs = []
|
||||
for _ in range(num_epochs):
|
||||
lrs.append(lr_scheduler.get_last_lr()[0])
|
||||
lr_scheduler.step()
|
||||
# No warmup: multi-step LR with milestone after 2 epochs
|
||||
original_schedule = [l_rate, l_rate, l_rate * 0.1, l_rate * 0.1]
|
||||
config.l_rate_warmup = LRWarmUpType.Linear
|
||||
config.l_rate_warmup_epochs = 0
|
||||
check_warmup(original_schedule)
|
||||
|
||||
assert np.allclose(lrs, expected_lrs)
|
||||
# 1 epoch warmup: linear function up to the initial learning rate gives a warmup value of half the initial LR
|
||||
config.l_rate_warmup_epochs = 1
|
||||
check_warmup([l_rate * 0.5] + original_schedule[:3])
|
||||
|
||||
# 2 epochs warmup
|
||||
config.l_rate_warmup_epochs = 2
|
||||
check_warmup([l_rate / 3, l_rate * 2 / 3] + original_schedule[:2])
|
||||
|
|
|
@ -5,15 +5,10 @@ channels:
dependencies:
  - pip=20.1.1
  - python=3.7.3
  - pytorch=1.3.0
  - pytorch=1.6.0
  - python-blosc==1.7.0
  - torchvision=0.7.0
  - pip:
    # We need a recent apex, as nvidia-apex=0.1 from conda-forge suffers from
    # https://github.com/NVIDIA/apex/issues/552
    # However this stops installation on CPU-only machines, and the alleged fix does not fix it:
    # https://github.com/NVIDIA/apex/issues/931
    # Commit a0d99fd, from 2020-07-09, seems to avoid both bugs.
    - git+https://github.com/NVIDIA/apex.git@a0d99fd#egg=apex
    - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio
    - azureml-mlflow==1.12.0
    - azureml-sdk==1.12.0
@ -55,5 +50,4 @@ dependencies:
    - tensorboard==2.3.0
    - tensorboardX==2.1
    - torchprof==1.1.1
    - torchvision==0.4.1