This commit is contained in:
Miltos Allamanis 2021-07-14 12:40:59 +01:00
Родитель 6edb8aeeba
Коммит ce0d8f6814
1 изменённых файлов: 12 добавлений и 12 удалений

Просмотреть файл

@ -206,21 +206,21 @@ class ModelTrainer(Generic[TRawDatapoint, TTensorizedDatapoint, TNeuralModule]):
with torch.cuda.amp.autocast(enabled=self._enable_amp):
mb_loss = self.neural_module(**mb_data)
scaler.scale(mb_loss).backward()
scaler.scale(mb_loss).backward()
if torch.isnan(mb_loss):
raise Exception("Loss has a NaN value.")
if torch.isnan(mb_loss):
raise Exception("Loss has a NaN value.")
if self._clip_gradient_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
self.neural_module.parameters(recurse=True), self._clip_gradient_norm
)
if self._clip_gradient_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
self.neural_module.parameters(recurse=True), self._clip_gradient_norm
)
scaler.step(optimizer)
scaler.update()
if scheduler is not None:
scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
scaler.step(optimizer)
scaler.update()
if scheduler is not None:
scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
num_minibatches += 1
num_samples += len(raw_samples)