[trainer] implement support for full fp16 in evaluation/predict (#10268)

* implement --fp16_full_eval

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* add test

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Stas Bekman 2021-02-18 17:02:35 -08:00 committed by GitHub
Parent d9a81fc0c5
Commit 4eddc459a9
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 7 deletions
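
In short: the PR adds an fp16_full_eval training argument. When it is set and the Trainer is not in ``train``, the model is no longer placed on the device in ``__init__``; instead the evaluation loop casts it to fp16 (``model.half()``) and then moves it to the device. A minimal usage sketch (the toy model, dataset and output_dir below are made up for illustration; only the ``fp16_full_eval`` flag itself comes from this PR, and it requires a CUDA device):

    import torch
    from torch import nn
    from transformers import Trainer, TrainingArguments

    class TinyRegressor(nn.Module):
        # toy stand-in for a real model, returning (loss, predictions) tuples like the repo's test models
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x=None, labels=None):
            # the Trainer does not cast inputs, so match the (possibly fp16) weight dtype here
            preds = self.linear(x.to(self.linear.weight.dtype))
            if labels is None:
                return (preds,)
            loss = nn.functional.mse_loss(preds.float(), labels.float())
            return (loss, preds)

    class TinyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 16

        def __getitem__(self, i):
            return {"x": torch.randn(4), "labels": torch.randn(1)}

    args = TrainingArguments(
        output_dir="out",             # hypothetical path
        per_device_eval_batch_size=8,
        fp16_full_eval=True,          # new flag: run evaluate/predict with the model cast to fp16
    )
    trainer = Trainer(model=TinyRegressor(), args=args, eval_dataset=TinyDataset())
    print(trainer.evaluate())         # the model is half()'ed and moved to the GPU inside the eval loop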

View file

@@ -218,6 +218,8 @@ class Trainer:
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to :obj:`False` if model parallel or deepspeed is used, or if the default
``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
- **is_in_train** -- Whether or not a model is currently running ``train`` (e.g. when ``evaluate`` is called
while in ``train``)
"""
@@ -243,6 +245,7 @@ class Trainer:
set_seed(self.args.seed)
self.hp_name = None
self.deepspeed = None
self.is_in_train = False
# memory metrics - must set up as early as possible
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -273,7 +276,7 @@ class Trainer:
# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
-if self.is_model_parallel or (args.deepspeed and args.do_train):
+if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
self.place_model_on_device = False
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
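
The condition above can be read as: skip the eager ``model.to(device)`` in ``__init__`` whenever something else will place (or first cast) the model later. A standalone restatement of that rule, for illustration only (the function name is made up; the logic mirrors the diff):

    def should_place_model_on_device(is_model_parallel, deepspeed, do_train, fp16_full_eval, default=True):
        # mirrors the decision in Trainer.__init__ after this change
        if is_model_parallel:
            return False   # the model manages its own device placement
        if deepspeed and do_train:
            return False   # deepspeed partitions/places the model itself
        if fp16_full_eval and not do_train:
            return False   # defer placement: eval will half() the model first, then move it
        return default

    # evaluation-only run with --fp16_full_eval: placement is deferred
    assert should_place_model_on_device(False, False, do_train=False, fp16_full_eval=True) is False
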
@@ -713,6 +716,10 @@ class Trainer:
return model
def _wrap_model(self, model, training=True):
# already initialized its own DDP and AMP
if self.deepspeed:
return model
# Mixed precision training with apex (torch < 1.6)
if self.use_apex and training:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
@@ -731,8 +738,6 @@ class Trainer:
model = ShardedDDP(model, self.optimizer)
elif is_sagemaker_distributed_available():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-elif self.deepspeed:
-pass # already initialized its own DDP earlier
elif self.args.local_rank != -1:
if self.args.ddp_find_unused_parameters is not None:
find_unused_parameters = self.args.ddp_find_unused_parameters
@@ -773,6 +778,8 @@ class Trainer:
# memory metrics - must set up as early as possible
self._memory_tracker.start()
self.is_in_train = True
if "model_path" in kwargs:
resume_from_checkpoint = kwargs.pop("model_path")
warnings.warn(
@@ -1088,6 +1095,12 @@ class Trainer:
self.lr_scheduler = None
self.model_wrapped = self.model
gc.collect() # force memory release
# to restore normal behavior outside of train replay the place_model_on_device logic w/o deepspeed
self.place_model_on_device = self.args.place_model_on_device
if self.is_model_parallel:
self.place_model_on_device = False
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
@@ -1689,6 +1702,11 @@ class Trainer:
model = self._wrap_model(self.model, training=False)
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, half it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
batch_size = dataloader.batch_size
num_examples = self.num_examples(dataloader)
logger.info("***** Running %s *****", description)

View file

@@ -155,7 +155,7 @@ class TrainingArguments:
:func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
initialized parameters.
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
-Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
+Whether to use 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
@@ -163,6 +166,9 @@ class TrainingArguments:
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
other choices will force the requested backend.
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
can harm metric values.
local_rank (:obj:`int`, `optional`, defaults to -1):
Rank of the process during distributed training.
tpu_num_cores (:obj:`int`, `optional`):
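
As the docstring says, full fp16 evaluation trades some numerical fidelity for speed and memory, so it is worth measuring the drift on your own model before trusting the metrics. A rough way to gauge it (toy network for illustration; assumes a CUDA device):

    import torch
    from torch import nn

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(128, 128), nn.GELU(), nn.Linear(128, 2))
    x = torch.randn(32, 128)

    with torch.no_grad():
        ref = model(x)                                         # fp32 reference on CPU
        out = model.half().to("cuda")(x.half().to("cuda"))     # full fp16 on GPU

    # typically on the order of 1e-3 for a small net; decide whether your metric tolerates it
    print((ref - out.float().cpu()).abs().max().item())
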
@@ -353,7 +356,7 @@ class TrainingArguments:
fp16: bool = field(
default=False,
metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA Apex) instead of 32-bit"},
metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
)
fp16_opt_level: str = field(
default="O1",
@@ -368,6 +371,10 @@ class TrainingArguments:
default="auto",
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
)
fp16_full_eval: bool = field(
default=False,
metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
tpu_num_cores: Optional[int] = field(
@@ -488,8 +495,10 @@ class TrainingArguments:
if self.run_name is None:
self.run_name = self.output_dir
-if is_torch_available() and self.device.type != "cuda" and self.fp16:
-raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
+raise ValueError(
+"Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
+)
if self.report_to is None:
logger.info(
"The default value for the training argument `--report_to` will change in v5 (from all installed "

View file

@@ -14,6 +14,7 @@
# limitations under the License.
import dataclasses
import gc
import os
import tempfile
import unittest
@@ -29,6 +30,7 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
@@ -912,6 +914,62 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(skip_memory_metrics=True)
self.check_mem_metrics(trainer, self.assertNotIn)
@require_torch_gpu
def test_fp16_full_eval(self):
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
# it's using pretty large safety margins, but small enough to detect broken functionality.
debug = 0
bs = 8
# make the params somewhat big so that there will be enough RAM consumed to be able to
# measure things. We should get about 64KB for a+b in fp32
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
# 1. fp32 eval baseline (fp16_full_eval disabled)
trainer = get_regression_trainer(a=a, b=b, eval_len=16)
metrics = trainer.evaluate()
del trainer
gc.collect()
fp32_init = metrics["init_mem_gpu_alloc_delta"]
fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
if debug:
print(f"fp32_init {fp32_init}")
print(f"fp32_eval {fp32_eval}")
# here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
# perfect world: fp32_init == 64<<10
self.assertGreater(fp32_init, 59_000)
# after eval there should be no extra memory allocated, within a small margin (other than the peak
# memory used by the forward pass, which gets released)
# perfect world: fp32_eval == close to zero
self.assertLess(fp32_eval, 5_000)
# 2. fp16 eval (fp16_full_eval enabled)
trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True)
metrics = trainer.evaluate()
fp16_init = metrics["init_mem_gpu_alloc_delta"]
fp16_eval = metrics["eval_mem_gpu_alloc_delta"]
if debug:
print(f"fp16_init {fp16_init}")
print(f"fp16_eval {fp16_eval}")
# here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
# perfect world: fp16_init == close to zero
self.assertLess(fp16_init, 5_000)
# here the model is put on the device only at eval time and only after `half()`-ing it, i.e. about 32K
# (again we ignore the peak margin which gets returned back)
# perfect world: fp16_eval == 32<<10
self.assertGreater(fp16_eval, 27_000)
# 3. relative comparison fp32 vs full fp16
# fp16_eval should be about half of fp32_init
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
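
For reference, the margins asserted above follow from simple parameter-size arithmetic (not part of the test file):

    n_elems = 1000 * 8              # each of a and b holds 8,000 elements
    fp32_bytes = 2 * n_elems * 4    # two fp32 tensors -> 64,000 bytes (~64K), hence assertGreater(fp32_init, 59_000)
    fp16_bytes = 2 * n_elems * 2    # the same tensors in fp16 -> 32,000 bytes (~32K), hence assertGreater(fp16_eval, 27_000)
    print(fp32_bytes, fp16_bytes)   # 64000 32000
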
@require_torch
@require_optuna