Mirror of https://github.com/microsoft/DeepSpeed.git
Merge branch 'master' into reduce_coalesced_fetch_bubble
Commit ed91abcf21
@@ -27,7 +27,7 @@ jobs:
     env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - id: setup-venv
         uses: ./.github/workflows/setup-venv
@@ -39,7 +39,7 @@ jobs:
     # The type of runner that the job will run on
     runs-on: [self-hosted, intel, gaudi2]
     container:
-      image: vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
+      image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
       ports:
         - 80
       options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
@@ -23,7 +23,7 @@ jobs:
   unit-tests:
     runs-on: [self-hosted, nvidia, a6000]
     container:
-      image: nvcr.io/nvidia/pytorch:23.03-py3
+      image: nvcr.io/nvidia/pytorch:24.03-py3
       ports:
         - 80
       options: --gpus all --shm-size "8G"
@@ -47,8 +47,6 @@ jobs:
       - name: Install deepspeed
         run: |
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          # Update packages included in the container that do not support pydantic 2+ to versions that do
-          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,inf]
           ds_report
       - name: Python environment
@@ -58,8 +56,8 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12"
-          python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12"
+          python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.3" --cuda_ver="12"
+          python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.3" --cuda_ver="12"
       - name: MII unit tests
         run: |
           BRANCH="main"
@@ -11,7 +11,7 @@ jobs:
   unit-tests:
     runs-on: [self-hosted, nvidia, a6000]
     container:
-      image: nvcr.io/nvidia/pytorch:23.03-py3
+      image: nvcr.io/nvidia/pytorch:24.03-py3
       ports:
         - 80
       options: --gpus all --shm-size "8G"
@@ -50,4 +50,4 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.0" --cuda_ver="12"
+          python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.3" --cuda_ver="12"
@@ -22,7 +22,7 @@ jobs:
     runs-on: [self-hosted, nvidia, cu121, v100]
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - id: setup-venv
         uses: ./.github/workflows/setup-venv
@@ -27,7 +27,7 @@ jobs:
   sd-tests:
     runs-on: [self-hosted, nvidia, a6000]
     container:
-      image: nvcr.io/nvidia/pytorch:23.03-py3
+      image: nvcr.io/nvidia/pytorch:24.03-py3
       ports:
         - 80
       options: --gpus all --shm-size "8G"
@@ -53,8 +53,6 @@ jobs:
           pip install image-similarity-measures
           python -m pip install opencv-python==4.6.* --force-reinstall
           python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
-          # Update packages included in the container that do not support pydantic 2+ to versions that do
-          python -m pip install thinc spacy confection --upgrade
           python -m pip install .[dev,1bit,autotuning,sd]
           ds_report
       - name: Python environment
@@ -64,7 +62,7 @@ jobs:
         run: |
           unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
           cd tests
-          python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.0" --cuda_ver="12"
+          python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.3" --cuda_ver="12"
 
       - name: Open GitHub issue if weekly CI fails
         if: ${{ failure() && (github.event_name == 'schedule') }}
@@ -20,7 +20,7 @@ jobs:
     env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
      - id: setup-venv
        uses: ./.github/workflows/setup-venv
@@ -20,7 +20,7 @@ jobs:
     env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - id: setup-venv
         uses: ./.github/workflows/setup-venv
@@ -277,8 +277,10 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
         if hasattr(model_config, "vision_config"):
             if "MllamaVisionEncoderLayer" in str(module):
                 num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config)
-            else:
+            elif hasattr(model_config, "text_config"):
                 num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config)
+            else:
+                num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
         else:
             num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
 
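Note on the hunk above: with only a bare `else:`, a config that defines `vision_config` but no `text_config` raised an `AttributeError`. A minimal standalone sketch of the corrected fallback order, using a hypothetical `pick_config` helper (not part of DeepSpeed):

def pick_config(model_config, module_repr):
    # Prefer the vision sub-config for Mllama vision layers, then the
    # text sub-config when present, else fall back to the top-level config.
    if hasattr(model_config, "vision_config"):
        if "MllamaVisionEncoderLayer" in module_repr:
            return model_config.vision_config
        if hasattr(model_config, "text_config"):
            return model_config.text_config
    return model_config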
@@ -24,7 +24,9 @@ def set_n_embd(num):
 
 def get_num_kv_heads():
     global num_kv_heads
-    return num_kv_heads
+    if 'num_kv_heads' in globals():
+        return num_kv_heads
+    return None
 
 
 def get_num_attention_heads():
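Worth noting for the hunk above: `global num_kv_heads` only declares intent; the name exists in the module namespace once `set_num_kv_heads` has assigned it, so the old `return num_kv_heads` raised `NameError` before the first setter call. A self-contained sketch of the sentinel pattern, mirroring tp_shard's setter/getter pair:

def set_num_kv_heads(n):
    global num_kv_heads  # the first assignment creates the module-level name
    num_kv_heads = n

def get_num_kv_heads():
    # The name is absent from globals() until the setter has run;
    # returning None gives callers an explicit "not configured" sentinel.
    if 'num_kv_heads' in globals():
        return num_kv_heads
    return None

assert get_num_kv_heads() is None  # before configuration
set_num_kv_heads(8)
assert get_num_kv_heads() == 8     # after configuration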
@@ -115,7 +115,7 @@ class FlopsProfiler(object):
                 get_accelerator().synchronize()
             module.__start_time__ = time.time()
 
-        if not hasattr(module, "__start_time_hook_handle"):
+        if not hasattr(module, "__start_time_hook_handle__"):
             module.__start_time_hook_handle__ = module.register_forward_pre_hook(start_time_hook)
 
         def end_time_hook(module, input, output):
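The fix above makes the `hasattr` probe match the attribute actually assigned on the next line; with the mismatched name, the pre-hook was re-registered on every profiler start. A minimal standalone sketch of the idempotent-registration pattern (toy module, not the profiler itself):

import time
import torch.nn as nn

def attach_start_timer(module: nn.Module):
    # Probe and assign the *same* attribute name so repeated calls
    # register the forward pre-hook at most once.
    if not hasattr(module, "__start_time_hook_handle__"):
        def start_time_hook(mod, inputs):
            mod.__start_time__ = time.time()
        module.__start_time_hook_handle__ = module.register_forward_pre_hook(start_time_hook)

m = nn.Linear(4, 4)
attach_start_timer(m)
attach_start_timer(m)  # no-op: the handle attribute already exists
assert len(m._forward_pre_hooks) == 1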
@@ -5,6 +5,15 @@
 
 import torch
 
+try:
+    from torch.compiler import is_compiling as torch_is_compiling
+except ImportError:
+    try:
+        from torch._dynamo.external_utils import is_compiling as torch_is_compiling
+    except ImportError:
+        # Torch does not have compiler support
+        torch_is_compiling = lambda: False
+
 
 def is_compile_supported():
     return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
@@ -14,3 +23,7 @@ def disable(func):
     if is_compile_supported():
         return torch.compiler.disable(func)
     return func
+
+
+def is_compiling():
+    return torch_is_compiling()
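The two hunks above give DeepSpeed a version-robust `is_compiling()`: the public `torch.compiler.is_compiling` on newer torch, the private dynamo helper on older torch 2.x, or a constant `False` when compile support is absent. A small usage sketch, assuming only the helpers defined above:

from deepspeed.runtime.compiler import is_compiling, disable

@disable  # excluded from torch.compile graphs when compile is supported
def debug_dump(tensor):
    print(tensor.shape)

def log_once(msg):
    if is_compiling():
        return  # skip logging during tracing to avoid graph breaks
    print(msg)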
@@ -287,7 +287,8 @@ class PipelineEngine(DeepSpeedEngine):
         weight_group_list = self.module.get_tied_weights_and_groups()
         for weight, group in weight_group_list:
             grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad
-            dist.all_reduce(grad, group=group)
+            if grad is not None:
+                dist.all_reduce(grad, group=group)
 
     def _exec_reduce_grads(self):
         self._force_grad_boundary = True
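Context for the guard above: a tied weight that did not participate in the backward pass on a given rank has `grad is None`, and passing `None` to `dist.all_reduce` would fail. A one-line illustration of the precondition:

import torch

w = torch.nn.Parameter(torch.ones(2))
assert w.grad is None  # until backward touches it, there is nothing to all-reduce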
@@ -38,7 +38,7 @@ def _apply_forward_and_backward_to_tensors_only(module, forward_function, backwa
 
 class ZeROOrderedDict(OrderedDict):
 
-    def __init__(self, parent_module=None, *args, **kwargs):
+    def __init__(self, parent_module, *args, **kwargs):
         """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
 
         Args:
@@ -49,6 +49,10 @@ class ZeROOrderedDict(OrderedDict):
         self._parent_module = parent_module
         self._in_forward = False
 
+    def __reduce__(self):
+        r0, _, *r2 = super().__reduce__()
+        return (r0, (self._parent_module, )) + tuple(r2)
+
     def __getitem__(self, key):
         param = super().__getitem__(key)
 
@@ -56,6 +60,7 @@ class ZeROOrderedDict(OrderedDict):
         if param is None:
             return param
 
+        # TODO: only weaken this check during compilation
         if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
             if self._parent_module._parameters._in_forward:
                 register_external_parameter(FWD_MODULE_STACK[-1], param)
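A note on the `__reduce__` override added above: `OrderedDict.__reduce__` returns a tuple whose second element holds the constructor arguments, and once `parent_module` became mandatory, pickling and `copy.deepcopy` must supply it when rebuilding the dict (the starred unpack yields a list, hence the `tuple(r2)` conversion before concatenation). A standalone sketch with a toy subclass, not the ZeRO class itself:

import copy
from collections import OrderedDict

class OwnedDict(OrderedDict):
    def __init__(self, owner, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.owner = owner

    def __reduce__(self):
        # super().__reduce__() -> (callable, ctor_args, state, ...);
        # swap in the mandatory ctor argument so reconstruction succeeds.
        r0, _, *r2 = super().__reduce__()
        return (r0, (self.owner, )) + tuple(r2)

d = OwnedDict("module-a", x=1)
d2 = copy.deepcopy(d)
assert d2.owner == "module-a" and d2["x"] == 1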
@@ -10,6 +10,8 @@ from torch.nn import Module
 
 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
+from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads
+from deepspeed.utils import groups
 
 
 def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
@@ -38,8 +40,132 @@ def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_he
     return post_func
 
 
+def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
+    seq_world_size = dist.get_world_size(group)
+    inp_shape = list(input.shape)
+    assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1"
+
+    if not (scatter_idx < 2):
+        input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size)
+        input = input.transpose(0, scatter_idx).contiguous()
+        local_heads = input_splits[groups._get_sequence_parallel_rank()]
+        output_splits = [local_heads] * seq_world_size
+
+        output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:])
+        output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype)
+        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
+                               input_split_sizes=input_splits, group=group)
+        ### [seq_ws*local_heads, ...] to [seq_ws, local_heads, ...]
+        output = output.view(seq_world_size, local_heads, *output.shape[1:])
+        ### [seq_ws, local_heads, b, seq_len, ...] to [seq_ws, seq_len, b, local_heads, ...]
+
+        ### batch_dim_idx=0 [seq_ws, local_heads, seq_len, b, ...] to [b, seq_ws, seq_len, local_heads, ...]
+        ### batch_dim_idx=1 [seq_ws, local_heads, b, seq_len, ...] to [seq_ws, seq_len, b, local_heads, ...]
+        if batch_dim_idx == 0:
+            order = [3, 0, 2, 1] + list(range(4, len(output.shape)))
+            output = output.permute(order).contiguous()
+            ### [b, seq_ws*local_seq_len, local_heads, ...]
+            output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size,
+                                 *output.shape[3:]).contiguous()
+        elif batch_dim_idx == 1:
+            output = output.transpose(1, 3).contiguous()
+            ### [seq_ws*local_seq_len, b, local_heads, ...]
+            output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous()
+    else:
+        # Compatibility handling for 4D and 3D tensors: standardize to 3D.
+        input = input.reshape(input.shape[0], input.shape[1], -1)
+
+        if batch_dim_idx == 0:  #b,s,h
+            input = input.permute(1, 2, 0).contiguous()  #s,h,b
+        elif batch_dim_idx == 1:  #s,b,h
+            input = input.transpose(1, 2).contiguous()  #s,h,b
+        seq_len, h, batch_size = input.shape
+        num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size)
+        local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()]
+        h_dim = h // local_heads
+        local_seq_len = seq_len // seq_world_size
+
+        input = input.view(seq_len * h, batch_size)
+        local_seq_len_with_heads = int(input.shape[0] / seq_world_size)  # dim size of local_seq_len*local_heads*hdim
+        input_splits = [local_seq_len_with_heads] * seq_world_size
+        coeff = local_seq_len_with_heads // local_heads  # per head: dim size of local_seq_len*hdim
+
+        # uneven head scaling coefficient: total_heads / local_heads
+        heads_scale_coeff = get_num_kv_heads() / local_heads
+
+        output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list]
+        output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads)
+        total_h = int(inp_shape[gather_idx] * heads_scale_coeff)
+        output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype)
+        dist.all_to_all_single(output, input, output_split_sizes=output_splits,
+                               input_split_sizes=input_splits, group=group)
+        ##################
+        # suppose 7 heads divide into 4 ranks as [2,2,2,1]
+        # chunk_num_heads_small = floor(7/4) = 1
+        # chunk_num_heads_large = ceil(7/4) = 2
+        # num_chunk_heads_large = len([2,2,2]) = 3, all2all_buffer_counts
+        # num_chunk_heads_small = len([1]) = 1, all2all_buffer_counts
+        # total_num_large_heads = sum([2,2,2]) = 6
+        # total_num_small_heads = sum([1]) = 1
+
+        chunk_num_heads_small = get_num_kv_heads() // seq_world_size  # even heads compatible
+        chunk_num_heads_large = chunk_num_heads_small + 1
+        num_chunk_heads_large = get_num_kv_heads() % seq_world_size
+        num_chunk_heads_small = seq_world_size - num_chunk_heads_large
+        total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large
+        total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small
+
+        heads_large_combine_size = coeff * total_num_large_heads
+        heads_small_combine_size = coeff * total_num_small_heads
+        heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size],
+                                                            dim=0)
+        heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim,
+                                                   batch_size)
+        heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim,
+                                                   batch_size)
+        if batch_dim_idx == 0:
+            # [all2all_buffer_counts, local_seq_len, n_heads, dim, batch] -> [batch, local_seq_len, all2all_buffer_counts*n_heads, dim]
+            order = [4, 1, 0, 2, 3]
+            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
+                                                                                   total_num_large_heads, h_dim)
+            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len,
+                                                                                   total_num_small_heads, h_dim)
+        elif batch_dim_idx == 1:
+            # [all2all_buffer_counts, local_seq_len, n_heads, dim, batch] -> [local_seq_len, batch, all2all_buffer_counts*n_heads, dim]
+            order = [1, 4, 0, 2, 3]
+            heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
+                                                                                   total_num_large_heads, h_dim)
+            heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size,
+                                                                                   total_num_small_heads, h_dim)
+
+        output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous()
+
+        inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
+        output_shape = inp_shape[:gather_idx] + \
+            [total_h,] + \
+            inp_shape[gather_idx + 1:]
+
+        output = output.view(output_shape)
+
+    return output
+
+
 def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
     seq_world_size = dist.get_world_size(group)
+    # we only need num_heads once
+    num_heads = input.shape[2]
+
+    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
+        # Assuming here that the number of heads for q is consistent with kv
+        # If not, additional logic is required for cases like GQA
+        if get_num_kv_heads() is None:
+            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
+            # set heads at the first call from num_total_heads, then rely on
+            # ``get_num_kv_heads() is not None`` to re-enter the uneven path.
+            set_num_kv_heads(num_heads)
+        assert async_op == False, "uneven head sp does not support async op"
+        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
+
     if batch_dim_idx == 0:
         # b, s, n, h
         if scatter_idx < 2:
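The `[2,2,2,1]` comment block above is the whole uneven-head bookkeeping in miniature. A standalone sketch of the arithmetic, assuming (as `get_shard_size_list` does) that remainder heads land on the leading ranks:

def shard_sizes(total_heads, world_size):
    small = total_heads // world_size       # chunk_num_heads_small
    large = small + 1                       # chunk_num_heads_large
    n_large = total_heads % world_size      # ranks that hold `large` heads
    n_small = world_size - n_large          # ranks that hold `small` heads
    return [large] * n_large + [small] * n_small

assert shard_sizes(7, 4) == [2, 2, 2, 1]   # the example from the comments
assert sum(shard_sizes(7, 4)) == 7
assert shard_sizes(8, 4) == [2, 2, 2, 2]   # even case degenerates cleanly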
@@ -484,6 +484,8 @@ def _get_sequence_parallel_rank():
     global mpu
     if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
         return mpu.get_sequence_parallel_rank()
+    if mesh_device is not None:
+        return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel"))
     return 0
 
 
@@ -7,8 +7,7 @@ import functools
 import logging
 import sys
 import os
-import torch
-from deepspeed.runtime.compiler import is_compile_supported
+from deepspeed.runtime.compiler import is_compile_supported, is_compiling
 
 log_levels = {
     "debug": logging.DEBUG,
@@ -26,7 +25,7 @@ class LoggerFactory:
 
         def warn_once(record):
             nonlocal warn
-            if is_compile_supported() and torch.compiler.is_compiling() and not warn:
+            if is_compile_supported() and is_compiling() and not warn:
                 warn = True
                 logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
                                " disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
@@ -39,7 +38,7 @@ class LoggerFactory:
 
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            if torch.compiler.is_compiling():
+            if is_compiling():
                 return
             else:
                 return func(*args, **kwargs)
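Both logging hunks above route the compile check through the new `deepspeed.runtime.compiler.is_compiling`, dropping the direct `torch.compiler` dependency that breaks on older torch. The wrapper is an instance of a general guard-decorator pattern; a minimal standalone sketch, with a hypothetical `skip_while` helper:

import functools

def skip_while(predicate):
    """Make the decorated function a no-op while `predicate()` is true."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if predicate():
                return
            return func(*args, **kwargs)
        return wrapper
    return decorator

compiling = False

@skip_while(lambda: compiling)
def info(msg):
    print(msg)

info("printed")     # predicate is False, call goes through
compiling = True
info("suppressed")  # predicate is True, call is skipped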
@@ -67,7 +67,7 @@ def get_default_compute_capabilities():
         # Special treatment of CUDA 11.0 because compute_86 is not supported.
         compute_caps += ";8.0"
     else:
-        compute_caps += ";8.0;8.6"
+        compute_caps += ";8.0;8.6;9.0"
     return compute_caps
 
 
@@ -11,9 +11,12 @@ from transformers import AutoModel
 from unit.common import DistributedTest
 from deepspeed.sequence.layer import _SeqAllToAll
 from unit.util import skip_on_arch
 
+from unit.simple_model import *
+from deepspeed.utils import groups
+from deepspeed.module_inject.tp_shard import get_shard_size_list
+# Use mesh device to create data and sequence parallel group
 
 
 class TestUlyssesUtils(DistributedTest):
     world_size = 4
@@ -75,3 +78,82 @@ class TestUlyssesAll2All(DistributedTest):
         # Check outputs are the same as input
         for i in range(1, len(outputs)):
             assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
+
+
+@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
+@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
+@pytest.mark.parametrize("num_heads", [3, 7])
+@pytest.mark.parametrize("head_dim", [16])
+class TestUlyssesAll2All_odd(DistributedTest):
+    world_size = 4
+
+    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
+
+        data_parallel_size = 2
+        seq_parallel_size = self.world_size // data_parallel_size
+        skip_on_arch(min_arch=8)
+
+        def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0):
+            d0 += offset_d0
+            d1 += offset_d1
+            h += offset_h
+            return d0 * 10 + h + d1 * 0.1
+
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim)
+        ds_engine, _, _, _ = initialize(model=model,
+                                        config_params={"train_batch_size": 8},
+                                        mesh_param=(data_parallel_size, seq_parallel_size))
+
+        scatter_idx = 2
+        outputs = []
+        inputs = []
+        batch_dims = [0, 1]
+        seq_dims = [1, 0]
+
+        for idx, seq_dim in enumerate(seq_dims):
+            gather_idx = seq_dim
+            batch_dim_idx = batch_dims[idx]
+
+            # 4D tensor: b,s,h,d or s,b,h,d
+            # create a hash tensor from pos_id, head_id, and batch_id
+            d0_indices = torch.arange(d0).reshape(-1, 1, 1, 1)
+            d1_indices = torch.arange(d1).reshape(1, -1, 1, 1)
+            h_indices = torch.arange(num_heads).reshape(1, 1, -1, 1)
+            input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
+            if batch_dim_idx == 1:  # seq_len_dim : 0(d0)
+                input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices,
+                                                       d0 * groups._get_sequence_parallel_rank(), 0)
+            elif batch_dim_idx == 0:  # seq_len_dim : 1(d1)
+                input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0,
+                                                       d1 * groups._get_sequence_parallel_rank())
+            inputs.append(input_tensor)
+
+            ### first all2all: sequence parallel to head parallel
+            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
+                                            batch_dim_idx)
+
+            # s2h_tensor check for the first all2all: compare with the expected ground truth
+            d0_indices = torch.arange(s2h_tensor.shape[0]).reshape(-1, 1, 1, 1)
+            d1_indices = torch.arange(s2h_tensor.shape[1]).reshape(1, -1, 1, 1)
+            h_indices = torch.arange(s2h_tensor.shape[2]).reshape(1, 1, -1, 1)
+            shard_list = get_shard_size_list(num_heads, groups._get_sequence_parallel_world_size())
+            head_offset = sum(shard_list[:groups._get_sequence_parallel_rank()])
+            s2h_truth = torch.zeros_like(s2h_tensor)
+            s2h_truth[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, 0, head_offset)
+
+            assert torch.allclose(s2h_truth,
+                                  s2h_tensor), f"s2h_tensor differs from the expected for sequence dim: {seq_dim}"
+            #No op
+            ### second all2all: head parallel to sequence parallel
+            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
+                                            batch_dim_idx)
+            print(
+                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
+            )
+            outputs.append(h2s_tensor)
+
+        # Check outputs for the second all2all
+        for i in range(0, len(outputs)):
+            assert torch.allclose(inputs[i],
+                                  outputs[i]), f"[{dist.get_rank()}] Outputs differ for sequence dim {seq_dims[i]}"
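The `seq_batch_heads_hash` above packs (position, head, batch) into one float so any misrouted element is detectable after the all-to-all; checked by hand, assuming the formula `d0 * 10 + h + d1 * 0.1`:

# d0=3, h=2, d1=5  ->  3 * 10 + 2 + 5 * 0.1 = 32.5
assert 3 * 10 + 2 + 5 * 0.1 == 32.5
# Distinct (d0, h, d1) triples map to distinct values for small ranges,
# so torch.allclose on the hashed tensors verifies exact routing.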