Mirror of https://github.com/microsoft/ptgnn.git
Option to catch occasional CUDA OOMs to allow for more robust training.
Parent: bf72e7f18a
Commit: ef13a9fec0
ptgnn/baseneuralmodel/trainer.py

@@ -12,6 +12,7 @@ from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar
 from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
 from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
 from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
+from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

 TRawDatapoint = TypeVar("TRawDatapoint")
 TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")
@@ -52,6 +53,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         target_validation_metric: Optional[str] = None,
         target_validation_metric_higher_is_better: bool = False,
         enable_amp: bool = False,
+        catch_cuda_ooms: bool = False,
     ):
         """
         :param model: The Component to be built and trained
@@ -64,6 +66,15 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         :param scheduler_creator: An optional function that accepts an optimizer and creates a scheduler
             implementing `AbstractScheduler`. This could be a wrapper for existing learning schedulers.
             The scheduler will be invoked after each training step.
+        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
+        :param target_validation_metric: An optional string of the name of the metric (returned by
+            the TNeuralModule) which is used to detect if the model performance improved in validation.
+            This is used for early stopping, and checkpointing the best model. If `None` the model
+            loss (value returned from `forward()` of TNeuralModule) is used.
+        :param target_validation_metric_higher_is_better: if `True` increases to `target_validation_metric`
+            imply improvements. Ignored if `target_validation_metric` is `None`.
+        :param enable_amp: Enable automatic mixed precision during training.
+        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
         """
         self.__model = model
         self.__neural_network: Optional[TNeuralModule] = None
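For orientation, a sketch of how a caller might opt in to the new flag. Only `enable_amp` and `catch_cuda_ooms` come from the constructor signature above; the model and path arguments are hypothetical placeholders:

    # Hypothetical setup; only enable_amp and catch_cuda_ooms are taken
    # from the constructor shown in this diff.
    trainer = ModelTrainer(
        model,                 # some AbstractNeuralModel instance (placeholder)
        checkpoint_path,       # placeholder save location
        enable_amp=True,       # keep mixed-precision training on
        catch_cuda_ooms=True,  # log and skip minibatches that exhaust GPU memory
    )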
@@ -87,6 +98,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
         self._clip_gradient_norm = clip_gradient_norm
         self._enable_amp = enable_amp
+        self._catch_cuda_ooms = catch_cuda_ooms

         self._target_metric = target_validation_metric
         if target_validation_metric is not None:
@@ -203,40 +215,41 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
             )
         ):
             optimizer.zero_grad()
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
-
-            scaler.scale(mb_loss).backward()
-
-            if torch.isnan(mb_loss):
-                raise Exception("Loss has a NaN value.")
-
-            if self._clip_gradient_norm is not None:
-                scaler.unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(
-                    self.neural_module.parameters(recurse=True), self._clip_gradient_norm
-                )
-
-            scaler.step(optimizer)
-            scaler.update()
-            if scheduler is not None:
-                scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
-
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            with torch.no_grad():
-                sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                mb_loss = float(mb_loss)
-                if num_minibatches == 1:  # First minibatch
-                    running_avg_loss = mb_loss
-                else:
-                    running_avg_loss = (
-                        exponential_running_average_factor * running_avg_loss
-                        + (1 - exponential_running_average_factor) * mb_loss
-                    )
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
+
+                scaler.scale(mb_loss).backward()
+
+                if torch.isnan(mb_loss):
+                    raise Exception("Loss has a NaN value.")
+
+                if self._clip_gradient_norm is not None:
+                    scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(
+                        self.neural_module.parameters(recurse=True), self._clip_gradient_norm
+                    )
+
+                scaler.step(optimizer)
+                scaler.update()
+                if scheduler is not None:
+                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                with torch.no_grad():
+                    sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    mb_loss = float(mb_loss)
+                    if num_minibatches == 1:  # First minibatch
+                        running_avg_loss = mb_loss
+                    else:
+                        running_avg_loss = (
+                            exponential_running_average_factor * running_avg_loss
+                            + (1 - exponential_running_average_factor) * mb_loss
+                        )
+                    progress_bar.update()
+                    progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")

         elapsed_time = time.time() - start_time
         self.LOGGER.info(
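Note the scope of the new `with catch_cuda_oom(...)` block: the forward pass, backward pass, optimizer step, and counter updates all sit inside it, so a minibatch that triggers an OOM is skipped in its entirety and never distorts `num_minibatches` or the running-average loss. A minimal standalone sketch of the same pattern, with hypothetical stand-ins for the loop variables:

    # Skip-on-OOM pattern; minibatches, module, and optimizer are placeholders.
    for mb_data in minibatches:
        optimizer.zero_grad()
        with catch_cuda_oom(enabled=True):
            loss = module(**mb_data)  # may raise a CUDA out-of-memory RuntimeError
            loss.backward()
            optimizer.step()  # not reached for a minibatch that OOMed above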
@@ -275,14 +288,17 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
             shuffle_input=False,
             parallelize=parallelize,
         ):
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}")
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    progress_bar.update()
+                    progress_bar.set_postfix(
+                        Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
+                    )

         elapsed_time = time.time() - start_time
         assert num_samples > 0, "No validation data was found."
ptgnn/baseneuralmodel/utils/oom.py (new file)

@@ -0,0 +1,22 @@
+from typing_extensions import Final
+
+import logging
+import torch
+from contextlib import contextmanager
+
+LOGGER: Final = logging.getLogger(__name__)
+
+
+@contextmanager
+def catch_cuda_oom(enabled: bool = True):
+    if enabled:
+        try:
+            yield
+        except RuntimeError as re:
+            if "CUDA out of memory." in repr(re):
+                LOGGER.exception("CUDA Out-Of-Memory Caught and Execution Resumed.", exc_info=re)
+                torch.cuda.empty_cache()
+            else:
+                raise re
+    else:
+        yield
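The context manager also works outside the trainer. A hedged example with hypothetical `model` and `oversized_batch`: the guard matches the "CUDA out of memory." substring in the exception's repr, logs it, and releases cached allocator blocks with `torch.cuda.empty_cache()`, while every other RuntimeError propagates unchanged:

    from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

    # model and oversized_batch are illustrative placeholders.
    with catch_cuda_oom(enabled=True):
        out = model(oversized_batch)  # a too-large batch may raise a CUDA OOM
    # Execution resumes here once the OOM is logged and the cache emptied;
    # with enabled=False the manager is a transparent no-op.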
@@ -17,7 +17,6 @@ Options:
   -h --help                  Show this screen.
   --debug                    Enable debug routines. [default: False]
 """
 import logging
 import random
 import torch
 import torch.distributed as dist