Mirror of https://github.com/microsoft/DeepSpeed.git
support model declaration in zero.Init context (#3592)
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent: a63f9152b9
Commit: 203ac9d7ac
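Summary of the change: previously, `zero.Init` kept a global list with one entry per nested context and wrapped only the `torch.nn.Module` subclasses that already existed when a context was entered; a model class declared *inside* the context was rejected with a RuntimeError at exit. This commit replaces the list with a plain depth counter (`zero_init_context`) plus a single `top_level_context`, so `torch.nn.Module` and the tensor builtins are patched once at the outermost `__enter__` and unpatched once at the matching outermost `__exit__`. Two source files change (evidently deepspeed/__init__.py and deepspeed/runtime/zero/partition_parameters.py, judging by the hunk contexts), plus the ZeRO context tests. The usage this enables, taken from the updated tests below (they run under a distributed test harness with ZeRO stage 3):

```python
import torch
import deepspeed

ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

with deepspeed.zero.Init(config_dict_or_path=ds_config):

    # Declaring the model class inside the context now works; before this
    # change it raised a RuntimeError when the context exited.
    class MyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(1, 1)

    model = MyModel()

# ZeRO-3 partitioned the parameters during construction.
assert hasattr(model.fc.weight, "ds_id")
```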
@@ -195,6 +195,9 @@ def initialize(args=None,
                                      config=config,
                                      config_class=config_class)

+    # Restore zero.Init context if necessary
+    zero.partition_parameters.restore_init_context()
+
     return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
     return tuple(return_items)
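The added `restore_init_context()` call pairs with `shutdown_init_context()` (see the last partition_parameters.py hunk below): when `deepspeed.initialize` is invoked while a `zero.Init` context is still open, the init-time wrappers are removed for engine construction, and this call puts them back before `initialize` returns. The scenario it enables, per the new `TestShutdownInNestingInit` test below (a sketch, assuming a distributed environment is already running):

```python
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model1 = torch.nn.Linear(4, 4)
    assert hasattr(model1.weight, "ds_id")

    # initialize() strips the zero.Init wrappers while it builds the engine,
    # then restore_init_context() re-installs them before returning ...
    engine1, *_ = deepspeed.initialize(model=model1, config_params=ds_config)

    # ... so modules created afterwards, still inside the context,
    # are partitioned as usual.
    model2 = torch.nn.Linear(4, 4)
    assert hasattr(model2.weight, "ds_id")
```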
@@ -36,8 +36,8 @@ from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper

 param_count = 0
 partitioned_param_data_shape = [0]
-zero_init_context = []
-all_wrapped_classes = set()
+zero_init_context = 0
+top_level_context = None


 class NoGatherHandle:
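The bookkeeping thus changes from "one list entry per nested context" to a depth counter plus a handle on the outermost context. The pattern in isolation (a minimal, self-contained sketch; `patch`/`unpatch` are hypothetical stand-ins for `patch_init_and_builtins`/`unpatch_init_and_builtins`):

```python
_depth = 0          # replaces the old per-context list
_top_level = None   # the context instance that owns the patches


class ReentrantPatch:

    def __enter__(self):
        global _depth, _top_level
        if _depth == 0:      # only the outermost entry installs the patches
            _top_level = self
            self.patch()
        _depth += 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global _depth, _top_level
        _depth -= 1
        if _depth == 0:      # only the matching outermost exit removes them
            self.unpatch()
            _top_level = None
        return False         # never swallow exceptions

    def patch(self):         # hypothetical stand-in
        print("patched")

    def unpatch(self):       # hypothetical stand-in
        print("unpatched")


with ReentrantPatch():
    with ReentrantPatch():   # nested: no re-patching
        pass
# prints "patched" once and "unpatched" once
```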
@@ -301,6 +301,54 @@ class InsertPostInitMethodToModuleSubClasses(object):
         if not self.enabled:
             return

+        global zero_init_context
+        if zero_init_context == 0:
+            self.patch_init_and_builtins()
+            global top_level_context
+            top_level_context = self
+
+        zero_init_context += 1
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if not self.enabled:
+            return
+
+        global zero_init_context
+        zero_init_context -= 1
+
+        # Exiting the top level context
+        if zero_init_context == 0:
+            self.unpatch_init_and_builtins()
+            global top_level_context
+            top_level_context = None
+
+            if dist.get_rank() == 0:
+                logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
+
+        # Now that we cleaned up the metaclass injection, raise the exception.
+        if exc_type is not None:
+            return False
+
+    # To be implemented by inheriting classes
+    def _post_init_method(self, module):
+        pass
+
+    def _set_dtype(self, ds_config, dtype):
+        if ds_config is not None and dtype is None:
+            if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
+                raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
+
+            if ds_config.bfloat16_enabled:
+                self.dtype = torch.bfloat16
+            elif ds_config.fp16_enabled:
+                self.dtype = torch.half
+            else:
+                self.dtype = torch.float
+        else:
+            self.dtype = dtype or torch.half
+
+    def patch_init_and_builtins(self):
+
         def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
             """many models make use of child modules like Linear or Embedding which
             perform their own weight initialization in their __init__ methods,
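With `__enter__`/`__exit__` reduced to counter updates, nesting becomes cheap and symmetric: only the first entry patches, and only the matching last exit unpatches and emits the rank-0 parameter-count log. Traced on the `TestNestingInit` case from the test changes below:

```python
with deepspeed.zero.Init(config_dict_or_path=ds_config):      # 0 -> 1: patch_init_and_builtins()
    with deepspeed.zero.Init(config_dict_or_path=ds_config):  # 1 -> 2: already patched, counted only
        model = torch.nn.Linear(4, 4)
    # 2 -> 1: still patched
# 1 -> 0: unpatch_init_and_builtins(); rank 0 logs
#         "finished initializing model with ...B parameters"
```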
@@ -401,79 +449,50 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = partition_after(cls.__init__)

         def _init_subclass(cls, **kwargs):
             cls._old_init = cls.__init__
             cls.__init__ = partition_after(cls.__init__)

         # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
-        global zero_init_context
-        self.nest_level = len(zero_init_context)
-
-        global all_wrapped_classes
         for subclass in get_all_subclasses(torch.nn.modules.module.Module):
-            # Only wrap classes that haven't been wrapped yet
-            if subclass not in all_wrapped_classes:
-                _enable_class(subclass)
-                self.wrapped_cls.add(subclass)
-
-        all_wrapped_classes = all_wrapped_classes.union(self.wrapped_cls)
-
-        # Wrap some functions only at top level call of Init
-        if self.nest_level == 0:
-            # holding onto some methods so we can put them back the way they were in __exit__
-            torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
-            torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
-            torch.Tensor.__old_new__ = torch.Tensor.__new__
-
-            # Replace .__init__() for future subclasses of torch.nn.Module
-            torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
-            torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
-
-            self._add_tensor_creation_wrappers()
-
-            if self.mem_efficient_linear:
-                print_rank_0(
-                    "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
-                    force=False)
-                self.linear_bk = torch.nn.functional.linear
-                torch.nn.functional.linear = zero3_linear_wrap
-
-            self.torch_func_wrapped = True
-
-        zero_init_context.append(self)
+            _enable_class(subclass)
+
+        # holding onto some methods so we can put them back the way they were in __exit__
+        torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
+        torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
+        torch.Tensor.__old_new__ = torch.Tensor.__new__
+
+        # Replace .__init__() for future subclasses of torch.nn.Module
+        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
+        torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
+
+        self._add_tensor_creation_wrappers()
+
+        if self.mem_efficient_linear:
+            print_rank_0(
+                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
+                force=False)
+            self.linear_bk = torch.nn.functional.linear
+            torch.nn.functional.linear = zero3_linear_wrap
+
+        self.patched = True

-    def __exit__(self, exc_type, exc_value, traceback):
-        if not self.enabled:
-            return
-
-        self.remove_wrappers()
-
-        # Exiting the top level context
-        global zero_init_context
-        zero_init_context.pop()
-        if self.nest_level == 0:
-            if dist.get_rank() == 0:
-                logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
-
-        # Now that we cleaned up the metaclass injection, raise the exception.
-        if exc_type is not None:
-            return False
-
-    # To be implemented by inheriting classes
-    def _post_init_method(self, module):
-        pass
-
-    def _set_dtype(self, ds_config, dtype):
-        if ds_config is not None and dtype is None:
-            if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
-                raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
-
-            if ds_config.bfloat16_enabled:
-                self.dtype = torch.bfloat16
-            elif ds_config.fp16_enabled:
-                self.dtype = torch.half
-            else:
-                self.dtype = torch.float
-        else:
-            self.dtype = dtype or torch.half
+    def unpatch_init_and_builtins(self):
+        if self.patched:
+
+            def _disable_class(cls):
+                cls.__init__ = cls._old_init
+
+            for subclass in get_all_subclasses(torch.nn.modules.module.Module):
+                _disable_class(subclass)
+
+            # putting methods back the way we found them
+            torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
+            torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
+
+            self._remove_tensor_creation_wrappers()
+
+            self.patched = False

     def _add_tensor_creation_wrappers(self):
         torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
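The key to supporting classes declared inside the context is the (pre-existing) `__init_subclass__` patch: while the context is live, *defining* a new `torch.nn.Module` subclass triggers `_init_subclass`, which wraps the fresh `__init__` immediately. That makes the old `all_wrapped_classes`/`wrapped_cls` bookkeeping, and the exit-time RuntimeError for late-declared classes, unnecessary; `unpatch_init_and_builtins` simply walks whatever subclasses exist at exit. A DeepSpeed-independent sketch of the hook:

```python
class Base:

    # Runs automatically whenever a subclass of Base is *defined*.
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        original_init = cls.__init__

        def wrapped_init(self, *args, **kw):
            original_init(self, *args, **kw)
            # stand-in for DeepSpeed's _post_init_method (parameter partitioning)
            print(f"post-init hook for {cls.__name__}")

        cls.__init__ = wrapped_init


class Child(Base):  # defined "inside the context": the hook fires here

    def __init__(self):
        self.x = 1


Child()  # prints: post-init hook for Child
```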
@@ -493,47 +512,22 @@ class InsertPostInitMethodToModuleSubClasses(object):
         torch.arange = _orig_torch_arange
         torch.eye = _orig_torch_eye

-    def remove_wrappers(self):
-
-        def _disable_class(cls):
-            cls.__init__ = cls._old_init
-
-        for subclass in self.wrapped_cls:
-            _disable_class(subclass)
-        self.wrapped_cls.clear()
-
-        # This context is the top level of nested Init
-        if self.nest_level == 0 and self.torch_func_wrapped:
-            # putting methods back the way we found them
-            torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
-            torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
-
-            self._remove_tensor_creation_wrappers()
-
-            # un doing it here will undo it during training
-            # if self.mem_efficient_linear:
-            #    torch.nn.functional.linear = self.linear_bk
-            # if self.mem_efficient_linear:
-            #    torch.nn.functional.linear = self.linear_bk
-
-            self.torch_func_wrapped = False
-
-        global all_wrapped_classes
-        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
-            if subclass not in all_wrapped_classes:
-                msg = f"`{subclass}' was not properly set up for sharding by zero.Init(). A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created."
-                raise RuntimeError(msg)
-        all_wrapped_classes.clear()
-

 def shutdown_init_context():
     """
     This function is used to initialize deepspeed engine inside the context of Init.
-    We need to remove the wrappers but keep the list of contexts.
+    We need to remove the wrappers but keep the context.
     """
-    global zero_init_context
-    for ctx in zero_init_context:
-        ctx.remove_wrappers()
+    if top_level_context:
+        top_level_context.unpatch_init_and_builtins()
+
+
+def restore_init_context():
+    """
+    This function is used to restore the wrappers after deepspeed engine is initialized.
+    """
+    if top_level_context:
+        top_level_context.patch_init_and_builtins()


 class AllGatherHandle:
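Both module-level helpers now delegate to the single `top_level_context`, and the old `remove_wrappers()`, including its exit-time validation that raised for classes declared inside the context, is gone entirely. The docstrings describe the intended sequence around engine construction; the exact call site of `shutdown_init_context()` is not shown in this diff, but the flow is roughly (a sketch with a hypothetical `build_engine` stand-in):

```python
# Engine creation while a zero.Init context is open, roughly:
shutdown_init_context()   # strip wrappers so engine setup sees vanilla torch
engine = build_engine()   # hypothetical stand-in for DeepSpeedEngine construction
restore_init_context()    # re-install wrappers (no-op when no context is live)
```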
@@ -10,10 +10,10 @@ from unit.common import DistributedTest
 import deepspeed


-class TestNewClassDeclaredInsideInit(DistributedTest):
+class TestNewClassDeclaredNestingInit(DistributedTest):
     world_size = 1

-    def test_new_class_declared_inside_init(self):
+    def test_new_class_declared_nesting_init(self):
         ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

         with deepspeed.zero.Init(config_dict_or_path=ds_config):
@@ -27,30 +27,27 @@ class TestNewClassDeclaredInsideInit(DistributedTest):
         with deepspeed.zero.Init(config_dict_or_path=ds_config):
             model = MyModel()

-        deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
         # ensure that zero3 processed the parameter
-        assert hasattr(deepspeed_engine.fc.weight, "ds_id")
+        assert hasattr(model.fc.weight, "ds_id")
+        deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)


-class TestNewClassDeclaredInsideInitFailure(DistributedTest):
+class TestNewClassDeclaredInsideNestingInit(DistributedTest):
     world_size = 1

-    def test_new_class_declared_inside_init_failure(self):
+    def test_new_class_declared_inside_nesting_init(self):
         ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

-        try:
-            with deepspeed.zero.Init(config_dict_or_path=ds_config):
-
-                class MyModel(torch.nn.Module):
-
-                    def __init__(self):
-                        super().__init__()
-                        self.fc = torch.nn.Linear(1, 1)
-
-                model = MyModel()
-
-            assert False, "Should have failed. A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created."
-        except RuntimeError as e:
-            pass
-        except:
-            assert False, "Should have failed. Runtime error is expected."
+        with deepspeed.zero.Init(config_dict_or_path=ds_config):
+
+            class MyModel(torch.nn.Module):
+
+                def __init__(self):
+                    super().__init__()
+                    self.fc = torch.nn.Linear(1, 1)
+
+            model = MyModel()
+
+        # ensure that zero3 processed the parameter
+        assert hasattr(model.fc.weight, "ds_id")
+        deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
@@ -20,6 +20,27 @@ class TestNestingInit(DistributedTest):
         with deepspeed.zero.Init(config_dict_or_path=ds_config):
             model = torch.nn.Linear(4, 4)

-        deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
         # ensure that zero3 processed the parameter
-        assert hasattr(deepspeed_engine.weight, "ds_id")
+        assert hasattr(model.weight, "ds_id")
+        deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
+
+
+class TestShutdownInNestingInit(DistributedTest):
+    world_size = 1
+
+    def test_shutdown_in_nesting_init(self):
+        ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
+
+        with deepspeed.zero.Init(config_dict_or_path=ds_config):
+            with deepspeed.zero.Init(config_dict_or_path=ds_config):
+                model1 = torch.nn.Linear(4, 4)
+
+            assert hasattr(model1.weight, "ds_id")
+            deepspeed_engine1, *_ = deepspeed.initialize(model=model1, config_params=ds_config)
+
+            with deepspeed.zero.Init(config_dict_or_path=ds_config):
+                model2 = torch.nn.Linear(4, 4)
+
+        # ensure that zero3 processed the parameter
+        assert hasattr(model2.weight, "ds_id")
+        deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config)