Merge branch 'master' into jomayeri/aio-file-offset

Logan Adams 2024-11-06 09:49:32 -08:00, committed by GitHub
Parents: 2f5446a584 351569dd4a
Commit: 58bc018beb
No key found matching this signature
GPG key ID: B5690EEEBB952194
4 changed files: 85 additions and 37 deletions

View file

@@ -133,7 +133,6 @@ class DeepSpeedZeRoOffload(object):
         self.persistent_parameters = self.mark_persistent_parameters(self.param_numel_persistence_threshold,
                                                                      self.model_persistence_threshold)
-        self.param_coordinators = {}
         self._prefetch_bucket_sz = int(prefetch_bucket_size)
         self._max_reuse_distance_in_numel = int(max_reuse_distance)
         self._max_available_parameters_in_numel = int(max_live_parameters)
@@ -141,12 +140,21 @@ class DeepSpeedZeRoOffload(object):
                                     ) if overlap_comm else get_accelerator().default_stream()
 
         if not hasattr(module, "ds_inflight_param_registry"):
-            module.ds_inflight_param_registry = dict()
-            # we need two registries, one for training and one for eval. They will be used when creating PartitionedParameterCoordinator
-            module.ds_inflight_param_registry[True] = InflightParamRegistry()
-            module.ds_inflight_param_registry[False] = InflightParamRegistry()
+            module.ds_inflight_param_registry = InflightParamRegistry()
         self.__inflight_param_registry = module.ds_inflight_param_registry
 
+        self.param_coordinator = PartitionedParameterCoordinator(
+            prefetch_bucket_sz=self._prefetch_bucket_sz,
+            max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
+            max_available_parameters_in_numel=self._max_available_parameters_in_numel,
+            allgather_stream=self.__allgather_stream,
+            inflight_param_registry=self.__inflight_param_registry,
+            prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
+            timers=self.timers,
+            zero_quantized_weights=self.zero_quantized_weights,
+            zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
+        )
+
         self.forward_hooks = []
         self.backward_hooks = []
         self.setup_zero_stage3_hooks()
@@ -161,26 +169,13 @@ class DeepSpeedZeRoOffload(object):
         """Partitioning Parameters that were not partitioned usually if parameters
         of modules whose input parameters do not require grad computation do not
         trigger post call and will therefore will remain unpartitioned"""
-        self.get_param_coordinator(training=self.module.training).release_and_reset_all(self.module)
+        self.get_param_coordinator().release_and_reset_all(self.module)
         for param in iter_params(self.module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")
 
-    def get_param_coordinator(self, training):
-        if not training in self.param_coordinators:
-            self.param_coordinators[training] = PartitionedParameterCoordinator(
-                prefetch_bucket_sz=self._prefetch_bucket_sz,
-                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
-                max_available_parameters_in_numel=self._max_available_parameters_in_numel,
-                allgather_stream=self.__allgather_stream,
-                inflight_param_registry=self.__inflight_param_registry[training],
-                prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
-                timers=self.timers,
-                zero_quantized_weights=self.zero_quantized_weights,
-                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-            )
-        return self.param_coordinators[training]
+    def get_param_coordinator(self):
+        return self.param_coordinator
 
     def empty_partition_cache(self):
         self.partition_all_parameters()
@@ -228,14 +223,14 @@ class DeepSpeedZeRoOffload(object):
         #reset step if in inference mode
         @instrument_w_nvtx
-        def _end_of_forward_hook(module, *args):
-            if not torch._C.is_grad_enabled():
-                self.get_param_coordinator(training=False).reset_step()
+        def _start_of_forward_hook(module, *args):
+            self.get_param_coordinator().reset_step()
+
+        self.module.register_forward_pre_hook(_start_of_forward_hook)
 
         #likely one of them should be enough but just to be safe
         self._register_hooks_recursively(self.module)
-        self.module.register_forward_hook(_end_of_forward_hook)
 
         # Add top module to stack trace
         global FWD_MODULE_STACK
@@ -447,7 +442,7 @@ class DeepSpeedZeRoOffload(object):
         global FWD_MODULE_STACK
         FWD_MODULE_STACK.append(sub_module)
 
-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
@@ -460,7 +455,7 @@ class DeepSpeedZeRoOffload(object):
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                          force=False)
 
-        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.release_sub_module(sub_module)
 
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
@@ -468,8 +463,8 @@ class DeepSpeedZeRoOffload(object):
     @torch.no_grad()
     def pre_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
-        param_coordinator = self.get_param_coordinator(training=True)
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        param_coordinator = self.get_param_coordinator()
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
@@ -477,12 +472,12 @@ class DeepSpeedZeRoOffload(object):
     @torch.no_grad()
     def post_sub_module_backward_function(self, sub_module):
-        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
+        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
             force=False)
 
-        self.get_param_coordinator(training=True).release_sub_module(sub_module)
+        self.get_param_coordinator().release_sub_module(sub_module)
 
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",

View file

@@ -18,6 +18,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedP
 from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
 from deepspeed.accelerator import get_accelerator
 import deepspeed.runtime.compiler as compiler
+from deepspeed.runtime.compiler import is_compiling
 
 import logging
@@ -92,7 +93,7 @@ class PartitionedParameterCoordinator:
         # keeps track of the number of submodules invoked so far.
         self.__step_id: int = 0
         # network tracing mode
-        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
+        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.INVALID
         # sequence of submodules/parameters in forward pass + backward pass
         self.__submodule_order: Iterable[Module] = []
         self.__param_order: Iterable[__class__.__ParamInTrace] = []
@@ -188,6 +189,9 @@ class PartitionedParameterCoordinator:
     @compiler.disable
     def record_module(self, sub_module: Module) -> None:
         """adds sub module to trace"""
+        if is_compiling():
+            return
+
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
@@ -195,6 +199,8 @@ class PartitionedParameterCoordinator:
         self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
 
     def record_parameters(self, sub_module: Module) -> None:
+        if is_compiling():
+            return
         """adds sub module to trace"""
         if not self.is_record_trace():
             raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
@@ -209,8 +215,12 @@ class PartitionedParameterCoordinator:
         for sub_module in self.__submodule_order:
             self.record_parameters(sub_module)
 
+    @compiler.disable
     def reset_step(self) -> None:
         """indicate that we have completed one fwd+bwd for the model"""
+        if is_compiling():
+            return
+
         self._clean_inflight_param_registry()
 
         if not self.is_complete_trace():  # not self.trace_complete:
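
Editor's note, not part of the diff: the additions above guard the trace bookkeeping against torch.compile. `@compiler.disable` keeps the decorated method out of compiled graphs (it runs eagerly), and the `is_compiling()` check makes the method a no-op if it is reached while tracing anyway. A condensed sketch of the pattern as applied to reset_step, mirroring the diff (the class name is a stand-in):

    import deepspeed.runtime.compiler as compiler
    from deepspeed.runtime.compiler import is_compiling

    class CoordinatorSketch:                  # stand-in for PartitionedParameterCoordinator
        @compiler.disable                     # excluded from torch.compile'd graphs
        def reset_step(self) -> None:
            if is_compiling():                # skip trace bookkeeping while compiling
                return
            # ... normal eager-mode trace/state reset goes here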

View file

@@ -593,8 +593,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         return device_buffer
 
-    def _get_param_coordinator(self, training):
-        return self.parameter_offload.get_param_coordinator(training)
+    def _get_param_coordinator(self):
+        return self.parameter_offload.get_param_coordinator()
 
     def _configure_offloading(self, offload_optimizer_config, offload_param_config):
         ###################### offload optimizer setup ##################################
@@ -1874,7 +1874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         see_memory_usage(f"In step before checking overflow", force=False)
 
         print_rank_0("Finished Tracing at Beginning of Step")
-        self._get_param_coordinator(training=True).hierarchy = 0
+        self._get_param_coordinator().hierarchy = 0
 
         print_rank_0("Finished Tracing at Beginning of Step")
@@ -2258,8 +2258,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         else:
             self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
 
-        self._get_param_coordinator(training=True).reset_step()
-
         if self.swap_optimizer:
             self.optimizer_swapper.post_backward()
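
Editor's note, not part of the diff: the reset_step() call removed from the backward path here is not dropped; in the first file of this merge it moves into _start_of_forward_hook, registered via register_forward_pre_hook, so the coordinator's step/trace state is reset at the start of every top-level forward pass in both train and eval mode. A rough training-loop view of the new boundary, with `engine` and `data_loader` as generic placeholders:

    for batch in data_loader:
        loss = engine(batch)        # forward pre-hook fires first and calls get_param_coordinator().reset_step()
        engine.backward(loss)       # no reset_step() in the backward path any more (removed above)
        engine.step()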

View file

@@ -1628,3 +1628,48 @@ class TestEmptyParameterGroup(DistributedTest):
                 optimizer=optimizer,
                 config=config_dict,
             )
+
+
+class TestZero3SwitchModes(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("prefetch_ratio", [0.0, 0.5, 1.0])
+    def test(self, prefetch_ratio, zero_stage=3):
+
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+
+        prefetch_bucket_size = int(sum([p.numel() for p in model.parameters(recurse=True)]) * prefetch_ratio)
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 2,
+            "gradient_accumulation_steps": 2,
+            "zero_optimization": {
+                "stage": zero_stage,
+                "stage3_prefetch_bucket_size": prefetch_bucket_size
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-3
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "initial_scale_power": 8
+            }
+        }
+
+        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)
+
+        for _ in range(3):
+            model.train()
+            for batch in data_loader:
+                loss = model(batch[0], batch[1])
+                model.backward(loss)
+                model.step()
+
+            model.eval()
+            with torch.no_grad():
+                for batch in data_loader:
+                    loss = model(batch[0], batch[1])