Option to catch occasional CUDA OOMs to allow for more robust training.

Miltos Allamanis 2021-09-10 18:59:55 +01:00 committed by Miltos
Parent bf72e7f18a
Commit ef13a9fec0
3 changed files: 76 additions and 39 deletions
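
In practice, "more robust" means that when the new option is enabled, a minibatch whose forward or backward pass runs out of GPU memory is logged and skipped instead of aborting the whole run. A minimal, hypothetical sketch of turning it on (the model and data names are placeholders, and any other required constructor arguments are omitted):

from ptgnn.baseneuralmodel import ModelTrainer  # import path assumed

# Hypothetical usage: only `catch_cuda_ooms` is introduced by this commit;
# `my_model`, `training_data`, and `validation_data` are placeholders.
trainer = ModelTrainer(
    model=my_model,        # an AbstractNeuralModel defined elsewhere
    catch_cuda_ooms=True,  # log and skip minibatches that hit a CUDA OOM
)
trainer.train(training_data, validation_data)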

View file

@@ -12,6 +12,7 @@ from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar
from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

TRawDatapoint = TypeVar("TRawDatapoint")
TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")
@@ -52,6 +53,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
        target_validation_metric: Optional[str] = None,
        target_validation_metric_higher_is_better: bool = False,
        enable_amp: bool = False,
        catch_cuda_ooms: bool = False,
    ):
        """
        :param model: The Component to be built and trained
@@ -64,6 +66,15 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
        :param scheduler_creator: An optional function that accepts an optimizer and creates a scheduler
            implementing `AbstractScheduler`. This could be a wrapper for existing learning schedulers.
            The scheduler will be invoked after each training step.
        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
        :param target_validation_metric: An optional string of the name of the metric (returned by
            the TNeuralModule) which is used to detect if the model performance improved in validation.
            This is used for early stopping, and checkpointing the best model. If `None` the model
            loss (value returned from `forward()` of TNeuralModule) is used.
        :param target_validation_metric_higher_is_better: if `True`, increases in `target_validation_metric`
            imply an improvement. Ignored if `target_validation_metric` is `None`.
        :param enable_amp: Enable automatic mixed precision during training.
        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
        """
        self.__model = model
        self.__neural_network: Optional[TNeuralModule] = None
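
For reference, a hypothetical trainer construction combining the options documented above (the model and metric name are placeholders; other required arguments are omitted):

# Hypothetical sketch of the documented keyword arguments.
trainer = ModelTrainer(
    model=my_model,                                  # placeholder AbstractNeuralModel
    clip_gradient_norm=1.0,                          # clip gradient norms during training
    target_validation_metric="Accuracy",             # placeholder metric name reported by the module
    target_validation_metric_higher_is_better=True,  # higher accuracy means improvement
    enable_amp=True,                                 # automatic mixed precision
    catch_cuda_ooms=True,                            # tolerate occasional CUDA OOMs
)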
@@ -87,6 +98,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
        self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
        self._clip_gradient_norm = clip_gradient_norm
        self._enable_amp = enable_amp
        self._catch_cuda_ooms = catch_cuda_ooms

        self._target_metric = target_validation_metric
        if target_validation_metric is not None:
@@ -203,40 +215,41 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
                )
            ):
                optimizer.zero_grad()
                with catch_cuda_oom(self._catch_cuda_ooms):
                    with torch.cuda.amp.autocast(enabled=self._enable_amp):
                        mb_loss = self.neural_module(**mb_data)

                    scaler.scale(mb_loss).backward()

                    if torch.isnan(mb_loss):
                        raise Exception("Loss has a NaN value.")

                    if self._clip_gradient_norm is not None:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.neural_module.parameters(recurse=True), self._clip_gradient_norm
                        )

                    scaler.step(optimizer)
                    scaler.update()
                    if scheduler is not None:
                        scheduler.step(epoch_idx=epoch, epoch_step=step_idx)

                    num_minibatches += 1
                    num_samples += len(raw_samples)
                    with torch.no_grad():
                        sum_epoch_loss += mb_loss
                        if show_progress_bar:
                            mb_loss = float(mb_loss)
                            if num_minibatches == 1:  # First minibatch
                                running_avg_loss = mb_loss
                            else:
                                running_avg_loss = (
                                    exponential_running_average_factor * running_avg_loss
                                    + (1 - exponential_running_average_factor) * mb_loss
                                )
                            progress_bar.update()
                            progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")

        elapsed_time = time.time() - start_time
        self.LOGGER.info(
@@ -275,14 +288,17 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
                shuffle_input=False,
                parallelize=parallelize,
            ):
                with catch_cuda_oom(self._catch_cuda_ooms):
                    with torch.cuda.amp.autocast(enabled=self._enable_amp):
                        mb_loss = self.neural_module(**mb_data)
                    num_minibatches += 1
                    num_samples += len(raw_samples)
                    sum_epoch_loss += mb_loss
                    if show_progress_bar:
                        progress_bar.update()
                        progress_bar.set_postfix(
                            Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
                        )

        elapsed_time = time.time() - start_time
        assert num_samples > 0, "No validation data was found."

View file

@@ -0,0 +1,22 @@
from typing_extensions import Final

import logging
import torch
from contextlib import contextmanager

LOGGER: Final = logging.getLogger(__name__)


@contextmanager
def catch_cuda_oom(enabled: bool = True):
    if enabled:
        try:
            yield
        except RuntimeError as re:
            if "CUDA out of memory." in repr(re):
                LOGGER.exception("CUDA Out-Of-Memory Caught and Execution Resumed.", exc_info=re)
                torch.cuda.empty_cache()
            else:
                raise re
    else:
        yield
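
The helper can also be used directly. A hypothetical sketch of wrapping a single training step with it (the loader, model, and optimizer are placeholders; only catch_cuda_oom itself comes from this file):

from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

# Hypothetical usage: an occasional CUDA OOM inside the block is logged,
# the CUDA cache is emptied, and the rest of the block is skipped, so the
# offending minibatch is simply dropped and the loop continues.
for minibatch in data_loader:            # placeholder iterable of tensorized minibatches
    optimizer.zero_grad()                # placeholder optimizer
    with catch_cuda_oom(enabled=True):   # with enabled=False the exception propagates normally
        loss = model(**minibatch)        # placeholder model call
        loss.backward()
        optimizer.step()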

View file

@@ -17,7 +17,6 @@ Options:
    -h --help                  Show this screen.
    --debug                    Enable debug routines. [default: False]
"""
-import logging
import random
import torch
import torch.distributed as dist