Mirror of https://github.com/microsoft/DeepSpeed.git
Merge branch 'master' into jomayeri/aio-file-offset
Commit 58bc018beb
@@ -133,7 +133,6 @@ class DeepSpeedZeRoOffload(object):
         self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
                                                                       self.model_persistence_threshold)

-        self.param_coordinators = {}
         self._prefetch_bucket_sz = int(prefetch_bucket_size)
         self._max_reuse_distance_in_numel = int(max_reuse_distance)
         self._max_available_parameters_in_numel = int(max_live_parameters)
@@ -141,12 +140,21 @@ class DeepSpeedZeRoOffload(object):
         ) if overlap_comm else get_accelerator().default_stream()

         if not hasattr(module, "ds_inflight_param_registry"):
-            module.ds_inflight_param_registry = dict()
-            # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
-            module.ds_inflight_param_registry[True] = InflightParamRegistry()
-            module.ds_inflight_param_registry[False] = InflightParamRegistry()
+            module.ds_inflight_param_registry = InflightParamRegistry()
         self.__inflight_param_registry = module.ds_inflight_param_registry

+        self.param_coordinator = PartitionedParameterCoordinator(
+            prefetch_bucket_sz=self._prefetch_bucket_sz,
+            max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
+            max_available_parameters_in_numel=self._max_available_parameters_in_numel,
+            allgather_stream=self.__allgather_stream,
+            inflight_param_registry=self.__inflight_param_registry,
+            prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
+            timers=self.timers,
+            zero_quantized_weights=self.zero_quantized_weights,
+            zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
+        )
+
         self.forward_hooks = []
         self.backward_hooks = []
         self.setup_zero_stage3_hooks()
@@ -161,26 +169,13 @@ class DeepSpeedZeRoOffload(object):
         """Partitioning Parameters that were not partitioned usually if parameters
        of modules whose input parameters do not require grad computation do not
        trigger post call and will therefore will remain unpartitioned"""
-        self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
+        self.get_param_coordinator().release_and_reset_all(self.module)
         for param in iter_params(self.module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")

-    def get_param_coordinator(self, training):
-        if not training in self.param_coordinators:
-            self.param_coordinators[training] = PartitionedParameterCoordinator(
-                prefetch_bucket_sz=self._prefetch_bucket_sz,
-                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
-                max_available_parameters_in_numel=self._max_available_parameters_in_numel,
-                allgather_stream=self.__allgather_stream,
-                inflight_param_registry=self.__inflight_param_registry[training],
-                prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
-                timers=self.timers,
-                zero_quantized_weights=self.zero_quantized_weights,
-                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-            )
-
-        return self.param_coordinators[training]
+    def get_param_coordinator(self):
+        return self.param_coordinator

     def empty_partition_cache(self):
         self.partition_all_parameters()
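Note: the hunks above collapse the per-mode coordinators (`param_coordinators[True]` / `[False]`) into a single `self.param_coordinator` built once in `__init__`, so the accessor no longer takes a `training` flag. A minimal caller-side sketch of the adaptation (the `offload` name is a placeholder for a DeepSpeedZeRoOffload instance, not something defined in this diff):

```python
# Sketch only: how calling code adapts to the accessor change in this merge.
def release_all(offload, module):
    # old API (removed here): offload.get_param_coordinator(training=module.training)
    # new API (added here): one shared coordinator for both training and eval
    offload.get_param_coordinator().release_and_reset_all(module)
```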
@@ -228,14 +223,14 @@ class DeepSpeedZeRoOffload(object):

         #reset step if in inference mode
         @instrument_w_nvtx
-        def _end_of_forward_hook(module, *args):
+        def _start_of_forward_hook(module, *args):

-            if not torch._C.is_grad_enabled():
-                self.get_param_coordinator(training=False).reset_step()
+            self.get_param_coordinator().reset_step()
+
+        self.module.register_forward_pre_hook(_start_of_forward_hook)

         #likely one of them should be enough but just to be safe
         self._register_hooks_recursively(self.module)
-        self.module.register_forward_hook(_end_of_forward_hook)

         # Add top module to stack trace
         global FWD_MODULE_STACK
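For context, this hunk swaps an end-of-forward hook (which only reset the coordinator when gradients were disabled) for a forward pre-hook that resets tracing state every time the root module is entered. A standalone PyTorch sketch of the pre-hook mechanism being relied on (the toy module and print are illustrative, not part of the diff):

```python
import torch
import torch.nn as nn

# nn.Module.register_forward_pre_hook runs the callback before forward(),
# which is where the merged code now calls get_param_coordinator().reset_step().
model = nn.Linear(4, 4)

def start_of_forward(module, inputs):
    print("entering forward; per-step tracing state would be reset here")

handle = model.register_forward_pre_hook(start_of_forward)
model(torch.randn(2, 4))
handle.remove()  # the hook can be detached when no longer needed
```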
@@ -447,7 +442,7 @@ class DeepSpeedZeRoOffload(object):
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)

-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
@@ -460,7 +455,7 @@ class DeepSpeedZeRoOffload(object):
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                          force=False)

-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.release_sub_module(sub_module)

         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
@@ -468,8 +463,8 @@ class DeepSpeedZeRoOffload(object):

     @torch.no_grad()
     def pre_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
-        param_coordinator = self.get_param_coordinator(training=True)
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
@@ -477,12 +472,12 @@ class DeepSpeedZeRoOffload(object):

     @torch.no_grad()
     def post_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
             force=False)

-        self.get_param_coordinator(training=True).release_sub_module(sub_module)
+        self.get_param_coordinator().release_sub_module(sub_module)

         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
@@ -18,6 +18,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedP
 from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
 from deepspeed.accelerator import get_accelerator
 import deepspeed.runtime.compiler as compiler
+from deepspeed.runtime.compiler import is_compiling

 import logging

@@ -92,7 +93,7 @@ class PartitionedParameterCoordinator:
         # keeps track of the number of submodules invoked so far.
         self.__step_id: int = 0
         # network tracing mode
-        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
+        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID
         # sequence of submodules/parameters in forward pass + backward pass
         self.__submodule_order: Iterable[Module] = []
         self.__param_order: Iterable[__class__.__ParamInTrace] = []
@@ -188,6 +189,9 @@ class PartitionedParameterCoordinator:
     @compiler.disable
     def record_module(self, sub_module: Module) -> None:
         """adds sub module to trace"""
+        if is_compiling():
+            return
+
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

@@ -195,6 +199,8 @@ class PartitionedParameterCoordinator:
         self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)

     def record_parameters(self, sub_module: Module) -> None:
+        if is_compiling():
+            return
         """adds sub module to trace"""
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
@@ -209,8 +215,12 @@ class PartitionedParameterCoordinator:
         for sub_module in self.__submodule_order:
             self.record_parameters(sub_module)

+    @compiler.disable
     def reset_step(self) -> None:
         """indicate that we have completed one fwd+bwd for the model"""
+        if is_compiling():
+            return
+
         self._clean_inflight_param_registry()

         if not self.is_complete_trace(): # not self.trace_complete:
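The `PartitionedParameterCoordinator` hunks above guard the trace bookkeeping (`record_module`, `record_parameters`, `reset_step`) so that it becomes a no-op while the model is being compiled. A minimal sketch of the same early-return pattern using the public `torch.compiler.is_compiling()`; the diff itself imports `is_compiling` from `deepspeed.runtime.compiler`, which I assume behaves equivalently:

```python
import torch

# Minimal sketch of the guard added in these hunks (assumption:
# deepspeed.runtime.compiler.is_compiling is analogous to the call below).
class TraceRecorder:
    def __init__(self):
        self.submodule_order = []

    def record_module(self, sub_module):
        if torch.compiler.is_compiling():
            return  # skip Python-side trace bookkeeping while torch.compile traces
        self.submodule_order.append(sub_module)
```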
@@ -593,8 +593,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

         return device_buffer

-    def _get_param_coordinator(self, training):
-        return self.parameter_offload.get_param_coordinator(training)
+    def _get_param_coordinator(self):
+        return self.parameter_offload.get_param_coordinator()

     def _configure_offloading(self, offload_optimizer_config, offload_param_config):
         ###################### offload optimizer setup ##################################
@@ -1874,7 +1874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         see_memory_usage(f"In step before checking overflow", force=False)

         print_rank_0("Finished Tracing at Beginning of Step")
-        self._get_param_coordinator(training=True).hierarchy = 0
+        self._get_param_coordinator().hierarchy = 0

         print_rank_0("Finished Tracing at Beginning of Step")

@@ -2258,8 +2258,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         else:
             self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

-        self._get_param_coordinator(training=True).reset_step()
-
         if self.swap_optimizer:
             self.optimizer_swapper.post_backward()

@@ -1628,3 +1628,48 @@ class TestEmptyParameterGroup(DistributedTest):
             optimizer=optimizer,
             config=config_dict,
         )
+
+
+class TestZero3SwitchModes(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0])
+    def test(self, prefetch_ratio, zero_stage=3):
+
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+
+        prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio)
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 2,
+            "gradient_accumulation_steps": 2,
+            "zero_optimization": {
+                "stage": zero_stage,
+                "stage3_prefetch_bucket_size": prefetch_bucket_size
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-3
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "initial_scale_power": 8
+            }
+        }
+
+        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)
+
+        for _ in range(3):
+            model.train()
+            for batch in data_loader:
+                loss = model(batch[0], batch[1])
+                model.backward(loss)
+                model.step()
+
+            model.eval()
+            with torch.no_grad():
+                for batch in data_loader:
+                    loss = model(batch[0], batch[1])
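For reference, the parametrized `prefetch_ratio` in the new test simply scales ZeRO-3's `stage3_prefetch_bucket_size` to a fraction of the model's total parameter count, so the train/eval switching is exercised with no prefetch budget, a partial budget, and a budget covering every element. A quick illustrative calculation (the total parameter count is made up, not derived from `SimpleModel`):

```python
# Illustrative arithmetic only; 1_210 is a made-up total parameter count.
total_numel = 1_210
for prefetch_ratio in (0.0, 0.5, 1.0):
    prefetch_bucket_size = int(total_numel * prefetch_ratio)
    print(prefetch_ratio, prefetch_bucket_size)  # -> 0, 605, 1210
```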