[inf] Add config var to enable keeping module on host (#6846)

Using the keep_module_on_host config var lets us control whether checkpoints loaded into model parameters are moved to the device or kept on the host.
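
A minimal usage sketch (the model name is taken from the test below; run under a distributed launcher such as `deepspeed --num_gpus 2` for tp > 1 — launch details here are illustrative, not part of this change):

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.float16,
                                  keep_module_on_host=True)

# With the flag enabled, the sharded parameters stay on the host.
assert all(p.device == torch.device('cpu') for p in engine.module.parameters())
```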

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Omar Elayan 2025-01-15 21:25:29 +02:00, committed by GitHub
Parent 66d3d3e94d
Commit fae714d6bd
No key found matching this signature
GPG key ID: B5690EEEBB952194
5 changed files: 53 additions and 15 deletions

View file

@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
values for :any:`DeepSpeedMoEConfig`.
"""
keep_module_on_host: bool = False
"""
When loading checkpoints into model parameters, they are normally moved to the device. For very large
models this might fill the device and cause OOM. Setting this flag to true keeps checkpoints on the
host instead of moving them directly to the device (for example, to allow quantizing the checkpoint
data before moving it to the device).
Applies only to models with injection policies and auto TP.
"""
quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.

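For reference, a quick sketch of the new field at the config level (hedged: assumes the `deepspeed.inference.config` import path and that `DeepSpeedInferenceConfig` can be constructed with defaults only):

```python
from deepspeed.inference.config import DeepSpeedInferenceConfig

# keep_module_on_host is a plain boolean on the inference config and defaults to False.
assert DeepSpeedInferenceConfig().keep_module_on_host is False
assert DeepSpeedInferenceConfig(keep_module_on_host=True).keep_module_on_host is True
```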
View file

@@ -169,7 +169,7 @@ class InferenceEngine(Module):
is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
if is_meta_device:
self.module.to_empty(device=device)
else:
elif not config.keep_module_on_host:
self.module.to(device)
if config.tensor_parallel.tp_size > 1:

View file

@@ -17,14 +17,14 @@ from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
def move(tensor, device):
def move(tensor, device, copy=True):
if tensor.is_meta:
return torch.empty_like(tensor, device=device)
else:
# Creating new tensors helps free memory (e.g., after a split); this was previously done by calling clone().
# Using copy=True instead of clone() also covers the cpu --> cpu case:
# otherwise to() would not create a new copy for a view of the full tensor, and the full tensor would never be de-referenced.
return tensor.to(device, copy=True)
return tensor.to(device, copy=copy)
class ReplaceWithTensorSlicing:
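
A standalone sketch (plain PyTorch, not DeepSpeed code) of the copy semantics the comment above relies on:

```python
import torch

full = torch.randn(4, 8)
shard = full.split(4, dim=1)[0]          # a view sharing storage with `full`

kept = shard.to('cpu')                   # same device, copy=False: returns the view itself
copied = shard.to('cpu', copy=True)      # fresh storage, so `full` can be freed afterwards

assert kept.data_ptr() == shard.data_ptr()
assert copied.data_ptr() != shard.data_ptr()
```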
@@ -189,7 +189,14 @@ class Loading():
class AutoTP():
def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
def __init__(self,
module,
all_reduce_linears,
prefix,
state_dict,
linear_layer_setting,
orig_layer_impl,
keep_module_on_host=False):
self.module = module
self.all_reduce_linears = all_reduce_linears
self.prefix = prefix
@@ -201,6 +208,7 @@ class AutoTP():
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
self.keep_module_on_host = keep_module_on_host
def in_module_list(module, module_list):
for item in module_list:
@@ -331,6 +339,10 @@ class AutoTP():
def _replace(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
# keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first (in some
# cases even directly from disk, to avoid filling the host's memory), so there is no need to create a new copy.
return_new_copy = not self.keep_module_on_host
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
@@ -368,7 +380,7 @@ class AutoTP():
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
setattr(child, "replaced", True)
@@ -376,10 +388,9 @@ class AutoTP():
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias,
get_accelerator().current_device_name())), self.mp_group)
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
else:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -392,22 +403,22 @@ class AutoTP():
#The copy is a regular copy; the shapes of dst and src are the same
data_dc = move(
prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)
bias_data_dc = None if child.bias is None else move(
prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
get_accelerator().current_device_name())
device_name, return_new_copy)
else:
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
dim=1 if self.conv_linear_layer else 0)
data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
if child.bias is not None:
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
dim=0)
bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
del bias_data
else:

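The device selection above boils down to the following pattern (a toy sketch with made-up shapes; 'cuda:0' stands in for get_accelerator().current_device_name(), and the real code derives shard sizes via get_shard_size_list()):

```python
import torch

keep_module_on_host = True
device_name = 'cpu' if keep_module_on_host else 'cuda:0'
return_new_copy = not keep_module_on_host

weight = torch.randn(8, 16)                               # toy full weight
tp_size, rank = 2, 0
shards = weight.split(weight.shape[0] // tp_size, dim=0)

# On host, the shard remains a cheap view of the loaded checkpoint (no extra copy);
# when moving to the accelerator, copy=True guarantees the full tensor can be freed.
shard = shards[rank].to(device_name, copy=return_new_copy).detach()
```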
View file

@@ -268,7 +268,8 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
#mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
# 1. Create AutoTP object
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
_autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
config.keep_module_on_host)
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

View file

@@ -554,6 +554,7 @@ class TestInjectionPolicy(DistributedTest):
@pytest.mark.seq_inference
@pytest.mark.parametrize('keep_module_on_host', [True, False])
@pytest.mark.parametrize(
"model_w_task",
[("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
@@ -570,6 +571,7 @@ class TestAutoTensorParallelism(DistributedTest):
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -592,13 +594,20 @@ class TestAutoTensorParallelism(DistributedTest):
framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)
print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
@pytest.mark.world_size(3)
def test_odd_world_size(
self,
@@ -607,6 +616,7 @@ class TestAutoTensorParallelism(DistributedTest):
inf_kwargs,
assert_fn,
dtype,
keep_module_on_host,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
@@ -624,13 +634,20 @@ class TestAutoTensorParallelism(DistributedTest):
framework="pt")
bs_output = pipe(query, **inf_kwargs)
pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
pipe.model = deepspeed.init_inference(pipe.model,
mp_size=world_size,
dtype=dtype,
keep_module_on_host=keep_module_on_host)
ds_output = pipe(query, **inf_kwargs)
print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)
if keep_module_on_host:
for name, param in model.named_parameters():
assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
@pytest.mark.nightly
@pytest.mark.parametrize(