* Fix gradient accumulation for SM Model Parallelism

* Style and divide loss by grad accum steps
Sylvain Gugger 2021-03-03 12:13:29 -05:00 committed by GitHub
Parent d064fb5647
Commit b70f441b72
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files: 16 additions and 6 deletions
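For context, the core of the fix: when gradients are accumulated over several micro-batches, each micro-batch loss has to be divided by gradient_accumulation_steps so that the summed gradients match a single step on the combined batch. A minimal plain-PyTorch sketch of that pattern (the model, optimizer, and dataloader names are illustrative, not from this PR):

import torch

def train_epoch(model, optimizer, dataloader, gradient_accumulation_steps=4):
    # Gradients from several micro-batches are summed into .grad before one
    # optimizer step; scaling each loss keeps the update equivalent to a
    # single step on the full effective batch.
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()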

View File

@@ -37,9 +37,10 @@ if is_smdistributed_available():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def forward_backward(model, inputs):
+    def forward_backward(model, inputs, gradient_accumulation_steps=1):
         outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
         model.backward(loss)
         return loss
 
@@ -73,8 +74,6 @@ class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
         self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
-        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
-            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
     def is_world_process_zero(self) -> bool:
         """
@@ -108,7 +107,7 @@ class SageMakerTrainer(Trainer):
             # Wrapping the base model twice in a DistributedModel will raise an error.
             if isinstance(self.model_wrapped, smp.model.DistributedModel):
                 return self.model_wrapped
-            return smp.DistributedModel(model)
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
         else:
             return super()._wrap_model(model)
 
@@ -121,7 +120,7 @@ class SageMakerTrainer(Trainer):
         if self.is_model_parallel_enabled:
             model.train()
             inputs = self._prepare_inputs(inputs)
-            loss_mb = forward_backward(model, inputs)
+            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
         else:
             return super().training_step(model, inputs)
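Taken together, the changes in this file route gradient accumulation through SageMaker model parallelism: the @smp.step-decorated forward_backward now scales each micro-batch loss, and smp.DistributedModel is told, via backward_passes_per_step, how many backward passes belong to one optimizer step. A rough sketch of the resulting calling pattern, reusing only the pieces visible in this diff; the surrounding loop and the bare optimizer handling are simplified and hypothetical (a real run goes through Trainer and smp's own optimizer wrapper):

import smdistributed.modelparallel.torch as smp  # only usable inside a SageMaker model-parallel job

@smp.step()
def forward_backward(model, inputs, gradient_accumulation_steps=1):
    outputs = model(**inputs)
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    loss /= gradient_accumulation_steps
    model.backward(loss)  # smp drives backward across the model partitions
    return loss

def training_loop(model, optimizer, dataloader, args):
    # Tell smp how many backward passes make up one optimizer step.
    model = smp.DistributedModel(model, backward_passes_per_step=args.gradient_accumulation_steps)
    for step, inputs in enumerate(dataloader):
        loss_mb = forward_backward(model, inputs, args.gradient_accumulation_steps)
        loss = loss_mb.reduce_mean()  # average over smp micro-batches, as in training_step above
        if (step + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()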

View File

@@ -87,3 +87,7 @@ class SageMakerTrainingArguments(TrainingArguments):
     @property
     def place_model_on_device(self):
         return not (is_smdistributed_available() and self.mp_parameters != "")
+
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        return False

View File

@@ -1039,7 +1039,7 @@ class Trainer:
                 if (
                     ((step + 1) % self.args.gradient_accumulation_steps != 0)
                     and self.args.local_rank != -1
-                    and not self.args.deepspeed
+                    and self.args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
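The condition above now delegates to args._no_sync_in_gradient_accumulation instead of hard-coding the DeepSpeed check, so integrations that manage gradient synchronization themselves can opt out. The underlying no_sync pattern, as a minimal sketch (the DDP-wrapped model and the loss are illustrative stand-ins):

import contextlib

def accumulate(model, loss, step, gradient_accumulation_steps, no_sync_allowed=True):
    # On intermediate accumulation steps, DDP's gradient all-reduce is wasted
    # work, so it is suppressed with no_sync(); the last step of the cycle
    # synchronizes the accumulated gradients once.
    is_last = (step + 1) % gradient_accumulation_steps == 0
    ctx = model.no_sync() if (no_sync_allowed and not is_last) else contextlib.nullcontext()
    with ctx:
        (loss / gradient_accumulation_steps).backward()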

View File

@@ -737,6 +737,13 @@ class TrainingArguments:
         """
         return True
 
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        """
+        Whether or not to use no_sync for the gradients when doing gradient accumulation.
+        """
+        return not self.deepspeed
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).
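With both properties in place, the no_sync decision becomes a simple override point: the base TrainingArguments allows no_sync unless DeepSpeed owns the backward pass, while SageMakerTrainingArguments (earlier in this commit) returns False because model parallelism synchronizes gradients itself. A stripped-down sketch of that override pattern (class names simplified for illustration):

class BaseArgs:
    deepspeed = None  # stand-in for the real DeepSpeed config field

    @property
    def _no_sync_in_gradient_accumulation(self):
        # no_sync is safe unless DeepSpeed manages the backward pass itself.
        return not self.deepspeed

class SageMakerModelParallelArgs(BaseArgs):
    @property
    def _no_sync_in_gradient_accumulation(self):
        # SageMaker model parallelism handles gradient synchronization on its own.
        return False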