Smp grad accum (#10488)
* Fix gradient accumulation for SM Model Parallelism
* Style and divide loss by grad accum steps
Parent: d064fb5647
Commit: b70f441b72
@@ -37,9 +37,10 @@ if is_smdistributed_available():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def forward_backward(model, inputs):
+    def forward_backward(model, inputs, gradient_accumulation_steps=1):
         outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
         model.backward(loss)
         return loss
 
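Dividing the loss by `gradient_accumulation_steps` inside `forward_backward` keeps the accumulated gradient equal to the gradient of the mean loss over the whole effective batch. This is the same scaling used in ordinary gradient accumulation; a minimal sketch of that pattern in plain PyTorch (toy model, data, and hyperparameters are made up for illustration, this is not the smdistributed code path):

import torch
from torch import nn

# Toy setup for illustration only (hypothetical model and data).
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
gradient_accumulation_steps = 4

optimizer.zero_grad()
for step in range(16):
    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)
    loss = loss_fn(model(inputs), targets)
    # Scale each micro-batch loss so the summed gradients match the
    # gradient of the mean loss over the full accumulation window.
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()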
@@ -73,8 +74,6 @@ class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
         self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
-        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
-            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
     def is_world_process_zero(self) -> bool:
         """
@@ -108,7 +107,7 @@ class SageMakerTrainer(Trainer):
             # Wrapping the base model twice in a DistributedModel will raise an error.
             if isinstance(self.model_wrapped, smp.model.DistributedModel):
                 return self.model_wrapped
-            return smp.DistributedModel(model)
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
         else:
             return super()._wrap_model(model)
 
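`backward_passes_per_step` tells `smp.DistributedModel` how many backward passes to accumulate locally before gradients are averaged across data-parallel ranks, so communication only happens on the last micro-batch of each accumulation window. A toy sketch of that bookkeeping (the `DeferredGradSync` helper is hypothetical, not the smdistributed implementation; the allreduce is skipped when no process group is initialized):

import torch.distributed as dist

class DeferredGradSync:
    """Toy helper: average gradients across ranks only every N backward passes."""

    def __init__(self, parameters, backward_passes_per_step):
        self.parameters = list(parameters)
        self.backward_passes_per_step = backward_passes_per_step
        self.passes = 0

    def after_backward(self):
        self.passes += 1
        if self.passes % self.backward_passes_per_step != 0:
            return  # intermediate pass: keep accumulating locally, no communication
        if dist.is_available() and dist.is_initialized():
            for p in self.parameters:
                if p.grad is not None:
                    dist.all_reduce(p.grad)          # sum gradients from all ranks
                    p.grad /= dist.get_world_size()  # turn the sum into a mean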
@@ -121,7 +120,7 @@ class SageMakerTrainer(Trainer):
         if self.is_model_parallel_enabled:
             model.train()
             inputs = self._prepare_inputs(inputs)
-            loss_mb = forward_backward(model, inputs)
+            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
         else:
             return super().training_step(model, inputs)
@@ -87,3 +87,7 @@ class SageMakerTrainingArguments(TrainingArguments):
     @property
     def place_model_on_device(self):
         return not (is_smdistributed_available() and self.mp_parameters != "")
+
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        return False
@@ -1039,7 +1039,7 @@ class Trainer:
                 if (
                     ((step + 1) % self.args.gradient_accumulation_steps != 0)
                     and self.args.local_rank != -1
-                    and not self.args.deepspeed
+                    and self.args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
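`_no_sync_in_gradient_accumulation` replaces the hard-coded `not self.args.deepspeed` check: on intermediate accumulation steps DDP's `no_sync()` context suppresses the per-backward allreduce, while backends that manage synchronization themselves (DeepSpeed here, SMP via `backward_passes_per_step` above) return `False` to opt out. A minimal sketch of the pattern, assuming a model already wrapped in `DistributedDataParallel` (the helper name is made up):

import contextlib

def accumulation_context(model, step, gradient_accumulation_steps, no_sync_allowed):
    """Return model.no_sync() on intermediate accumulation steps, else a no-op context."""
    is_intermediate = (step + 1) % gradient_accumulation_steps != 0
    if is_intermediate and no_sync_allowed:
        # Skip DDP's gradient allreduce for this backward pass; gradients
        # still accumulate locally in param.grad.
        return model.no_sync()
    return contextlib.nullcontext()

# Usage inside the training loop (ddp_model is a DistributedDataParallel instance):
#     with accumulation_context(ddp_model, step, args.gradient_accumulation_steps,
#                               args._no_sync_in_gradient_accumulation):
#         loss.backward()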
@@ -737,6 +737,13 @@ class TrainingArguments:
         """
         return True
 
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        """
+        Whether or not to use no_sync for the gradients when doing gradient accumulation.
+        """
+        return not self.deepspeed
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).
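Because the switch is a property on `TrainingArguments`, a backend-specific arguments class can opt out of `no_sync` without touching the training loop; `SageMakerTrainingArguments` above does exactly that. A hypothetical custom backend would follow the same pattern:

from dataclasses import dataclass

from transformers import TrainingArguments

@dataclass
class MyBackendTrainingArguments(TrainingArguments):
    """Hypothetical arguments class for a backend that averages gradients itself."""

    @property
    def _no_sync_in_gradient_accumulation(self):
        # The backend performs its own gradient synchronization, so the Trainer
        # should not wrap intermediate accumulation steps in model.no_sync().
        return False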