diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3578b1b5d..a24cbfe71 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -218,6 +218,8 @@ class Trainer:
         - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be
           set to :obj:`False` if model parallel or deepspeed is used, or if the default
           ``TrainingArguments.place_model_on_device`` is overridden to return :obj:`False` .
+        - **is_in_train** -- Whether or not a model is currently running ``train`` (e.g. when ``evaluate`` is called
+          while in ``train``)
 
     """
 
@@ -243,6 +245,7 @@ class Trainer:
         set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
+        self.is_in_train = False
 
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -273,7 +276,7 @@ class Trainer:
 
         # one place to sort out whether to place the model on device or not
         self.place_model_on_device = args.place_model_on_device
-        if self.is_model_parallel or (args.deepspeed and args.do_train):
+        if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
            self.place_model_on_device = False
 
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
@@ -713,6 +716,10 @@ class Trainer:
         return model
 
     def _wrap_model(self, model, training=True):
+        # already initialized its own DDP and AMP
+        if self.deepspeed:
+            return model
+
         # Mixed precision training with apex (torch < 1.6)
         if self.use_apex and training:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
@@ -731,8 +738,6 @@ class Trainer:
             model = ShardedDDP(model, self.optimizer)
         elif is_sagemaker_distributed_available():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
-        elif self.deepspeed:
-            pass  # already initialized its own DDP earlier
         elif self.args.local_rank != -1:
             if self.args.ddp_find_unused_parameters is not None:
                 find_unused_parameters = self.args.ddp_find_unused_parameters
@@ -773,6 +778,8 @@ class Trainer:
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
 
+        self.is_in_train = True
+
         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
             warnings.warn(
@@ -1088,6 +1095,12 @@ class Trainer:
             self.lr_scheduler = None
             self.model_wrapped = self.model
             gc.collect()  # force memory release
+            # to restore normal behavior outside of train, replay the place_model_on_device logic w/o deepspeed
+            self.place_model_on_device = self.args.place_model_on_device
+            if self.is_model_parallel:
+                self.place_model_on_device = False
+
+        self.is_in_train = False
 
         self._memory_tracker.stop_and_update_metrics(metrics)
 
@@ -1689,6 +1702,11 @@ class Trainer:
 
         model = self._wrap_model(self.model, training=False)
 
+        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
+        # ``train`` is running, cast the model to fp16 first and then put it on device
+        if not self.is_in_train and self.args.fp16_full_eval:
+            model = model.half().to(self.args.device)
+
         batch_size = dataloader.batch_size
         num_examples = self.num_examples(dataloader)
         logger.info("***** Running %s *****", description)
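For readers who want the gist of the trainer.py hunks above without tracing the diff, here is a minimal, hedged sketch (not part of the patch; the function and argument names are illustrative only) of the flow they implement: with fp16-only evaluation the model is left off the device at ``Trainer.__init__`` time and is cast to half precision right before the prediction loop moves it to the GPU.

import torch


def maybe_prepare_fp16_eval_model(model: torch.nn.Module, device: torch.device,
                                  is_in_train: bool, fp16_full_eval: bool) -> torch.nn.Module:
    # Mirrors the prediction_loop hunk: cast to fp16 and move to the device only when
    # evaluate()/predict() is called standalone, not from inside train().
    if not is_in_train and fp16_full_eval:
        model = model.half().to(device)
    return model
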
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 74e996477..05c05bb6a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -155,7 +155,7 @@ class TrainingArguments:
             :func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
             initialized parameters.
         fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
+            Whether to use 16-bit (mixed) precision training instead of 32-bit training.
         fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
             For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
             on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
@@ -163,6 +163,9 @@ class TrainingArguments:
             The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
             :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
             other choices will force the requested backend.
+        fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
+            can harm metric values.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
         tpu_num_cores (:obj:`int`, `optional`):
@@ -353,7 +356,7 @@ class TrainingArguments:
     fp16: bool = field(
         default=False,
-        metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA Apex) instead of 32-bit"},
+        metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
     )
     fp16_opt_level: str = field(
         default="O1",
@@ -368,6 +371,10 @@ class TrainingArguments:
         default="auto",
         metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
     )
+    fp16_full_eval: bool = field(
+        default=False,
+        metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
+    )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
 
     tpu_num_cores: Optional[int] = field(
@@ -488,8 +495,10 @@ class TrainingArguments:
         if self.run_name is None:
             self.run_name = self.output_dir
 
-        if is_torch_available() and self.device.type != "cuda" and self.fp16:
-            raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+        if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
+            raise ValueError(
+                "Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
+            )
         if self.report_to is None:
             logger.info(
                 "The default value for the training argument `--report_to` will change in v5 (from all installed "
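A hedged usage sketch of the new argument from the user side (not part of the patch; the model and dataset objects are placeholders): with ``fp16_full_eval=True`` and no training requested, the Trainer skips placing the model on the device at init and evaluates a half-precision copy instead. On the command line the equivalent is the ``--fp16_full_eval`` flag exposed through ``HfArgumentParser``.

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",           # placeholder path
    per_device_eval_batch_size=8,
    fp16_full_eval=True,           # evaluate with the model cast to fp16
)
# `model` and `eval_dataset` are assumed to exist already:
# trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
# metrics = trainer.evaluate()
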
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index dc8209ab6..7d9cadfc4 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import dataclasses
+import gc
 import os
 import tempfile
 import unittest
@@ -29,6 +30,7 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_gpu,
     slow,
 )
@@ -912,6 +914,62 @@ class TrainerIntegrationTest(unittest.TestCase):
         trainer = get_regression_trainer(skip_memory_metrics=True)
         self.check_mem_metrics(trainer, self.assertNotIn)
 
+    @require_torch_gpu
+    def test_fp16_full_eval(self):
+
+        # this is a sensitive test, so let's keep debugging printouts in place for quick diagnosis.
+        # it's using pretty large safety margins, but small enough to detect broken functionality.
+        debug = 0
+
+        bs = 8
+        # make the params somewhat big so that there will be enough RAM consumed to be able to
+        # measure things. We should get about 64KB for a+b in fp32
+        a = torch.ones(1000, bs) + 0.001
+        b = torch.ones(1000, bs) - 0.001
+
+        # 1. default fp32 evaluation
+        trainer = get_regression_trainer(a=a, b=b, eval_len=16)
+        metrics = trainer.evaluate()
+        del trainer
+        gc.collect()
+
+        fp32_init = metrics["init_mem_gpu_alloc_delta"]
+        fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
+
+        if debug:
+            print(f"fp32_init {fp32_init}")
+            print(f"fp32_eval {fp32_eval}")
+
+        # here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
+        # perfect world: fp32_init == 64<<10
+        self.assertGreater(fp32_init, 59_000)
+        # after eval there should be no extra memory allocated - with a small margin (other than the peak
+        # memory consumption for the forward calculation that gets recovered)
+        # perfect world: fp32_eval == close to zero
+        self.assertLess(fp32_eval, 5_000)
+
+        # 2. with fp16_full_eval enabled
+        trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True)
+        metrics = trainer.evaluate()
+        fp16_init = metrics["init_mem_gpu_alloc_delta"]
+        fp16_eval = metrics["eval_mem_gpu_alloc_delta"]
+
+        if debug:
+            print(f"fp16_init {fp16_init}")
+            print(f"fp16_eval {fp16_eval}")
+
+        # here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
+        # perfect world: fp16_init == close to zero
+        self.assertLess(fp16_init, 5_000)
+        # here we put the model on device in eval and only `half()` it, i.e. about 32K (again we ignore the peak margin which gets returned back)
+        # perfect world: fp16_eval == 32<<10
+        self.assertGreater(fp16_eval, 27_000)
+
+        # 3. relative comparison fp32 vs full fp16
+        # fp16_eval should be about half of fp32_init
+        # perfect world: fp32_init/2 == fp16_eval
+        self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
+
 
 @require_torch
 @require_optuna
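The thresholds asserted in the test follow from simple size arithmetic on the two 1000 x 8 regression parameters; a quick back-of-the-envelope check (not part of the test suite, numbers taken from the test's own comments):

n_elements = 2 * 1000 * 8      # params a and b from the test, 1000 x bs elements each
fp32_bytes = n_elements * 4    # 64_000 -> assertGreater(fp32_init, 59_000)
fp16_bytes = n_elements * 2    # 32_000 -> assertGreater(fp16_eval, 27_000)
assert fp16_bytes == fp32_bytes // 2   # matches the fp32_init / 2 relative comparison
print(fp32_bytes, fp16_bytes)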