Mirror of https://github.com/microsoft/DeepSpeed.git

Tensor-Parallelism general support (#1512)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

Parent: b16dd943a4
Commit: 9ce00a2171
deepspeed/__init__.py

@@ -237,7 +237,9 @@ def init_inference(model,
                    dtype=None,
                    injection_policy=None,
                    replace_method='auto',
-                   quantization_setting=None):
+                   quantization_setting=None,
+                   replace_with_kernel_inject=False,
+                   return_tuple=True):
     """Initialize the DeepSpeed InferenceEngine.
 
     Arguments:
@@ -267,6 +269,7 @@ def init_inference(model,
             of groups used in quantization. A tuple is passed in if we want to mention that there is extra-grouping
             for the MLP part of a Transformer layer (e.g. (True, 8) shows we quantize the model using 8 groups for
             the whole network except the MLP part, for which we use 8 extra groups).
+        replace_with_kernel_inject: If set, kernels are injected as the inference engine is initialized
 
     Returns:
         A deepspeed.InferenceEngine wrapped model.
@@ -286,7 +289,9 @@ def init_inference(model,
                              checkpoint,
                              dtype,
                              injection_policy,
+                             return_tuple,
                              replace_method,
-                             quantization_setting)
+                             quantization_setting,
+                             replace_with_kernel_inject)
 
     return engine
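For orientation, a hedged usage sketch of the two arguments this file gains; the model choice, token ids, and launch details are illustrative assumptions, not part of the commit:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')

# replace_with_kernel_inject=True injects DeepSpeed's fused inference
# kernels; False keeps the original modules and only applies the
# tensor-parallel slicing added by this PR.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.half,
                                  replace_method='auto',
                                  replace_with_kernel_inject=False,
                                  return_tuple=True)

input_ids = torch.tensor([[15496, 995]]).to(torch.cuda.current_device())
outputs = engine(input_ids)
```

This would run under the distributed launcher (e.g. `deepspeed --num_gpus 2 script.py`) so that `mp_size=2` has two ranks to shard across.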
deepspeed/inference/engine.py

@@ -25,7 +25,8 @@ class InferenceEngine(Module):
                  injection_dict=None,
                  return_tuple=True,
                  replace_method='auto',
-                 quantization_setting=None):
+                 quantization_setting=None,
+                 replace_with_kernel_inject=False):
         """
         Args:
             model: torch.nn.Module
@@ -74,15 +75,17 @@ class InferenceEngine(Module):
             self.mp_group = self.mpu.get_model_parallel_group()
         elif self.mp_world_size > 1:
             self._create_model_parallel_group()
 
         # apply injection policy
         if self.injection_dict:
             for client_module, injection_policy in self.injection_dict.items():
                 self._apply_injection_policy(client_module,
                                              injection_policy,
-                                             return_tuple)
-        elif replace_method == "auto":
-            self._apply_injection_policy()
+                                             return_tuple,
+                                             replace_with_kernel_inject)
+        elif replace_method == 'auto':
+            self._apply_injection_policy(
+                return_tuple=return_tuple,
+                replace_with_kernel_inject=replace_with_kernel_inject)
 
         device = torch.cuda.current_device()
         logger.info(f"Place model to device: {device}")
@@ -152,7 +155,9 @@ class InferenceEngine(Module):
     def _apply_injection_policy(self,
                                 client_module=None,
                                 injection_policy=None,
-                                return_tuple=True):
+                                return_tuple=True,
+                                replace_with_kernel_inject=False):
 
         replace_transformer_layer(client_module,
                                   self.module,
                                   policy=injection_policy,
@@ -166,7 +171,8 @@ class InferenceEngine(Module):
                                   quantize_settings=(self.quantization_scales,
                                                      self.quantize_merge_count,
                                                      self.mlp_extra_grouping,
-                                                     self.quantize_groups))
+                                                     self.quantize_groups),
+                                  replace_with_kernel_inject=replace_with_kernel_inject)
 
     def _load_checkpoint(self, load_dir, load_module_strict=True):
         sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
deepspeed/module_inject/replace_module.py

@@ -2,10 +2,40 @@ import copy
 import torch
 import deepspeed
 import deepspeed.ops.transformer as transformer_inference
-from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy
+from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy
 from .replace_policy import replace_policies
 from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE
 from ..runtime.weight_quantizer import WeightQuantization
 from torch import nn
 
 
+class LinearAllreduce(nn.Module):
+    def __init__(self, weight, bias=None, mp_group=None):
+        super(LinearAllreduce, self).__init__()
+        self.weight = weight
+        self.bias = bias
+        self.mp_group = mp_group
+
+    def forward(self, input):
+        output = torch.matmul(input, self.weight)
+        if self.mp_group is not None:
+            torch.distributed.all_reduce(output, group=self.mp_group)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
+class LinearLayer(nn.Module):
+    def __init__(self, weight, bias=None):
+        super(LinearLayer, self).__init__()
+        self.weight = weight
+        self.bias = bias
+
+    def forward(self, input):
+        output = torch.matmul(input, self.weight)
+        if self.bias is not None:
+            output += self.bias
+        return output
+
+
 class ReplaceWithTensorSlicing:
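The two modules added above are the standard row-parallel and column-parallel building blocks. A minimal, single-process sketch (an assumption: two "ranks" simulated by slicing, no `torch.distributed`) of why summing the partial outputs, which is what the `all_reduce` in `LinearAllreduce.forward` does across ranks, reproduces the full matmul:

```python
import torch

torch.manual_seed(0)
full_weight = torch.randn(8, 4)   # (in_features, out_features), stored pre-transposed
x = torch.randn(2, 8)             # a batch of inputs

# shard the contraction (input) dimension across two simulated ranks
w0, w1 = full_weight[:4, :], full_weight[4:, :]
x0, x1 = x[:, :4], x[:, 4:]

partial_sum = x0 @ w0 + x1 @ w1   # what all_reduce would sum across ranks
assert torch.allclose(partial_sum, x @ full_weight, atol=1e-5)
```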
@@ -103,13 +133,17 @@ def replace_transformer_layer(orig_layer_impl,
                               training=True,
                               quantize=False,
                               quantize_settings=None,
-                              return_tuple=False):
+                              return_tuple=True,
+                              replace_with_kernel_inject=False,
+                              linear_layer_setting=None):
     """ Replace bert-style transformer layers with DeepSpeed's transformer layer
     Arguments:
         orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
             e.g., transformers.modeling_bert.BertLayer.
         model (torch.nn.Module): user's nn.module representing their model
-        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters
+        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
+            replace_with_kernel_inject is set; otherwise, it provides the names of two linear layers as
+            a tuple: (attention_output projection, transformer output projection)
         micro_batch_size (int): micro batch size per gpu used during training/eval
         config (dict): model config containing hidden size, attention heads, etc.
         seed (int): random seed value
@@ -127,7 +161,12 @@ def replace_transformer_layer(orig_layer_impl,
             It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
         return_tuple (bool): if set, the transformer layer returns a tuple as the output.
             Note: this flag needs to be set for huggingface models.
-
+        replace_with_kernel_inject (bool): injection mode; if true, kernels are injected along with
+            configuring Tensor-Parallelism
+        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear
+            and embedding layers
+        attention_params (list of strings) [Optional]: shows the parameters in the attention part that
+            need to be adjusted based on the model-parallelism
     Returns:
         Updated nn.module with replaced transformer layers
     """
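A hedged sketch of the non-kernel policy form the docstring describes: the tuple names the attention-output and transformer-output projections of the client module. The T5 class path and layer names below are assumptions for illustration, and `model` is assumed to be a loaded T5 instance:

```python
import deepspeed
from transformers.models.t5.modeling_t5 import T5Block

engine = deepspeed.init_inference(
    model,  # assumed: a loaded T5 model
    mp_size=2,
    injection_policy={T5Block: ('SelfAttention.o', 'DenseReluDense.wo')},
    replace_with_kernel_inject=False)
```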
@@ -299,6 +338,110 @@ def replace_transformer_layer(orig_layer_impl,
             new_module.output_b.data = _4hh_b
         return new_module
 
+    def replace_wo_policy(module, all_reduce_linears):
+        def _replace(child, name, conv_linear_layer):
+            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
+            if name in all_reduce_linears:
+                new_weight = torch.empty(
+                    (child.weight.shape[0]
+                     if conv_linear_layer else child.weight.shape[1] // mp_size,
+                     child.weight.shape[1]
+                     if conv_linear_layer else child.weight.shape[0]),
+                    device=child.weight.device,
+                    dtype=torch.half if fp16 else torch.float)
+                if not conv_linear_layer:
+                    child.weight.data.view(-1).copy_(
+                        child.weight.data.transpose(-1, -2).contiguous().view(-1))
+                    child.weight.data = child.weight.data.reshape(
+                        child.weight.data.shape[-1], child.weight.data.shape[-2])
+                data = mp_replace.copy(new_weight,
+                                       child.weight.data).to(torch.cuda.current_device())
+                return LinearAllreduce(data, child.bias if child.bias is None else \
+                            child.bias.to(torch.cuda.current_device()), mp_group)
+            else:
+                new_weight = torch.empty(
+                    (child.weight.shape[0] // mp_size
+                     if conv_linear_layer else child.weight.shape[1],
+                     child.weight.shape[1]
+                     if conv_linear_layer else child.weight.shape[0] // mp_size),
+                    device=child.weight.device,
+                    dtype=torch.half if fp16 else torch.float)
+                if not conv_linear_layer:
+                    child.weight.data.view(-1).copy_(
+                        child.weight.data.transpose(-1, -2).contiguous().view(-1))
+                    child.weight.data = child.weight.data.reshape(
+                        child.weight.data.shape[-1], child.weight.data.shape[-2])
+                data = mp_replace.copy(new_weight, child.weight.data)
+                new_bias = torch.empty((child.weight.shape[1] // mp_size),
+                                       device=child.weight.device,
+                                       dtype=torch.half if fp16 else torch.float)
+                bias_data = None if child.bias is None else mp_replace.copy(
+                    new_bias, child.bias.data).to(torch.cuda.current_device())
+                return LinearLayer(data.to(torch.cuda.current_device()), bias_data)
+
+        def _slice_embedding(child, name, conv_linear_layer):
+            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
+            new_weight = torch.empty((child.weight.shape[0],
+                                      child.weight.shape[1] // mp_size),
+                                     device=child.weight.device,
+                                     dtype=child.weight.dtype)
+            data = mp_replace.copy(new_weight, child.weight.data)
+            new_embedding = nn.Embedding(child.weight.shape[0],
+                                         child.weight.shape[1] // mp_size)
+            new_embedding.weight.data.copy_(data)
+            return new_embedding
+
+        def update_mp_params(child):
+            if hasattr(child, 'n_heads'):
+                child.n_heads = child.n_heads // mp_size
+            if hasattr(child, 'inner_dim'):
+                child.inner_dim = child.inner_dim // mp_size
+            if hasattr(child, 'num_heads'):
+                child.num_heads = child.num_heads // mp_size
+            if hasattr(child, 'num_attention_heads'):
+                child.num_attention_heads = child.num_attention_heads // mp_size
+            if hasattr(child, 'all_head_size'):
+                child.all_head_size = child.all_head_size // mp_size
+            if hasattr(child, 'embed_dim'):
+                child.embed_dim = child.embed_dim // mp_size
+
+        conv_linear_layer = False
+        if linear_layer_setting is not None:
+            linear_policies = {linear_layer_setting[0]: _replace}
+            if len(linear_layer_setting) == 2:
+                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
+        else:
+            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
+                try:
+                    import transformers
+                    conv_linear_layer = True
+                    linear_policies = {transformers.modeling_utils.Conv1D: _replace}
+                except ImportError:
+                    linear_policies = {nn.Linear: _replace}
+            else:
+                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}
+
+        def _replace_module(r_module, prev_name=''):
+            for name, child in r_module.named_children():
+                if child.__class__ in linear_policies:
+                    setattr(
+                        r_module,
+                        name,
+                        linear_policies[child.__class__](child,
+                                                         prev_name + '.' + name,
+                                                         conv_linear_layer))
+                else:
+                    update_mp_params(child)
+                    _replace_module(child, name)
+            return r_module
+
+        return _replace_module(module)
+
     def replace_fn(child, _policy, layer_id=0):
         if training:
             # copy relevant state from child -> new module
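To make the slicing and the head-count rescaling concrete, a small worked example; the square layer shape, `mp_size`, and attribute values below are assumptions:

```python
# Worked example for nn.Linear(4096, 4096) with mp_size=4; the weight is
# transposed to (in_features, out_features) in _replace before slicing.
mp_size = 4
in_f, out_f = 4096, 4096

# _replace, all-reduce branch (row-parallel): slice the input dimension
row_shard = (in_f // mp_size, out_f)   # (1024, 4096) per rank

# _replace, default branch (column-parallel): slice the output dimension
col_shard = (in_f, out_f // mp_size)   # (4096, 1024) per rank
col_bias = (out_f // mp_size,)         # (1024,) per rank

# update_mp_params: per-rank bookkeeping after slicing
num_attention_heads = 16 // mp_size    # 16 -> 4 heads per rank
all_head_size = 4096 // mp_size        # 4096 -> 1024 features per rank
```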
@@ -306,11 +449,15 @@ def replace_transformer_layer(orig_layer_impl,
 
         else:
             # copy relevant state from child -> new module
-            new_module = replace_with_policy(child,
-                                             _policy,
-                                             inference=True,
-                                             preln=(policy is not HFBertLayerPolicy),
-                                             layer_id=layer_id)
+            if replace_with_kernel_inject:
+                new_module = replace_with_policy(
+                    child,
+                    _policy,
+                    inference=True,
+                    preln=(_policy is not HFBertLayerPolicy),
+                    layer_id=layer_id)
+            else:
+                new_module = replace_wo_policy(child, _policy)
 
         return new_module
@@ -327,7 +474,6 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
             e.g., transformers.modeling_bert.BertLayer.
         model (torch.nn.Module): user's nn.module representing their model
         config (dict): model config containing hidden size, attention heads, etc.
-
     Returns:
         Updated nn.module with original bert-style transformer layers
     """
@@ -396,7 +542,6 @@ def replace_module(model, orig_class, replace_fn, _replace_policy):
         orig_class (torch.nn.Module): the module to search for
         replace_fn (method): a method to convert instances of ``orig_class`` to the
             desired type and return a new instance.
-
     Returns:
         A modified ``model``.
     """
@@ -422,20 +567,17 @@ def _replace_module(model, policies, layer_id=0):
     Arguments:
         model (torch.nn.Module): model to augment
         policies (dict): Mapping of source class to replacement function.
-
     Returns:
         Modified ``model``.
     """
     for name, child in model.named_children():
         if child.__class__ in policies:
-            orig = repr(child)
             setattr(
                 model,
                 name,
                 policies[child.__class__][0](child,
                                              policies[child.__class__][-1],
                                              layer_id))
-            new = getattr(model, name)
             layer_id += 1
         else:
             _, layer_id = _replace_module(child, policies, layer_id=layer_id)
deepspeed/ops/transformer/inference/transformer_inference.py

@@ -620,6 +620,6 @@ class DeepSpeedTransformerInference(nn.Module):
             output = (output, presents)
 
         if self.config.return_tuple:
-            return (output, )
+            return output if type(output) is tuple else (output, )
         else:
             return output
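The changed line guards against double-wrapping: when presents are produced, `output` is already the tuple `(output, presents)`. A minimal illustration of the expression with placeholder values standing in for tensors:

```python
def wrap(output):
    return output if type(output) is tuple else (output, )

assert wrap('hidden') == ('hidden', )                           # bare value gets wrapped
assert wrap(('hidden', 'presents')) == ('hidden', 'presents')   # tuple passes through unchanged
```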