[trainer] move secondary methods into a separate file (#10363)

* move secondary methods into a separate file

* cleanup

* style
This commit is contained in:
Stas Bekman 2021-02-24 08:32:52 -08:00 коммит произвёл GitHub
Родитель 5f2a3d721c
Коммит bdbb2c756b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 91 добавлений и 82 удалений

Просмотреть файл

@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
import collections
import gc
import inspect
import json
import math
import os
import re
@ -82,7 +81,6 @@ from .trainer_pt_utils import (
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_learning_rate,
nested_concat,
nested_detach,
nested_numpify,
@ -226,6 +224,8 @@ class Trainer:
"""
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
def __init__(
self,
model: Union[PreTrainedModel, torch.nn.Module] = None,
@ -1130,7 +1130,7 @@ class Trainer:
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = get_learning_rate(self)
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
@ -1345,61 +1345,6 @@ class Trainer:
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format
Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""
metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)
return metrics_copy
def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and

Просмотреть файл

@ -16,11 +16,13 @@
Torch utilities for the Trainer class.
"""
import json
import math
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union
from typing import Dict, Iterator, List, Optional, Union
import numpy as np
import torch
@ -263,29 +265,6 @@ def _get_first_shape(arrays):
return arrays.shape
def get_learning_rate(trainer):
if trainer.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = trainer.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
trainer.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else trainer.lr_scheduler.get_lr()[0]
)
return last_lr
class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler):
assert len(indices) == self.num_samples
return iter(indices)
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here
def _get_learning_rate(self):
if self.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
return last_lr
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format
Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""
metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)
return metrics_copy
def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)