[trainer] move secondary methods into a separate file (#10363)
* move secondary methods into a separate file
* cleanup
* style
Parent: 5f2a3d721c
Commit: bdbb2c756b
trainer.py

@@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
 import collections
 import gc
 import inspect
-import json
 import math
 import os
 import re
@@ -82,7 +81,6 @@ from .trainer_pt_utils import (
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
-    get_learning_rate,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -226,6 +224,8 @@ class Trainer:

     """

+    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
+
     def __init__(
         self,
         model: Union[PreTrainedModel, torch.nn.Module] = None,
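The class-body import added above is what turns these module-level helpers into regular Trainer methods: a plain function whose first parameter is self becomes a bound method as soon as its name lands in a class namespace, whether by assignment or by a from-import executed inside the class body. A minimal, self-contained sketch of that mechanism (hypothetical names, not the actual transformers modules):

    # helpers.py (stand-in for trainer_pt_utils.py): a free function written as if
    # it were a method -- its first parameter is deliberately named `self`.
    def describe(self):
        return f"{type(self).__name__} with prefix={self.prefix!r}"


    # main module (stand-in for trainer.py): binding the function inside the class
    # body -- e.g. via `from helpers import describe` -- makes it an ordinary method.
    class Worker:
        describe = describe  # same effect as a class-body `from helpers import describe`

        def __init__(self, prefix):
            self.prefix = prefix


    print(Worker("demo").describe())  # -> "Worker with prefix='demo'"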
@@ -1130,7 +1130,7 @@ class Trainer:
             tr_loss -= tr_loss

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
-            logs["learning_rate"] = get_learning_rate(self)
+            logs["learning_rate"] = self._get_learning_rate()

             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
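For context, the surrounding logging code averages the loss accumulated since the previous log and then resets the accumulator (tr_loss -= tr_loss zeroes the running tensor in place). A simplified, standalone illustration with made-up numbers:

    # Made-up numbers, not real Trainer state: average the loss collected since the
    # last logging step, as the `logs["loss"]` line above does.
    global_step = 150              # current optimizer step
    globalstep_last_logged = 100   # step at which metrics were last logged
    tr_loss_scalar = 42.5          # sum of per-step losses accumulated since then

    avg_loss = round(tr_loss_scalar / (global_step - globalstep_last_logged), 4)
    print(avg_loss)  # 0.85 -> this is the value reported as logs["loss"]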
@@ -1345,61 +1345,6 @@ class Trainer:
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

-    def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
-        """
-        Reformat Trainer metrics values to a human-readable format
-
-        Args:
-            metrics (:obj:`Dict[str, float]`):
-                The metrics returned from train/evaluate/predict
-
-        Returns:
-            metrics (:obj:`Dict[str, float]`): The reformatted metrics
-        """
-
-        metrics_copy = metrics.copy()
-        for k, v in metrics_copy.items():
-            if "_mem_" in k:
-                metrics_copy[k] = f"{ v >> 20 }MB"
-            elif k == "total_flos":
-                metrics_copy[k] = f"{ int(v) >> 30 }GF"
-            elif type(metrics_copy[k]) == float:
-                metrics_copy[k] = round(v, 4)
-
-        return metrics_copy
-
-    def log_metrics(self, split, metrics):
-        """
-        Log metrics in a specially formatted way
-
-        Args:
-            split (:obj:`str`):
-                Mode/split name: one of ``train``, ``eval``, ``test``
-            metrics (:obj:`Dict[str, float]`):
-                The metrics returned from train/evaluate/predictmetrics: metrics dict
-        """
-
-        logger.info(f"***** {split} metrics *****")
-        metrics_formatted = self.metrics_format(metrics)
-        k_width = max(len(str(x)) for x in metrics_formatted.keys())
-        v_width = max(len(str(x)) for x in metrics_formatted.values())
-        for key in sorted(metrics_formatted.keys()):
-            logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-
-    def save_metrics(self, split, metrics):
-        """
-        Save metrics into a json file for that split, e.g. ``train_results.json``.
-
-        Args:
-            split (:obj:`str`):
-                Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
-            metrics (:obj:`Dict[str, float]`):
-                The metrics returned from train/evaluate/predict
-        """
-        path = os.path.join(self.args.output_dir, f"{split}_results.json")
-        with open(path, "w") as f:
-            json.dump(metrics, f, indent=4, sort_keys=True)
-
     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
         """
         Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
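To make the behavior of the relocated helpers concrete, here is an illustrative input for metrics_format and the log_metrics output it leads to (the metric values are invented for the example; the _mem_ entries are byte counts and total_flos is a raw floating-point-operation count):

    # Invented example values -- metrics_format rewrites memory counters as MB and
    # total_flos as GF (both via integer bit shifts) and rounds other floats to 4 digits.
    metrics = {
        "train_runtime": 3.061592,
        "train_samples_per_second": 5.227,
        "total_flos": 2128960512,             # 2128960512 >> 30 == 1  -> "1GF"
        "init_mem_cpu_alloc_delta": 2097152,  # 2097152 >> 20 == 2     -> "2MB"
    }
    # metrics_format(metrics) would return:
    # {"train_runtime": 3.0616, "train_samples_per_second": 5.227,
    #  "total_flos": "1GF", "init_mem_cpu_alloc_delta": "2MB"}
    #
    # and log_metrics("train", metrics) would then log an aligned block such as:
    # ***** train metrics *****
    #   init_mem_cpu_alloc_delta =    2MB
    #   total_flos               =    1GF
    #   train_runtime            = 3.0616
    #   train_samples_per_second =  5.227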
trainer_pt_utils.py

@@ -16,11 +16,13 @@
 Torch utilities for the Trainer class.
 """

+import json
 import math
+import os
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union

 import numpy as np
 import torch
@@ -263,29 +265,6 @@ def _get_first_shape(arrays):
     return arrays.shape


-def get_learning_rate(trainer):
-    if trainer.deepspeed:
-        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
-        # not run for the first few dozen steps while loss scale is too large, and thus during
-        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
-        try:
-            last_lr = trainer.lr_scheduler.get_last_lr()[0]
-        except AssertionError as e:
-            if "need to call step" in str(e):
-                logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
-                last_lr = 0
-            else:
-                raise
-    else:
-        last_lr = (
-            # backward compatibility for pytorch schedulers
-            trainer.lr_scheduler.get_last_lr()[0]
-            if version.parse(torch.__version__) >= version.parse("1.4")
-            else trainer.lr_scheduler.get_lr()[0]
-        )
-    return last_lr
-
-
 class DistributedTensorGatherer:
     """
     A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
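The function removed here reappears below as _get_learning_rate, bound to the Trainer. Its DeepSpeed branch exists because, while fp16 dynamic loss scaling is still warming up, the LR scheduler may not have stepped yet and get_last_lr() raises an AssertionError. A standalone sketch of that guard, using a toy scheduler in place of the real one:

    # Toy stand-in for an LR scheduler that refuses to report a rate before step():
    class ToyScheduler:
        def __init__(self):
            self.stepped = False

        def get_last_lr(self):
            assert self.stepped, "need to call step() before get_last_lr()"
            return [5e-5]


    def safe_last_lr(scheduler):
        # mirror the workaround: swallow only the "not stepped yet" assertion
        try:
            return scheduler.get_last_lr()[0]
        except AssertionError as e:
            if "need to call step" in str(e):
                return 0  # report lr=0 during warm-up instead of crashing the log step
            raise


    sched = ToyScheduler()
    print(safe_last_lr(sched))  # 0     (scheduler has not stepped yet)
    sched.stepped = True
    print(safe_last_lr(sched))  # 5e-05 (after the first optimizer/scheduler step)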
@@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler):
         assert len(indices) == self.num_samples

         return iter(indices)
+
+
+# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
+# helper methods here
+
+
+def _get_learning_rate(self):
+    if self.deepspeed:
+        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
+        # not run for the first few dozen steps while loss scale is too large, and thus during
+        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
+        try:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        except AssertionError as e:
+            if "need to call step" in str(e):
+                logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
+                last_lr = 0
+            else:
+                raise
+    else:
+        last_lr = (
+            # backward compatibility for pytorch schedulers
+            self.lr_scheduler.get_last_lr()[0]
+            if version.parse(torch.__version__) >= version.parse("1.4")
+            else self.lr_scheduler.get_lr()[0]
+        )
+    return last_lr
+
+
+def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
+    """
+    Reformat Trainer metrics values to a human-readable format
+
+    Args:
+        metrics (:obj:`Dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+
+    Returns:
+        metrics (:obj:`Dict[str, float]`): The reformatted metrics
+    """
+
+    metrics_copy = metrics.copy()
+    for k, v in metrics_copy.items():
+        if "_mem_" in k:
+            metrics_copy[k] = f"{ v >> 20 }MB"
+        elif k == "total_flos":
+            metrics_copy[k] = f"{ int(v) >> 30 }GF"
+        elif type(metrics_copy[k]) == float:
+            metrics_copy[k] = round(v, 4)
+
+    return metrics_copy
+
+
+def log_metrics(self, split, metrics):
+    """
+    Log metrics in a specially formatted way
+
+    Args:
+        split (:obj:`str`):
+            Mode/split name: one of ``train``, ``eval``, ``test``
+        metrics (:obj:`Dict[str, float]`):
+            The metrics returned from train/evaluate/predictmetrics: metrics dict
+    """
+
+    logger.info(f"***** {split} metrics *****")
+    metrics_formatted = self.metrics_format(metrics)
+    k_width = max(len(str(x)) for x in metrics_formatted.keys())
+    v_width = max(len(str(x)) for x in metrics_formatted.values())
+    for key in sorted(metrics_formatted.keys()):
+        logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
+
+
+def save_metrics(self, split, metrics):
+    """
+    Save metrics into a json file for that split, e.g. ``train_results.json``.
+
+    Args:
+        split (:obj:`str`):
+            Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
+        metrics (:obj:`Dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+    """
+    path = os.path.join(self.args.output_dir, f"{split}_results.json")
+    with open(path, "w") as f:
+        json.dump(metrics, f, indent=4, sort_keys=True)
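After the move, example training scripts keep calling the relocated helpers as ordinary Trainer methods. A rough usage sketch (the trainer object and the metric values are assumed for illustration, not part of this diff):

    # Assumes an already-configured `trainer`; values below are illustrative.
    train_result = trainer.train()
    metrics = train_result.metrics

    trainer.log_metrics("train", metrics)   # prints the "***** train metrics *****" block
    trainer.save_metrics("train", metrics)  # writes <output_dir>/train_results.json

    # save_metrics dumps the raw metrics as sorted, indented JSON, e.g.:
    # {
    #     "epoch": 3.0,
    #     "train_runtime": 3.061592,
    #     "train_samples_per_second": 5.227
    # }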