[Trainer] memory tracker metrics (#10225)
* memory tracker metrics * go back to eval for somewhat consistency * handle no-gpu case * deal with stackable eval calls * restore callback order * style * simplify the API * add test * docs * consistently use eval_ prefix * improve docs * Update src/transformers/trainer_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * rename method * style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Родитель
d7f38c5d1d
Коммит
97e688bc22
|
@ -588,9 +588,12 @@ def main():
|
|||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** train metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
|
@ -603,17 +606,19 @@ def main():
|
|||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** val metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
|
@ -628,12 +633,14 @@ def main():
|
|||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** test metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
|
|
|
@ -88,8 +88,8 @@ class TestDeepSpeed(TestCasePlus):
|
|||
extra_args_str="--do_eval",
|
||||
remove_args_str="--do_train",
|
||||
)
|
||||
val_metrics = load_json(os.path.join(output_dir, "val_results.json"))
|
||||
assert "val_bleu" in val_metrics
|
||||
val_metrics = load_json(os.path.join(output_dir, "eval_results.json"))
|
||||
assert "eval_bleu" in val_metrics
|
||||
|
||||
# XXX: need to do better validation beyond just that the run was successful
|
||||
def run_quick(self, distributed=True, extra_args_str=None, remove_args_str=None):
|
||||
|
|
|
@ -236,6 +236,15 @@ def is_torch_available():
|
|||
return _torch_available
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
return torch.cuda.is_available()
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
|
|
@ -93,6 +93,7 @@ from .trainer_utils import (
|
|||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
PredictionOutput,
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
|
@ -243,6 +244,10 @@ class Trainer:
|
|||
self.hp_name = None
|
||||
self.deepspeed = None
|
||||
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||
self._memory_tracker.start()
|
||||
|
||||
# force device and distributed setup init explicitly
|
||||
args._setup_devices
|
||||
|
||||
|
@ -394,6 +399,9 @@ class Trainer:
|
|||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||
|
||||
# very last
|
||||
self._memory_tracker.stop_and_update_metrics()
|
||||
|
||||
def add_callback(self, callback):
|
||||
"""
|
||||
Add a callback to the current list of :class:`~transformer.TrainerCallback`.
|
||||
|
@ -761,6 +769,10 @@ class Trainer:
|
|||
kwargs:
|
||||
Additional keyword arguments used to hide deprecated arguments
|
||||
"""
|
||||
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker.start()
|
||||
|
||||
if "model_path" in kwargs:
|
||||
resume_from_checkpoint = kwargs.pop("model_path")
|
||||
warnings.warn(
|
||||
|
@ -1077,6 +1089,8 @@ class Trainer:
|
|||
self.model_wrapped = self.model
|
||||
gc.collect() # force memory release
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
|
@ -1306,6 +1320,29 @@ class Trainer:
|
|||
self.state.log_history.append(output)
|
||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
|
||||
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
|
||||
"""
|
||||
Reformat Trainer metrics values to a human-readable format
|
||||
|
||||
Args:
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
|
||||
Returns:
|
||||
metrics (:obj:`Dict[str, float]`): The reformatted metrics
|
||||
"""
|
||||
|
||||
metrics_copy = metrics.copy()
|
||||
for k, v in metrics_copy.items():
|
||||
if "_mem_" in k:
|
||||
metrics_copy[k] = f"{ v >> 20 }MB"
|
||||
elif k == "total_flos":
|
||||
metrics_copy[k] = f"{ int(v) >> 30 }GF"
|
||||
elif type(metrics_copy[k]) == float:
|
||||
metrics_copy[k] = round(v, 4)
|
||||
|
||||
return metrics_copy
|
||||
|
||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||
"""
|
||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||
|
@ -1542,6 +1579,9 @@ class Trainer:
|
|||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||
dictionary also contains the epoch number which comes from the training state.
|
||||
"""
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker.start()
|
||||
|
||||
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
|
@ -1567,6 +1607,9 @@ class Trainer:
|
|||
xm.master_print(met.metrics_report())
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
||||
|
||||
return output.metrics
|
||||
|
||||
def predict(
|
||||
|
@ -1602,6 +1645,9 @@ class Trainer:
|
|||
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
||||
contained labels).
|
||||
"""
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker.start()
|
||||
|
||||
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
|
||||
|
@ -1612,6 +1658,9 @@ class Trainer:
|
|||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
||||
|
||||
return output
|
||||
|
||||
def prediction_loop(
|
||||
|
|
|
@ -17,15 +17,24 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
|
|||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .file_utils import is_sagemaker_distributed_available, is_tf_available, is_torch_available, is_torch_tpu_available
|
||||
from .file_utils import (
|
||||
is_sagemaker_distributed_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_tpu_available,
|
||||
)
|
||||
from .tokenization_utils_base import ExplicitEnum
|
||||
|
||||
|
||||
|
@ -234,3 +243,175 @@ class SchedulerType(ExplicitEnum):
|
|||
POLYNOMIAL = "polynomial"
|
||||
CONSTANT = "constant"
|
||||
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
||||
|
||||
|
||||
class TrainerMemoryTracker:
|
||||
"""
|
||||
A helper class that tracks cpu and gpu memory.
|
||||
|
||||
When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
|
||||
|
||||
Example ::
|
||||
|
||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||
self._memory_tracker.start()
|
||||
code ...
|
||||
metrics = {"train_runtime": 10.5}
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
|
||||
At the moment gpu tracking is only for pytorch, but can be extended to support tensorflow.
|
||||
|
||||
Understanding the reports:
|
||||
|
||||
- ``*_alloc_delta`` - is the difference in the used/allocated memory counter between the end and the start of the
|
||||
stage - it can be negative if a function released more memory than it allocated.
|
||||
|
||||
- ``*_peaked_delta`` - is any extra memory that was consumed and then freed - relative to the current allocated
|
||||
memory counter - it is never negative.
|
||||
|
||||
So when you look at the metrics of any stage you add up ``alloc_delta`` + ``peaked_delta`` and you know how much
|
||||
memory was needed to complete that stage.
|
||||
|
||||
The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
|
||||
main process does the bulk of work, but it could be not quite so if model parallel is used and then other gpus may
|
||||
use a different amount of gpu RAM. Perhaps in the future this tracker will evolve to measure those too.
|
||||
|
||||
Note that this tracker doesn't account for memory allocations outside of :class:`~transformers.Trainer`'s
|
||||
``__init__``, ``train``, ``evaluate`` and ``predict`` calls.
|
||||
|
||||
Because ``evaluation`` calls may happen during ``train``, we can't handle nested invocations because
|
||||
``torch.cuda.max_memory_allocated`` is a single counter, so if it gets reset by a nested eval call, ``train``'s
|
||||
tracker will report incorrect info. If this `pytorch issue <https://github.com/pytorch/pytorch/issues/16266>`__
|
||||
gets resolved it will be possible to change this class to be re-entrant. Until then we will only track the outer
|
||||
level of ``train``, ``evaluate`` and ``predict`` methods. Which means that if ``eval`` is called during ``train``,
|
||||
it's the latter that will account for its memory usage and that of the former.
|
||||
|
||||
This also means that if any other tool that is used along the :class:`~transformers.Trainer` calls
|
||||
``torch.cuda.reset_peak_memory_stats``, the gpu peak memory stats could be invalid. And the
|
||||
:class:`~transformers.Trainer` will disrupt the normal behavior of any such tools that rely on calling
|
||||
``torch.cuda.reset_peak_memory_stats`` themselves.
|
||||
|
||||
"""
|
||||
|
||||
# map trainer methods to metrics prefix
|
||||
stages = {
|
||||
"__init__": "init",
|
||||
"train": "train",
|
||||
"evaluate": "eval",
|
||||
"predict": "test",
|
||||
}
|
||||
|
||||
def __init__(self, skip_memory_metrics=False):
|
||||
if is_torch_cuda_available():
|
||||
import torch
|
||||
|
||||
self.torch = torch
|
||||
self.gpu = {}
|
||||
else:
|
||||
self.torch = None
|
||||
|
||||
self.cur_stage = None
|
||||
self.cpu = {}
|
||||
self.init_reported = False
|
||||
self.skip_memory_metrics = skip_memory_metrics
|
||||
|
||||
def derive_stage(self):
|
||||
""" derives the stage/caller name automatically """
|
||||
caller = inspect.currentframe().f_back.f_back.f_code.co_name
|
||||
if caller in self.stages:
|
||||
return self.stages[caller]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
|
||||
)
|
||||
|
||||
def start(self):
|
||||
""" start tracking for the caller's stage """
|
||||
if self.skip_memory_metrics:
|
||||
return
|
||||
|
||||
stage = self.derive_stage()
|
||||
# deal with nested calls of eval during train - simply ignore those
|
||||
if self.cur_stage is not None and self.cur_stage != stage:
|
||||
return
|
||||
|
||||
self.cur_stage = stage
|
||||
|
||||
if self.torch is not None:
|
||||
self.torch.cuda.reset_peak_memory_stats()
|
||||
self.torch.cuda.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
# gpu
|
||||
if self.torch is not None:
|
||||
self.gpu[self.cur_stage] = {}
|
||||
self.gpu[self.cur_stage]["alloc"] = self.torch.cuda.memory_allocated()
|
||||
self.gpu[self.cur_stage]["peaked"] = 0
|
||||
|
||||
# cpu
|
||||
self.cpu[self.cur_stage] = {}
|
||||
tracemalloc.start()
|
||||
|
||||
def stop(self, stage):
|
||||
""" stop tracking for the passed stage """
|
||||
|
||||
# deal with nested calls of eval during train - simply ignore those
|
||||
if self.cur_stage is not None and self.cur_stage != stage:
|
||||
return
|
||||
|
||||
if self.torch is not None:
|
||||
self.torch.cuda.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
# gpu
|
||||
if self.torch is not None:
|
||||
mem_cur = self.torch.cuda.memory_allocated()
|
||||
# this is the difference between the start and the end allocated memory
|
||||
self.gpu[self.cur_stage]["alloc"] = mem_cur - self.gpu[self.cur_stage]["alloc"] # can be negative
|
||||
# this is the difference if any between the start and the peak
|
||||
self.gpu[self.cur_stage]["peaked"] = max(0, self.torch.cuda.max_memory_allocated() - mem_cur)
|
||||
|
||||
# cpu
|
||||
cpu_mem_used_delta, cpu_mem_used_peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop() # reset accounting
|
||||
self.cpu[self.cur_stage]["alloc"] = cpu_mem_used_delta # can be negative
|
||||
self.cpu[self.cur_stage]["peaked"] = max(0, cpu_mem_used_peak - cpu_mem_used_delta)
|
||||
|
||||
# reset - cycle finished
|
||||
self.cur_stage = None
|
||||
|
||||
def update_metrics(self, stage, metrics):
|
||||
""" stop tracking for the passed stage """
|
||||
if self.skip_memory_metrics:
|
||||
return
|
||||
|
||||
# deal with nested calls of eval during train - simply ignore those
|
||||
if self.cur_stage is not None and self.cur_stage != stage:
|
||||
return
|
||||
|
||||
# since we don't have a way to return init metrics, we push them into the first of train/val/predict
|
||||
stages = [stage]
|
||||
if not self.init_reported:
|
||||
stages.insert(0, "init")
|
||||
self.init_reported = True
|
||||
|
||||
for stage in stages:
|
||||
for t in ["alloc", "peaked"]:
|
||||
if stage in self.cpu and t in self.cpu[stage]:
|
||||
metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
|
||||
if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
|
||||
metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
|
||||
|
||||
def stop_and_update_metrics(self, metrics=None):
|
||||
""" combine stop + update in one call for simpler code """
|
||||
if self.skip_memory_metrics:
|
||||
return
|
||||
|
||||
stage = self.derive_stage()
|
||||
self.stop(stage)
|
||||
|
||||
# init doesn't have metrics to update so we just save that data for later stages to retrieve
|
||||
if metrics is not None:
|
||||
self.update_metrics(stage, metrics)
|
||||
|
|
|
@ -252,6 +252,9 @@ class TrainingArguments:
|
|||
otherwise.
|
||||
dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`)):
|
||||
Whether you want to pin memory in data loaders or not. Will default to :obj:`True`.
|
||||
skip_memory_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`)):
|
||||
Whether to skip adding of memory profiler reports to metrics. Defaults to :obj:`False`.
|
||||
|
||||
"""
|
||||
|
||||
output_dir: Optional[str] = field(
|
||||
|
@ -451,6 +454,9 @@ class TrainingArguments:
|
|||
dataloader_pin_memory: bool = field(
|
||||
default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
|
||||
)
|
||||
skip_memory_metrics: bool = field(
|
||||
default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
|
||||
)
|
||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
|
@ -884,6 +884,34 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
trainer.train()
|
||||
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
||||
|
||||
def check_mem_metrics(self, trainer, check_func):
|
||||
metrics = trainer.train().metrics
|
||||
check_func("init_mem_cpu_alloc_delta", metrics)
|
||||
check_func("train_mem_cpu_alloc_delta", metrics)
|
||||
if torch.cuda.device_count() > 0:
|
||||
check_func("init_mem_gpu_alloc_delta", metrics)
|
||||
check_func("train_mem_gpu_alloc_delta", metrics)
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
check_func("eval_mem_cpu_alloc_delta", metrics)
|
||||
if torch.cuda.device_count() > 0:
|
||||
check_func("eval_mem_gpu_alloc_delta", metrics)
|
||||
|
||||
metrics = trainer.predict(RegressionDataset()).metrics
|
||||
check_func("test_mem_cpu_alloc_delta", metrics)
|
||||
if torch.cuda.device_count() > 0:
|
||||
check_func("test_mem_gpu_alloc_delta", metrics)
|
||||
|
||||
def test_mem_metrics(self):
|
||||
|
||||
# with mem metrics enabled
|
||||
trainer = get_regression_trainer()
|
||||
self.check_mem_metrics(trainer, self.assertIn)
|
||||
|
||||
# with mem metrics disabled
|
||||
trainer = get_regression_trainer(skip_memory_metrics=True)
|
||||
self.check_mem_metrics(trainer, self.assertNotIn)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
|
|
Загрузка…
Ссылка в новой задаче