Mirror of https://github.com/microsoft/ptgnn.git

Option to catch occasional CUDA OOMs to allow for more robust training.

Parent: bf72e7f18a
Commit: ef13a9fec0

@@ -12,6 +12,7 @@ from typing import Callable, Dict, Generic, Iterable, List, Optional, TypeVar
 from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
 from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
 from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
+from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom
 
 TRawDatapoint = TypeVar("TRawDatapoint")
 TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")

@@ -52,6 +53,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         target_validation_metric: Optional[str] = None,
         target_validation_metric_higher_is_better: bool = False,
         enable_amp: bool = False,
+        catch_cuda_ooms: bool = False,
     ):
         """
         :param model: The Component to be built and trained

@@ -64,6 +66,15 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         :param scheduler_creator: An optional function that accepts an optimizer and creates a scheduler
             implementing `AbstractScheduler`. This could be a wrapper for existing learning schedulers.
             The scheduler will be invoked at after each training step.
+        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
+        :param target_validation_metric: An optional string of the name of the metric (returned by
+            the TNeuralModule) which is used to detect if the model performance improved in validation.
+            This is used for early stopping, and checkpointing the best model. If `None` the model
+            loss (value returned from `forward()` of TNeuralModule) is used.
+        :param target_validation_metric_higher_is_better: if `True` increases to `target_validation_metric`
+            imply improvements. Ignored if `target_validation_metric` is `None`.
+        :param enable_amp: Enable automatic mixed precision during training.
+        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
         """
         self.__model = model
         self.__neural_network: Optional[TNeuralModule] = None

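As a usage illustration (not part of the commit), here is a minimal sketch of constructing a trainer with the new option enabled. Only `enable_amp`, `catch_cuda_ooms`, and the two `target_validation_metric*` keywords come from the docstring above; the import path, the positional checkpoint argument, the metric name, and `my_model` are assumptions made for illustration and may differ from the actual API.

from pathlib import Path

from ptgnn.baseneuralmodel import ModelTrainer  # import path assumed

# `my_model` stands in for any AbstractNeuralModel implementation.
trainer = ModelTrainer(
    my_model,
    Path("trained-model.pkl.gz"),         # checkpoint location (assumed argument)
    enable_amp=True,                      # documented above
    catch_cuda_ooms=True,                 # new flag added by this commit
    target_validation_metric="accuracy",  # hypothetical metric name
    target_validation_metric_higher_is_better=True,
)
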
@@ -87,6 +98,7 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
         self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
         self._clip_gradient_norm = clip_gradient_norm
         self._enable_amp = enable_amp
+        self._catch_cuda_ooms = catch_cuda_ooms
 
         self._target_metric = target_validation_metric
         if target_validation_metric is not None:

@@ -203,40 +215,41 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
             )
         ):
             optimizer.zero_grad()
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
-
-            scaler.scale(mb_loss).backward()
-
-            if torch.isnan(mb_loss):
-                raise Exception("Loss has a NaN value.")
-
-            if self._clip_gradient_norm is not None:
-                scaler.unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(
-                    self.neural_module.parameters(recurse=True), self._clip_gradient_norm
-                )
-
-            scaler.step(optimizer)
-            scaler.update()
-            if scheduler is not None:
-                scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
-
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            with torch.no_grad():
-                sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                mb_loss = float(mb_loss)
-                if num_minibatches == 1:  # First minibatch
-                    running_avg_loss = mb_loss
-                else:
-                    running_avg_loss = (
-                        exponential_running_average_factor * running_avg_loss
-                        + (1 - exponential_running_average_factor) * mb_loss
-                    )
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
+
+                scaler.scale(mb_loss).backward()
+
+                if torch.isnan(mb_loss):
+                    raise Exception("Loss has a NaN value.")
+
+                if self._clip_gradient_norm is not None:
+                    scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(
+                        self.neural_module.parameters(recurse=True), self._clip_gradient_norm
+                    )
+
+                scaler.step(optimizer)
+                scaler.update()
+                if scheduler is not None:
+                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                with torch.no_grad():
+                    sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    mb_loss = float(mb_loss)
+                    if num_minibatches == 1:  # First minibatch
+                        running_avg_loss = mb_loss
+                    else:
+                        running_avg_loss = (
+                            exponential_running_average_factor * running_avg_loss
+                            + (1 - exponential_running_average_factor) * mb_loss
+                        )
+                    progress_bar.update()
+                    progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
 
         elapsed_time = time.time() - start_time
         self.LOGGER.info(

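Because the whole training step is now nested under `with catch_cuda_oom(self._catch_cuda_ooms):`, a caught OOM abandons the rest of that step (backward pass, optimizer/scheduler step, and the minibatch counters) and the loop continues with the next minibatch. A tiny self-contained demonstration of this suppression behaviour, using a generic exception in place of a real CUDA OOM, is shown below (this is illustrative code, not part of the commit).

from contextlib import contextmanager

@contextmanager
def swallow(exc_type):
    # Analogue of catch_cuda_oom: catching without re-raising suppresses the
    # exception, so the remainder of the `with` body is skipped.
    try:
        yield
    except exc_type:
        pass  # in the trainer: log the OOM and empty the CUDA cache

for step in range(3):
    with swallow(ValueError):
        if step == 1:
            raise ValueError("simulated CUDA out of memory")
        print(f"completed training step {step}")
# Output: steps 0 and 2 complete; step 1 is skipped and the loop resumes.
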
@@ -275,14 +288,17 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
                 shuffle_input=False,
                 parallelize=parallelize,
             ):
-                with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                    mb_loss = self.neural_module(**mb_data)
-                num_minibatches += 1
-                num_samples += len(raw_samples)
-                sum_epoch_loss += mb_loss
-                if show_progress_bar:
-                    progress_bar.update()
-                    progress_bar.set_postfix(Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}")
+                with catch_cuda_oom(self._catch_cuda_ooms):
+                    with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                        mb_loss = self.neural_module(**mb_data)
+                    num_minibatches += 1
+                    num_samples += len(raw_samples)
+                    sum_epoch_loss += mb_loss
+                    if show_progress_bar:
+                        progress_bar.update()
+                        progress_bar.set_postfix(
+                            Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
+                        )
 
         elapsed_time = time.time() - start_time
         assert num_samples > 0, "No validation data was found."

New file:

@@ -0,0 +1,22 @@
+from typing_extensions import Final
+
+import logging
+import torch
+from contextlib import contextmanager
+
+LOGGER: Final = logging.getLogger(__name__)
+
+
+@contextmanager
+def catch_cuda_oom(enabled: bool = True):
+    if enabled:
+        try:
+            yield
+        except RuntimeError as re:
+            if "CUDA out of memory." in repr(re):
+                LOGGER.exception("CUDA Out-Of-Memory Caught and Execution Resumed.", exc_info=re)
+                torch.cuda.empty_cache()
+            else:
+                raise re
+    else:
+        yield

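The context manager can also be used on its own outside the trainer. A minimal sketch follows; `module` and `batch` are placeholders, and the assumption (matching the trainer code above) is that calling the module returns a loss tensor. A matched "CUDA out of memory." RuntimeError is logged and the cache cleared, after which execution continues past the `with` block; any other RuntimeError propagates unchanged.

import torch

from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom  # module added above

def run_minibatch(module: torch.nn.Module, batch: dict) -> None:
    # With enabled=True, an OOM raised inside the block is logged via
    # LOGGER.exception and torch.cuda.empty_cache() is called, skipping the
    # rest of the block. With enabled=False the context manager is a no-op.
    with catch_cuda_oom(enabled=True):
        loss = module(**batch)
        loss.backward()
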
@@ -17,7 +17,6 @@ Options:
     -h --help                  Show this screen.
     --debug                    Enable debug routines. [default: False]
 """
-import logging
 import random
 import torch
 import torch.distributed as dist