Mirror of https://github.com/microsoft/DeepSpeed.git
[CPU] Support Intel CPU inference (#3041)
* add fallback path for kernels used in megatron
* temporary numactl WA for SPR 56core
* adapt core allocation according to number of ranks
* add switch to turn on numactl
* detect number of cores on the system
* allow selecting a subset of the cores on the system to bind
* remove unneeded changes
* add ccl backend
* change nccl to ccl
* remove unused code
* add comm/ccl to ops
* initial ccl comm support
* first broadcast case passed
* add CCL_Backend to DeepSpeed
* support comm timer for CPU
* support barrier for comm backend
* support specifying the master address from the deepspeed command line
* support pytorch 2.0
* remove 'block' from api
* Tweak for debug (Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>)
* Remove unnecessary directory (Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>)
* Add bf16 kernel support for inference
* Add temporary torch implementation for cpu inference
* Add softmax ops cpu fallback for inference
* bind cores to numa domain as well
* merge latest change in gma/numactl
* initial bf16 kernel support with fallback path
* initial fallback path for bloom kernel injection
* fix softmax attn mask
* check KMP_AFFINITY to avoid conflict with numactl
* New CCLBackend which utilizes TorchBackend for initialization
* roll back last change because of a result error
* fix issue where bloom injection policy TP could not work: injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}
* Use TorchBackend to initialize CCLBackend, make behavior consistent
* remove comm under deepspeed/ops
* add license header
* code clean up
* fix format issue
* remove magic number in main address
* add caching support but do not turn it on by default
* change name of inference_cuda_module to inference_module
* Check for is_synchronized_device in accelerator before getting Event
* fix typo
* Fix fallback path of softmax kernel on CUDA device for BF16 data type; because CUDA tril does not support the BF16 datatype, enforce fp32 data type
* add cpu backend files
* change CPU_Accelerator op_builder_dir
* remove cpu_kernel_path
* use CPU_Accelerator on non-cuda device
* fix deepspeed.op_builder => deepspeed.ops.op_builder
* add alias for num_gpus: num_accelerators
* allow loading cpu_builder in build stage
* Assume cuda available if torch not installed
* add oneccl_binding_pt to requirements
* move oneccl-binding-pt to separate requirements-cpu.txt
* add missing file
* use dependency_links in setuptools.setup() call for additional dependency links
* install oneccl_bind_pt in workflows
* change oneccl_bind_pt's version from 1.13 to 2.0
* use intel_extension_for_pytorch as indicator that CPU_Accelerator should be used
* Add indicator for Accelerator used
* change foo.c to foo.cpp
* exclude 'cpu' directory in CUDA op builder reflection
* add a cpu-inference workflow
* run cpu-inference workflow on self-hosted instance
* change cpu runs-on node to v100 node
* print out python version in workflow
* add verbose in pip command to understand oneccl_bind_pt install issue
* update cpu-inference workflow
* add a stage to detect instance instruction sets
* add back bf16 support for CPU inference
* enable autoTP for bloom (Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>)
* update workflow to detect cpu instruction sets
* temporary WA for Intel Extension for PyTorch AVX2 instruction set detection
* change cpu-inference workflow machine to ubuntu-20.04
* add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage (Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>)
* enable policy for llama
* use a special ipex build to test the avx2 detection fix
* fix format
* fix test failure issue (Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>)
* fix gptj sharded checkpoint loading problem (Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>)
* return a not-implemented builder in get_op_builder in cpu_backend
* support cpu device in tests
* use cpuinfo to extract number of CPUs
* use ~/tmp as transformer cache rather than /blob/
* Add support for mpich launcher with prefer_deepspeed_comm
* add missing modification in accelerator
* enable IMPI launcher
* remove unused file and fix formatting
* clean up ccl.cpp
* Less confusing error message when certain op builders are not implemented
* Fix license header
* Add license header
* add license headers
* add license header
* fix cuda specific code in test
* update CPU workflow
* use numactl to bind to core
* allow bind_cores_to_rank in multi-node impi runner
* fix format error
* Remove InferenceBuilder
* fix format error in numa.py
* check whether op is in installed ops in ds_report.py
* allow overriding accelerator with DS_ACCELERATOR='cuda', 'cpu' or 'xpu'
* lazy init class_dict in CUDA_Accelerator to avoid cyclic initialization of CUDA_Accelerator
* put the short path at the beginning in real_accelerator.py
* device_count returns number of NUMA nodes
* fix typo
* install numactl in cpu workflow
* Follow comments
* Better implementation of device_count() and current_device()
* remove dependency_link for Intel Extension for DeepSpeed
* check is_synchronized_device in timer only once
* remove env mapping WA in cpu_accelerator
* fix duplicate definition
* fix format error
* refine ccl backend selection
* move comments to the right place
* remove prefer_deepspeed_comm, use CCLBackend by default
* refactor fallback path
* Fix execution failure in kernel injection path
* do not refactor the kernel injection fallback path in residual_add because it contains a function call with side effects
* guard residual_add fallback path with environ DS_KI_FALLBACK=True
* fix format error
* add test for allreduce on CPU workflow
* fix format error
* Fall back to TorchBackend if CCLBackend kernels are not implemented
* Update Intel Extension for PyTorch installation link
* Don't specify version number of Intel Extension for PyTorch
* install oneCCL for CCLBackend
* fix link path for CPU comm kernels
* fix source oneCCL environment
* source oneCCL env before running UT
* Give more specific instruction when CCL_ROOT not defined

---------

Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: sdp <sdp@aia-sdp-spr-108864.jf.intel.com>
Co-authored-by: Cao, Zhong Z <zhong.z.cao@intel.com>
Co-authored-by: Zhenhuan Chen <zhenhuan.chen@intel.com>
Co-authored-by: baodii <di.bao@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Parent
5147b90aa4
Commit
1f72082fc0
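To show how the pieces in this commit fit together, here is a minimal, illustrative sketch (not part of the diff) of running DeepSpeed inference on an Intel CPU host. It assumes intel_extension_for_pytorch and oneccl_bindings_for_pytorch are installed so the new CPU accelerator and CCL backend are selected; the model name and arguments are examples only.

# illustrative usage sketch; assumes IPEX and oneCCL bindings for PyTorch are installed
import os
os.environ.setdefault("DS_ACCELERATOR", "cpu")  # optional explicit override added by this commit

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16)
# bf16 AutoTP path on CPU; kernel injection falls back to the torch implementation where kernels are missing
engine = deepspeed.init_inference(model, dtype=torch.bfloat16, mp_size=1, replace_with_kernel_inject=False)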
@@ -0,0 +1,83 @@
name: cpu-inference

on:
  push:
    branches:
      - 'staging**'
    paths-ignore:
      - 'docs/**'
  pull_request:
    paths-ignore:
      - 'docs/**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  unit-tests:
    runs-on: ubuntu-20.04

    steps:
      - uses: actions/checkout@v2

      - id: setup-venv
        uses: ./.github/workflows/setup-venv

      - name: Detect instruction sets on instance
        run: |
          lscpu
          pip install cmake
          git clone https://github.com/intel/intel-extension-for-pytorch
          cd intel-extension-for-pytorch/tests/cpu/isa
          cmake .
          make
          ./cpu_features

      - name: Install numactl
        run: |
          sudo apt-get install -y numactl

      - name: Install oneCCL Bindings for PyTorch
        run: |
          python -m pip install intel_extension_for_pytorch
          python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu

      - name: Install oneCCL
        run: |
          git clone https://github.com/oneapi-src/oneCCL
          cd oneCCL
          mkdir build
          cd build
          cmake ..
          make
          make install
          #source ./_install/env/setvars.sh
          # test whether oneCCL is correctly installed
          #mpirun -n 2 ./examples/benchmark/benchmark

      - name: Install transformers
        run: |
          git clone https://github.com/huggingface/transformers
          cd transformers
          git rev-parse --short HEAD
          pip install .

      - name: Install deepspeed
        run: |
          # check why the host does not have AVX2 support
          pip install .[dev,1bit,autotuning,inf]
          ds_report

      - name: Python environment
        run: |
          pip list

      - name: Unit tests
        run: |
          source oneCCL/build/_install/env/setvars.sh
          unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
          if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
          cd tests
          TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
          TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -k TestDistAllReduce unit/comm/test_dist.py
@@ -13,6 +13,10 @@ class DeepSpeedAccelerator(ABC):
        self._name = None
        self._communication_backend_name = None

    @abc.abstractmethod
    def is_synchronized_device(self):
        ...

    # Device APIs
    @abc.abstractmethod
    def device_name(self, device_index):
@@ -0,0 +1,260 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import oneccl_bindings_for_pytorch  # noqa: F401
import psutil
import os


# accelerator for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cpu'
        self._communication_backend_name = 'ccl'
        self.max_mem = psutil.Process().memory_info().rss

    def is_synchronized_device(self):
        return True

    # Device APIs
    def device_name(self, device_index=None):
        return 'cpu'

    def device(self, device_index=None):
        return None

    def set_device(self, device_index):
        return

    def current_device(self):
        return os.environ.get('LOCAL_RANK', 0)

    def current_device_name(self):
        return 'cpu'

    def device_count(self):
        device_count = int(os.environ.get('LOCAL_SIZE', 0))
        if device_count > 0:
            return os.environ.get('LOCAL_SIZE')
        else:
            from deepspeed.utils.numa import get_numa_cores
            # Count NUMA node for number of cpu accelerators. On machine with HBM
            # In flat mode, HBM is in separate NUMA node with no cores on this node.
            # Ignore these NUMA nodes with no cores.
            numa_core_lists = get_numa_cores()
            numa_count = 0
            for core_list in numa_core_lists:
                if len(core_list) > 0:
                    numa_count += 1
            return numa_count

    def synchronize(self, device_index=None):
        return

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index == None:
            return torch.set_rng_state(new_state)
        return torch.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        return torch.get_rng_state()

    def manual_seed(self, seed):
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.manual_seed(seed)

    def initial_seed(self, seed):
        return torch.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.default_generator

    # Streams/Events
    @property
    def Stream(self):
        return None

    def stream(self, stream):
        from deepspeed.runtime.utils import noop_decorator
        return noop_decorator

    def current_stream(self, device_index=None):
        return None

    def default_stream(self, device_index=None):
        return None

    @property
    def Event(self):
        return None

    # Memory management
    def empty_cache(self):
        return

    def get_rss(self):
        mem = psutil.Process().memory_info().rss
        if mem > self.max_mem:
            self.max_mem = mem
        return mem

    def reset_rss(self):
        mem = psutil.Process().memory_info().rss
        self.max_mem = mem
        return mem

    def memory_allocated(self, device_index=None):
        return self.get_rss()

    def max_memory_allocated(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_allocated(self, device_index=None):
        self.reset_rss()
        return

    def memory_cached(self, device_index=None):
        return self.get_rss()

    def max_memory_cached(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_cached(self, device_index=None):
        self.reset_rss()
        return

    def memory_stats(self, device_index=None):
        return self.get_rss()

    def reset_peak_memory_stats(self, device_index=None):
        self.reset_rss()
        return

    def memory_reserved(self, device_index=None):
        return self.get_rss()

    def max_memory_reserved(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def total_memory(self, device_index=None):
        return psutil.virtual_memory().total

    # Misc
    def amp(self):
        return torch.cpu.amp

    def is_available(self):
        return True

    def range_push(self, msg):
        # TODO itt is currently not supported yet
        # return torch.profiler.itt.range_push(msg)
        return

    def range_pop(self):
        # TODO itt is currently not supported yet
        # return torch.profiler.itt.range_pop()
        return

    def lazy_call(self, callback):
        return callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return True

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.FloatTensor

    @property
    def HalfTensor(self):
        return torch.HalfTensor

    @property
    def IntTensor(self):
        return torch.IntTensor

    @property
    def LongTensor(self):
        return torch.LongTensor

    def pin_memory(self, tensor):
        return tensor

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder.cpu"
        except ImportError:
            return "deepspeed.ops.op_builder.cpu"

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('cpu'):
            return True
        else:
            return False

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, op_name):
        builder_class = self.get_op_builder(op_name)
        if builder_class != None:
            return builder_class()
        return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
        except ImportError:
            from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder

        if class_name == "CCLCommBuilder":
            return CCLCommBuilder
        else:
            # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests
            return NotImplementedBuilder

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
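A quick, illustrative way to exercise the new accelerator from Python (assuming the file above is installed as deepspeed/accelerator/cpu_accelerator.py and intel_extension_for_pytorch is present so it gets selected):

from deepspeed.accelerator import get_accelerator

acc = get_accelerator()
print(acc.device_name())                 # 'cpu'
print(acc.communication_backend_name())  # 'ccl'
# device_count() reports LOCAL_SIZE if the launcher set it, otherwise the number
# of NUMA nodes that actually have cores (HBM-only nodes in flat mode are skipped)
print(acc.device_count())
print(acc.total_memory())                # total system RAM via psutil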
@@ -22,21 +22,8 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl'

        # begin initialize for create_op_builder()
        # put all valid class name <--> class type mapping into class_dict
        op_builder_dir = self.op_builder_dir()
        op_builder_module = importlib.import_module(op_builder_dir)
        for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
            # avoid self references
            if module_name != 'all_ops' and module_name != 'builder':
                module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                for member_name in module.__dir__():
                    if member_name.endswith(
                            'Builder'
                    ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                        if not member_name in self.class_dict:
                            self.class_dict[member_name] = getattr(module, member_name)
        # end initialize for create_op_builder()
    def is_synchronized_device(self):
        return False

    # Device APIs
    def device_name(self, device_index=None):
@@ -235,10 +222,32 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
    # dict that holds class name <--> class type mapping i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = {}
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict != None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mapping into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
                # avoid self references
                if module_name != 'all_ops' and module_name != 'builder' and module_name != 'cpu':
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if not member_name in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of op builder and return, name specified by class_name
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
@@ -246,6 +255,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
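The point of the lazy initialization above is that constructing CUDA_Accelerator no longer triggers op-builder reflection (which could import modules that in turn import the accelerator, the cyclic-initialization problem mentioned in the commit message). A minimal sketch of the resulting behavior, for illustration only, on a machine where the CUDA accelerator is usable:

from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator

acc = CUDA_Accelerator()                            # class_dict is still None at this point
builder_cls = acc.get_op_builder("AsyncIOBuilder")  # reflection runs lazily on the first lookup
builder = acc.create_op_builder("AsyncIOBuilder")   # returns an instance of the same class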
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import os

try:
    from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
@@ -36,25 +37,76 @@ def _validate_accelerator(accel_obj):

def get_accelerator():
    global ds_accelerator
    if ds_accelerator is None:
        try:
            from intel_extension_for_deepspeed import XPU_Accelerator
        except ImportError as e:
    if ds_accelerator is not None:
        return ds_accelerator

    accelerator_name = None
    ds_set_method = None
    # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
    # DS_ACCELERATOR = 'cuda'|'xpu'|'cpu'
    if 'DS_ACCELERATOR' in os.environ.keys():
        accelerator_name = os.environ['DS_ACCELERATOR']
        if accelerator_name == 'xpu':
            try:
                from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401
            except ImportError as e:
                raise ValueError(
                    f'XPU_Accelerator requires intel_extension_for_deepspeed, which is not installed on this system.')
        elif accelerator_name == 'cpu':
            try:
                import intel_extension_for_pytorch  # noqa: F401
            except ImportError as e:
                raise ValueError(
                    f'CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.')
        elif accelerator_name == 'cuda':
            pass
        else:
            ds_accelerator = XPU_Accelerator()
            _validate_accelerator(ds_accelerator)
            return ds_accelerator
            raise ValueError(
                f'DS_ACCELERATOR must be one of "cuda", "cpu", or "xpu". Value "{accelerator_name}" is not supported')
        ds_set_method = 'override'

    # 2. If no override, detect which accelerator to use automatically
    if accelerator_name == None:
        try:
            from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401,F811
            accelerator_name = 'xpu'
        except ImportError as e:
            # We need a way to choose between CUDA_Accelerator and CPU_Accelerator
            # Currently we detect whether intel_extension_for_pytorch is installed
            # in the environment and use CPU_Accelerator if the answer is True.
            # An alternative might be to detect whether a CUDA device is installed on
            # the system, but this comes with two pitfalls:
            # 1. the system may not have torch pre-installed, so
            #    get_accelerator().is_available() may not work.
            # 2. Some scenario like install on login node (without CUDA device)
            #    and run on compute node (with CUDA device) may cause mismatch
            #    between installation time and runtime.
            try:
                import intel_extension_for_pytorch  # noqa: F401,F811
                accelerator_name = 'cpu'
            except ImportError as e:
                accelerator_name = 'cuda'
        ds_set_method = 'auto detect'

    # 3. Set ds_accelerator accordingly
    if accelerator_name == 'cuda':
        from .cuda_accelerator import CUDA_Accelerator
        ds_accelerator = CUDA_Accelerator()
        _validate_accelerator(ds_accelerator)
    elif accelerator_name == 'cpu':
        from .cpu_accelerator import CPU_Accelerator
        ds_accelerator = CPU_Accelerator()
    elif accelerator_name == 'xpu':
        # XPU_Accelerator is already imported in detection stage
        ds_accelerator = XPU_Accelerator()
    _validate_accelerator(ds_accelerator)
    print(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
    return ds_accelerator


def set_accelerator(accel_obj):
    global ds_accelerator
    _validate_accelerator(accel_obj)
    print(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
    ds_accelerator = accel_obj
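In short, the selection order is: explicit DS_ACCELERATOR override, then auto-detection (XPU via intel_extension_for_deepspeed, CPU via intel_extension_for_pytorch, otherwise CUDA). An illustrative override, assuming the Intel Extension for PyTorch is installed:

import os
os.environ["DS_ACCELERATOR"] = "cpu"   # must be set before the first get_accelerator() call,
                                       # since the chosen accelerator is cached globally

from deepspeed.accelerator import get_accelerator
assert get_accelerator().device_name() == "cpu"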
@@ -0,0 +1,211 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

#include <oneapi/ccl.hpp>

std::set<int> _comm_ids;
std::set<int> _colors;
ccl::vector_class<ccl::communicator> _ccl_comms;

#define CCLCHECK(cmd) \
    do {              \
        cmd;          \
    } while (0)

#define KVS_CREATE_SUCCESS 0
#define KVS_CREATE_FAILURE -1

bool is_initialized = 0;

int world_rank = -1;
int world_size = -1;

ccl::shared_ptr_class<ccl::kvs> kvs;

void initialize(int size, int rank, torch::Tensor& kvs_data)
{
    if (is_initialized) return;
    world_size = size;
    world_rank = rank;
    is_initialized = 1;

    ccl::kvs::address_type main_addr;

    if (rank != 0) {
        memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
        kvs = ccl::create_kvs(main_addr);
    }

    _ccl_comms.emplace_back(ccl::create_communicator(size, rank, kvs));
}

/*
    rank == 0: create main kvs and return its address
    rank == else: return an empty address
*/
std::vector<uint8_t> get_kvs_addr(int rank)
{
    if (rank == 0) {
        kvs = ccl::create_main_kvs();
        ccl::kvs::address_type main_addr = kvs->get_address();
        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
        return ccl_kvs_addr;
    } else {
        ccl::kvs::address_type main_addr;
        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
        return ccl_kvs_addr;
    }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }

// Find the next ordered, unique value to a set. E.g. <0,1,2,7> --> 3
int next_unique_val(std::set<int> s)
{
    std::set<int>::iterator itr;
    // Base case. Add 0 to start of set.
    if (s.empty() || *s.begin() != 0) {
        return 0;
        // second base case where s = {0} (the case of s = {n != 0} is caught above)
    } else if (s.size() == 1) {
        return 1;
    } else {
        int prev_val = *s.begin();
        for (itr = std::next(s.begin()); itr != s.end(); itr++) {
            if (*itr != prev_val + 1) { return prev_val + 1; }
            prev_val = *itr;
        }
        return *(s.end()) + 1;
    }
}

py::object new_group(std::vector<int> ranks)
{
    int comm_id = next_unique_val(_comm_ids);
    int color = next_unique_val(_colors);
    std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
              << std::endl;
}

ccl::datatype get_ccl_datatype(c10::ScalarType type)
{
    ccl::datatype ccl_type;
    switch (type) {
        case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break;
        case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break;
        case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break;
        case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break;
        case c10::ScalarType::Half: ccl_type = ccl::datatype::float16; break;
        default: ccl_type = ccl::datatype::int8;
    }
    return ccl_type;
}

ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
{
    py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
    if (!py::isinstance(op, ReduceOp)) {
        throw std::runtime_error("Error: Op must be of type ReduceOp");
    }

    int op_val = py::int_(op.attr("value"));
    ccl::reduction ccl_op;

    if (input.scalar_type() == at::kBool) {
        if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
            // For bool tensors, map sum to max, which both represent a bitwise or.
            // This is to prevent overflow issues with sum, since we use uint8 to
            // represent a bool (see cclDataType mapping).
            ccl_op = ccl::reduction::max;
        } else if (op_val == (int)py::int_(ReduceOp.attr("AVG").attr("value"))) {
            throw std::runtime_error("Error: For bool tensors, op must be of type ReduceOp");
        }
    }

    if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
        ccl_op = ccl::reduction::sum;
    } else if (op_val == (int)py::int_(ReduceOp.attr("MIN").attr("value"))) {
        ccl_op = ccl::reduction::min;
    } else if (op_val == (int)py::int_(ReduceOp.attr("MAX").attr("value"))) {
        ccl_op = ccl::reduction::max;
    } else if (op_val == (int)py::int_(ReduceOp.attr("PRODUCT").attr("value"))) {
        ccl_op = ccl::reduction::prod;
    } else {
        throw std::runtime_error("Error: Unrecognized ReduceOp type");
    }
    return ccl_op;
}

ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }

ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }

void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
{
    CCLCHECK(ccl::broadcast(data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            src,
                            _get_comm_from_group(group))
                 .wait());
}

// TODO: implement torch's async_op behavior, document it.
void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
{
    CCLCHECK(ccl::allreduce(data.data_ptr(),
                            data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            get_ccl_reduce_op(op, data),
                            _get_comm_from_group(group))
                 .wait());
}

void all_reduce_caching(torch::Tensor& data,
                        py::object op,
                        std::string match_id,
                        py::object group,
                        bool async_op)
{
    ccl::allreduce_attr attr = ccl::default_allreduce_attr;
    auto match_str = ccl::v1::string(match_id);
    attr.template set<ccl::operation_attr_id::to_cache>(true);
    attr.template set<ccl::operation_attr_id::match_id>(match_str);
    // To control this, use operation attribute and set true value for to_cache field and unique
    // string (for example, tensor name) for match_id field. Note that:
    // match_id should be the same for a specific communication operation across all ranks.
    // If the same tensor is a part of different communication operations, match_id should have
    // different values for each of these operations.
    CCLCHECK(ccl::allreduce(data.data_ptr(),
                            data.data_ptr(),
                            data.numel(),
                            get_ccl_datatype(data.scalar_type()),
                            get_ccl_reduce_op(op, data),
                            _get_comm_from_group(group),
                            attr)
                 .wait());
}

void barrier(py::object group, bool async_op)
{
    CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
    m.def("initialize", &initialize, "ccl initialize");
    m.def("get_rank", &get_rank, "get rank");
    m.def("get_world_size", &get_world_size, "get world size");
    m.def("broadcast", &broadcast, "ccl broadcast");
    m.def("all_reduce", &all_reduce, "ccl all_reduce");
    m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
    m.def("barrier", &barrier, "barrier");
}
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''

import torch
from deepspeed.accelerator import get_accelerator
from .reduce_op import ReduceOp
from .torch import TorchBackend


def build_ccl_op():
    builder = get_accelerator().create_op_builder("CCLCommBuilder")
    if builder == None:
        return None
    ccl_cpp_module = builder.load()
    print(f'DeepSpeed {builder.absolute_name()} built successfully')
    return ccl_cpp_module


class CCLBackend(TorchBackend):

    def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
        super(CCLBackend, self).__init__(backend='ccl',
                                         name='torch',
                                         rank=rank,
                                         world_size=world_size,
                                         timeout=timeout,
                                         init_method=init_method)
        self.name = 'ccl'
        self.ccl_comm_op = build_ccl_op()
        if self.ccl_comm_op == None:
            # set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded
            self.initialized = False
            return
        size = self.get_world_size()
        rank = self.get_rank()
        main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
        main_kvs = torch.tensor(main_kvs).to(torch.uint8)
        super(CCLBackend, self).broadcast(main_kvs, 0)
        self.ccl_comm_op.initialize(size, rank, main_kvs)
        self.initialized = True

    def is_initialized(self):
        return self.initialized

    def broadcast(self, tensor, src, group=None, async_op=False):
        self.ccl_comm_op.broadcast(tensor, src, group, async_op)

    def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
        use_caching = False
        if use_caching:
            match_id = f"{tensor.size()}-{op}"
            self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
        else:
            self.ccl_comm_op.all_reduce(tensor, op, group, async_op)

    def barrier(self, group=None, async_op=False):
        self.ccl_comm_op.barrier(group, async_op)
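Illustrative end-to-end use of the CCL backend through deepspeed.comm on a CPU host with the Intel extensions installed (the launcher is expected to set RANK/WORLD_SIZE/LOCAL_RANK/LOCAL_SIZE and the master address/port, e.g. the deepspeed or impi launcher from this commit):

import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()        # picks CCLBackend because the CPU accelerator reports 'ccl'
t = torch.ones(8, dtype=torch.bfloat16)
dist.all_reduce(t)                  # dispatched to the ccl.cpp extension above
dist.barrier()
print(dist.get_rank(), t)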
@@ -21,26 +21,14 @@
    -- deepspeed groups API should be brought into ds.comm
"""

from enum import Enum
import torch
import os

from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from .reduce_op import *
from deepspeed.accelerator import get_accelerator


class ReduceOp(Enum):
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BAND = 4
    BOR = 5
    BXOR = 6
    AVG = 7
    UNUSED = 8

from deepspeed.comm.ccl import CCLBackend

from deepspeed.utils.comms_logging import CommsLogger
from deepspeed.utils import timer, get_caller_func
@@ -50,7 +38,6 @@ from deepspeed import utils
from datetime import timedelta

# Current deepspeed.comm backend (cdb) global object for simple access by client code
use_ds_backend = False
cdb = None

# Create global timer for ops
@@ -62,11 +49,20 @@ comms_logger = CommsLogger()
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
ccl_backend = None

# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *


class ProcessGroup():

    def __init__(self, comm_id, ranks=[]):
        self.ranks = ranks
        self.comm_id = comm_id
        self.size = len(ranks)


def _configure_using_config_file(config):
    if config.comms_logger_enabled:
        comms_logger.configure(config)
@@ -143,11 +139,14 @@ def timed_op(func):


# UNUSED: Future helper function to initialize DS backends
def init_deepspeed_backend(ds_backend):
def init_deepspeed_backend(ds_backend, timeout, init_method):
    global cdb
    global nccl_backend
    global mpi_backend
    global use_ds_backend
    global ccl_backend

    rank = int(os.environ["RANK"])
    size = int(os.environ["WORLD_SIZE"])

    if ds_backend == NCCL_BACKEND:
        utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
@@ -155,6 +154,9 @@ def init_deepspeed_backend(ds_backend):
        utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
    elif ds_backend == GLOO_BACKEND:
        utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
    elif ds_backend == CCL_BACKEND:
        ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method)
        utils.logger.info(f"Initialize {ds_backend} backend")
    else:
        utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
@@ -189,26 +191,23 @@ def is_available() -> bool:
    return True


def set_backend(backend_name):
    if not use_ds_backend:
        utils.logger.error(
            "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
        )
        raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')

def set_backend():
    global cdb
    global nccl_backend
    global mpi_backend
    global ccl_backend

    try:
        if backend_name == NCCL_BACKEND:
            if nccl_backend is not None and nccl_backend.is_initialized():
                cdb = nccl_backend
        elif backend_name == MPI_BACKEND:
            if mpi_backend is not None and mpi_backend.is_initialized():
                cdb = mpi_backend
    except Exception as inst:
        print(inst)
    backend_name = get_accelerator().communication_backend_name()

    if backend_name == NCCL_BACKEND:
        if nccl_backend is not None and nccl_backend.is_initialized():
            cdb = nccl_backend
    elif backend_name == MPI_BACKEND:
        if mpi_backend is not None and mpi_backend.is_initialized():
            cdb = mpi_backend
    elif backend_name == CCL_BACKEND:
        if ccl_backend is not None and ccl_backend.is_initialized():
            cdb = ccl_backend


@timed_op
@@ -392,7 +391,7 @@ def scatter(tensor,
@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
    global cdb
    return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
    return cdb.barrier(group=group, async_op=async_op)


@timed_op
@@ -589,6 +588,10 @@ def init_distributed(dist_backend=None,
    if dist_init_required is None:
        dist_init_required = cdb is None or not cdb.is_initialized()

    if cdb is None:
        init_deepspeed_backend(get_accelerator().communication_backend_name(), timeout, init_method)
        set_backend()
        utils.logger.info(f'cdb={cdb}')
    if cdb is None and torch.distributed.is_initialized():
        # The user initialized torch.dist themselves, create cdb and short-circuit
        cdb = TorchBackend(dist_backend, timeout, init_method)
@@ -4,6 +4,7 @@
# DeepSpeed Team

NCCL_BACKEND = 'nccl'
CCL_BACKEND = 'ccl'
MPI_BACKEND = 'mpi'
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'
@@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from enum import Enum


class ReduceOp(Enum):
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BAND = 4
    BOR = 5
    BXOR = 6
    AVG = 7
    UNUSED = 8
@@ -50,7 +50,7 @@ def op_report(verbose=True):
    for op_name, builder in ALL_OPS.items():
        dots = "." * (max_dots - len(op_name))
        is_compatible = OKAY if builder.is_compatible(verbose) else no
        is_installed = installed if installed_ops[op_name] else no
        is_installed = installed if installed_ops.get(op_name, False) else no
        dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
        print(op_name, dots, is_installed, dots2, is_compatible)
    print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
@@ -418,6 +418,7 @@ class InferenceEngine(Module):

            generic_injection(self.module,
                              fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
                              bf16=(config.dtype == torch.bfloat16),
                              enable_cuda_graph=config.enable_cuda_graph)

        if isinstance(self.module, torch.nn.Module):
@@ -8,6 +8,7 @@ PDSH_MAX_FAN_OUT = 1024

OPENMPI_LAUNCHER = 'openmpi'
MPICH_LAUNCHER = 'mpich'
IMPI_LAUNCHER = 'impi'
SLURM_LAUNCHER = 'slurm'
MVAPICH_LAUNCHER = 'mvapich'
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
@@ -19,13 +19,12 @@ import base64
import time
import signal
import psutil
import distutils
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger
from ..utils import logger, get_numactl_cmd
from ..elasticity import is_torch_elastic_compatible
from .constants import ELASTIC_TRAINING_ID_DEFAULT
@@ -130,89 +129,6 @@ def terminate_process_tree(pid):
            p.kill()


def parse_range(rng):
    try:
        value = int(rng)
        return range(value, value + 1)
    except ValueError:
        # value is not a single number
        parts = rng.split('-')
        if len(parts) != 2:
            raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
                             (rng, ))
        start = int(parts[0])
        end = int(parts[1])
        if start > end:
            raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
        return range(start, end + 1)


# parse comma and dash separated range list into list
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Range list items are comma separated; each item is either a single number,
#    or a range marked by two numbers (both numbers are included in the range)
# 2. Sub ranges must be in ascending order and not overlap with each other
# 3. No space in the range expression
def parse_range_list(range_str):
    number_list = []
    last = -1
    range_list = range_str.split(',')
    for sub_range in range_list:
        sub_number_list = parse_range(sub_range)
        if sub_number_list[0] <= last:
            raise ValueError(
                "Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
                (range_str, ))
        last = sub_number_list[-1]
        number_list.extend(sub_number_list)
    return number_list


# return a list of lists for cores-to-numa mapping
# [
#   [ cores for numa 0 ]
#   [ cores belonging to numa 1 ]
#   ...
# ]
def get_numa_cores():
    ret = []
    output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
    lines = output.split('\n')
    for line in lines:
        if line.startswith('available:'):
            num_numas = int(line.split(' ')[1])
            break
    for numa in range(num_numas):
        for line in lines:
            if line.startswith(f'node {numa} cpus:'):
                cores = line.split(' ')[3:]
                ret.append([int(core) for core in cores])
    return ret


def check_for_numactl_pkg():
    libs = dict(
        dpkg=["-l", "numactl", "apt"],
        pacman=["-Q", "numactl", "pacman"],
        rpm=["-q", "numactl", "yum"],
    )

    found = False
    for pkgmgr, data in libs.items():
        flag, lib, tool = data
        path = distutils.spawn.find_executable(pkgmgr)
        if path is not None:
            cmd = f"{pkgmgr} {flag} {lib}"
            result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            if result.wait() == 0:
                found = True
            else:
                print(f"please install the {lib} package with {tool}")
            break
    return found


def main():
    args = parse_args()
    current_env = os.environ.copy()
@@ -308,39 +224,9 @@ def main():
    # spawn the processes
    cmd = []
    if args.bind_cores_to_rank:
        check_for_numactl_pkg()
        if 'KMP_AFFINITY' in os.environ.keys():
            raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
                             "because it interfere with how many CPU cores numactl can set. "
                             "Unset KMP_AFFINITY before launching deepspeed.\n\n"
                             "\t$ unset KMP_AFFINITY\n"
                             "\t$ deepspeed <deepspeed command parameters>")
        if args.bind_core_list != None:
            core_list = parse_range_list(args.bind_core_list)
            total_cores = len(core_list)
        else:
            total_cores = psutil.cpu_count(logical=False)
            core_list = range(total_cores)
        cores_per_rank = total_cores // num_local_procs
        assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
        core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
        cores_per_rank, numactl_cmd = get_numactl_cmd(args.bind_core_list, num_local_procs, local_rank)
        current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}"
        cmd.append("numactl")

        # check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter
        numa_cores = get_numa_cores()
        num_numas = len(numa_cores)
        for i in range(num_numas):
            if set(core_list_for_rank) <= set(numa_cores[i]):
                cmd.append("-m")
                cmd.append(f"{i}")
                break

        cmd.append("-C")
        core_list_str = f"{core_list_for_rank[0]}"
        for core_id in core_list_for_rank[1:]:
            core_list_str = f"{core_list_str},{core_id}"
        cmd.append(f"{core_list_str}")
        cmd = cmd + numactl_cmd
    if not args.no_python:
        cmd.append(sys.executable)
        cmd.append("-u")
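The core-binding logic that used to live here now goes through get_numactl_cmd, which the launcher prepends to each rank's command. An illustrative call (the helper is exposed via deepspeed.utils per the imports above; the exact numactl arguments depend on the machine's NUMA layout):

from deepspeed.utils import get_numactl_cmd

# 16 allowed cores split across 4 local ranks -> 4 cores per rank for local rank 1
cores_per_rank, numactl_cmd = get_numactl_cmd("0-15", 4, 1)
print(cores_per_rank)   # 4
print(numactl_cmd)      # something like ['numactl', '-m', <numa node>, '-C', <cores for this rank>]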
|
@@ -11,7 +11,7 @@ import warnings
from shlex import split
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator
from ..utils import logger
from ..utils import logger, get_numactl_cmd
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
@@ -228,6 +228,88 @@ class MPICHRunner(MultiNodeRunner):
        return mpirun_cmd + export_cmd + per_host_cmd


class IMPIRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

    def backend_exists(self):
        #TODO: if IB is available we should suggest mpich
        return shutil.which('mpirun')  #mpich_info

    @property
    def name(self):
        return "impi"

    def validate_args(self):
        super().validate_args()
        #TODO: Allow for include/exclude at node-level but not gpu-level
        if self.args.include != "" or self.args.exclude != "":
            raise ValueError(f"{self.name} backend does not support worker include/exclusion")

        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
            raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")

    def get_cmd(self, environment, active_resources):
        devices_per_node = self.resource_pool.values()
        total_process_count = sum(devices_per_node)
        process_per_node = list(devices_per_node)[0]
        if not all([n == process_per_node for n in devices_per_node]):
            raise ValueError("Intel MPI requires same number of devices per node")

        mpirun_cmd = [
            'mpirun',
            '-ppn',
            f'{process_per_node}',
        ] + split(self.args.launcher_args)
        export_cmd = []

        for k, v in self.exports.items():
            export_cmd += ['-genv', f'{k}', f'{v}']

        if self.args.bind_cores_to_rank:
            cores_per_rank, _ = get_numactl_cmd(self.args.bind_core_list, process_per_node, 0)
            export_cmd += ['-genv', 'OMP_NUM_THREADS', str(cores_per_rank)]

        export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)]
        export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)]
        export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)]
        export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)]

        export_cmd += ['-hosts']
        hosts = ""
        for i, host in enumerate(self.resource_pool.keys()):
            if i == 0:
                hosts = f"{host}"
            else:
                hosts += f",{host}"
        export_cmd += [hosts]

        per_host_cmd = []

        for i in range(total_process_count):
            local_rank = i % process_per_node
            python_exec = []
            if self.args.bind_cores_to_rank:
                _, numactl_cmd = get_numactl_cmd(self.args.bind_core_list, process_per_node, local_rank)
                python_exec += numactl_cmd

            if not self.args.no_python:
                python_exec += [sys.executable, "-u"]
                if self.args.module:
                    python_exec.append("-m")
            env_mapping = ['-env', 'RANK', str(i)]
            env_mapping += ['-env', 'LOCAL_RANK', str(local_rank)]
            if i == 0:
                per_host_cmd = ['-n', '1'] + env_mapping + python_exec + [self.user_script] + self.user_arguments
            else:
                per_host_cmd = per_host_cmd + [':', '-n', '1'] + env_mapping + python_exec + [self.user_script
                                                                                              ] + self.user_arguments
        print(mpirun_cmd + export_cmd + per_host_cmd)
        return mpirun_cmd + export_cmd + per_host_cmd


class SlurmRunner(MultiNodeRunner):

    def __init__(self, args, world_info_base64, resource_pool):
|
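For two nodes with two ranks each, the command assembled by get_cmd above has roughly this shape (illustrative only; hostnames, port, and script name are placeholders, and per-rank numactl prefixes are added when --bind_cores_to_rank is set):

    mpirun -ppn 2 -genv MASTER_ADDR 192.168.0.1 -genv MASTER_PORT 29500 -genv WORLD_SIZE 4 -genv LOCAL_SIZE 2 \
        -hosts host1,host2 \
        -n 1 -env RANK 0 -env LOCAL_RANK 0 python -u train.py : \
        -n 1 -env RANK 1 -env LOCAL_RANK 1 python -u train.py : ...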
|
@ -21,8 +21,8 @@ from copy import deepcopy
|
|||
import signal
|
||||
import time
|
||||
|
||||
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner
|
||||
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER
|
||||
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
|
||||
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
|
||||
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
|
||||
from ..nebula.constants import NEBULA_EXPORT_ENVS
|
||||
from ..utils import logger
|
||||
|
@@ -95,6 +95,7 @@ def parse_args(args=None):
                        "Default is num_nodes when elastic training is enabled")

    parser.add_argument("--num_gpus",
                        "--num_accelerators",
                        type=int,
                        default=-1,
                        help="Max number of GPUs to use on each node, will use "
@@ -116,7 +117,7 @@ def parse_args(args=None):
                        default=PDSH_LAUNCHER,
                        type=str,
                        help="(optional) choose launcher backend for multi-node "
                        "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
                        "training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH, IMPI.")

    parser.add_argument("--launcher_args",
                        default="",
@@ -504,6 +505,8 @@ def main(args=None):
        runner = OpenMPIRunner(args, world_info_base64, resource_pool)
    elif args.launcher == MPICH_LAUNCHER:
        runner = MPICHRunner(args, world_info_base64, resource_pool)
    elif args.launcher == IMPI_LAUNCHER:
        runner = IMPIRunner(args, world_info_base64, resource_pool)
    elif args.launcher == MVAPICH_LAUNCHER:
        runner = MVAPICHRunner(args, world_info_base64, resource_pool)
    elif args.launcher == SLURM_LAUNCHER:
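With the runner changes above, the Intel MPI launcher can be selected from the command line. An illustrative invocation (hostfile contents and script name are placeholders):

    deepspeed --launcher impi --hostfile ./hostfile --bind_cores_to_rank run_model.py --deepspeed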
@@ -13,7 +13,7 @@ from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttent
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

inference_cuda_module = None
inference_module = None


class DeepSpeedTransformerInference(nn.Module):
@@ -48,10 +48,10 @@ class DeepSpeedTransformerInference(nn.Module):
        DeepSpeedTransformerInference.layer_id += 1

        data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
        global inference_cuda_module
        if inference_cuda_module is None:
        global inference_module
        if inference_module is None:
            builder = InferenceBuilder()
            inference_cuda_module = builder.load()
            inference_module = builder.load()

        if DeepSpeedTransformerInference.layer_id == 1:
            log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
@@ -74,14 +74,22 @@ class DeepSpeedTransformerInference(nn.Module):
        self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
                                   requires_grad=False)
        self.layer_past = None
        self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if config.dtype == torch.float32 else \
                        inference_cuda_module.allocate_workspace_fp16
        self._alloc_workspace = True
        try:
            if config.dtype == torch.float32:
                self.allocate_workspace = inference_module.allocate_workspace_fp32
            elif config.dtype == torch.bfloat16:
                self.allocate_workspace = inference_module.allocate_workspace_bf16
            else:
                self.allocate_workspace = inference_module.allocate_workspace_fp32
            self._alloc_workspace = True
        except AttributeError:
            self.allocate_workspace = None
            self._alloc_workspace = False

    @classmethod
    def reset_cache(cls):
        if inference_cuda_module is not None:
            inference_cuda_module.reset_cache()
        if inference_module is not None:
            inference_module.reset_cache()

    def forward(
            self,
@@ -163,7 +171,7 @@ class DeepSpeedTransformerInference(nn.Module):
            output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)

            if not self.config.pre_layer_norm:
                output = inference_cuda_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
                output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)

        output = output.to(input_type)
        if get_present:
@@ -184,7 +184,7 @@ def _module_match(module):
    return None


def generic_injection(module, fp16=False, enable_cuda_graph=True):
def generic_injection(module, fp16=False, bf16=False, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
@@ -199,6 +199,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
            hidden_size=hidden_size,
            heads=heads,
            fp16=fp16,
            bf16=bf16,
            triangular_masking=False,
            max_out_tokens=4096,
        )
@@ -231,8 +232,8 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
    if isinstance(module, torch.nn.Module):
        pass
    else:
        if fp16 is False:
            raise ValueError("Generic injection only supported with FP16")
        if fp16 is False and bf16 is False:
            raise ValueError("Generic injection only supported with FP16 or BF16")

        try:
            import diffusers
@@ -33,6 +33,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
            using model-parallel architecture. If the client model already takes care of this, there is no
            need to pass this argument.
        fp16: Enable half-precision computation
        bf16: Enable bf16 floating point computation
        pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
        stochastic_mode: Enable for high performance, please note that this flag has some level of
            non-determinism and can produce different results on different runs. However, we have seen
@@ -13,7 +13,7 @@ from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

# Cuda modules will be imported if needed
inference_cuda_module = None
inference_module = None
minus_inf = -10000.0
triton_flash_attn = None
@@ -77,8 +77,7 @@ class DeepSpeedDiffusersAttentionFunction(Function):
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads,
                                                                         do_flash_attn)
            query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn)
            attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
            context_layer = _transpose_for_context(torch.matmul(attention_scores, value))
@ -118,10 +117,10 @@ class DeepSpeedDiffusersAttention(nn.Module):
|
|||
|
||||
data_type = self.config.dtype
|
||||
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
|
||||
global inference_cuda_module
|
||||
if inference_cuda_module is None:
|
||||
global inference_module
|
||||
if inference_module is None:
|
||||
builder = InferenceBuilder()
|
||||
inference_cuda_module = builder.load()
|
||||
inference_module = builder.load()
|
||||
|
||||
if DeepSpeedDiffusersAttention.layer_id == 1:
|
||||
log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
|
||||
|
@ -173,13 +172,13 @@ class DeepSpeedDiffusersAttention(nn.Module):
|
|||
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191
|
||||
|
||||
if self.config.dtype in [torch.float16, torch.int8]:
|
||||
self.score_context_func = inference_cuda_module.softmax_context_fp16
|
||||
self.linear_func = inference_cuda_module.linear_layer_fp16
|
||||
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp16
|
||||
self.score_context_func = inference_module.softmax_context_fp16
|
||||
self.linear_func = inference_module.linear_layer_fp16
|
||||
self.allocate_workspace = inference_module.allocate_workspace_fp16
|
||||
else:
|
||||
self.score_context_func = inference_cuda_module.softmax_context_fp32
|
||||
self.linear_func = inference_cuda_module.linear_layer_fp32
|
||||
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32
|
||||
self.score_context_func = inference_module.softmax_context_fp32
|
||||
self.linear_func = inference_module.linear_layer_fp32
|
||||
self.allocate_workspace = inference_module.allocate_workspace_fp32
|
||||
|
||||
def forward(self, input, context=None, input_mask=None):
|
||||
if self.config.layer_id == 0:
|
||||
|
|
|
@@ -7,9 +7,8 @@ import json
 import math
 import torch
 from torch.autograd import Function
-#from ...inference.engine import inference_cuda_module, specialized_mode
-# Cuda modules will be imported if needed
-inference_cuda_module = None
+# accelerator modules will be imported if needed
+inference_module = None
 specialized_mode = None
 import torch.nn as nn
 from .ds_attention import DeepSpeedSelfAttention

@@ -35,6 +34,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
             using model-parallel architecture. If the client model already takes care of this, there is no
             need to pass this argument.
         fp16: Enable half-precision computation
+        bf16: Enable bf16 floating point computation
         pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
         stochastic_mode: Enable for high performance, please note that this flag has some level of
             non-determinism and can produce different results on different runs. However, we have seen

@@ -55,6 +55,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
                  local_rank=-1,
                  mp_size=1,
                  fp16=False,
+                 bf16=False,
                  q_int8=False,
                  pre_layer_norm=True,
                  stochastic_mode=False,

@@ -76,9 +77,9 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
                  scale_attn_by_inverse_layer_idx=False):
         super(DeepSpeedMoEInferenceConfig,
               self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
-                             num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, q_int8, pre_layer_norm,
-                             stochastic_mode, scale_attention, triangular_masking, local_attention, window_size,
-                             return_tuple)
+                             num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, bf16, q_int8,
+                             pre_layer_norm, stochastic_mode, scale_attention, triangular_masking, local_attention,
+                             window_size, return_tuple)
         self.moe_experts = moe_experts
         self.k = k
         self.capacity_factor = capacity_factor

@@ -111,14 +112,12 @@ class DeepSpeedMLPFunction(Function):
     def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group,
                 async_op):
         if config.q_int8:
-            intermediate = inference_cuda_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon,
-                                                                      q_scales[2], (q_groups * (2**merge_count)),
-                                                                      config.pre_layer_norm)
-            output = inference_cuda_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups,
-                                                              (merge_count))
+            intermediate = inference_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon, q_scales[2],
+                                                                 (q_groups * (2**merge_count)), config.pre_layer_norm)
+            output = inference_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups, (merge_count))
         else:
-            mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
-                                    inference_cuda_module.fused_gemm_gelu_fp32
+            mlp_gemm_func = inference_module.fused_gemm_gelu_fp16 if config.fp16 else \
+                                    inference_module.fused_gemm_gelu_fp32

             output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op)
         if mp_group is not None and dist.get_world_size(group=mp_group) > 1:

@@ -188,17 +187,17 @@ class DeepSpeedMoEInference(nn.Module):

         self.config = config
         self.config.layer_id = DeepSpeedMoEInference.layer_id
-        global inference_cuda_module
+        global inference_module
         global specialized_mode
-        if inference_cuda_module is None:
+        if inference_module is None:
             specialized_mode = False
             # InferenceSpecializedBuilder is not among DeepSpeed provided builder yet, so we infer by builder name string
             builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder")
             if builder != None and builder.is_compatible():
-                inference_cuda_module = builder.load()
+                inference_module = builder.load()
                 specialized_mode = True
             else:
-                inference_cuda_module = InferenceBuilder().load()
+                inference_module = InferenceBuilder().load()
         self.config.specialized_mode = specialized_mode
         assert self.config.dtype != torch.bfloat16, "DeepSpeed MoE Transformer Inference not yet tested for bfloat support"

@@ -214,10 +213,10 @@ class DeepSpeedMoEInference(nn.Module):
             self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping,
                                            mp_group)
             self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2))
-            self.coef_func = inference_cuda_module.softmax_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
-                                      inference_cuda_module.softmax_fp32
-            self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if self.config.dtype == torch.float16 else \
-                                      inference_cuda_module.vector_matmul_fp32
+            self.coef_func = inference_module.softmax_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
+                                      inference_module.softmax_fp32
+            self.vector_matmul_func = inference_module.vector_matmul_fp16 if self.config.dtype == torch.float16 else \
+                                      inference_module.vector_matmul_fp32

         config.mp_size = 1
         self.mlp = nn.ModuleList(

@@ -235,12 +234,12 @@ class DeepSpeedMoEInference(nn.Module):
             print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__)

-        self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
-                                    inference_cuda_module.bias_residual_fp32
-        self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
-                                    inference_cuda_module.layer_norm_fp32
-        self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
-                                    inference_cuda_module.einsum_sec_sm_ecm_fp32
+        self.bias_residual_func = inference_module.bias_residual_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
+                                    inference_module.bias_residual_fp32
+        self.ds_layernorm = inference_module.layer_norm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
+                                    inference_module.layer_norm_fp32
+        self.einsum_sec_sm_ecm = inference_module.einsum_sec_sm_ecm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
+                                    inference_module.einsum_sec_sm_ecm_fp32

     def res_coef_func(self, inp, async_op):
         inp = self.vector_matmul_func(inp, self.res_coef, async_op)

@@ -347,7 +346,7 @@ class DeepSpeedMoEInference(nn.Module):
                                          dim=0)[dist.get_rank(group=self.expert_mp_group)]

             if self.config.mlp_type == 'residual':
-                inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output)
+                inference_module.moe_res_matmul(res_mlp_out, res_coef_out, output)

             output = self.bias_residual_func(output, residual_add, torch.empty(1))
@@ -10,11 +10,11 @@ from deepspeed.ops.op_builder import InferenceBuilder


 class BaseOp(torch.nn.Module):
-    inference_cuda_module = None
+    inference_module = None

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(BaseOp, self).__init__()
         self.config = config
-        if BaseOp.inference_cuda_module is None:
+        if BaseOp.inference_module is None:
             builder = InferenceBuilder()
-            BaseOp.inference_cuda_module = builder.load()
+            BaseOp.inference_module = builder.load()
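Note: BaseOp caches the loaded kernel module on the class, so every op binding shares one module and the (potentially slow) JIT build runs at most once per process. A self-contained sketch of that caching behavior, using made-up classes rather than the real builder:

    class _CountingBuilder:
        loads = 0

        def load(self):
            _CountingBuilder.loads += 1
            return object()

    class FakeBaseOp:
        inference_module = None  # shared by all instances, like BaseOp above

        def __init__(self):
            if FakeBaseOp.inference_module is None:
                FakeBaseOp.inference_module = _CountingBuilder().load()

    FakeBaseOp(); FakeBaseOp(); FakeBaseOp()
    assert _CountingBuilder.loads == 1  # loaded once, reused afterwards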
@@ -12,12 +12,18 @@ class GELUGemmOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(GELUGemmOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp16  # type: ignore
-        elif self.config.dtype == torch.bfloat16:
-            self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_bf16
-        else:
-            self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp32  # type: ignore
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16  # type: ignore
+            elif self.config.dtype == torch.bfloat16:
+                self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
+            else:
+                self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32  # type: ignore
+        except AttributeError:
+            self.fused_gemm_gelu = self.gelu_gemm_fallback
+
+    def gelu_gemm_fallback(self, input, weight, scale, bias, out, out_scale, dtype, transpose):
+        raise NotImplementedError

     def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, weight_out: torch.Tensor):
@@ -12,12 +12,18 @@ class LinearOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(LinearOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.linear_func = self.inference_cuda_module.linear_layer_fp16
-        elif self.config.dtype == torch.bfloat16:
-            self.linear_func = self.inference_cuda_module.linear_layer_bf16
-        else:
-            self.linear_func = self.inference_cuda_module.linear_layer_fp32
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.linear_func = self.inference_module.linear_layer_fp16
+            elif self.config.dtype == torch.bfloat16:
+                self.linear_func = self.inference_module.linear_layer_bf16
+            else:
+                self.linear_func = self.inference_module.linear_layer_fp32
+        except AttributeError:
+            self.linear_func = self.linear_fallback
+
+    def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose):
+        raise NotImplementedError

     def forward(self,
                 input: torch.Tensor,
@@ -5,7 +5,9 @@

 from typing import Optional

+import os
 import torch
+import torch.nn.functional as F
 from ..config import DeepSpeedInferenceConfig
 from .base import BaseOp
 from deepspeed.utils.types import NormType

@@ -15,21 +17,43 @@ class MLPGemmOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(MLPGemmOp, self).__init__(config)
-        if self.config.norm_type == NormType.LayerNorm:
-            if self.config.dtype in [torch.float16, torch.int8]:
-                self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp16  # type: ignore
-            elif self.config.dtype == torch.bfloat16:
-                self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_bf16
-            else:
-                self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp32  # type: ignore
-        elif self.config.norm_type == NormType.RMSNorm:
-            if self.config.dtype in [torch.float16, torch.int8]:
-                self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_fp16  # type: ignore
-            elif self.config.dtype == torch.bfloat16:
-                self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_bf16
-            else:
-                self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_fp32  # type: ignore
+        try:
+            if self.config.norm_type == NormType.LayerNorm:
+                if self.config.dtype in [torch.float16, torch.int8]:
+                    self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16  # type: ignore
+                elif self.config.dtype == torch.bfloat16:
+                    self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16
+                else:
+                    self.mlp_gemm_func = self.inference_module.mlp_gemm_fp32  # type: ignore
+            elif self.config.norm_type == NormType.RMSNorm:
+                if self.config.dtype in [torch.float16, torch.int8]:
+                    self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp16  # type: ignore
+                elif self.config.dtype == torch.bfloat16:
+                    self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_bf16
+                else:
+                    self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp32  # type: ignore
+        except AttributeError:
+            if self.config.norm_type == NormType.LayerNorm:
+                self.mlp_gemm_func = self.mlp_gemm_fallback
+            elif self.config.norm_type == NormType.RMSNorm:
+                self.mlp_gemm_func = self.rms_mlp_gemm_fallback
+
+    def mlp_gemm_fallback(self, input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps,
+                          pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type,
+                          transpose):
+        if os.environ.get('DS_KI_FALLBACK') == 'True' and mlp_after_attn and not transpose:
+            residual_add = F.layer_norm(input + residual + input_bias, (input.shape[2], ), gamma, beta,
+                                        self.config.epsilon)
+            tmp = torch.matmul(residual_add, weight_interm)
+            tmp = F.gelu(tmp + bias)
+            output = torch.matmul(tmp, weight_out)
+            return (output, residual_add)
+        else:
+            raise NotImplementedError
+
+    def rms_mlp_gemm_fallback(self, input, residual, weight_interm, weight_out, gamma, eps, interm_scale, out_scale,
+                              dtype, mlp_act_func_type, transpose):
+        raise NotImplementedError

     def forward(self,
                 input: torch.Tensor,
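Note: the op bindings above only take the pure-PyTorch fallback when the environment variable DS_KI_FALLBACK is set to the literal string 'True'; otherwise a missing fused kernel raises NotImplementedError. A self-contained sketch of the math mlp_gemm_fallback performs (LayerNorm of input + residual + bias, then GEMM, GeLU, GEMM); the shapes below are only illustrative:

    import os
    import torch
    import torch.nn.functional as F

    os.environ['DS_KI_FALLBACK'] = 'True'  # enable the fallback path before building the engine

    batch, seq, hidden, inter = 1, 4, 8, 16
    inp = torch.randn(batch, seq, hidden)
    residual = torch.randn(batch, seq, hidden)
    input_bias = torch.randn(hidden)
    gamma, beta = torch.ones(hidden), torch.zeros(hidden)
    w1, b1 = torch.randn(hidden, inter), torch.randn(inter)
    w2 = torch.randn(inter, hidden)

    # same sequence of ops as mlp_gemm_fallback above
    residual_add = F.layer_norm(inp + residual + input_bias, (hidden, ), gamma, beta, 1e-5)
    out = torch.matmul(F.gelu(torch.matmul(residual_add, w1) + b1), w2)
    print(out.shape)  # torch.Size([1, 4, 8])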
@@ -3,7 +3,9 @@

 # DeepSpeed Team

+import os
 import torch
+import torch.nn.functional as F
 from ..config import DeepSpeedInferenceConfig
 from .base import BaseOp
 from deepspeed.utils.types import NormType

@@ -13,21 +15,40 @@ class QKVGemmOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(QKVGemmOp, self).__init__(config)
-        if self.config.norm_type == NormType.LayerNorm:
-            if self.config.dtype in [torch.float16, torch.int8]:
-                self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp16  # type: ignore
-            elif self.config.dtype == torch.bfloat16:
-                self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_bf16
-            else:
-                self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp32  # type: ignore
-        elif self.config.norm_type == NormType.RMSNorm:
-            if self.config.dtype in [torch.float16, torch.int8]:
-                self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_fp16  # type: ignore
-            elif self.config.dtype == torch.bfloat16:
-                self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_bf16
-            else:
-                self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_fp32  # type: ignore
+        try:
+            if self.config.norm_type == NormType.LayerNorm:
+                if self.config.dtype in [torch.float16, torch.int8]:
+                    self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16  # type: ignore
+                elif self.config.dtype == torch.bfloat16:
+                    self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16
+                else:
+                    self.qkv_gemm_func = self.inference_module.qkv_gemm_fp32  # type: ignore
+            elif self.config.norm_type == NormType.RMSNorm:
+                if self.config.dtype in [torch.float16, torch.int8]:
+                    self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp16  # type: ignore
+                elif self.config.dtype == torch.bfloat16:
+                    self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_bf16
+                else:
+                    self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp32  # type: ignore
+        except AttributeError:
+            if self.config.norm_type == NormType.LayerNorm:
+                self.qkv_gemm_func = self.qkv_gemm_fallback
+            elif self.config.norm_type == NormType.RMSNorm:
+                self.qkv_gemm_func = self.rms_qkv_gemm_fallback
+
+    def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose):
+        if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
+            inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps)
+            tmp = torch.matmul(inp_norm, weight)
+            if add_bias:
+                tmp += bias
+            output = [tmp, inp_norm]
+            return output
+        else:
+            raise NotImplementedError
+
+    def rms_qkv_gemm_fallback(self, input, weight, q_scale, gamma, eps, q_int8, transpose):
+        raise NotImplementedError

     def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor,
                 beta: torch.Tensor):
@@ -3,6 +3,7 @@

 # DeepSpeed Team

+import os
 import torch
 from typing import Optional
 from ..config import DeepSpeedInferenceConfig

@@ -13,13 +14,17 @@ class ResidualAddOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(ResidualAddOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp16
-        elif self.config.dtype == torch.bfloat16:
-            self.residual_add_func = self.inference_cuda_module.residual_add_bias_bf16
-        else:
-            self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp32
-        self._vector_add = self.inference_cuda_module._vector_add
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.residual_add_func = self.inference_module.residual_add_bias_fp16
+            elif self.config.dtype == torch.bfloat16:
+                self.residual_add_func = self.inference_module.residual_add_bias_bf16
+            else:
+                self.residual_add_func = self.inference_module.residual_add_bias_fp32
+            self._vector_add = self.inference_module._vector_add
+        except AttributeError:
+            self.residual_add_func = None
+            self._vector_add = None

     def forward(self,
                 hidden_state: torch.Tensor,

@@ -30,14 +35,28 @@ class ResidualAddOp(BaseOp):
                 attention_bias: Optional[torch.Tensor] = None,
                 final_bias: Optional[torch.Tensor] = None):

-        if final_bias is None:
-            residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size)
-        else:
-            if not self.config.pre_layer_norm and residual_add is not None:
-                # only use residual add if its set and we are not pre layer norm
-                residual = residual_add
-
-            self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias,
-                                   self.config.mp_size, self.config.mlp_after_attn, add_bias,
-                                   self.config.pre_layer_norm)
+        if self.residual_add_func != None:
+            if final_bias is None:
+                residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size)
+            else:
+                if not self.config.pre_layer_norm and residual_add is not None:
+                    # only use residual add if its set and we are not pre layer norm
+                    residual = residual_add
+
+                self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias,
+                                       self.config.mp_size, self.config.mlp_after_attn, add_bias,
+                                       self.config.pre_layer_norm)
+        else:
+            # fallback
+            if os.environ.get('DS_KI_FALLBACK') == 'True' and self.config.mlp_after_attn:
+                if self.config.pre_layer_norm:
+                    tmp = (residual.float() + attention_output.float() + attention_bias.float() +
+                           final_bias.float()) / self.config.mp_size + hidden_state.float()
+                else:
+                    tmp = residual.float() + hidden_state.float() + final_bias.float()
+
+                input_dtype = hidden_state.dtype
+                residual = tmp.to(input_dtype)
+            else:
+                raise NotImplementedError
         return residual
@@ -3,7 +3,9 @@

 # DeepSpeed Team

+import os
 import torch
+import torch.nn.functional as F
 from ..config import DeepSpeedInferenceConfig
 from .base import BaseOp

@@ -12,19 +14,40 @@ class SoftmaxOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(SoftmaxOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.softmax_func = self.inference_cuda_module.softmax_fp16
-        elif self.config.dtype == torch.bfloat16:
-            self.softmax_func = self.inference_cuda_module.softmax_bf16
-        else:
-            self.softmax_func = self.inference_cuda_module.softmax_fp32
-
-    def _not_implemented(self, *args, **kwargs):
-        raise NotImplementedError
+        self.num_attention_heads_per_partition = config.heads // config.mp_size
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.softmax_func = self.inference_module.softmax_fp16
+            elif self.config.dtype == torch.bfloat16:
+                self.softmax_func = self.inference_module.softmax_bf16
+            else:
+                self.softmax_func = self.inference_module.softmax_fp32
+        except AttributeError:
+            self.softmax_func = self.softmax_fallback
+
+    def softmax_fallback(self, attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
+                         async_op, layer_scale, head_offset, mp_size):
+        if os.environ.get('DS_KI_FALLBACK') == 'True':
+            alibi = alibi[head_offset:head_offset + self.num_attention_heads_per_partition]
+            input_dtype = attn_scores.dtype
+            if (triangular):
+                tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool)
+                attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min)
+            if alibi is not None:
+                attn_scores += alibi
+            if attn_mask is not None:
+                # expand atten_mask from two dim into 4 dim, insert two dims in the middle
+                attn_mask = attn_mask[:, None, None, :]
+                attn_scores += attn_mask
+            output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype)
+            return output
+        else:
+            raise NotImplementedError

     def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool,
                 recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float,
                 head_offset: int):
         output = self.softmax_func(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
                                    async_op, layer_scale, head_offset, self.config.mp_size)

         return output
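Note: the softmax fallback above applies a causal (triangular) mask, an optional alibi bias, and an additive attention mask, then computes the softmax in fp32 and casts back to the input dtype. A self-contained sketch of those semantics with illustrative shapes:

    import torch
    import torch.nn.functional as F

    bsz, heads, q_len, k_len = 1, 2, 4, 4
    scores = torch.randn(bsz, heads, q_len, k_len, dtype=torch.bfloat16)
    alibi = torch.zeros(heads, q_len, k_len, dtype=scores.dtype)
    attn_mask = torch.zeros(bsz, k_len, dtype=scores.dtype)  # 2-D mask, expanded below

    tri = ~torch.tril(torch.ones(scores.size(), device=scores.device)).to(bool)
    scores = torch.masked_fill(scores * 1.0, tri, torch.finfo(scores.dtype).min)
    scores = scores + alibi + attn_mask[:, None, None, :]
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(torch.bfloat16)
    assert torch.allclose(probs.sum(-1).float(), torch.ones(bsz, heads, q_len), atol=1e-2)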
@@ -13,12 +13,20 @@ class SoftmaxContextOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(SoftmaxContextOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.softmax_context_func = self.inference_cuda_module.softmax_context_fp16
-        elif self.config.dtype == torch.bfloat16:
-            self.softmax_context_func = self.inference_cuda_module.softmax_context_bf16
-        else:
-            self.softmax_context_func = self.inference_cuda_module.softmax_context_fp32
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.softmax_context_func = self.inference_module.softmax_context_fp16
+            elif self.config.dtype == torch.bfloat16:
+                self.softmax_context_func = self.inference_module.softmax_context_bf16
+            else:
+                self.softmax_context_func = self.inference_module.softmax_context_fp32
+        except AttributeError:
+            self.softmax_context_func = self.softmax_context_fallback
+
+    def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, roteate_every_two, heads,
+                                 norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id,
+                                 num_layers, alibi):
+        raise NotImplementedError

     def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, norm_factor: float,
                 no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor):
@@ -3,6 +3,7 @@

 # DeepSpeed Team

+import os
 import torch
 from ..config import DeepSpeedInferenceConfig
 from .base import BaseOp

@@ -12,12 +13,21 @@ class VectorMatMulOp(BaseOp):

     def __init__(self, config: DeepSpeedInferenceConfig):
         super(VectorMatMulOp, self).__init__(config)
-        if self.config.dtype in [torch.float16, torch.int8]:
-            self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp16
-        elif self.config.dtype == torch.bfloat16:
-            self.vector_matmul_func = self.inference_cuda_module.vector_matmul_bf16
-        else:
-            self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32
+        try:
+            if self.config.dtype in [torch.float16, torch.int8]:
+                self.vector_matmul_func = self.inference_module.vector_matmul_fp16
+            elif self.config.dtype == torch.bfloat16:
+                self.vector_matmul_func = self.inference_module.vector_matmul_bf16
+            else:
+                self.vector_matmul_func = self.inference_module.vector_matmul_fp32
+        except AttributeError:
+            self.vector_matmul_func = self.vector_matmul_fallback
+
+    def vector_matmul_fallback(self, input, weight, async_op, q_scale, q_int8, transpose):
+        if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
+            return torch.matmul(input, weight)
+        else:
+            raise NotImplementedError

     def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False):
         q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)
@@ -14,3 +14,4 @@ from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
 from .mixed_precision_linkage import link_hp_params
 from deepspeed.runtime.dataloader import RepeatingLoader
+from .numa import get_numactl_cmd
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# return a list of list for cores to numa mapping
# [
#     [ cores for numa 0 ]
#     [ cores belong to numa 1 ]
#     ...
# ]

import distutils
import os
import psutil
import subprocess


# return a list of list for cores to numa mapping
# [
#     [ cores for numa 0 ]
#     [ cores belong to numa 1 ]
#     ...
# ]
def get_numa_cores():
    ret = []
    output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
    lines = output.split('\n')
    for line in lines:
        if line.startswith('available:'):
            num_numas = int(line.split(' ')[1])
            break
    for numa in range(num_numas):
        for line in lines:
            if line.startswith(f'node {numa} cpus:'):
                cores = line.split(' ')[3:]
                ret.append([int(core) for core in cores])
    return ret


def check_for_numactl_pkg():
    libs = dict(
        dpkg=["-l", "numactl", "apt"],
        pacman=["-Q", "numactl", "pacman"],
        rpm=["-q", "numactl", "yum"],
    )

    found = False
    for pkgmgr, data in libs.items():
        flag, lib, tool = data
        path = distutils.spawn.find_executable(pkgmgr)
        if path is not None:
            cmd = f"{pkgmgr} {flag} {lib}"
            result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            if result.wait() == 0:
                found = True
            else:
                print(f"please install the {lib} package with {tool}")
            break
    return found


def parse_range(rng):
    try:
        value = int(rng)
        return range(value, value + 1)
    except ValueError:
        # value is not a single number
        parts = rng.split('-')
        if len(parts) != 2:
            raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
                             (rng, ))
        start = int(parts[0])
        end = int(parts[1])
        if start > end:
            raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
        return range(start, end + 1)


# parse comma and dash separated range list into list
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Range list number be comma separated, each item are either a single number,
#    or a range marked by two numbers (both number are included in the range)
# 2. Sub ranges must be in ascend order and not overlap with each other
# 3. No space in the range expression
def parse_range_list(range_str):
    number_list = []
    last = -1
    range_list = range_str.split(',')
    for sub_range in range_list:
        sub_number_list = parse_range(sub_range)
        if sub_number_list[0] <= last:
            raise ValueError(
                "Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
                (range_str, ))
        last = sub_number_list[-1]
        number_list.extend(sub_number_list)
    return number_list


def get_numactl_cmd(bind_core_list, num_local_procs, local_rank):
    numactl_cmd = []
    check_for_numactl_pkg()
    if 'KMP_AFFINITY' in os.environ.keys():
        raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
                         "because it interfere with how many CPU cores numactl can set. "
                         "Unset KMP_AFFINITY before launching deepspeed.\n\n"
                         "\t$ unset KMP_AFFINITY\n"
                         "\t$ deepspeed <deepspeed command parameters>")
    if bind_core_list != None:
        core_list = parse_range_list(bind_core_list)
        total_cores = len(core_list)
    else:
        total_cores = psutil.cpu_count(logical=False)
        core_list = range(total_cores)
    cores_per_rank = total_cores // num_local_procs
    assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
    core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
    numactl_cmd.append("numactl")

    # check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter
    numa_cores = get_numa_cores()
    num_numas = len(numa_cores)
    for i in range(num_numas):
        if set(core_list_for_rank) <= set(numa_cores[i]):
            numactl_cmd.append("-m")
            numactl_cmd.append(f"{i}")
            break

    numactl_cmd.append("-C")
    last_core = core_list_for_rank[0]
    first_core = last_core
    core_list_str = f"{last_core}"
    for core_id in core_list_for_rank[1:]:
        if core_id == last_core + 1:
            last_core = core_id
            continue
        else:
            if first_core == last_core:
                core_list_str = f"{core_list_str},{core_id}"
            else:
                core_list_str = f"{core_list_str}-{last_core},{core_id}"
            first_core = core_id
            last_core = core_id
    if first_core != last_core:
        core_list_str = f"{core_list_str}-{last_core}"
    numactl_cmd.append(f"{core_list_str}")
    return cores_per_rank, numactl_cmd
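Note: a short usage sketch for the helpers above. It assumes numactl is installed; the concrete numbers and the "-m 1" binding depend on the machine (the example imagines two sockets with 28 physical cores each):

    from deepspeed.utils.numa import parse_range_list, get_numactl_cmd

    print(parse_range_list("0,2-4,6"))      # [0, 2, 3, 4, 6]

    # bind rank 1 of 2 local ranks to the second half of cores 0-55
    cores_per_rank, numactl_cmd = get_numactl_cmd("0-55", num_local_procs=2, local_rank=1)
    print(cores_per_rank)                   # 28
    print(" ".join(numactl_cmd))            # e.g. "numactl -m 1 -C 28-55"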
@@ -40,27 +40,43 @@ class SynchronizedWallClockTimer:
             self.name_ = name
             self.started_ = False
             self.event_timers = []
+            self.use_host_timer = get_accelerator().is_synchronized_device()
+            self.start_event = None
             self.elapsed_records = None
+            self.start_time = 0.0
+            self.end_time = 0.0

         def start(self):
             """Start the timer."""
             assert not self.started_, f"{self.name_} timer has already been started"
-            self.start_event = get_accelerator().Event(enable_timing=True)
-            self.start_event.record()
+            if self.use_host_timer:
+                self.start_time = time.time()
+            else:
+                event_class = get_accelerator().Event
+                self.start_event = event_class(enable_timing=True)
+                self.start_event.record()
             self.started_ = True

         def stop(self, reset=False, record=False):
             """Stop the timer."""
             assert self.started_, "timer is not started"
-            end_event = get_accelerator().Event(enable_timing=True)
-            end_event.record()
-            self.event_timers.append(CudaEventTimer(self.start_event, end_event))
-            self.start_event = None
+            event_class = get_accelerator().Event
+            if self.use_host_timer:
+                self.end_time = time.time()
+                self.event_timers.append(self.end_time - self.start_time)
+            else:
+                event_class = get_accelerator().Event
+                end_event = event_class(enable_timing=True)
+                end_event.record()
+                self.event_timers.append(CudaEventTimer(self.start_event, end_event))
+                self.start_event = None
             self.started_ = False

         def _get_elapsed_msec(self):
-            self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
+            if self.use_host_timer:
+                self.elapsed_records = [et * 1000.0 for et in self.event_timers]
+            else:
+                self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
             self.event_timers.clear()
             return sum(self.elapsed_records)
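Note: on accelerators that report is_synchronized_device() (the CPU backend does), the timer above records wall-clock time on the host instead of device events. A minimal, DeepSpeed-independent sketch of that host-timer path (the event-based branch is omitted):

    import time

    class HostTimer:
        def __init__(self):
            self.records = []          # elapsed seconds per start/stop pair
            self.start_time = 0.0

        def start(self):
            self.start_time = time.time()

        def stop(self):
            self.records.append(time.time() - self.start_time)

        def elapsed_msec(self):
            return sum(r * 1000.0 for r in self.records)

    t = HostTimer()
    t.start(); time.sleep(0.01); t.stop()
    print(f"{t.elapsed_msec():.1f} ms")   # roughly 10 ms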
@@ -441,7 +441,7 @@ class OpBuilder(ABC):

     def load(self, verbose=True):
         from deepspeed.git_version_info import installed_ops, torch_info
-        if installed_ops[self.name]:
+        if installed_ops.get(self.name, False):
             # Ensure the op we're about to load was compiled with the same
             # torch/cuda versions we are currently using at runtime.
             self.validate_torch_version(torch_info)
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''

from .comm import CCLCommBuilder
from .no_impl import NotImplementedBuilder
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

try:
    # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
    # if successful this also means we're doing a local install and not JIT compile path
    from op_builder import __deepspeed__  # noqa: F401
    from op_builder.builder import OpBuilder
except ImportError:
    from deepspeed.ops.op_builder.builder import OpBuilder


class CPUOpBuilder(OpBuilder):

    def builder(self):
        from torch.utils.cpp_extension import CppExtension as ExtensionBuilder

        compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}

        cpp_ext = ExtensionBuilder(name=self.absolute_name(),
                                   sources=self.strip_empty_entries(self.sources()),
                                   include_dirs=self.strip_empty_entries(self.include_paths()),
                                   libraries=self.strip_empty_entries(self.libraries_args()),
                                   extra_compile_args=compile_args)

        return cpp_ext

    def cxx_args(self):
        return ['-O3', '-std=c++14', '-g', '-Wno-reorder']

    def libraries_args(self):
        return []
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from .builder import CPUOpBuilder


class CCLCommBuilder(CPUOpBuilder):
    BUILD_VAR = "DS_BUILD_CCL_COMM"
    NAME = "deepspeed_ccl_comm"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.comm.{self.NAME}_op'

    def sources(self):
        return ['csrc/cpu/comm/ccl.cpp']

    def include_paths(self):
        includes = ['csrc/cpu/includes']
        return includes

    def is_compatible(self, verbose=True):
        # TODO: add soft compatibility check for private binary release.
        #       a soft check, as in we know it can be trivially changed.
        return super().is_compatible(verbose)

    def extra_ldflags(self):
        ccl_root_path = os.environ.get("CCL_ROOT")
        if ccl_root_path == None:
            raise ValueError(
                "Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable"
            )
            return []
        else:
            return ['-lccl', f'-L{ccl_root_path}/lib']
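Note: a usage sketch for loading the CCL comm op on a CPU machine. It assumes oneCCL is installed and its env/vars.sh has been sourced so CCL_ROOT is set, and that the CPU accelerator is active (i.e. intel_extension_for_pytorch is installed); the variable names are illustrative:

    import os
    from deepspeed.accelerator import get_accelerator

    assert os.environ.get("CCL_ROOT"), "source oneCCL's env/vars.sh first so CCL_ROOT is set"
    builder = get_accelerator().create_op_builder("CCLCommBuilder")
    if builder is not None and builder.is_compatible():
        ccl_comm_module = builder.load()   # JIT-compiles csrc/cpu/comm/ccl.cpp on first use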
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CPUOpBuilder


class NotImplementedBuilder(CPUOpBuilder):
    BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
    NAME = "deepspeed_not_implemented"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.comm.{self.NAME}_op'

    def load(self, verbose=True):
        raise ValueError("This op had not been implemented on CPU backend.")

    def sources(self):
        return []
@@ -0,0 +1 @@
intel_extension_for_pytorch
@@ -107,7 +107,13 @@ class TestDistributedFixture(DistributedTest):


 class TestDistAllReduce(DistributedTest):
-    world_size = [1, 2, 4]
+    device_count = get_accelerator().device_count()
+    if device_count >= 4:
+        world_size = [1, 2, 4]
+    elif device_count >= 2:
+        world_size = [1, 2]
+    else:
+        world_size = [1]

     def test(self):
         x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
@@ -56,22 +56,26 @@ def set_accelerator_visible():
         if is_rocm_pytorch:
             rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
             gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
-            num_gpus = len(list(gpu_ids))
+            num_accelerators = len(list(gpu_ids))
         else:
             nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
-            num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n'))
-    else:
-        assert get_accelerator().device_name() == 'xpu'
+            num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
+    elif get_accelerator().device_name() == 'xpu':
         import re
         clinfo = subprocess.check_output(['clinfo'])
         lines = clinfo.decode('utf-8').strip().split('\n')
-        num_gpus = 0
+        num_accelerators = 0
         for line in lines:
             match = re.search('Device Type.*GPU', line)
             if match:
-                num_gpus += 1
+                num_accelerators += 1
+    else:
+        assert get_accelerator().device_name() == 'cpu'
+        cpu_sockets = int(
+            subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
+        num_accelerators = cpu_sockets

-    cuda_visible = ",".join(map(str, range(num_gpus)))
+    cuda_visible = ",".join(map(str, range(num_accelerators)))

     # rotate list based on xdist worker id, example below
     # wid=0 -> ['0', '1', '2', '3']
@@ -54,7 +54,7 @@ class TestModelProfiling(DistributedTest):
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
         world_size = int(os.getenv("WORLD_SIZE", "1"))

-        pipe = pipeline(task, model, framework="pt", device=local_rank)
+        pipe = pipeline(task, model, framework="pt", device=get_accelerator().device_name(local_rank))
         pipe.model = deepspeed.init_inference(pipe.model,
                                               dtype=dtype,
                                               mp_size=world_size,
@@ -6,7 +6,7 @@
 import argparse
 import pytest
 import deepspeed
-from deepspeed.launcher.launch import parse_range_list
+from deepspeed.utils.numa import parse_range_list


 def basic_parser():