diff --git a/ptgnn/baseneuralmodel/trainer.py b/ptgnn/baseneuralmodel/trainer.py
index f429d06..144e465 100644
--- a/ptgnn/baseneuralmodel/trainer.py
+++ b/ptgnn/baseneuralmodel/trainer.py
@@ -12,6 +12,7 @@ from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar
 from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
 from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
 from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
+from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

 TRawDatapoint = TypeVar("TRawDatapoint")
 TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")
@@ -52,6 +53,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         target_validation_metric: Optional[str] = None,
         target_validation_metric_higher_is_better: bool = False,
         enable_amp: bool = False,
+        catch_cuda_ooms: bool = False,
     ):
         """
         :param model: The Component to be built and trained
@@ -64,6 +66,15 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         :param scheduler_creator: An optional function that accepts an optimizer and creates
             a scheduler implementing `AbstractScheduler`. This could be a wrapper for existing
             learning schedulers. The scheduler will be invoked at after each training step.
+        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
+        :param target_validation_metric: An optional string of the name of the metric (returned by
+            the TNeuralModule) which is used to detect if the model performance improved in validation.
+            This is used for early stopping, and checkpointing the best model. If `None` the model
+            loss (value returned from `forward()` of TNeuralModule) is used.
+        :param target_validation_metric_higher_is_better: if `True` increases to `target_validation_metric`
+            imply improvements. Ignored if `target_validation_metric` is `None`.
+        :param enable_amp: Enable automatic mixed precision during training.
+        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
         """
         self.__model = model
         self.__neural_network: Optional[TNeuralModule] = None
@@ -87,6 +98,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
         self._clip_gradient_norm = clip_gradient_norm
         self._enable_amp = enable_amp
+        self._catch_cuda_ooms = catch_cuda_ooms

         self._target_metric = target_validation_metric
         if target_validation_metric is not None:
@@ -203,40 +215,41 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
                 )
             ):
                 optimizer.zero_grad()
-                with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                    mb_loss = self.neural_module(**mb_data)
+                with catch_cuda_oom(self._catch_cuda_ooms):
+                    with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                        mb_loss = self.neural_module(**mb_data)

-                scaler.scale(mb_loss).backward()
+                    scaler.scale(mb_loss).backward()

-                if torch.isnan(mb_loss):
-                    raise Exception("Loss has a NaN value.")
+                    if torch.isnan(mb_loss):
+                        raise Exception("Loss has a NaN value.")

-                if self._clip_gradient_norm is not None:
-                    scaler.unscale_(optimizer)
-                    torch.nn.utils.clip_grad_norm_(
-                        self.neural_module.parameters(recurse=True), self._clip_gradient_norm
-                    )
-
-                scaler.step(optimizer)
-                scaler.update()
-                if scheduler is not None:
-                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
-
-                num_minibatches += 1
-                num_samples += len(raw_samples)
-                with torch.no_grad():
-                    sum_epoch_loss += mb_loss
-                if show_progress_bar:
-                    mb_loss = float(mb_loss)
-                    if num_minibatches == 1:  # First minibatch
-                        running_avg_loss = mb_loss
-                    else:
-                        running_avg_loss = (
-                            exponential_running_average_factor * running_avg_loss
-                            + (1 - exponential_running_average_factor) * mb_loss
+                    if self._clip_gradient_norm is not None:
+                        scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(
+                            self.neural_module.parameters(recurse=True), self._clip_gradient_norm
                         )
-                    progress_bar.update()
-                    progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
+
+                    scaler.step(optimizer)
+                    scaler.update()
+                    if scheduler is not None:
+                        scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+
+                    num_minibatches += 1
+                    num_samples += len(raw_samples)
+                    with torch.no_grad():
+                        sum_epoch_loss += mb_loss
+                    if show_progress_bar:
+                        mb_loss = float(mb_loss)
+                        if num_minibatches == 1:  # First minibatch
+                            running_avg_loss = mb_loss
+                        else:
+                            running_avg_loss = (
+                                exponential_running_average_factor * running_avg_loss
+                                + (1 - exponential_running_average_factor) * mb_loss
+                            )
+                        progress_bar.update()
+                        progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")

             elapsed_time = time.time() - start_time
             self.LOGGER.info(
@@ -275,14 +288,17 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
                 shuffle_input=False,
                 parallelize=parallelize,
             ):
-                with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                    mb_loss = self.neural_module(**mb_data)
-                num_minibatches += 1
-                num_samples += len(raw_samples)
-                sum_epoch_loss += mb_loss
-                if show_progress_bar:
-                    progress_bar.update()
-                    progress_bar.set_postfix(Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}")
+                with catch_cuda_oom(self._catch_cuda_ooms):
+                    with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                        mb_loss = self.neural_module(**mb_data)
+                    num_minibatches += 1
+                    num_samples += len(raw_samples)
+                    sum_epoch_loss += mb_loss
+                    if show_progress_bar:
+                        progress_bar.update()
+                        progress_bar.set_postfix(
+                            Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
+                        )

             elapsed_time = time.time() - start_time
             assert num_samples > 0, "No validation data was found."
diff --git a/ptgnn/baseneuralmodel/utils/oom.py b/ptgnn/baseneuralmodel/utils/oom.py
new file mode 100644
index 0000000..f30d394
--- /dev/null
+++ b/ptgnn/baseneuralmodel/utils/oom.py
@@ -0,0 +1,22 @@
+from typing_extensions import Final
+
+import logging
+import torch
+from contextlib import contextmanager
+
+LOGGER: Final = logging.getLogger(__name__)
+
+
+@contextmanager
+def catch_cuda_oom(enabled: bool = True):
+    if enabled:
+        try:
+            yield
+        except RuntimeError as re:
+            if "CUDA out of memory." in repr(re):
+                LOGGER.exception("CUDA Out-Of-Memory Caught and Execution Resumed.", exc_info=re)
+                torch.cuda.empty_cache()
+            else:
+                raise re
+    else:
+        yield
diff --git a/ptgnn/implementations/typilus/traindistributed.py b/ptgnn/implementations/typilus/traindistributed.py
index 5a743ff..e24175f 100755
--- a/ptgnn/implementations/typilus/traindistributed.py
+++ b/ptgnn/implementations/typilus/traindistributed.py
@@ -17,7 +17,6 @@ Options:
     -h --help                  Show this screen.
     --debug                    Enable debug routines. [default: False]
 """
-import logging
 import random
 import torch
 import torch.distributed as dist
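
Below is a minimal, self-contained sketch of how the new catch_cuda_oom context manager from ptgnn/baseneuralmodel/utils/oom.py could be used in a hand-rolled loop; inside ModelTrainer it is enabled by passing catch_cuda_ooms=True to the constructor. The toy Linear model, optimizer, and synthetic batches are hypothetical placeholders for illustration only; only catch_cuda_oom itself comes from this patch.

import torch

from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

# A toy module standing in for the neural module a ModelTrainer would build.
model = torch.nn.Linear(128, 1).cuda()
optimizer = torch.optim.Adam(model.parameters())

for _ in range(100):
    optimizer.zero_grad()
    with catch_cuda_oom(enabled=True):  # pass False to disable the guard entirely
        # Deliberately large batches may occasionally exhaust GPU memory.
        batch = torch.randn(4096, 128, device="cuda")
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
    # If a "CUDA out of memory" RuntimeError is raised inside the block, it is
    # logged, torch.cuda.empty_cache() is called, the exception is suppressed,
    # and the loop simply moves on to the next minibatch.

Note that any other RuntimeError is re-raised unchanged, so the guard only swallows CUDA OOMs.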