support model declaration in zero.Init context (#3592)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Masahiro Tanaka 2023-06-26 15:19:42 -07:00 committed by GitHub
Parent a63f9152b9
Commit 203ac9d7ac
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 130 additions and 115 deletions


@@ -195,6 +195,9 @@ def initialize(args=None,
config=config,
config_class=config_class)
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()
return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return tuple(return_items)
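The call to zero.partition_parameters.restore_init_context() added here re-installs the zero.Init wrappers once the engine has been constructed, so models declared afterwards inside a still-open zero.Init block are partitioned as well. A schematic of the usage this enables, mirroring the TestShutdownInNestingInit case added further down in this diff (it assumes a ZeRO stage-3 config and a distributed launcher, as the tests do):

    import deepspeed
    import torch

    ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        with deepspeed.zero.Init(config_dict_or_path=ds_config):
            model1 = torch.nn.Linear(4, 4)
        # initialize() removes the Init wrappers while the engine is built;
        # the restore_init_context() call added above puts them back afterwards
        deepspeed_engine1, *_ = deepspeed.initialize(model=model1, config_params=ds_config)

        with deepspeed.zero.Init(config_dict_or_path=ds_config):
            model2 = torch.nn.Linear(4, 4)
        # because the context was restored, the second model is partitioned too
        assert hasattr(model2.weight, "ds_id")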


@@ -36,8 +36,8 @@ from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwa
param_count = 0
partitioned_param_data_shape = [0]
zero_init_context = []
all_wrapped_classes = set()
zero_init_context = 0
top_level_context = None
class NoGatherHandle:
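The module-level state introduced above replaces the earlier list of contexts and per-context nest level with a plain counter plus a reference to the top-level context: the first __enter__ (counter at 0) patches torch, every nested entry only increments the counter, and only the matching final __exit__ unpatches. A minimal illustration of this reference-counting pattern, not DeepSpeed code, using a trivial torch.nn.functional.linear patch as the stand-in:

    import torch

    _depth = 0            # counts nested, currently-open contexts
    _orig_linear = None   # saved original so the last exit can restore it

    class PatchLinear:
        def __enter__(self):
            global _depth, _orig_linear
            if _depth == 0:  # top-level entry: install the patch once
                _orig_linear = torch.nn.functional.linear
                torch.nn.functional.linear = lambda x, w, b=None: _orig_linear(x, w, b)
            _depth += 1
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            global _depth
            _depth -= 1
            if _depth == 0:  # last exit: restore the original function
                torch.nn.functional.linear = _orig_linear
            return False     # never swallow exceptions

    with PatchLinear():
        with PatchLinear():  # nested entry only bumps the counter
            pass
        # still patched here: the inner exit did not unpatch
    # fully unpatched here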
@@ -301,6 +301,54 @@ class InsertPostInitMethodToModuleSubClasses(object):
if not self.enabled:
return
global zero_init_context
if zero_init_context == 0:
self.patch_init_and_builtins()
global top_level_context
top_level_context = self
zero_init_context += 1
def __exit__(self, exc_type, exc_value, traceback):
if not self.enabled:
return
global zero_init_context
zero_init_context -= 1
# Exiting the top level context
if zero_init_context == 0:
self.unpatch_init_and_builtins()
global top_level_context
top_level_context = None
if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
# Now that we cleaned up the metaclass injection, raise the exception.
if exc_type is not None:
return False
# To be implemented by inheriting classes
def _post_init_method(self, module):
pass
def _set_dtype(self, ds_config, dtype):
if ds_config is not None and dtype is None:
if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
if ds_config.bfloat16_enabled:
self.dtype = torch.bfloat16
elif ds_config.fp16_enabled:
self.dtype = torch.half
else:
self.dtype = torch.float
else:
self.dtype = dtype or torch.half
def patch_init_and_builtins(self):
def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
"""many models make use of child modules like Linear or Embedding which
perform their own weight initialization in their __init__ methods,
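The apply_with_gather docstring above (truncated in this hunk) motivates wrapping torch.nn.Module.apply: child modules such as Linear or Embedding run their own weight initialization, which under ZeRO-3 has to operate on gathered rather than partitioned parameters. A conceptual sketch of that effect, assuming the public deepspeed.zero.GatheredParameters helper rather than the wrapper defined in this file:

    import deepspeed
    import torch

    def apply_with_full_params(module: torch.nn.Module, fn):
        # gather the partitioned parameters, let rank 0 modify them via fn,
        # and re-partition (broadcasting the changes) when the block exits
        params = list(module.parameters(recurse=True))
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            module.apply(fn)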
@@ -401,79 +449,50 @@ class InsertPostInitMethodToModuleSubClasses(object):
cls.__init__ = partition_after(cls.__init__)
def _init_subclass(cls, **kwargs):
cls._old_init = cls.__init__
cls.__init__ = partition_after(cls.__init__)
# Replace .__init__() for all existing subclasses of torch.nn.Module recursively
global zero_init_context
self.nest_level = len(zero_init_context)
global all_wrapped_classes
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
# Only wrap classes that haven't been wrapped yet
if subclass not in all_wrapped_classes:
_enable_class(subclass)
self.wrapped_cls.add(subclass)
_enable_class(subclass)
all_wrapped_classes = all_wrapped_classes.union(self.wrapped_cls)
# holding onto some methods so we can put them back the way they were in __exit__
torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
torch.Tensor.__old_new__ = torch.Tensor.__new__
# Wrap some functions only at top level call of Init
if self.nest_level == 0:
# holding onto some methods so we can put them back the way they were in __exit__
torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
torch.Tensor.__old_new__ = torch.Tensor.__new__
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
# Replace .__init__() for future subclasses of torch.nn.Module
torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply)
self._add_tensor_creation_wrappers()
self._add_tensor_creation_wrappers()
if self.mem_efficient_linear:
print_rank_0(
"nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
force=False)
self.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = zero3_linear_wrap
if self.mem_efficient_linear:
print_rank_0(
"nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
force=False)
self.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = zero3_linear_wrap
self.patched = True
self.torch_func_wrapped = True
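patch_init_and_builtins covers both existing and future model classes: every subclass of torch.nn.Module that already exists gets its __init__ wrapped directly, and __init_subclass__ is overridden so classes defined later (including inside a zero.Init block, which is what this commit enables) are wrapped at class-definition time. An illustration of the mechanism, independent of DeepSpeed, with a print standing in for the real post-init partitioning hook:

    import functools
    import torch

    def _post_init(module):
        print("post-init hook for", type(module).__name__)

    def _wrap_init(orig_init):
        @functools.wraps(orig_init)
        def wrapper(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)
            _post_init(self)  # e.g. partition the freshly created parameters
        return wrapper

    def _get_all_subclasses(cls):
        result, stack = set(), [cls]
        while stack:
            for sub in stack.pop().__subclasses__():
                if sub not in result:
                    result.add(sub)
                    stack.append(sub)
        return result

    def _patch():
        # 1) existing subclasses: wrap their __init__ in place
        for sub in _get_all_subclasses(torch.nn.Module):
            sub._old_init = sub.__init__
            sub.__init__ = _wrap_init(sub._old_init)

        # 2) future subclasses: __init_subclass__ fires when a class is *defined*,
        #    so declarations made inside the context are picked up as well
        def _init_subclass(cls, **kwargs):
            cls._old_init = cls.__init__
            cls.__init__ = _wrap_init(cls._old_init)

        torch.nn.Module._old_init_subclass = torch.nn.Module.__init_subclass__
        torch.nn.Module.__init_subclass__ = classmethod(_init_subclass)

    def _unpatch():
        # mirror of unpatch_init_and_builtins below: put everything back
        for sub in _get_all_subclasses(torch.nn.Module):
            if "_old_init" in sub.__dict__:
                sub.__init__ = sub._old_init
        torch.nn.Module.__init_subclass__ = torch.nn.Module._old_init_subclass

    _patch()
    layer = torch.nn.Linear(2, 2)  # prints the post-init message
    _unpatch()

Note that a real implementation also has to guard against the hook firing more than once per module when a wrapped class calls super().__init__() of another wrapped class; the sketch above skips that.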
def unpatch_init_and_builtins(self):
zero_init_context.append(self)
if self.patched:
def __exit__(self, exc_type, exc_value, traceback):
if not self.enabled:
return
def _disable_class(cls):
cls.__init__ = cls._old_init
self.remove_wrappers()
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class(subclass)
# Exiting the top level context
global zero_init_context
zero_init_context.pop()
if self.nest_level == 0:
if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters", param_count / 1e9)
# putting methods back the way we found them
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
# Now that we cleaned up the metaclass injection, raise the exception.
if exc_type is not None:
return False
self._remove_tensor_creation_wrappers()
# To be implemented by inheriting classes
def _post_init_method(self, module):
pass
def _set_dtype(self, ds_config, dtype):
if ds_config is not None and dtype is None:
if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")
if ds_config.bfloat16_enabled:
self.dtype = torch.bfloat16
elif ds_config.fp16_enabled:
self.dtype = torch.half
else:
self.dtype = torch.float
else:
self.dtype = dtype or torch.half
self.patched = False
def _add_tensor_creation_wrappers(self):
torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
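The tensor-creation wrappers referenced here replace torch.Tensor.__new__ and torch factory functions such as torch.arange and torch.eye so that tensors created inside the context pick up the configured dtype, and _remove_tensor_creation_wrappers (whose tail is visible just below) restores the originals. A standalone illustration of that wrapping style for a single factory function, not the DeepSpeed implementation:

    import torch

    _orig_torch_zeros = torch.zeros

    def _zeros_with_default_dtype(dtype):
        def wrapped(*args, **kwargs):
            kwargs.setdefault("dtype", dtype)  # only fill in dtype if the caller didn't
            return _orig_torch_zeros(*args, **kwargs)
        return wrapped

    torch.zeros = _zeros_with_default_dtype(torch.half)
    try:
        assert torch.zeros(2, 3).dtype == torch.half
    finally:
        torch.zeros = _orig_torch_zeros  # put the builtin back, as the unpatch path does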
@@ -493,47 +512,22 @@ class InsertPostInitMethodToModuleSubClasses(object):
torch.arange = _orig_torch_arange
torch.eye = _orig_torch_eye
def remove_wrappers(self):
def _disable_class(cls):
cls.__init__ = cls._old_init
for subclass in self.wrapped_cls:
_disable_class(subclass)
self.wrapped_cls.clear()
# This context is the top level of nested Init
if self.nest_level == 0 and self.torch_func_wrapped:
# putting methods back the way we found them
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
self._remove_tensor_creation_wrappers()
# un doing it here will undo it during training
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
self.torch_func_wrapped = False
global all_wrapped_classes
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
if subclass not in all_wrapped_classes:
msg = f"`{subclass}' was not properly set up for sharding by zero.Init(). A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created."
raise RuntimeError(msg)
all_wrapped_classes.clear()
def shutdown_init_context():
"""
This function is used to initialize deepspeed engine inside the context of Init.
We need to remove the wrappers but keep the list of contexts.
We need to remove the wrappers but keep the context.
"""
global zero_init_context
for ctx in zero_init_context:
ctx.remove_wrappers()
if top_level_context:
top_level_context.unpatch_init_and_builtins()
def restore_init_context():
"""
This function is used to restore the wrappers after deepspeed engine is initialized.
"""
if top_level_context:
top_level_context.patch_init_and_builtins()
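Per their docstrings, shutdown_init_context() and restore_init_context() are meant to bracket DeepSpeed engine construction that happens inside an open zero.Init block: the wrappers are taken down while the engine is built and put back afterwards (the first hunk of this diff adds the restore call at the end of deepspeed.initialize). A hypothetical sketch of that pairing; build_engine_inside_init and make_engine are made-up names for illustration:

    from deepspeed.runtime.zero.partition_parameters import (
        restore_init_context,
        shutdown_init_context,
    )

    def build_engine_inside_init(make_engine):
        shutdown_init_context()    # take the zero.Init wrappers down for the duration
        try:
            return make_engine()   # construct the engine without the init hooks active
        finally:
            restore_init_context()  # re-install them so later declarations are partitioned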
class AllGatherHandle:


@@ -10,10 +10,10 @@ from unit.common import DistributedTest
import deepspeed
class TestNewClassDeclaredInsideInit(DistributedTest):
class TestNewClassDeclaredNestingInit(DistributedTest):
world_size = 1
def test_new_class_declared_inside_init(self):
def test_new_class_declared_nesting_init(self):
ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
with deepspeed.zero.Init(config_dict_or_path=ds_config):
@@ -27,30 +27,27 @@ class TestNewClassDeclaredInsideInit(DistributedTest):
with deepspeed.zero.Init(config_dict_or_path=ds_config):
model = MyModel()
deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
# ensure that zero3 processed the parameter
assert hasattr(deepspeed_engine.fc.weight, "ds_id")
assert hasattr(model.fc.weight, "ds_id")
deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
class TestNewClassDeclaredInsideInitFailure(DistributedTest):
class TestNewClassDeclaredInsideNestingInit(DistributedTest):
world_size = 1
def test_new_class_declared_inside_init_failure(self):
def test_new_class_declared_inside_nesting_init(self):
ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
try:
with deepspeed.zero.Init(config_dict_or_path=ds_config):
with deepspeed.zero.Init(config_dict_or_path=ds_config):
class MyModel(torch.nn.Module):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(1, 1)
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(1, 1)
model = MyModel()
model = MyModel()
assert False, "Should have failed. A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created."
except RuntimeError as e:
pass
except:
assert False, "Should have failed. Runtime error is expected."
# ensure that zero3 processed the parameter
assert hasattr(model.fc.weight, "ds_id")
deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)


@@ -20,6 +20,27 @@ class TestNestingInit(DistributedTest):
with deepspeed.zero.Init(config_dict_or_path=ds_config):
model = torch.nn.Linear(4, 4)
deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
# ensure that zero3 processed the parameter
assert hasattr(deepspeed_engine.weight, "ds_id")
assert hasattr(model.weight, "ds_id")
deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
class TestShutdownInNestingInit(DistributedTest):
world_size = 1
def test_shutdown_in_nesting_init(self):
ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))
with deepspeed.zero.Init(config_dict_or_path=ds_config):
with deepspeed.zero.Init(config_dict_or_path=ds_config):
model1 = torch.nn.Linear(4, 4)
assert hasattr(model1.weight, "ds_id")
deepspeed_engine1, *_ = deepspeed.initialize(model=model1, config_params=ds_config)
with deepspeed.zero.Init(config_dict_or_path=ds_config):
model2 = torch.nn.Linear(4, 4)
# ensure that zero3 processed the parameter
assert hasattr(model2.weight, "ds_id")
deepspeed_engine2, *_ = deepspeed.initialize(model=model2, config_params=ds_config)