Mirror of https://github.com/microsoft/DeepSpeed.git
Add no_sync context manager (#6675)
Fix #1902

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Parent: d702eb5f79
Commit: fc4e73370d
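
For orientation, a minimal usage sketch of the new API (not part of the commit; `engine` is assumed to be a DeepSpeedEngine returned by deepspeed.initialize() with ZeRO stage 0 or 1, `micro_batches` is a placeholder iterable of (inputs, labels) pairs, and the wrapped module is assumed to return a loss):

# Skip gradient all-reduce for all but the last micro-batch, similar in spirit
# to torch.nn.parallel.DistributedDataParallel.no_sync().
*accum_batches, last_batch = list(micro_batches)

with engine.no_sync():
    for inputs, labels in accum_batches:
        loss = engine(inputs, labels)
        engine.backward(loss)  # gradient reduction is skipped inside the context

# Run the final micro-batch outside the context so gradients are reduced once,
# then apply the optimizer step (step() inside the context raises an AssertionError).
inputs, labels = last_batch
loss = engine(inputs, labels)
engine.backward(loss)
engine.step()
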
@@ -14,7 +14,7 @@ from ..inference_parameter import InferenceParameter
 # Currently have dependency loops for the type hints.
 InferenceModel = Type["InferenceModel"]
-LayerContainer = Type["LayerContainer"]
+LayerContainer = Type["LayerContainer"]  # noqa: F811

 MAPPING_KEY = "PARAM_MAPPING"
 PLIST_HELPERS = "_ds_plist_strip_vals"
@@ -161,7 +161,7 @@ class LayerMetaclass(type):
         return instance


-class LayerContainer(metaclass=LayerMetaclass):
+class LayerContainer(metaclass=LayerMetaclass):  # noqa: F811
     """
     Abstract base class for containing model parameters.


@@ -17,6 +17,7 @@ from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from contextlib import contextmanager

 from typing import Callable, Dict, Union, Iterable, Container

@@ -216,6 +217,7 @@ class DeepSpeedEngine(Module):
         self.loaded_checkpoint_mp_world_size = None
         self.loaded_checkpoint_dp_world_size = None
         self.enable_backward_allreduce = True
+        self.inside_no_sync_ctxt = False
         self.progressive_layer_drop = None
         self.eigenvalue = None
         self.block_eigenvalue = None
@@ -1981,12 +1983,31 @@ class DeepSpeedEngine(Module):
                 grads = None
                 self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features.
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
+
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
+
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False
+
     @instrument_w_nvtx
-    def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
+    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
         r"""Execute backward pass on the loss
         Arguments:
             loss: Torch tensor on which to execute backward propagation
-            allreduce_gradients: is deprecated, ignored, and will soon be removed'
             retain_graph: bool, default: false
                 forward on user defined choice of retain_graph
         """
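
Since no_sync() asserts when gradient partitioning is active, callers that must run under any ZeRO stage can gate its use on the engine's reported stage. A minimal sketch, assuming an initialized `engine` and an already computed `loss`; the local variable name is illustrative, while `zero_optimization_stage()` is the engine accessor already used in the assertion above:

from contextlib import nullcontext

# Suppress reduction only for ZeRO stage <= 1; stages 2/3 rely on reduction
# for gradient partitioning, so fall back to a no-op context for them.
maybe_no_sync = engine.no_sync() if engine.zero_optimization_stage() <= 1 else nullcontext()
with maybe_no_sync:
    engine.backward(loss)
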
@@ -1996,11 +2017,10 @@ class DeepSpeedEngine(Module):
         if self.scale_wrt_gas is not None:
             scale_wrt_gas = self.scale_wrt_gas

-        if not allreduce_gradients:
-            logger.warning(f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed")
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt

-        # scale loss w.r.t. gradient accumulation if needed
-        if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
+        # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())

         # Log training loss
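
A side effect of this hunk: inside no_sync() the engine no longer applies the usual 1/gradient_accumulation_steps loss scaling, because do_gradient_reduction is False there. A caller who wants averaged rather than summed gradients across micro-batches can scale the loss manually; an illustrative sketch, with `engine` and `accum_batches` as in the earlier example:

gas = engine.gradient_accumulation_steps()  # e.g. 4
with engine.no_sync():
    for inputs, labels in accum_batches:
        loss = engine(inputs, labels)
        engine.backward(loss / gas)  # manual 1/GAS scaling; the engine skips it inside the context
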
@@ -2049,7 +2069,7 @@ class DeepSpeedEngine(Module):

         self._start_timers(self.engine_timers.backward_reduce_timers)

-        if allreduce_gradients and self.enable_backward_allreduce:
+        if do_gradient_reduction:
             # Traditional code path that allreduces the module parameter grads
             self.allreduce_gradients()

@@ -2185,6 +2205,9 @@ class DeepSpeedEngine(Module):
         r"""Execute the weight update step after forward and backward propagation
         on effective_train_batch.
         """
+        assert not self.inside_no_sync_ctxt, \
+            "It is illegal to call Engine.step() inside no_sync context manager"
+
         see_memory_usage("Engine before step", force=self.memory_breakdown())

         # Check early because self.global_steps is incremented at some point here.

@@ -2297,11 +2297,6 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
         self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder)

-    @property
-    def param_groups(self):
-        """Forward the wrapped optimizer's parameters."""
-        return self.optimizer.param_groups
-
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
         self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale)

New test file:
@@ -0,0 +1,197 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest

from contextlib import nullcontext
import torch

from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest

import deepspeed
import deepspeed.comm as dist
from deepspeed.utils import safe_get_full_grad


class TestNoSyncCtxt(DistributedTest):
    world_size = 2

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
    def test_zero_stage(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

        invalid_cfg = zero_stage > 1
        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

        with pytest.raises(AssertionError) if invalid_cfg else nullcontext() as assertinfo:
            with model.no_sync():
                for _, batch in enumerate(data_loader):
                    loss = model(batch[0], batch[1])
                    model.backward(loss)
        if invalid_cfg:
            assert ("no_sync context manager is incompatible" in str(assertinfo))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_engine_step(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                model.backward(loss)
                with pytest.raises(AssertionError) as assertinfo:
                    model.step()
                assert ("It is illegal to call Engine.step() inside no_sync context manager" in str(assertinfo))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("zero_stage", [0, 1])
    def test_multiple_ctxts(self, zero_stage, dtype):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
        }

        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}

        hidden_dim = 64
        total_samples = 32
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=total_samples,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)
        dist.barrier()

        param_list = list(model.parameters())
        first_losses = []
        first_grad_norms = []
        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                first_losses.append(loss.item())
                model.backward(loss)
                grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
                first_grad_norms.append(grad_norm.item())

        second_losses = []
        second_grad_norms = []

        model.zero_grad()
        with model.no_sync():
            for _, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                second_losses.append(loss.item())
                model.backward(loss)
                grad_norm = sum([safe_get_full_grad(p).norm() for p in param_list])
                second_grad_norms.append(grad_norm.item())

        assert len(first_losses) == len(second_losses)
        for x, y in zip(first_losses, second_losses):
            assert x == y

        assert len(first_grad_norms) == len(second_grad_norms)
        for x, y in zip(first_grad_norms, second_grad_norms):
            assert x == y

    def test_reentry(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "zero_optimization": {
                "stage": 1,
            },
        }

        hidden_dim = 64
        model = SimpleModel(hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        dist.barrier()

        with model.no_sync():
            with pytest.raises(AssertionError) as assertinfo:
                with model.no_sync():
                    pass
            assert ("no_sync context manager reentry is unsupported" in str(assertinfo))