Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)
* Add support for ZeRO-2/3 and ZeRO-offload in fairscale
* Quality
* Rework from review comments
* Add doc
* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Parent: 88cc26dcd1
Commit: 9d14be5c20
@@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper <https://arxiv.

 1. Optimizer State Sharding
 2. Gradient Sharding
+3. Model Parameters Sharding (new and very experimental)
+4. CPU offload (new and very experimental)

 You will need at least two GPUs to use this feature.
@@ -255,8 +257,9 @@ To deploy this feature:
    or find more details on `FairScale's GitHub page
    <https://github.com/facebookresearch/fairscale/#installation>`__.

-2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
-   torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
+2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
+   and make sure you have added the distributed launcher ``-m torch.distributed.launch
+   --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.

    For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
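For readers wondering what ``--sharded_ddp simple`` does under the hood: it pairs fairscale's ``OSS`` optimizer wrapper with its ``ShardedDataParallel`` model wrapper, as the ``Trainer`` changes further down in this diff show. A minimal standalone sketch; ``MyModel``, the learning rate, and the already initialized process group are placeholders and assumptions, not part of this commit:

.. code-block:: python

    import torch

    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS

    # assumes torch.distributed.launch already initialized the process group
    model = MyModel().to("cuda")  # MyModel is a placeholder nn.Module
    # OSS shards the optimizer state across the data-parallel ranks
    optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=3e-5)
    # ShardedDDP reduces each gradient only to the rank that owns its optimizer shard
    model = ShardedDDP(model, optimizer)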
@@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
        --do_train --max_train_samples 500 --num_train_epochs 1 \
        --dataset_name wmt16 --dataset_config "ro-en" \
        --task translation_en_to_ro --source_prefix "translate English to Romanian: " \
-       --fp16 --sharded_ddp
+       --fp16 --sharded_ddp simple

    Notes:

    - This feature requires distributed training (so multiple GPUs).
    - It is not implemented for TPUs.
    - It works with ``--fp16`` too, to make things even faster.
-   - One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
-     to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
+   - One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be
+     able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
      significantly shorter training time.

+3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
+   to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
+   --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
+
+   For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
+
+   .. code-block:: bash
+
+       python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
+       --model_name_or_path t5-small --per_device_train_batch_size 1 \
+       --output_dir output_dir --overwrite_output_dir \
+       --do_train --max_train_samples 500 --num_train_epochs 1 \
+       --dataset_name wmt16 --dataset_config "ro-en" \
+       --task translation_en_to_ro --source_prefix "translate English to Romanian: " \
+       --fp16 --sharded_ddp zero_dp_2
+
+   :obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
+   gradients and optimizer states.
+
+   Both are compatible with adding :obj:`offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp
+   "zero_dp_2 offload"`).
+
+   Notes:
+
+   - This feature requires distributed training (so multiple GPUs).
+   - It is not implemented for TPUs.
+   - It works with ``--fp16`` too, to make things even faster.
+   - The ``offload`` additional option requires ``--fp16``.
+   - This is an area of active development, so make sure you have a source install of fairscale to use this feature as
+     some bugs you encounter may have been fixed there already.
+
+   Known caveats:
+
+   - This feature is incompatible with :obj:`--predict_with_generate` in the ``run_seq2seq.py`` script.
+   - Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
+     :obj:`FullyShardedDataParallel` of fairscale. This is not done automatically by any of the example scripts of the
+     :class:`~transformers.Trainer`.


 DeepSpeed
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
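Under the hood, ``zero_dp_2`` and ``zero_dp_3`` both map onto fairscale's ``FullyShardedDataParallel`` wrapper. The flag-to-argument mapping below is taken from this commit's ``Trainer`` changes, while ``MyModel`` and the device handling are placeholders:

.. code-block:: python

    import torch

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

    # assumes an initialized torch.distributed process group
    model = MyModel()  # placeholder module; left on CPU until the wrapper is moved
    model = FullyShardedDDP(
        model,
        mixed_precision=True,        # --fp16
        reshard_after_forward=True,  # True for zero_dp_3, False for zero_dp_2
        cpu_offload=False,           # True when "offload" is added (requires --fp16)
    ).to(torch.device("cuda"))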
@@ -64,12 +64,13 @@ def require_apex(test_case):


 class TestTrainerExt(TestCasePlus):
-    def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
+    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
-        assert "eval_bleu" in first_step_stats
+        if predict_with_generate:
+            assert "eval_bleu" in first_step_stats

     @require_torch_non_multi_gpu
     def test_run_seq2seq_no_dist(self):
@@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus):
     # test --sharded_ddp w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    def test_run_seq2seq_ddp_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
+    def test_run_seq2seq_sharded_ddp(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

     # test --sharded_ddp w/ --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
+    def test_run_seq2seq_sharded_ddp_fp16(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
+
+    # test --sharded_ddp zero_dp_2 w/o --fp16
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_run_seq2seq_fully_sharded_ddp(self):
+        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
+
+    # test --sharded_ddp zero_dp_2 w/ --fp16
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
+        self.run_seq2seq_quick(
+            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
+        )

     @require_apex
     def test_run_seq2seq_apex(self):
@@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus):
         num_train_epochs: int,
         distributed: bool = False,
         extra_args_str: str = None,
+        predict_with_generate: bool = True,
     ):
         data_dir = self.examples_dir / "test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus):
             --learning_rate 3e-3
             --warmup_steps 8
             --evaluation_strategy steps
-            --predict_with_generate
             --logging_steps 0
             --save_steps {str(eval_steps)}
             --eval_steps {str(eval_steps)}
@@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus):
             --task translation
             --target_lang ro_RO
             --source_lang en_XX
-        """.split()
+        """
+        if predict_with_generate:
+            args += "--predict_with_generate"
+
+        args = args.split()

         if extra_args_str is not None:
             args.extend(extra_args_str.split())
@@ -93,6 +93,7 @@ from .trainer_utils import (
     EvalPrediction,
     HPSearchBackend,
     PredictionOutput,
+    ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
@@ -131,10 +132,16 @@ if is_torch_tpu_available():
     import torch_xla.distributed.parallel_loader as pl

 if is_fairscale_available():
+    import fairscale
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler

+    if version.parse(fairscale.__version__) >= version.parse("0.3"):
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+    else:
+        FullyShardedDDP = None
+
 if is_sagemaker_distributed_available():
     import smdistributed.dataparallel.torch.distributed as dist
     from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
@@ -277,9 +284,38 @@ class Trainer:
         else:
             self.is_model_parallel = False

+        # Setup Sharded DDP training
+        self.sharded_ddp = None
+        if len(args.sharded_ddp) > 0:
+            if args.deepspeed:
+                raise ValueError(
+                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+
+            if args.local_rank == -1:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+                raise ImportError(
+                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+                )
+            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.SIMPLE
+            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
         # one place to sort out whether to place the model on device or not
         self.place_model_on_device = args.place_model_on_device
-        if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
+        if (
+            self.is_model_parallel
+            or (args.deepspeed and args.do_train)
+            or (args.fp16_full_eval and not args.do_train)
+            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+        ):
             self.place_model_on_device = False

         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
@@ -346,21 +382,6 @@ class Trainer:
         if isinstance(eval_dataset, datasets.Dataset):
             self._remove_unused_columns(self.eval_dataset, description="evaluation")

-        # Setup Sharded DDP training
-        self.sharded_dpp = False
-        if args.sharded_ddp:
-            if args.deepspeed:
-                raise ValueError(
-                    "Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
-                )
-
-            if args.local_rank == -1:
-                raise ValueError("Using sharded DDP only works in distributed training.")
-            elif not is_fairscale_available():
-                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
-            else:
-                self.sharded_dpp = True
-
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
@@ -376,7 +397,7 @@ class Trainer:
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
+                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
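``ShardedGradScaler`` is fairscale's drop-in replacement for ``torch.cuda.amp.GradScaler``, with the same interface but loss-scale bookkeeping that works with sharded optimizer state. A hedged sketch of one AMP step; ``model``, ``batch`` and ``optimizer`` are assumed from a surrounding training loop:

.. code-block:: python

    import torch
    from fairscale.optim.grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()  # same API as torch.cuda.amp.GradScaler
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss  # assumed model returning an output with .loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # e.g. the OSS wrapper from earlier in this diff
    scaler.update()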
@@ -619,7 +640,7 @@ class Trainer:
                 "eps": self.args.adam_epsilon,
             }
             optimizer_kwargs["lr"] = self.args.learning_rate
-            if self.sharded_dpp:
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 self.optimizer = OSS(
                     params=optimizer_grouped_parameters,
                     optim=optimizer_cls,
@@ -737,8 +758,19 @@ class Trainer:
             return model

         # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_dpp:
-            model = ShardedDDP(model, self.optimizer)
+        if self.sharded_ddp is not None:
+            # Sharded DDP!
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                model = ShardedDDP(model, self.optimizer)
+            else:
+                mixed_precision = self.args.fp16
+                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
+                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
+                # XXX: Breaking the self.model convention but I see no way around it for now.
+                self.model = model = FullyShardedDDP(
+                    model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
+                ).to(self.args.device)
+
         elif is_sagemaker_distributed_available():
             model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
         elif self.args.local_rank != -1:
@@ -855,6 +887,7 @@ class Trainer:
             num_train_epochs = 1
             num_update_steps_per_epoch = max_steps

+        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if self.args.deepspeed:
             model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
             self.model = model.module
@@ -862,7 +895,7 @@ class Trainer:
             self.deepspeed = model  # DeepSpeedEngine object
             self.optimizer = optimizer
             self.lr_scheduler = lr_scheduler
-        else:
+        elif not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
@@ -877,6 +910,9 @@ class Trainer:
         if model is not self.model:
             self.model_wrapped = model

+        if delay_optimizer_creation:
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
         # important: at this point:
         # self.model is the Transformers Model
         # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
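A note on ``delay_optimizer_creation`` above: ``FullyShardedDDP`` flattens and shards the model's parameters, so an optimizer built before wrapping would hold references to tensors the wrapper has replaced. A two-line sketch of the required ordering (``model`` and the learning rate are assumed):

.. code-block:: python

    model = FullyShardedDDP(model, reshard_after_forward=True)  # parameters are flattened/sharded here
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)  # so the optimizer is created afterwards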
@@ -1026,6 +1062,9 @@ class Trainer:
                     if hasattr(self.optimizer, "clip_grad_norm"):
                         # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                         self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                    elif hasattr(model, "clip_grad_norm_"):
+                        # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+                        model.clip_grad_norm_(self.args.max_grad_norm)
                     else:
                         # Revert to normal clipping otherwise, handling Apex or full precision
                         torch.nn.utils.clip_grad_norm_(
@@ -1148,8 +1187,8 @@ class Trainer:

     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
-        # want to save.
-        assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
+        # want to save except FullyShardedDDP.
+        # assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"

         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1173,7 +1212,7 @@ class Trainer:
             self.deepspeed.save_checkpoint(output_dir)

         # Save optimizer and scheduler
-        if self.sharded_dpp:
+        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()

         if is_torch_tpu_available():
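Context for ``consolidate_state_dict`` above: the ``OSS`` optimizer keeps only a shard of its state on each rank, so the full state must be gathered before it can be saved. A hedged sketch of the pattern (the path and rank handling are assumptions, not from this commit; fairscale consolidates onto rank 0 by default):

.. code-block:: python

    import torch
    import torch.distributed as dist

    optimizer.consolidate_state_dict()  # gather all shards (onto rank 0 by default)
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), "optimizer.pt")  # full, consolidated state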
@@ -1479,7 +1518,11 @@ class Trainer:
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+        if isinstance(_model_unwrap(self.model), PreTrainedModel):
+            if xm.is_master_ordinal():
+                _model_unwrap(self.model).config.save_pretrained(output_dir)
+        else:
+            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
@@ -1494,7 +1537,10 @@ class Trainer:
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
-        if not isinstance(self.model, PreTrainedModel):
-            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+        if isinstance(_model_unwrap(self.model), PreTrainedModel):
+            _model_unwrap(self.model).config.save_pretrained(output_dir)
+        else:
+            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
             state_dict = self.model.state_dict()
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
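For reference, ``_model_unwrap`` used in the save paths above simply peels container wrappers (DDP, ShardedDDP, FullyShardedDDP) that keep the real model in a ``.module`` attribute. A minimal sketch consistent with how the diff uses it (the actual helper lives in ``trainer.py``):

.. code-block:: python

    import torch

    def _model_unwrap(model: torch.nn.Module) -> torch.nn.Module:
        # there can be several levels of wrapping (e.g. DDP around a sharded wrapper), so recurse
        if hasattr(model, "module"):
            return _model_unwrap(model.module)
        return model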
@@ -421,3 +421,10 @@ class TrainerMemoryTracker:
         # init doesn't have metrics to update so we just save that data for later stages to retrieve
         if metrics is not None:
             self.update_metrics(stage, metrics)
+
+
+class ShardedDDPOption(ExplicitEnum):
+    SIMPLE = "simple"
+    ZERO_DP_2 = "zero_dp_2"
+    ZERO_DP_3 = "zero_dp_3"
+    OFFLOAD = "offload"
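Since ``ExplicitEnum`` members are constructed from their string values, the space-separated ``--sharded_ddp`` value can be parsed directly into options, which is exactly what ``__post_init__`` does further down in this diff. A small sketch of that round trip:

.. code-block:: python

    from transformers.trainer_utils import ShardedDDPOption

    options = [ShardedDDPOption(s) for s in "zero_dp_2 offload".split()]
    assert options == [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.OFFLOAD]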
@@ -25,7 +25,7 @@ from .file_utils import (
     is_torch_tpu_available,
     torch_required,
 )
-from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
+from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
 from .utils import logging
@@ -236,9 +236,22 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
             step can take a long time) but will not yield the same results as the interrupted training would have.
-        sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
+        sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
             Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
             training only). This is an experimental feature.
+
+            A list of options chosen from the following:
+
+            - :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
+              to ZeRO-2.
+            - :obj:`"zero_dp_2"`: to use the second instance of sharded DDP released by fairscale
+              (:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`).
+            - :obj:`"zero_dp_3"`: to use the second instance of sharded DDP released by fairscale
+              (:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`).
+            - :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
+
+            If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
+            list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
         deepspeed (:obj:`str`, `optional`):
             Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
             evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
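A hedged usage example for the new argument surface (the option strings come from this commit's docstring and ``choices``; the output directory is arbitrary):

.. code-block:: python

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="output_dir",
        do_train=True,
        fp16=True,                        # the "offload" option requires fp16
        sharded_ddp="zero_dp_2 offload",  # normalized to a list of ShardedDDPOption in __post_init__
    )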
@@ -443,9 +456,14 @@ class TrainingArguments:
             "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
         },
     )
-    sharded_ddp: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
+    sharded_ddp: str = field(
+        default="",
+        metadata={
+            "choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
+            "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
+            "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
+            "like this: `zero_dp_2 offload` or `zero_dp_3 offload`",
+        },
     )
     deepspeed: Optional[str] = field(
         default=None,
@@ -535,6 +553,20 @@ class TrainingArguments:
                 "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
             )

+        if isinstance(self.sharded_ddp, bool):
+            self.sharded_ddp = "simple" if self.sharded_ddp else ""
+        if isinstance(self.sharded_ddp, str):
+            self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
+        if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
+            raise ValueError(
+                "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
+                '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
+            )
+        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
+            raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
+        elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
+            raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
+
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
         # those deprecated arguments are removed from TrainingArguments. (TODO: v5)
@@ -662,7 +694,7 @@ class TrainingArguments:

         - :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
         - :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
-        - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
+        - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
           :obj:`torch.nn.DistributedDataParallel`).
         - :obj:`ParallelMode.TPU`: several TPU cores.
         """
@@ -692,6 +724,8 @@ class TrainingArguments:
         for k, v in d.items():
             if isinstance(v, Enum):
                 d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
         return d

     def to_json_string(self):