From 9d14be5c20478a030b5d9cbf335f4842e10e7e0b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 25 Feb 2021 11:07:53 -0500 Subject: [PATCH] Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354) * Ass support for ZeRO-2/3 and ZeRO-offload in fairscale * Quality * Rework from review comments * Add doc * Apply suggestions from code review Co-authored-by: Stas Bekman * Address review comments Co-authored-by: Stas Bekman --- docs/source/main_classes/trainer.rst | 51 +++++++++-- examples/tests/trainer/test_trainer_ext.py | 37 ++++++-- src/transformers/trainer.py | 98 ++++++++++++++++------ src/transformers/trainer_utils.py | 7 ++ src/transformers/training_args.py | 46 ++++++++-- 5 files changed, 193 insertions(+), 46 deletions(-) diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 537c8df84..a6edaccf3 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper `__. -2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m - torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already. +2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments, + and make sure you have added the distributed launcher ``-m torch.distributed.launch + --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already. For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs: @@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs: --do_train --max_train_samples 500 --num_train_epochs 1 \ --dataset_name wmt16 --dataset_config "ro-en" \ --task translation_en_to_ro --source_prefix "translate English to Romanian: " \ - --fp16 --sharded_ddp + --fp16 --sharded_ddp simple Notes: - This feature requires distributed training (so multiple GPUs). - It is not implemented for TPUs. - It works with ``--fp16`` too, to make things even faster. -- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able - to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to +- One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be + able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to significantly shorter training time. +3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3` + to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch + --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already. + +For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs: + +.. code-block:: bash + + python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \ + --model_name_or_path t5-small --per_device_train_batch_size 1 \ + --output_dir output_dir --overwrite_output_dir \ + --do_train --max_train_samples 500 --num_train_epochs 1 \ + --dataset_name wmt16 --dataset_config "ro-en" \ + --task translation_en_to_ro --source_prefix "translate English to Romanian: " \ + --fp16 --sharded_ddp zero_dp_2 + +:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights, +gradients and optimizer states. + +Both are compatible with adding :obj:`cpu_offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp +"zero_dp_2 cpu_offload"`). + +Notes: + +- This feature requires distributed training (so multiple GPUs). +- It is not implemented for TPUs. +- It works with ``--fp16`` too, to make things even faster. +- The ``cpu_offload`` additional option requires ``--fp16``. +- This is an area of active development, so make sure you have a source install of fairscale to use this feature as + some bugs you encounter may have been fixed there already. + +Known caveats: + +- This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script. +- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container + :obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the + :class:`~transformers.Trainer`. + DeepSpeed ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/tests/trainer/test_trainer_ext.py b/examples/tests/trainer/test_trainer_ext.py index d74cd84b6..9da3c1bec 100644 --- a/examples/tests/trainer/test_trainer_ext.py +++ b/examples/tests/trainer/test_trainer_ext.py @@ -64,12 +64,13 @@ def require_apex(test_case): class TestTrainerExt(TestCasePlus): - def run_seq2seq_quick(self, distributed=False, extra_args_str=None): - output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str) + def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True): + output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate) logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history eval_metrics = [log for log in logs if "eval_loss" in log.keys()] first_step_stats = eval_metrics[0] - assert "eval_bleu" in first_step_stats + if predict_with_generate: + assert "eval_bleu" in first_step_stats @require_torch_non_multi_gpu def test_run_seq2seq_no_dist(self): @@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus): # test --sharded_ddp w/o --fp16 @require_torch_multi_gpu @require_fairscale - def test_run_seq2seq_ddp_sharded_ddp(self): - self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp") + def test_run_seq2seq_sharded_ddp(self): + self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple") # test --sharded_ddp w/ --fp16 @require_torch_multi_gpu @require_fairscale - def test_run_seq2seq_ddp_sharded_ddp_fp16(self): - self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16") + def test_run_seq2seq_sharded_ddp_fp16(self): + self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16") + + # test --sharded_ddp zero2 w/o --fp16 + @require_torch_multi_gpu + @require_fairscale + def test_run_seq2seq_fully_sharded_ddp(self): + self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False) + + # test --sharded_ddp zero2 w/ --fp16 + @require_torch_multi_gpu + @require_fairscale + def test_run_seq2seq_fully_sharded_ddp_fp16(self): + self.run_seq2seq_quick( + distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False + ) @require_apex def test_run_seq2seq_apex(self): @@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus): num_train_epochs: int, distributed: bool = False, extra_args_str: str = None, + predict_with_generate: bool = True, ): data_dir = self.examples_dir / "test_data/wmt_en_ro" output_dir = self.get_auto_remove_tmp_dir() @@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus): --learning_rate 3e-3 --warmup_steps 8 --evaluation_strategy steps - --predict_with_generate --logging_steps 0 --save_steps {str(eval_steps)} --eval_steps {str(eval_steps)} @@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus): --task translation --target_lang ro_RO --source_lang en_XX - """.split() + """ + if predict_with_generate: + args += "--predict_with_generate" + + args = args.split() if extra_args_str is not None: args.extend(extra_args_str.split()) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2f2030a9d..815e14d5e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -93,6 +93,7 @@ from .trainer_utils import ( EvalPrediction, HPSearchBackend, PredictionOutput, + ShardedDDPOption, TrainerMemoryTracker, TrainOutput, default_compute_objective, @@ -131,10 +132,16 @@ if is_torch_tpu_available(): import torch_xla.distributed.parallel_loader as pl if is_fairscale_available(): + import fairscale from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler + if version.parse(fairscale.__version__) >= version.parse("0.3"): + from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP + else: + FullyShardedDDP = None + if is_sagemaker_distributed_available(): import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP @@ -277,9 +284,38 @@ class Trainer: else: self.is_model_parallel = False + # Setup Sharded DDP training + self.sharded_ddp = None + if len(args.sharded_ddp) > 0: + if args.deepspeed: + raise ValueError( + "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + + if args.local_rank == -1: + raise ValueError("Using sharded DDP only works in distributed training.") + elif not is_fairscale_available(): + raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") + elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: + raise ImportError( + "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " + f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." + ) + elif ShardedDDPOption.SIMPLE in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.SIMPLE + elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 + elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 + # one place to sort out whether to place the model on device or not self.place_model_on_device = args.place_model_on_device - if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train): + if ( + self.is_model_parallel + or (args.deepspeed and args.do_train) + or (args.fp16_full_eval and not args.do_train) + or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + ): self.place_model_on_device = False default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) @@ -346,21 +382,6 @@ class Trainer: if isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(self.eval_dataset, description="evaluation") - # Setup Sharded DDP training - self.sharded_dpp = False - if args.sharded_ddp: - if args.deepspeed: - raise ValueError( - "Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags." - ) - - if args.local_rank == -1: - raise ValueError("Using sharded DDP only works in distributed training.") - elif not is_fairscale_available(): - raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") - else: - self.sharded_dpp = True - # Mixed precision setup self.use_apex = False self.use_amp = False @@ -376,7 +397,7 @@ class Trainer: if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 if self.fp16_backend == "amp": self.use_amp = True - self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler() + self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler() else: if not is_apex_available(): raise ImportError( @@ -619,7 +640,7 @@ class Trainer: "eps": self.args.adam_epsilon, } optimizer_kwargs["lr"] = self.args.learning_rate - if self.sharded_dpp: + if self.sharded_ddp == ShardedDDPOption.SIMPLE: self.optimizer = OSS( params=optimizer_grouped_parameters, optim=optimizer_cls, @@ -737,8 +758,19 @@ class Trainer: return model # Distributed training (should be after apex fp16 initialization) - if self.sharded_dpp: - model = ShardedDDP(model, self.optimizer) + if self.sharded_ddp is not None: + # Sharded DDP! + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + model = ShardedDDP(model, self.optimizer) + else: + mixed_precision = self.args.fp16 + cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp + zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 + # XXX: Breaking the self.model convention but I see no way around it for now. + self.model = model = FullyShardedDDP( + model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload + ).to(self.args.device) + elif is_sagemaker_distributed_available(): model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False) elif self.args.local_rank != -1: @@ -855,6 +887,7 @@ class Trainer: num_train_epochs = 1 num_update_steps_per_epoch = max_steps + delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE if self.args.deepspeed: model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps) self.model = model.module @@ -862,7 +895,7 @@ class Trainer: self.deepspeed = model # DeepSpeedEngine object self.optimizer = optimizer self.lr_scheduler = lr_scheduler - else: + elif not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() @@ -877,6 +910,9 @@ class Trainer: if model is not self.model: self.model_wrapped = model + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. @@ -1026,6 +1062,9 @@ class Trainer: if hasattr(self.optimizer, "clip_grad_norm"): # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping self.optimizer.clip_grad_norm(self.args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + model.clip_grad_norm_(self.args.max_grad_norm) else: # Revert to normal clipping otherwise, handling Apex or full precision torch.nn.utils.clip_grad_norm_( @@ -1148,8 +1187,8 @@ class Trainer: def _save_checkpoint(self, model, trial, metrics=None): # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we - # want to save. - assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" + # want to save except FullyShardedDDP. + # assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" @@ -1173,7 +1212,7 @@ class Trainer: self.deepspeed.save_checkpoint(output_dir) # Save optimizer and scheduler - if self.sharded_dpp: + if self.sharded_ddp == ShardedDDPOption.SIMPLE: self.optimizer.consolidate_state_dict() if is_torch_tpu_available(): @@ -1479,7 +1518,11 @@ class Trainer: # They can then be reloaded using `from_pretrained()` xm.rendezvous("saving_checkpoint") if not isinstance(self.model, PreTrainedModel): - logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if isinstance(_model_unwrap(self.model), PreTrainedModel): + if xm.is_master_ordinal(): + _model_unwrap(self.model).config.save_pretrained(output_dir) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = self.model.state_dict() xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: @@ -1494,7 +1537,10 @@ class Trainer: # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): - logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if isinstance(_model_unwrap(self.model), PreTrainedModel): + _model_unwrap(self.model).config.save_pretrained(output_dir) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = self.model.state_dict() torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 76622d34a..cd70001c7 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -421,3 +421,10 @@ class TrainerMemoryTracker: # init doesn't have metrics to update so we just save that data for later stages to retrieve if metrics is not None: self.update_metrics(stage, metrics) + + +class ShardedDDPOption(ExplicitEnum): + SIMPLE = "simple" + ZERO_DP_2 = "zero2" + ZERO_DP_3 = "zero3" + OFFLOAD = "offload" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 22504aa10..90c04f89d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -25,7 +25,7 @@ from .file_utils import ( is_torch_tpu_available, torch_required, ) -from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType +from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption from .utils import logging @@ -236,9 +236,22 @@ class TrainingArguments: When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have. - sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`): + sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`): Use Sharded DDP training from `FairScale `__ (in distributed training only). This is an experimental feature. + + A list of options along the following: + + - :obj:`"simple"`: to use first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar + to ZeRO-2. + - :obj:`"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale + (:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`). + - :obj:`"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale + (:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`). + - :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`). + + If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty + list for :obj:`False` and :obj:`["simple"]` for :obj:`True`. deepspeed (:obj:`str`, `optional`): Use `Deepspeed `__. This is an experimental feature and its API may evolve in the future. The value is the location of its json config file (usually ``ds_config.json``). @@ -443,9 +456,14 @@ class TrainingArguments: "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data." }, ) - sharded_ddp: bool = field( - default=False, - metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."}, + sharded_ddp: str = field( + default="", + metadata={ + "choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"], + "help": "Whether or not to use sharded DDP training (in distributed training only). The base option " + "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` " + "like this: zero_dp_2 offload` or `zero_dp_3 offload`", + }, ) deepspeed: Optional[str] = field( default=None, @@ -535,6 +553,20 @@ class TrainingArguments: "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training" ) + if isinstance(self.sharded_ddp, bool): + self.sharded_ddp = "simple" if self.sharded_ddp else "" + if isinstance(self.sharded_ddp, str): + self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()] + if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]: + raise ValueError( + "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or " + '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.' + ) + elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp: + raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") + elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: + raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.") + def __repr__(self): # We override the default repr to remove deprecated arguments from the repr. This method should be removed once # those deprecated arguments are removed form TrainingArguments. (TODO: v5) @@ -662,7 +694,7 @@ class TrainingArguments: - :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU). - :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`). - - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses + - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses :obj:`torch.nn.DistributedDataParallel`). - :obj:`ParallelMode.TPU`: several TPU cores. """ @@ -692,6 +724,8 @@ class TrainingArguments: for k, v in d.items(): if isinstance(v, Enum): d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] return d def to_json_string(self):