[CPU] Support Intel CPU inference (#3041)

* add fallback path for kernels used in megatron

* temporary numactl workaround for 56-core SPR

* adapt core allocation according to number of ranks

* add switch to turn on numactl

* detect number of cores on the system

* allow selecting a subset of the cores on the system to bind

* remove unneeded changes

* add ccl backend

* change nccl to ccl

* remove unused code

* add comm/ccl to ops

* initial ccl comm support

* first broadcast case passed

* add CCL_Backend to DeepSpeed

* support comm timer for CPU

* support barrier for comm backend

* support specifying master address from the deepspeed command line

* support pytorch 2.0

* remove 'block' from api

* Tweak for debug

Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>

* Remove unnecessary directory

Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>

* Add bf16 kernel support for inference

* Add temporary torch implementation for cpu inference

* Add softmax ops cpu fallback for inference

* bind cores to numa domain as well

* merge latest change in gma/numactl

* initial bf16 kernel support with fallback path

* initial fallback path for bloom kernel injection

* fix softmax attn mask

* check KMP_AFFINITY to avoid conflict with numactl

* New CCLBackend which utilizes TorchBackend for initialization

* roll back last change because it produced incorrect results

* fix bloom injection policy so that TP works (see the usage sketch after this list)

injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}

* Use TorchBackend to initialize CCLBackend, make behavior consistent

* remove comm under deepspeed/ops

* add license header

* code clean up

* fix format issue

* remove magic number in main address

* add caching support but do not turn it on by default

* change name of inference_cuda_module to inference_module

* Check for is_synchronized_device in accelerator before getting Event

* fix typo

* Fix fallback path of softmax kernel on CUDA device for BF16 data type: because CUDA tril does not support the BF16 datatype, enforce the fp32 data type

* add cpu backend files

* change CPU_Accelerator op_builder_dir

* remove cpu_kernel_path

* using CPU_Accelerator on non-cuda device

* fix deepspeed.op_builder => deepspeed.ops.op_builder

* add alias for num_gpus: num_accelerators

* allow loading cpu_builder in build stage

* Assume cuda available if torch not installed

* add oneccl_binding_pt to requirements

* move oneccl-binding-pt to separate requirements-cpu.txt

* add missing file

* use dependency_links in setuptools.setup() call for additional dependency links

* install oneccl_bind_pt in workflows

* change oneccl_bind_pt's version from 1.13 to 2.0

* use intel_extension_for_pytorch as indicator that CPU_Accelerator should be used

* Add indicator for Accelerator used

* change foo.c to foo.cpp

* exclude 'cpu' directory in CUDA op builder reflection

* add a cpu-inference workflow

* run cpu-inference workflow on self-hosted instance

* change cpu runs-on node to v100 node

* print out python version in workflow

* add verbose flag to pip command to understand oneccl_bind_pt install issue

* update cpu-inference workflow

* add a stage to detect instance instruction sets

* add back bf16 support for CPU inference

* enable autoTP for bloom

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update workflow to detect cpu instruction sets

* temporary WA for Intel Extension for PyTorch AVX2 instruction set detection

* change cpu-inference workflow machine to ubuntu-20.04

* add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable policy for llama

* use a special build of ipex to test avx2 detection fix

* fix format

* fix test failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix gptj sharded checkpoint loading problem

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* return a NotImplementedBuilder in get_op_builder in cpu backend

* support cpu device in tests

* use cpuinfo to extract number of CPUs

* use ~/tmp as transformer cache rather than /blob/

* Add support for mpich launcher with prefer_deepspeed_comm

* add missing modification in accelerator

* enable IMPI launcher

* remove unused file and fix formatting

* clean up ccl.cpp

* Less confusing error message when certain op builders are not implemented

* Fix license header

* Add license header

* add license headers

* add license header

* fix cuda specific code in test

* update CPU workflow

* use numactl to bind to core

* allow bind_cores_to_rank in multi-node impi runner

* fix format error

* Remove InferenceBuilder

* fix format error in numa.py

* check whether op is in installed ops in ds_report.py

* allow overriding accelerator with DS_ACCELERATOR='cuda', 'cpu' or 'xpu'

* lazy init class_dict in CUDA_Accelerator to avoid cyclic initialization of CUDA_Accelerator

* put short path at the beginning of real_accelerator.py

* device_count returns number of NUMA nodes

* fix typo

* install numactl in cpu workflow

* Address review comments

* Better implementation of device_count() and current_device()

* remove dependency_link for Intel Extension for DeepSpeed

* check is_synchronized_device in timer only once

* remove env mapping WA in cpu_accelerator

* fix duplicate definition

* fix format error

* refine ccl backend selection

* move comments to the right place

* remove prefer_deepspeed_comm, use CCLBackend by default

* refactor fallback path

* Fix execution failure in kernel injection path

* do not refactor kernel injection fallback path in residual_add because it contains a function call with side effects

* guard residual_add fallback path with environ DS_KI_FALLBACK=True

* fix format error

* add test for allreduce on CPU workflow

* fix format error

* Fall back to TorchBackend if CCLBackend kernels are not implemented

* Update Intel Extension for PyTorch installation link

* Don't specify version number of Intel Extension for PyTorch

* install oneCCL for CCLBackend

* fix link path for CPU comm kernels

* fix source oneCCL environment

* source oneCCL env before running UT

* Give more specific instructions when CCL_ROOT is not defined
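
For reference, a minimal sketch of how the Bloom injection policy fixed above can be passed to `deepspeed.init_inference`; the model name and tensor-parallel size are illustrative assumptions, not part of this change:

```python
# Hedged sketch: model name and mp_size below are illustrative, not part of this PR.
import torch
import deepspeed
from transformers import AutoModelForCausalLM
from transformers.models.bloom.modeling_bloom import BloomBlock

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16)

# Injection policy from the fix above: shard the attention output projection and
# the second MLP projection across tensor-parallel ranks.
engine = deepspeed.init_inference(
    model,
    mp_size=2,
    dtype=torch.bfloat16,
    injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")},
)
```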

---------

Signed-off-by: Cao, Zhong Z <zhong.z.cao@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: sdp <sdp@aia-sdp-spr-108864.jf.intel.com>
Co-authored-by: Cao, Zhong Z <zhong.z.cao@intel.com>
Co-authored-by: Zhenhuan Chen <zhenhuan.chen@intel.com>
Co-authored-by: baodii <di.bao@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
This commit is contained in:
Ma, Guokai 2023-05-16 23:59:22 +08:00, committed by GitHub
Parent 5147b90aa4
Commit 1f72082fc0
43 changed files: 1414 additions and 329 deletions

.github/workflows/cpu-inference.yml (new file, 83 lines)

@ -0,0 +1,83 @@
name: cpu-inference
on:
push:
branches:
- 'staging**'
paths-ignore:
- 'docs/**'
pull_request:
paths-ignore:
- 'docs/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
unit-tests:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- id: setup-venv
uses: ./.github/workflows/setup-venv
- name: Detect instruction sets on instance
run: |
lscpu
pip install cmake
git clone https://github.com/intel/intel-extension-for-pytorch
cd intel-extension-for-pytorch/tests/cpu/isa
cmake .
make
./cpu_features
- name: Install numactl
run: |
sudo apt-get install -y numactl
- name: Install oneCCL Bindings for PyTorch
run: |
python -m pip install intel_extension_for_pytorch
python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu
- name: Install oneCCL
run: |
git clone https://github.com/oneapi-src/oneCCL
cd oneCCL
mkdir build
cd build
cmake ..
make
make install
#source ./_install/env/setvars.sh
# test whether oneCCL is correctly installed
#mpirun -n 2 ./examples/benchmark/benchmark
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
cd transformers
git rev-parse --short HEAD
pip install .
- name: Install deepspeed
run: |
# check why the host does not have AVX2 support
pip install .[dev,1bit,autotuning,inf]
ds_report
- name: Python environment
run: |
pip list
- name: Unit tests
run: |
source oneCCL/build/_install/env/setvars.sh
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -k TestDistAllReduce unit/comm/test_dist.py


@ -13,6 +13,10 @@ class DeepSpeedAccelerator(ABC):
self._name = None
self._communication_backend_name = None
@abc.abstractmethod
def is_synchronized_device(self):
...
# Device APIs
@abc.abstractmethod
def device_name(self, device_index):


@ -0,0 +1,260 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import oneccl_bindings_for_pytorch # noqa: F401
import psutil
import os
# accelerator for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cpu'
self._communication_backend_name = 'ccl'
self.max_mem = psutil.Process().memory_info().rss
def is_synchronized_device(self):
return True
# Device APIs
def device_name(self, device_index=None):
return 'cpu'
def device(self, device_index=None):
return None
def set_device(self, device_index):
return
def current_device(self):
return os.environ.get('LOCAL_RANK', 0)
def current_device_name(self):
return 'cpu'
def device_count(self):
device_count = int(os.environ.get('LOCAL_SIZE', 0))
if device_count > 0:
return device_count
else:
from deepspeed.utils.numa import get_numa_cores
# Count NUMA nodes for the number of cpu accelerators. On machines with HBM
# in flat mode, HBM is exposed as a separate NUMA node with no cores on it.
# Ignore these NUMA nodes with no cores.
numa_core_lists = get_numa_cores()
numa_count = 0
for core_list in numa_core_lists:
if len(core_list) > 0:
numa_count += 1
return numa_count
def synchronize(self, device_index=None):
return
# RNG APIs
def random(self):
return torch.random
def set_rng_state(self, new_state, device_index=None):
if device_index == None:
return torch.set_rng_state(new_state)
return torch.set_rng_state(new_state, device_index)
def get_rng_state(self, device_index=None):
return torch.get_rng_state()
def manual_seed(self, seed):
return torch.manual_seed(seed)
def manual_seed_all(self, seed):
return torch.manual_seed(seed)
def initial_seed(self, seed):
return torch.initial_seed(seed)
def default_generator(self, device_index):
return torch.default_generator
# Streams/Events
@property
def Stream(self):
return None
def stream(self, stream):
from deepspeed.runtime.utils import noop_decorator
return noop_decorator
def current_stream(self, device_index=None):
return None
def default_stream(self, device_index=None):
return None
@property
def Event(self):
return None
# Memory management
def empty_cache(self):
return
def get_rss(self):
mem = psutil.Process().memory_info().rss
if mem > self.max_mem:
self.max_mem = mem
return mem
def reset_rss(self):
mem = psutil.Process().memory_info().rss
self.max_mem = mem
return mem
def memory_allocated(self, device_index=None):
return self.get_rss()
def max_memory_allocated(self, device_index=None):
self.get_rss()
return self.max_mem
def reset_max_memory_allocated(self, device_index=None):
self.reset_rss()
return
def memory_cached(self, device_index=None):
return self.get_rss()
def max_memory_cached(self, device_index=None):
self.get_rss()
return self.max_mem
def reset_max_memory_cached(self, device_index=None):
self.reset_rss()
return
def memory_stats(self, device_index=None):
return self.get_rss()
def reset_peak_memory_stats(self, device_index=None):
self.reset_rss()
return
def memory_reserved(self, device_index=None):
return self.get_rss()
def max_memory_reserved(self, device_index=None):
self.get_rss()
return self.max_mem
def total_memory(self, device_index=None):
return psutil.virtual_memory().total
# Misc
def amp(self):
return torch.cpu.amp
def is_available(self):
return True
def range_push(self, msg):
# TODO itt is currently not supported yet
# return torch.profiler.itt.range_push(msg)
return
def range_pop(self):
# TODO itt is currently not supported yet
# return torch.profiler.itt.range_pop()
return
def lazy_call(self, callback):
return callback()
def communication_backend_name(self):
return self._communication_backend_name
# Data types
def is_bf16_supported(self):
return True
def is_fp16_supported(self):
return True
# Tensor operations
@property
def BFloat16Tensor(self):
return torch.BFloat16Tensor
@property
def ByteTensor(self):
return torch.ByteTensor
@property
def DoubleTensor(self):
return torch.DoubleTensor
@property
def FloatTensor(self):
return torch.FloatTensor
@property
def HalfTensor(self):
return torch.HalfTensor
@property
def IntTensor(self):
return torch.IntTensor
@property
def LongTensor(self):
return torch.LongTensor
def pin_memory(self, tensor):
return tensor
def op_builder_dir(self):
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
return "op_builder.cpu"
except ImportError:
return "deepspeed.ops.op_builder.cpu"
def on_accelerator(self, tensor):
device_str = str(tensor.device)
if device_str.startswith('cpu'):
return True
else:
return False
# create an instance of op builder and return, name specified by class_name
def create_op_builder(self, op_name):
builder_class = self.get_op_builder(op_name)
if builder_class != None:
return builder_class()
return None
# return an op builder class, name specified by class_name
def get_op_builder(self, class_name):
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
except ImportError:
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
if class_name == "CCLCommBuilder":
return CCLCommBuilder
else:
# return a NotImplementedBuilder to avoid getting a NoneType error in unit tests
return NotImplementedBuilder
def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension
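
A short sketch of how the new CPU accelerator above behaves from client code, assuming intel_extension_for_pytorch is installed so that get_accelerator() resolves to CPU_Accelerator; the printed values are illustrative:

```python
# Sketch of CPU_Accelerator semantics; assumes intel_extension_for_pytorch is
# installed so get_accelerator() resolves to the CPU_Accelerator defined above.
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()
print(acc.device_name())                 # 'cpu'
print(acc.communication_backend_name())  # 'ccl'

# Without LOCAL_SIZE set by a launcher, device_count() counts NUMA nodes that
# actually have cores, so HBM-only nodes in flat mode are not treated as
# accelerators.
print(acc.device_count())

# Memory APIs are backed by process RSS via psutil rather than a device memory
# pool; max_memory_allocated() reports the peak RSS observed so far.
print(acc.memory_allocated(), acc.max_memory_allocated())
```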


@ -22,21 +22,8 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
# begin initialize for create_op_builder()
# put all valid class name <--> class type mapping into class_dict
op_builder_dir = self.op_builder_dir()
op_builder_module = importlib.import_module(op_builder_dir)
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder':
module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
for member_name in module.__dir__():
if member_name.endswith(
'Builder'
) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes
if not member_name in self.class_dict:
self.class_dict[member_name] = getattr(module, member_name)
# end initialize for create_op_builder()
def is_synchronized_device(self):
return False
# Device APIs
def device_name(self, device_index=None):
@ -235,10 +222,32 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
# dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
# this dict will be filled at init stage
class_dict = {}
class_dict = None
def _lazy_init_class_dict(self):
if self.class_dict != None:
return
else:
self.class_dict = {}
# begin initialize for create_op_builder()
# put all valid class name <--> class type mapping into class_dict
op_builder_dir = self.op_builder_dir()
op_builder_module = importlib.import_module(op_builder_dir)
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder' and module_name != 'cpu':
module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
for member_name in module.__dir__():
if member_name.endswith(
'Builder'
) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder": # avoid abstract classes
if not member_name in self.class_dict:
self.class_dict[member_name] = getattr(module, member_name)
# end initialize for create_op_builder()
# create an instance of op builder and return, name specified by class_name
def create_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]()
else:
@ -246,6 +255,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
# return an op builder class, name specified by class_name
def get_op_builder(self, class_name):
self._lazy_init_class_dict()
if class_name in self.class_dict:
return self.class_dict[class_name]
else:


@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
try:
from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
@ -36,25 +37,76 @@ def _validate_accelerator(accel_obj):
def get_accelerator():
global ds_accelerator
if ds_accelerator is None:
try:
from intel_extension_for_deepspeed import XPU_Accelerator
except ImportError as e:
if ds_accelerator is not None:
return ds_accelerator
accelerator_name = None
ds_set_method = None
# 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
# DS_ACCELERATOR = 'cuda'|'xpu'|'cpu'
if 'DS_ACCELERATOR' in os.environ.keys():
accelerator_name = os.environ['DS_ACCELERATOR']
if accelerator_name == 'xpu':
try:
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401
except ImportError as e:
raise ValueError(
f'XPU_Accelerator requires intel_extension_for_deepspeed, which is not installed on this system.')
elif accelerator_name == 'cpu':
try:
import intel_extension_for_pytorch # noqa: F401
except ImportError as e:
raise ValueError(
f'CPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.')
elif accelerator_name == 'cuda':
pass
else:
ds_accelerator = XPU_Accelerator()
_validate_accelerator(ds_accelerator)
return ds_accelerator
raise ValueError(
f'DS_ACCELERATOR must be one of "cuda", "cpu", or "xpu". Value "{accelerator_name}" is not supported')
ds_set_method = 'override'
# 2. If no override, detect which accelerator to use automatically
if accelerator_name == None:
try:
from intel_extension_for_deepspeed import XPU_Accelerator # noqa: F401,F811
accelerator_name = 'xpu'
except ImportError as e:
# We need a way to choose between CUDA_Accelerator and CPU_Accelerator.
# Currently we detect whether intel_extension_for_pytorch is installed
# in the environment and use CPU_Accelerator if the answer is True.
# An alternative might be to detect whether a CUDA device is installed on
# the system, but this comes with two pitfalls:
# 1. the system may not have torch pre-installed, so
# get_accelerator().is_available() may not work.
# 2. Some scenarios, like installing on a login node (without CUDA device)
# and running on a compute node (with CUDA device), may cause a mismatch
# between installation time and runtime.
try:
import intel_extension_for_pytorch # noqa: F401,F811
accelerator_name = 'cpu'
except ImportError as e:
accelerator_name = 'cuda'
ds_set_method = 'auto detect'
# 3. Set ds_accelerator accordingly
if accelerator_name == 'cuda':
from .cuda_accelerator import CUDA_Accelerator
ds_accelerator = CUDA_Accelerator()
_validate_accelerator(ds_accelerator)
elif accelerator_name == 'cpu':
from .cpu_accelerator import CPU_Accelerator
ds_accelerator = CPU_Accelerator()
elif accelerator_name == 'xpu':
# XPU_Accelerator is already imported in detection stage
ds_accelerator = XPU_Accelerator()
_validate_accelerator(ds_accelerator)
print(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
return ds_accelerator
def set_accelerator(accel_obj):
global ds_accelerator
_validate_accelerator(accel_obj)
print(f"Setting ds_accelerator to {accel_obj._name} (model specified)")
ds_accelerator = accel_obj
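
A minimal sketch of the new override path above, assuming intel_extension_for_pytorch is installed for the 'cpu' case; the environment variable must be set before the first get_accelerator() call:

```python
# Sketch of the DS_ACCELERATOR override added above; 'cpu' assumes
# intel_extension_for_pytorch is installed on the system.
import os
os.environ["DS_ACCELERATOR"] = "cpu"

from deepspeed.accelerator import get_accelerator
acc = get_accelerator()   # prints: Setting ds_accelerator to cpu (override)
assert acc.device_name() == "cpu"
```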

csrc/cpu/comm/ccl.cpp (new file, 211 lines)

@ -0,0 +1,211 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h>
#include <oneapi/ccl.hpp>
std::set<int> _comm_ids;
std::set<int> _colors;
ccl::vector_class<ccl::communicator> _ccl_comms;
#define CCLCHECK(cmd) \
do { \
cmd; \
} while (0)
#define KVS_CREATE_SUCCESS 0
#define KVS_CREATE_FAILURE -1
bool is_initialized = 0;
int world_rank = -1;
int world_size = -1;
ccl::shared_ptr_class<ccl::kvs> kvs;
void initialize(int size, int rank, torch::Tensor& kvs_data)
{
if (is_initialized) return;
world_size = size;
world_rank = rank;
is_initialized = 1;
ccl::kvs::address_type main_addr;
if (rank != 0) {
memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
kvs = ccl::create_kvs(main_addr);
}
_ccl_comms.emplace_back(ccl::create_communicator(size, rank, kvs));
}
/*
rank == 0: create main kvs and return its address
rank == else: return an empty address
*/
std::vector<uint8_t> get_kvs_addr(int rank)
{
if (rank == 0) {
kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = kvs->get_address();
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
} else {
ccl::kvs::address_type main_addr;
auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
return ccl_kvs_addr;
}
}
int get_rank(int group = 0) { return world_rank; }
int get_world_size(int group = 0) { return world_size; }
// Find the next ordered, unique value to a set. E.g. <0,1,2,7> --> 3
int next_unique_val(std::set<int> s)
{
std::set<int>::iterator itr;
// Base case. Add 0 to start of set.
if (s.empty() || *s.begin() != 0) {
return 0;
// second base case where s = {0} (the case of s = {n != 0} is caught above)
} else if (s.size() == 1) {
return 1;
} else {
int prev_val = *s.begin();
for (itr = std::next(s.begin()); itr != s.end(); itr++) {
if (*itr != prev_val + 1) { return prev_val + 1; }
prev_val = *itr;
}
return *(s.end()) + 1;
}
}
py::object new_group(std::vector<int> ranks)
{
int comm_id = next_unique_val(_comm_ids);
int color = next_unique_val(_colors);
std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
<< std::endl;
}
ccl::datatype get_ccl_datatype(c10::ScalarType type)
{
ccl::datatype ccl_type;
switch (type) {
case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break;
case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break;
case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break;
case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break;
case c10::ScalarType::Half: ccl_type = ccl::datatype::float16; break;
default: ccl_type = ccl::datatype::int8;
}
return ccl_type;
}
ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
{
py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
if (!py::isinstance(op, ReduceOp)) {
throw std::runtime_error("Error: Op must be of type ReduceOp");
}
int op_val = py::int_(op.attr("value"));
ccl::reduction ccl_op;
if (input.scalar_type() == at::kBool) {
if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
// For bool tensors, map sum to max, which both represent a bitwise or.
// This is to prevent overflow issues with sum, since we use uint8 to
// represent a bool (see cclDataType mapping).
ccl_op = ccl::reduction::max;
} else if (op_val == (int)py::int_(ReduceOp.attr("AVG").attr("value"))) {
throw std::runtime_error("Error: For bool tensors, op must be of type ReduceOp");
}
}
if (op_val == (int)py::int_(ReduceOp.attr("SUM").attr("value"))) {
ccl_op = ccl::reduction::sum;
} else if (op_val == (int)py::int_(ReduceOp.attr("MIN").attr("value"))) {
ccl_op = ccl::reduction::min;
} else if (op_val == (int)py::int_(ReduceOp.attr("MAX").attr("value"))) {
ccl_op = ccl::reduction::max;
} else if (op_val == (int)py::int_(ReduceOp.attr("PRODUCT").attr("value"))) {
ccl_op = ccl::reduction::prod;
} else {
throw std::runtime_error("Error: Unrecognized ReduceOp type");
}
return ccl_op;
}
ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
{
CCLCHECK(ccl::broadcast(data.data_ptr(),
data.numel(),
get_ccl_datatype(data.scalar_type()),
src,
_get_comm_from_group(group))
.wait());
}
// TODO: implement torch's async_op behavior, document it.
void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
{
CCLCHECK(ccl::allreduce(data.data_ptr(),
data.data_ptr(),
data.numel(),
get_ccl_datatype(data.scalar_type()),
get_ccl_reduce_op(op, data),
_get_comm_from_group(group))
.wait());
}
void all_reduce_caching(torch::Tensor& data,
py::object op,
std::string match_id,
py::object group,
bool async_op)
{
ccl::allreduce_attr attr = ccl::default_allreduce_attr;
auto match_str = ccl::v1::string(match_id);
attr.template set<ccl::operation_attr_id::to_cache>(true);
attr.template set<ccl::operation_attr_id::match_id>(match_str);
// To control this, use operation attribute and set true value for to_cache field and unique
// string (for example, tensor name) for match_id field. Note that:
// match_id should be the same for a specific communication operation across all ranks.
// If the same tensor is a part of different communication operations, match_id should have
// different values for each of these operations.
CCLCHECK(ccl::allreduce(data.data_ptr(),
data.data_ptr(),
data.numel(),
get_ccl_datatype(data.scalar_type()),
get_ccl_reduce_op(op, data),
_get_comm_from_group(group),
attr)
.wait());
}
void barrier(py::object group, bool async_op)
{
CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
m.def("initialize", &initialize, "ccl initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("broadcast", &broadcast, "ccl broadcast");
m.def("all_reduce", &all_reduce, "ccl all_reduce");
m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
m.def("barrier", &barrier, "barrier");
}

deepspeed/comm/ccl.py (new file, 62 lines)

@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
from deepspeed.accelerator import get_accelerator
from .reduce_op import ReduceOp
from .torch import TorchBackend
def build_ccl_op():
builder = get_accelerator().create_op_builder("CCLCommBuilder")
if builder == None:
return None
ccl_cpp_module = builder.load()
print(f'DeepSpeed {builder.absolute_name()} built successfully')
return ccl_cpp_module
class CCLBackend(TorchBackend):
def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
super(CCLBackend, self).__init__(backend='ccl',
name='torch',
rank=rank,
world_size=world_size,
timeout=timeout,
init_method=init_method)
self.name = 'ccl'
self.ccl_comm_op = build_ccl_op()
if self.ccl_comm_op == None:
# set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded
self.initialized = False
return
size = self.get_world_size()
rank = self.get_rank()
main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
main_kvs = torch.tensor(main_kvs).to(torch.uint8)
super(CCLBackend, self).broadcast(main_kvs, 0)
self.ccl_comm_op.initialize(size, rank, main_kvs)
self.initialized = True
def is_initialized(self):
return self.initialized
def broadcast(self, tensor, src, group=None, async_op=False):
self.ccl_comm_op.broadcast(tensor, src, group, async_op)
def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
use_caching = False
if use_caching:
match_id = f"{tensor.size()}-{op}"
self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
else:
self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
def barrier(self, group=None, async_op=False):
self.ccl_comm_op.barrier(group, async_op)
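
A hedged usage sketch of the backend above: two or more CPU ranks launched by the deepspeed launcher (so RANK, WORLD_SIZE and LOCAL_RANK are set) performing an all-reduce through CCLBackend; the tensor size is illustrative:

```python
# Sketch: all-reduce over the CCL backend on CPU ranks. Assumes the script is
# started by the deepspeed launcher and get_accelerator() resolves to
# CPU_Accelerator, whose communication backend name is 'ccl'.
import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()   # selects CCLBackend via communication_backend_name()

x = torch.ones(4, dtype=torch.bfloat16)
dist.all_reduce(x)             # ReduceOp.SUM by default
dist.barrier()
if dist.get_rank() == 0:
    print(x)                   # every element equals the world size
```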


@ -21,26 +21,14 @@
-- deepspeed groups API should be brought into ds.comm
"""
from enum import Enum
import torch
import os
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from .reduce_op import *
from deepspeed.accelerator import get_accelerator
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
AVG = 7
UNUSED = 8
from deepspeed.comm.ccl import CCLBackend
from deepspeed.utils.comms_logging import CommsLogger
from deepspeed.utils import timer, get_caller_func
@ -50,7 +38,6 @@ from deepspeed import utils
from datetime import timedelta
# Current deepspeed.comm backend (cdb) global object for simple access by client code
use_ds_backend = False
cdb = None
# Create global timer for ops
@ -62,11 +49,20 @@ comms_logger = CommsLogger()
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
ccl_backend = None
# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *
class ProcessGroup():
def __init__(self, comm_id, ranks=[]):
self.ranks = ranks
self.comm_id = comm_id
self.size = len(ranks)
def _configure_using_config_file(config):
if config.comms_logger_enabled:
comms_logger.configure(config)
@ -143,11 +139,14 @@ def timed_op(func):
# UNUSED: Future helper function to initialize DS backends
def init_deepspeed_backend(ds_backend):
def init_deepspeed_backend(ds_backend, timeout, init_method):
global cdb
global nccl_backend
global mpi_backend
global use_ds_backend
global ccl_backend
rank = int(os.environ["RANK"])
size = int(os.environ["WORLD_SIZE"])
if ds_backend == NCCL_BACKEND:
utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
@ -155,6 +154,9 @@ def init_deepspeed_backend(ds_backend):
utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
elif ds_backend == GLOO_BACKEND:
utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
elif ds_backend == CCL_BACKEND:
ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method)
utils.logger.info(f"Initialize {ds_backend} backend")
else:
utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
@ -189,26 +191,23 @@ def is_available() -> bool:
return True
def set_backend(backend_name):
if not use_ds_backend:
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')
def set_backend():
global cdb
global nccl_backend
global mpi_backend
global ccl_backend
try:
if backend_name == NCCL_BACKEND:
if nccl_backend is not None and nccl_backend.is_initialized():
cdb = nccl_backend
elif backend_name == MPI_BACKEND:
if mpi_backend is not None and mpi_backend.is_initialized():
cdb = mpi_backend
except Exception as inst:
print(inst)
backend_name = get_accelerator().communication_backend_name()
if backend_name == NCCL_BACKEND:
if nccl_backend is not None and nccl_backend.is_initialized():
cdb = nccl_backend
elif backend_name == MPI_BACKEND:
if mpi_backend is not None and mpi_backend.is_initialized():
cdb = mpi_backend
elif backend_name == CCL_BACKEND:
if ccl_backend is not None and ccl_backend.is_initialized():
cdb = ccl_backend
@timed_op
@ -392,7 +391,7 @@ def scatter(tensor,
@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
return cdb.barrier(group=group, async_op=async_op)
@timed_op
@ -589,6 +588,10 @@ def init_distributed(dist_backend=None,
if dist_init_required is None:
dist_init_required = cdb is None or not cdb.is_initialized()
if cdb is None:
init_deepspeed_backend(get_accelerator().communication_backend_name(), timeout, init_method)
set_backend()
utils.logger.info(f'cdb={cdb}')
if cdb is None and torch.distributed.is_initialized():
# The user initialized torch.dist themselves, create cdb and short-circuit
cdb = TorchBackend(dist_backend, timeout, init_method)


@ -4,6 +4,7 @@
# DeepSpeed Team
NCCL_BACKEND = 'nccl'
CCL_BACKEND = 'ccl'
MPI_BACKEND = 'mpi'
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'


@ -0,0 +1,18 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from enum import Enum
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
AVG = 7
UNUSED = 8


@ -50,7 +50,7 @@ def op_report(verbose=True):
for op_name, builder in ALL_OPS.items():
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible(verbose) else no
is_installed = installed if installed_ops[op_name] else no
is_installed = installed if installed_ops.get(op_name, False) else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))


@ -418,6 +418,7 @@ class InferenceEngine(Module):
generic_injection(self.module,
fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
bf16=(config.dtype == torch.bfloat16),
enable_cuda_graph=config.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module):


@ -8,6 +8,7 @@ PDSH_MAX_FAN_OUT = 1024
OPENMPI_LAUNCHER = 'openmpi'
MPICH_LAUNCHER = 'mpich'
IMPI_LAUNCHER = 'impi'
SLURM_LAUNCHER = 'slurm'
MVAPICH_LAUNCHER = 'mvapich'
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'


@ -19,13 +19,12 @@ import base64
import time
import signal
import psutil
import distutils
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger
from ..utils import logger, get_numactl_cmd
from ..elasticity import is_torch_elastic_compatible
from .constants import ELASTIC_TRAINING_ID_DEFAULT
@ -130,89 +129,6 @@ def terminate_process_tree(pid):
p.kill()
def parse_range(rng):
try:
value = int(rng)
return range(value, value + 1)
except ValueError:
# value is not a single number
parts = rng.split('-')
if len(parts) != 2:
raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
(rng, ))
start = int(parts[0])
end = int(parts[1])
if start > end:
raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
return range(start, end + 1)
# parse comma and dash separated range list into list
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Range list number be comma separated, each item are either a single number,
# or a range marked by two numbers (both number are included in the range)
# 2. Sub ranges must be in ascend order and not overlap with each other
# 3. No space in the range expression
def parse_range_list(range_str):
number_list = []
last = -1
range_list = range_str.split(',')
for sub_range in range_list:
sub_number_list = parse_range(sub_range)
if sub_number_list[0] <= last:
raise ValueError(
"Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
(range_str, ))
last = sub_number_list[-1]
number_list.extend(sub_number_list)
return number_list
# return a list of list for cores to numa mapping
# [
# [ cores for numa 0 ]
# [ cores belong to numa 1 ]
# ...
# ]
def get_numa_cores():
ret = []
output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
lines = output.split('\n')
for line in lines:
if line.startswith('available:'):
num_numas = int(line.split(' ')[1])
break
for numa in range(num_numas):
for line in lines:
if line.startswith(f'node {numa} cpus:'):
cores = line.split(' ')[3:]
ret.append([int(core) for core in cores])
return ret
def check_for_numactl_pkg():
libs = dict(
dpkg=["-l", "numactl", "apt"],
pacman=["-Q", "numactl", "pacman"],
rpm=["-q", "numactl", "yum"],
)
found = False
for pkgmgr, data in libs.items():
flag, lib, tool = data
path = distutils.spawn.find_executable(pkgmgr)
if path is not None:
cmd = f"{pkgmgr} {flag} {lib}"
result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
if result.wait() == 0:
found = True
else:
print(f"please install the {lib} package with {tool}")
break
return found
def main():
args = parse_args()
current_env = os.environ.copy()
@ -308,39 +224,9 @@ def main():
# spawn the processes
cmd = []
if args.bind_cores_to_rank:
check_for_numactl_pkg()
if 'KMP_AFFINITY' in os.environ.keys():
raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
"because it interfere with how many CPU cores numactl can set. "
"Unset KMP_AFFINITY before launching deepspeed.\n\n"
"\t$ unset KMP_AFFINITY\n"
"\t$ deepspeed <deepspeed command parameters>")
if args.bind_core_list != None:
core_list = parse_range_list(args.bind_core_list)
total_cores = len(core_list)
else:
total_cores = psutil.cpu_count(logical=False)
core_list = range(total_cores)
cores_per_rank = total_cores // num_local_procs
assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
cores_per_rank, numactl_cmd = get_numactl_cmd(args.bind_core_list, num_local_procs, local_rank)
current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}"
cmd.append("numactl")
# check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter
numa_cores = get_numa_cores()
num_numas = len(numa_cores)
for i in range(num_numas):
if set(core_list_for_rank) <= set(numa_cores[i]):
cmd.append("-m")
cmd.append(f"{i}")
break
cmd.append("-C")
core_list_str = f"{core_list_for_rank[0]}"
for core_id in core_list_for_rank[1:]:
core_list_str = f"{core_list_str},{core_id}"
cmd.append(f"{core_list_str}")
cmd = cmd + numactl_cmd
if not args.no_python:
cmd.append(sys.executable)
cmd.append("-u")


@ -11,7 +11,7 @@ import warnings
from shlex import split
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator
from ..utils import logger
from ..utils import logger, get_numactl_cmd
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
@ -228,6 +228,88 @@ class MPICHRunner(MultiNodeRunner):
return mpirun_cmd + export_cmd + per_host_cmd
class IMPIRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
def backend_exists(self):
#TODO: if IB is available we should suggest mpich
return shutil.which('mpirun') #mpich_info
@property
def name(self):
return "impi"
def validate_args(self):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0]
if not all([n == process_per_node for n in devices_per_node]):
raise ValueError("Intel MPI requires same number of devices per node")
mpirun_cmd = [
'mpirun',
'-ppn',
f'{process_per_node}',
] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-genv', f'{k}', f'{v}']
if self.args.bind_cores_to_rank:
cores_per_rank, _ = get_numactl_cmd(self.args.bind_core_list, process_per_node, 0)
export_cmd += ['-genv', 'OMP_NUM_THREADS', str(cores_per_rank)]
export_cmd += ['-genv', 'MASTER_ADDR', str(self.args.master_addr)]
export_cmd += ['-genv', 'MASTER_PORT', str(self.args.master_port)]
export_cmd += ['-genv', 'WORLD_SIZE', str(total_process_count)]
export_cmd += ['-genv', 'LOCAL_SIZE', str(process_per_node)]
export_cmd += ['-hosts']
hosts = ""
for i, host in enumerate(self.resource_pool.keys()):
if i == 0:
hosts = f"{host}"
else:
hosts += f",{host}"
export_cmd += [hosts]
per_host_cmd = []
for i in range(total_process_count):
local_rank = i % process_per_node
python_exec = []
if self.args.bind_cores_to_rank:
_, numactl_cmd = get_numactl_cmd(self.args.bind_core_list, process_per_node, local_rank)
python_exec += numactl_cmd
if not self.args.no_python:
python_exec += [sys.executable, "-u"]
if self.args.module:
python_exec.append("-m")
env_mapping = ['-env', 'RANK', str(i)]
env_mapping += ['-env', 'LOCAL_RANK', str(local_rank)]
if i == 0:
per_host_cmd = ['-n', '1'] + env_mapping + python_exec + [self.user_script] + self.user_arguments
else:
per_host_cmd = per_host_cmd + [':', '-n', '1'] + env_mapping + python_exec + [self.user_script
] + self.user_arguments
print(mpirun_cmd + export_cmd + per_host_cmd)
return mpirun_cmd + export_cmd + per_host_cmd
class SlurmRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):


@ -21,8 +21,8 @@ from copy import deepcopy
import signal
import time
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER
from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import NEBULA_EXPORT_ENVS
from ..utils import logger
@ -95,6 +95,7 @@ def parse_args(args=None):
"Default is num_nodes when elastic training is enabled")
parser.add_argument("--num_gpus",
"--num_accelerators",
type=int,
default=-1,
help="Max number of GPUs to use on each node, will use "
@ -116,7 +117,7 @@ def parse_args(args=None):
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH, IMPI.")
parser.add_argument("--launcher_args",
default="",
@ -504,6 +505,8 @@ def main(args=None):
runner = OpenMPIRunner(args, world_info_base64, resource_pool)
elif args.launcher == MPICH_LAUNCHER:
runner = MPICHRunner(args, world_info_base64, resource_pool)
elif args.launcher == IMPI_LAUNCHER:
runner = IMPIRunner(args, world_info_base64, resource_pool)
elif args.launcher == MVAPICH_LAUNCHER:
runner = MVAPICHRunner(args, world_info_base64, resource_pool)
elif args.launcher == SLURM_LAUNCHER:


@ -13,7 +13,7 @@ from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttent
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
inference_cuda_module = None
inference_module = None
class DeepSpeedTransformerInference(nn.Module):
@ -48,10 +48,10 @@ class DeepSpeedTransformerInference(nn.Module):
DeepSpeedTransformerInference.layer_id += 1
data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_cuda_module
if inference_cuda_module is None:
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_cuda_module = builder.load()
inference_module = builder.load()
if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
@ -74,14 +74,22 @@ class DeepSpeedTransformerInference(nn.Module):
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if config.dtype == torch.float32 else \
inference_cuda_module.allocate_workspace_fp16
self._alloc_workspace = True
try:
if config.dtype == torch.float32:
self.allocate_workspace = inference_module.allocate_workspace_fp32
elif config.dtype == torch.bfloat16:
self.allocate_workspace = inference_module.allocate_workspace_bf16
else:
self.allocate_workspace = inference_module.allocate_workspace_fp32
self._alloc_workspace = True
except AttributeError:
self.allocate_workspace = None
self._alloc_workspace = False
@classmethod
def reset_cache(cls):
if inference_cuda_module is not None:
inference_cuda_module.reset_cache()
if inference_module is not None:
inference_module.reset_cache()
def forward(
self,
@ -163,7 +171,7 @@ class DeepSpeedTransformerInference(nn.Module):
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
if not self.config.pre_layer_norm:
output = inference_cuda_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
output = inference_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
output = output.to(input_type)
if get_present:
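
The allocate_workspace change above, and the op_binding changes further below, all follow the same dispatch pattern: bind the compiled kernel when the loaded inference module provides it, otherwise fall back to a torch implementation. A condensed, hypothetical sketch of that pattern (the class and kernel names here are illustrative, not the exact ops from this diff):

```python
# Hypothetical sketch of the kernel-fallback dispatch pattern; names are
# illustrative, and DS_KI_FALLBACK mirrors the switch mentioned in the commit list.
import os
import torch

class SoftmaxOpSketch:

    def __init__(self, inference_module, dtype):
        try:
            # CUDA builds expose dtype-specialized kernels; a CPU build may not.
            self.softmax_func = (inference_module.softmax_fp16
                                 if dtype in (torch.float16, torch.int8)
                                 else inference_module.softmax_fp32)
        except AttributeError:
            self.softmax_func = self.softmax_fallback

    def softmax_fallback(self, scores, mask):
        if os.environ.get("DS_KI_FALLBACK", "").lower() != "true":
            raise NotImplementedError("softmax kernel not built for this accelerator")
        if mask is not None:
            scores = scores + mask
        return torch.softmax(scores, dim=-1)
```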


@ -184,7 +184,7 @@ def _module_match(module):
return None
def generic_injection(module, fp16=False, enable_cuda_graph=True):
def generic_injection(module, fp16=False, bf16=False, enable_cuda_graph=True):
def replace_attn(child, policy):
policy_attn = policy.attention(child)
@ -199,6 +199,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
hidden_size=hidden_size,
heads=heads,
fp16=fp16,
bf16=bf16,
triangular_masking=False,
max_out_tokens=4096,
)
@ -231,8 +232,8 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
if isinstance(module, torch.nn.Module):
pass
else:
if fp16 is False:
raise ValueError("Generic injection only supported with FP16")
if fp16 is False and bf16 is False:
raise ValueError("Generic injection only supported with FP16 or BF16")
try:
import diffusers


@ -33,6 +33,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
using model-parallel architecture. If the client model already takes care of this, there is no
need to pass this argument.
fp16: Enable half-precision computation
bf16: Enable bf16 floating point computation
pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
stochastic_mode: Enable for high performance, please note that this flag has some level of
non-determinism and can produce different results on different runs. However, we have seen


@ -13,7 +13,7 @@ from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
# Cuda modules will be imported if needed
inference_cuda_module = None
inference_module = None
minus_inf = -10000.0
triton_flash_attn = None
@ -77,8 +77,7 @@ class DeepSpeedDiffusersAttentionFunction(Function):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads,
do_flash_attn)
query, key, value = inference_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn)
attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
context_layer = _transpose_for_context(torch.matmul(attention_scores, value))
@ -118,10 +117,10 @@ class DeepSpeedDiffusersAttention(nn.Module):
data_type = self.config.dtype
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
global inference_cuda_module
if inference_cuda_module is None:
global inference_module
if inference_module is None:
builder = InferenceBuilder()
inference_cuda_module = builder.load()
inference_module = builder.load()
if DeepSpeedDiffusersAttention.layer_id == 1:
log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0])
@ -173,13 +172,13 @@ class DeepSpeedDiffusersAttention(nn.Module):
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191
if self.config.dtype in [torch.float16, torch.int8]:
self.score_context_func = inference_cuda_module.softmax_context_fp16
self.linear_func = inference_cuda_module.linear_layer_fp16
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp16
self.score_context_func = inference_module.softmax_context_fp16
self.linear_func = inference_module.linear_layer_fp16
self.allocate_workspace = inference_module.allocate_workspace_fp16
else:
self.score_context_func = inference_cuda_module.softmax_context_fp32
self.linear_func = inference_cuda_module.linear_layer_fp32
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32
self.score_context_func = inference_module.softmax_context_fp32
self.linear_func = inference_module.linear_layer_fp32
self.allocate_workspace = inference_module.allocate_workspace_fp32
def forward(self, input, context=None, input_mask=None):
if self.config.layer_id == 0:


@ -7,9 +7,8 @@ import json
import math
import torch
from torch.autograd import Function
#from ...inference.engine import inference_cuda_module, specialized_mode
# Cuda modules will be imported if needed
inference_cuda_module = None
# accelerator modules will be imported if needed
inference_module = None
specialized_mode = None
import torch.nn as nn
from .ds_attention import DeepSpeedSelfAttention
@ -35,6 +34,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
using model-parallel architecture. If the client model already takes care of this, there is no
need to pass this argument.
fp16: Enable half-precision computation
bf16: Enable bf16 floating point computation
pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture
stochastic_mode: Enable for high performance, please note that this flag has some level of
non-determinism and can produce different results on different runs. However, we have seen
@ -55,6 +55,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
local_rank=-1,
mp_size=1,
fp16=False,
bf16=False,
q_int8=False,
pre_layer_norm=True,
stochastic_mode=False,
@ -76,9 +77,9 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
scale_attn_by_inverse_layer_idx=False):
super(DeepSpeedMoEInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, q_int8, pre_layer_norm,
stochastic_mode, scale_attention, triangular_masking, local_attention, window_size,
return_tuple)
num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, bf16, q_int8,
pre_layer_norm, stochastic_mode, scale_attention, triangular_masking, local_attention,
window_size, return_tuple)
self.moe_experts = moe_experts
self.k = k
self.capacity_factor = capacity_factor
@ -111,14 +112,12 @@ class DeepSpeedMLPFunction(Function):
def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group,
async_op):
if config.q_int8:
intermediate = inference_cuda_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon,
q_scales[2], (q_groups * (2**merge_count)),
config.pre_layer_norm)
output = inference_cuda_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups,
(merge_count))
intermediate = inference_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon, q_scales[2],
(q_groups * (2**merge_count)), config.pre_layer_norm)
output = inference_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups, (merge_count))
else:
mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
inference_cuda_module.fused_gemm_gelu_fp32
mlp_gemm_func = inference_module.fused_gemm_gelu_fp16 if config.fp16 else \
inference_module.fused_gemm_gelu_fp32
output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op)
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
@ -188,17 +187,17 @@ class DeepSpeedMoEInference(nn.Module):
self.config = config
self.config.layer_id = DeepSpeedMoEInference.layer_id
global inference_cuda_module
global inference_module
global specialized_mode
if inference_cuda_module is None:
if inference_module is None:
specialized_mode = False
# InferenceSpecializedBuilder is not among the DeepSpeed-provided builders yet, so we infer by builder name string
builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder")
if builder != None and builder.is_compatible():
inference_cuda_module = builder.load()
inference_module = builder.load()
specialized_mode = True
else:
inference_cuda_module = InferenceBuilder().load()
inference_module = InferenceBuilder().load()
self.config.specialized_mode = specialized_mode
assert self.config.dtype != torch.bfloat16, "DeepSpeed MoE Transformer Inference not yet tested for bfloat support"
@ -214,10 +213,10 @@ class DeepSpeedMoEInference(nn.Module):
self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping,
mp_group)
self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2))
self.coef_func = inference_cuda_module.softmax_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_cuda_module.softmax_fp32
self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if self.config.dtype == torch.float16 else \
inference_cuda_module.vector_matmul_fp32
self.coef_func = inference_module.softmax_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_module.softmax_fp32
self.vector_matmul_func = inference_module.vector_matmul_fp16 if self.config.dtype == torch.float16 else \
inference_module.vector_matmul_fp32
config.mp_size = 1
self.mlp = nn.ModuleList(
@ -235,12 +234,12 @@ class DeepSpeedMoEInference(nn.Module):
print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__)
self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_cuda_module.bias_residual_fp32
self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_cuda_module.layer_norm_fp32
self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_cuda_module.einsum_sec_sm_ecm_fp32
self.bias_residual_func = inference_module.bias_residual_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_module.bias_residual_fp32
self.ds_layernorm = inference_module.layer_norm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_module.layer_norm_fp32
self.einsum_sec_sm_ecm = inference_module.einsum_sec_sm_ecm_fp16 if self.config.dtype in [torch.float16, torch.int8] else \
inference_module.einsum_sec_sm_ecm_fp32
def res_coef_func(self, inp, async_op):
inp = self.vector_matmul_func(inp, self.res_coef, async_op)
@ -347,7 +346,7 @@ class DeepSpeedMoEInference(nn.Module):
dim=0)[dist.get_rank(group=self.expert_mp_group)]
if self.config.mlp_type == 'residual':
inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output)
inference_module.moe_res_matmul(res_mlp_out, res_coef_out, output)
output = self.bias_residual_func(output, residual_add, torch.empty(1))


@ -10,11 +10,11 @@ from deepspeed.ops.op_builder import InferenceBuilder
class BaseOp(torch.nn.Module):
inference_cuda_module = None
inference_module = None
def __init__(self, config: DeepSpeedInferenceConfig):
super(BaseOp, self).__init__()
self.config = config
if BaseOp.inference_cuda_module is None:
if BaseOp.inference_module is None:
builder = InferenceBuilder()
BaseOp.inference_cuda_module = builder.load()
BaseOp.inference_module = builder.load()


@ -12,12 +12,18 @@ class GELUGemmOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(GELUGemmOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_bf16
else:
self.fused_gemm_gelu = self.inference_cuda_module.fused_gemm_gelu_fp32 # type: ignore
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
self.fused_gemm_gelu = self.gelu_gemm_fallback
def gelu_gemm_fallback(self, input, weight, scale, bias, out, out_scale, dtype, transpose):
raise NotImplementedError
def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, weight_out: torch.Tensor):

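The GELU, linear, MLP, QKV, softmax and residual ops in this PR all adopt the same dispatch idiom: resolve the dtype-specific kernel from the loaded inference module and, if the module does not export that symbol (as on the CPU backend), fall back to a Python implementation. A minimal standalone sketch of the idiom, using hypothetical names (_FakeModule and my_op_* are illustrative stand-ins, not DeepSpeed APIs):

# Minimal sketch of the kernel-dispatch-with-fallback pattern used above.
import torch

class _FakeModule:
    # pretend this came from the op builder's load(); the fp32 kernel is missing
    def my_op_fp16(self, x):
        return x.half()

def my_op_fallback(x):
    # pure-PyTorch path used when the compiled kernel is unavailable
    return x.float()

module = _FakeModule()
dtype = torch.float32
try:
    # AttributeError fires while resolving the missing attribute below
    op = module.my_op_fp16 if dtype == torch.float16 else module.my_op_fp32
except AttributeError:
    op = my_op_fallback

print(op(torch.ones(2)))  # tensor([1., 1.]) via the fallback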

@ -12,12 +12,18 @@ class LinearOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(LinearOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.linear_func = self.inference_cuda_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_cuda_module.linear_layer_bf16
else:
self.linear_func = self.inference_cuda_module.linear_layer_fp32
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
else:
self.linear_func = self.inference_module.linear_layer_fp32
except AttributeError:
self.linear_func = self.linear_fallback
def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose):
raise NotImplementedError
def forward(self,
input: torch.Tensor,


@ -5,7 +5,9 @@
from typing import Optional
import os
import torch
import torch.nn.functional as F
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
from deepspeed.utils.types import NormType
@ -15,21 +17,43 @@ class MLPGemmOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16
else:
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp32 # type: ignore
elif self.config.norm_type == NormType.RMSNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_bf16
else:
self.mlp_gemm_func = self.inference_module.rms_mlp_gemm_fp32 # type: ignore
except AttributeError:
if self.config.norm_type == NormType.LayerNorm:
self.mlp_gemm_func = self.mlp_gemm_fallback
elif self.config.norm_type == NormType.RMSNorm:
self.mlp_gemm_func = self.rms_mlp_gemm_fallback
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_bf16
else:
self.mlp_gemm_func = self.inference_cuda_module.mlp_gemm_fp32 # type: ignore
elif self.config.norm_type == NormType.RMSNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_bf16
else:
self.mlp_gemm_func = self.inference_cuda_module.rms_mlp_gemm_fp32 # type: ignore
def mlp_gemm_fallback(self, input, residual, input_bias, weight_interm, weight_out, bias, gamma, beta, eps,
pre_layer_norm, mlp_after_attn, interm_scale, out_scale, dtype, mlp_act_func_type,
transpose):
if os.environ.get('DS_KI_FALLBACK') == 'True' and mlp_after_attn and not transpose:
residual_add = F.layer_norm(input + residual + input_bias, (input.shape[2], ), gamma, beta,
self.config.epsilon)
tmp = torch.matmul(residual_add, weight_interm)
tmp = F.gelu(tmp + bias)
output = torch.matmul(tmp, weight_out)
return (output, residual_add)
else:
raise NotImplementedError
def rms_mlp_gemm_fallback(self, input, residual, weight_interm, weight_out, gamma, eps, interm_scale, out_scale,
dtype, mlp_act_func_type, transpose):
raise NotImplementedError
def forward(self,
input: torch.Tensor,

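The fallback body only runs when the DS_KI_FALLBACK environment variable equals the string 'True'; otherwise it raises NotImplementedError. A hedged sketch of the layer-norm MLP math it performs (mlp_after_attn, no transpose), with illustrative shapes:

# Sketch of the math in mlp_gemm_fallback; shapes are made up for illustration.
import torch
import torch.nn.functional as F

hidden = 32
inp = torch.randn(1, 4, hidden)
residual = torch.randn(1, 4, hidden)
input_bias = torch.zeros(hidden)
gamma, beta = torch.ones(hidden), torch.zeros(hidden)
weight_interm = torch.randn(hidden, 4 * hidden)
bias = torch.zeros(4 * hidden)
weight_out = torch.randn(4 * hidden, hidden)
epsilon = 1e-5

residual_add = F.layer_norm(inp + residual + input_bias, (inp.shape[2], ), gamma, beta, epsilon)
output = torch.matmul(F.gelu(torch.matmul(residual_add, weight_interm) + bias), weight_out)
print(output.shape, residual_add.shape)  # torch.Size([1, 4, 32]) torch.Size([1, 4, 32])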

@ -3,7 +3,9 @@
# DeepSpeed Team
import os
import torch
import torch.nn.functional as F
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
from deepspeed.utils.types import NormType
@ -13,21 +15,40 @@ class QKVGemmOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(QKVGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16
else:
self.qkv_gemm_func = self.inference_module.qkv_gemm_fp32 # type: ignore
elif self.config.norm_type == NormType.RMSNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_bf16
else:
self.qkv_gemm_func = self.inference_module.rms_qkv_gemm_fp32 # type: ignore
except AttributeError:
if self.config.norm_type == NormType.LayerNorm:
self.qkv_gemm_func = self.qkv_gemm_fallback
elif self.config.norm_type == NormType.RMSNorm:
self.qkv_gemm_func = self.rms_qkv_gemm_fallback
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_bf16
else:
self.qkv_gemm_func = self.inference_cuda_module.qkv_gemm_fp32 # type: ignore
elif self.config.norm_type == NormType.RMSNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_bf16
else:
self.qkv_gemm_func = self.inference_cuda_module.rms_qkv_gemm_fp32 # type: ignore
def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose):
if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps)
tmp = torch.matmul(inp_norm, weight)
if add_bias:
tmp += bias
output = [tmp, inp_norm]
return output
else:
raise NotImplementedError
def rms_qkv_gemm_fallback(self, input, weight, q_scale, gamma, eps, q_int8, transpose):
raise NotImplementedError
def forward(self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, gamma: torch.Tensor,
beta: torch.Tensor):


@ -3,6 +3,7 @@
# DeepSpeed Team
import os
import torch
from typing import Optional
from ..config import DeepSpeedInferenceConfig
@ -13,13 +14,17 @@ class ResidualAddOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(ResidualAddOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp16
elif self.config.dtype == torch.bfloat16:
self.residual_add_func = self.inference_cuda_module.residual_add_bias_bf16
else:
self.residual_add_func = self.inference_cuda_module.residual_add_bias_fp32
self._vector_add = self.inference_cuda_module._vector_add
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.residual_add_func = self.inference_module.residual_add_bias_fp16
elif self.config.dtype == torch.bfloat16:
self.residual_add_func = self.inference_module.residual_add_bias_bf16
else:
self.residual_add_func = self.inference_module.residual_add_bias_fp32
self._vector_add = self.inference_module._vector_add
except AttributeError:
self.residual_add_func = None
self._vector_add = None
def forward(self,
hidden_state: torch.Tensor,
@ -30,14 +35,28 @@ class ResidualAddOp(BaseOp):
attention_bias: Optional[torch.Tensor] = None,
final_bias: Optional[torch.Tensor] = None):
if final_bias is None:
residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size)
else:
if not self.config.pre_layer_norm and residual_add is not None:
# only use residual add if its set and we are not pre layer norm
residual = residual_add
if self.residual_add_func != None:
if final_bias is None:
residual = self._vector_add(residual, hidden_state, 1.0 / self.config.mp_size)
else:
if not self.config.pre_layer_norm and residual_add is not None:
# only use residual add if it's set and we are not pre-layer-norm
residual = residual_add
self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias,
self.config.mp_size, self.config.mlp_after_attn, add_bias,
self.config.pre_layer_norm)
self.residual_add_func(hidden_state, residual, attention_output, attention_bias, final_bias,
self.config.mp_size, self.config.mlp_after_attn, add_bias,
self.config.pre_layer_norm)
else:
# fallback
if os.environ.get('DS_KI_FALLBACK') == 'True' and self.config.mlp_after_attn:
if self.config.pre_layer_norm:
tmp = (residual.float() + attention_output.float() + attention_bias.float() +
final_bias.float()) / self.config.mp_size + hidden_state.float()
else:
tmp = residual.float() + hidden_state.float() + final_bias.float()
input_dtype = hidden_state.dtype
residual = tmp.to(input_dtype)
else:
raise NotImplementedError
return residual


@ -3,7 +3,9 @@
# DeepSpeed Team
import os
import torch
import torch.nn.functional as F
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
@ -12,19 +14,40 @@ class SoftmaxOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(SoftmaxOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.softmax_func = self.inference_cuda_module.softmax_fp16
elif self.config.dtype == torch.bfloat16:
self.softmax_func = self.inference_cuda_module.softmax_bf16
else:
self.softmax_func = self.inference_cuda_module.softmax_fp32
self.num_attention_heads_per_partition = config.heads // config.mp_size
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.softmax_func = self.inference_module.softmax_fp16
elif self.config.dtype == torch.bfloat16:
self.softmax_func = self.inference_module.softmax_bf16
else:
self.softmax_func = self.inference_module.softmax_fp32
except AttributeError:
self.softmax_func = self.softmax_fallback
def _not_implemented(self, *args, **kwargs):
raise NotImplementedError
def softmax_fallback(self, attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
async_op, layer_scale, head_offset, mp_size):
if os.environ.get('DS_KI_FALLBACK') == 'True':
alibi = alibi[head_offset:head_offset + self.num_attention_heads_per_partition]
input_dtype = attn_scores.dtype
if (triangular):
tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool)
attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min)
if alibi is not None:
attn_scores += alibi
if attn_mask is not None:
# expand attn_mask from 2 dims into 4 dims by inserting two dims in the middle
attn_mask = attn_mask[:, None, None, :]
attn_scores += attn_mask
output = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype)
return output
else:
raise NotImplementedError
def forward(self, attn_scores: torch.Tensor, attn_mask: torch.Tensor, alibi: torch.Tensor, triangular: bool,
recompute: bool, local_attention: bool, window_size: int, async_op: bool, layer_scale: float,
head_offset: int):
output = self.softmax_func(attn_scores, attn_mask, alibi, triangular, recompute, local_attention, window_size,
async_op, layer_scale, head_offset, self.config.mp_size)
return output

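The softmax fallback reproduces the fused kernel in plain PyTorch: slice the alibi tensor to this rank's heads, apply the causal (triangular) mask, broadcast a 2-D attention mask to 4-D, and soft-max in fp32 before casting back. A small standalone sketch of the masking steps, with illustrative shapes:

# Sketch of the causal-mask and mask-broadcast steps in softmax_fallback above.
import torch
import torch.nn.functional as F

layer_scale = 1.0
attn_scores = torch.randn(1, 2, 4, 4)     # [batch, heads, q_len, k_len]
input_dtype = attn_scores.dtype

# mask out the upper triangle with the most negative representable value
tri = ~torch.tril(torch.ones(attn_scores.size(), device=attn_scores.device)).to(bool)
attn_scores = torch.masked_fill(attn_scores * layer_scale, tri, torch.finfo(input_dtype).min)

# a 2-D padding mask is broadcast to 4-D before being added
attn_mask = torch.zeros(1, 4)
attn_scores += attn_mask[:, None, None, :]

probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(input_dtype)
print(probs[0, 0].sum(dim=-1))            # each row sums to 1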

@ -13,12 +13,20 @@ class SoftmaxContextOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(SoftmaxContextOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.softmax_context_func = self.inference_cuda_module.softmax_context_fp16
elif self.config.dtype == torch.bfloat16:
self.softmax_context_func = self.inference_cuda_module.softmax_context_bf16
else:
self.softmax_context_func = self.inference_cuda_module.softmax_context_fp32
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.softmax_context_func = self.inference_module.softmax_context_fp16
elif self.config.dtype == torch.bfloat16:
self.softmax_context_func = self.inference_module.softmax_context_bf16
else:
self.softmax_context_func = self.inference_module.softmax_context_fp32
except AttributeError:
self.softmax_context_func = self.softmax_context_fallback
def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads,
norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id,
num_layers, alibi):
raise NotImplementedError
def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, norm_factor: float,
no_masking: bool, layer_id: int, num_layers: int, alibi: torch.Tensor):


@ -3,6 +3,7 @@
# DeepSpeed Team
import os
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
@ -12,12 +13,21 @@ class VectorMatMulOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(VectorMatMulOp, self).__init__(config)
if self.config.dtype in [torch.float16, torch.int8]:
self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp16
elif self.config.dtype == torch.bfloat16:
self.vector_matmul_func = self.inference_cuda_module.vector_matmul_bf16
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.vector_matmul_func = self.inference_module.vector_matmul_fp16
elif self.config.dtype == torch.bfloat16:
self.vector_matmul_func = self.inference_module.vector_matmul_bf16
else:
self.vector_matmul_func = self.inference_module.vector_matmul_fp32
except AttributeError:
self.vector_matmul_func = self.vector_matmul_fallback
def vector_matmul_fallback(self, input, weight, async_op, q_scale, q_int8, transpose):
if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
return torch.matmul(input, weight)
else:
self.vector_matmul_func = self.inference_cuda_module.vector_matmul_fp32
raise NotImplementedError
def forward(self, input: torch.Tensor, weight: torch.Tensor, async_op: bool = False):
q_scale = weight.scale if hasattr(weight, 'scale') else torch.empty(1)

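With DS_KI_FALLBACK enabled and a non-transposed weight, the vector-matmul fallback reduces to a plain matmul; a one-line sanity check with illustrative shapes:

# Sketch of vector_matmul_fallback: plain torch.matmul when the kernel is unavailable.
import torch

inp = torch.randn(1, 4, 8)
weight = torch.randn(8, 16)
print(torch.matmul(inp, weight).shape)  # torch.Size([1, 4, 16])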

@ -14,3 +14,4 @@ from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd

deepspeed/utils/numa.py (new file, 148 lines)

@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import distutils
import os
import psutil
import subprocess
# return a list of lists describing the core-to-NUMA-node mapping:
# [
# [ cores of numa node 0 ]
# [ cores of numa node 1 ]
# ...
# ]
def get_numa_cores():
ret = []
output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
lines = output.split('\n')
for line in lines:
if line.startswith('available:'):
num_numas = int(line.split(' ')[1])
break
for numa in range(num_numas):
for line in lines:
if line.startswith(f'node {numa} cpus:'):
cores = line.split(' ')[3:]
ret.append([int(core) for core in cores])
return ret
def check_for_numactl_pkg():
libs = dict(
dpkg=["-l", "numactl", "apt"],
pacman=["-Q", "numactl", "pacman"],
rpm=["-q", "numactl", "yum"],
)
found = False
for pkgmgr, data in libs.items():
flag, lib, tool = data
path = distutils.spawn.find_executable(pkgmgr)
if path is not None:
cmd = f"{pkgmgr} {flag} {lib}"
result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
if result.wait() == 0:
found = True
else:
print(f"please install the {lib} package with {tool}")
break
return found
def parse_range(rng):
try:
value = int(rng)
return range(value, value + 1)
except ValueError:
# value is not a single number
parts = rng.split('-')
if len(parts) != 2:
raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
(rng, ))
start = int(parts[0])
end = int(parts[1])
if start > end:
raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
return range(start, end + 1)
# parse a comma- and dash-separated range list into a list of numbers
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Items are comma separated; each item is either a single number
# or a range given by two numbers (both ends are included in the range)
# 2. Sub ranges must be in ascending order and must not overlap with each other
# 3. No space is allowed in the range expression
def parse_range_list(range_str):
number_list = []
last = -1
range_list = range_str.split(',')
for sub_range in range_list:
sub_number_list = parse_range(sub_range)
if sub_number_list[0] <= last:
raise ValueError(
"Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
(range_str, ))
last = sub_number_list[-1]
number_list.extend(sub_number_list)
return number_list
def get_numactl_cmd(bind_core_list, num_local_procs, local_rank):
numactl_cmd = []
check_for_numactl_pkg()
if 'KMP_AFFINITY' in os.environ.keys():
raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
"because it interfere with how many CPU cores numactl can set. "
"Unset KMP_AFFINITY before launching deepspeed.\n\n"
"\t$ unset KMP_AFFINITY\n"
"\t$ deepspeed <deepspeed command parameters>")
if bind_core_list != None:
core_list = parse_range_list(bind_core_list)
total_cores = len(core_list)
else:
total_cores = psutil.cpu_count(logical=False)
core_list = range(total_cores)
cores_per_rank = total_cores // num_local_procs
assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
numactl_cmd.append("numactl")
# check if all cores belong to the same NUMA node; if so, bind the process to that node with the -m parameter
numa_cores = get_numa_cores()
num_numas = len(numa_cores)
for i in range(num_numas):
if set(core_list_for_rank) <= set(numa_cores[i]):
numactl_cmd.append("-m")
numactl_cmd.append(f"{i}")
break
numactl_cmd.append("-C")
last_core = core_list_for_rank[0]
first_core = last_core
core_list_str = f"{last_core}"
for core_id in core_list_for_rank[1:]:
if core_id == last_core + 1:
last_core = core_id
continue
else:
if first_core == last_core:
core_list_str = f"{core_list_str},{core_id}"
else:
core_list_str = f"{core_list_str}-{last_core},{core_id}"
first_core = core_id
last_core = core_id
if first_core != last_core:
core_list_str = f"{core_list_str}-{last_core}"
numactl_cmd.append(f"{core_list_str}")
return cores_per_rank, numactl_cmd

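Putting the helpers together: parse_range_list expands a core-binding string into core ids, and get_numactl_cmd turns that plus the rank information into a numactl prefix. A hedged usage sketch (requires numactl and psutil on the host; the exact -m/-C values depend on the machine's NUMA layout):

# Sketch: how a launcher could use the helpers in deepspeed/utils/numa.py.
from deepspeed.utils.numa import get_numactl_cmd, parse_range_list

# "0,2-4,6" expands to individual core ids
print(parse_range_list("0,2-4,6"))          # [0, 2, 3, 4, 6]

# bind rank 1 of 2 local ranks to the second half of cores 0-27;
# on a single-socket box this might yield something like: numactl -m 0 -C 14-27
cores_per_rank, numactl_cmd = get_numactl_cmd("0-27", num_local_procs=2, local_rank=1)
print(cores_per_rank, " ".join(numactl_cmd))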

@ -40,27 +40,43 @@ class SynchronizedWallClockTimer:
self.name_ = name
self.started_ = False
self.event_timers = []
self.use_host_timer = get_accelerator().is_synchronized_device()
self.start_event = None
self.elapsed_records = None
self.start_time = 0.0
self.end_time = 0.0
def start(self):
"""Start the timer."""
assert not self.started_, f"{self.name_} timer has already been started"
self.start_event = get_accelerator().Event(enable_timing=True)
self.start_event.record()
if self.use_host_timer:
self.start_time = time.time()
else:
event_class = get_accelerator().Event
self.start_event = event_class(enable_timing=True)
self.start_event.record()
self.started_ = True
def stop(self, reset=False, record=False):
"""Stop the timer."""
assert self.started_, "timer is not started"
end_event = get_accelerator().Event(enable_timing=True)
end_event.record()
self.event_timers.append(CudaEventTimer(self.start_event, end_event))
self.start_event = None
event_class = get_accelerator().Event
if self.use_host_timer:
self.end_time = time.time()
self.event_timers.append(self.end_time - self.start_time)
else:
event_class = get_accelerator().Event
end_event = event_class(enable_timing=True)
end_event.record()
self.event_timers.append(CudaEventTimer(self.start_event, end_event))
self.start_event = None
self.started_ = False
def _get_elapsed_msec(self):
self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
if self.use_host_timer:
self.elapsed_records = [et * 1000.0 for et in self.event_timers]
else:
self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
self.event_timers.clear()
return sum(self.elapsed_records)

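With this change, accelerators whose is_synchronized_device() returns True (such as the CPU accelerator) time spans with the host wall clock instead of device events, and _get_elapsed_msec converts the stored seconds to milliseconds. A minimal sketch of the host-timer path, with the accelerator call stubbed out:

# Sketch of the host-timer path; `use_host_timer` stands in for
# get_accelerator().is_synchronized_device().
import time

use_host_timer = True          # True on synchronized devices such as CPU
event_timers = []

start_time = time.time()
# ... timed region ...
if use_host_timer:
    event_timers.append(time.time() - start_time)   # stored in seconds
# the device path would instead append a CudaEventTimer built from start/end events

elapsed_msec = sum(et * 1000.0 for et in event_timers)  # normalized to milliseconds
print(f"elapsed: {elapsed_msec:.3f} ms")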

@ -441,7 +441,7 @@ class OpBuilder(ABC):
def load(self, verbose=True):
from deepspeed.git_version_info import installed_ops, torch_info
if installed_ops[self.name]:
if installed_ops.get(self.name, False):
# Ensure the op we're about to load was compiled with the same
# torch/cuda versions we are currently using at runtime.
self.validate_torch_version(torch_info)


@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
from .comm import CCLCommBuilder
from .no_impl import NotImplementedBuilder

op_builder/cpu/builder.py (new file, 34 lines)

@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder
class CPUOpBuilder(OpBuilder):
def builder(self):
from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}
cpp_ext = ExtensionBuilder(name=self.absolute_name(),
sources=self.strip_empty_entries(self.sources()),
include_dirs=self.strip_empty_entries(self.include_paths()),
libraries=self.strip_empty_entries(self.libraries_args()),
extra_compile_args=compile_args)
return cpp_ext
def cxx_args(self):
return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
def libraries_args(self):
return []

op_builder/cpu/comm.py (new file, 41 lines)

@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from .builder import CPUOpBuilder
class CCLCommBuilder(CPUOpBuilder):
BUILD_VAR = "DS_BUILD_CCL_COMM"
NAME = "deepspeed_ccl_comm"
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
def absolute_name(self):
return f'deepspeed.ops.comm.{self.NAME}_op'
def sources(self):
return ['csrc/cpu/comm/ccl.cpp']
def include_paths(self):
includes = ['csrc/cpu/includes']
return includes
def is_compatible(self, verbose=True):
# TODO: add soft compatibility check for private binary release.
# a soft check, as in we know it can be trivially changed.
return super().is_compatible(verbose)
def extra_ldflags(self):
ccl_root_path = os.environ.get("CCL_ROOT")
if ccl_root_path is None:
raise ValueError(
"Didn't find CCL_ROOT; install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment script so that CCL_ROOT is set"
)
else:
return ['-lccl', f'-L{ccl_root_path}/lib']

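Building the CCL comm op therefore requires a sourced oneCCL environment. A hedged sketch of JIT-loading it through the accelerator abstraction used elsewhere in this PR (the builder name string is assumed to be recognized by the CPU accelerator):

# Sketch: JIT-loading the CCL comm op; assumes oneCCL's environment has been
# sourced so that CCL_ROOT points at a valid installation.
import os
from deepspeed.accelerator import get_accelerator

assert os.environ.get("CCL_ROOT"), "source the oneCCL environment first"
builder = get_accelerator().create_op_builder("CCLCommBuilder")
if builder is not None and builder.is_compatible():
    ccl_comm_op = builder.load()   # compiles csrc/cpu/comm/ccl.cpp on first use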
op_builder/cpu/no_impl.py (new file, 24 lines)

@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import CPUOpBuilder
class NotImplementedBuilder(CPUOpBuilder):
BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED"
NAME = "deepspeed_not_implemented"
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
def absolute_name(self):
return f'deepspeed.ops.comm.{self.NAME}_op'
def load(self, verbose=True):
raise ValueError("This op had not been implemented on CPU backend.")
def sources(self):
return []


@ -0,0 +1 @@
intel_extension_for_pytorch


@ -107,7 +107,13 @@ class TestDistributedFixture(DistributedTest):
class TestDistAllReduce(DistributedTest):
world_size = [1, 2, 4]
device_count = get_accelerator().device_count()
if device_count >= 4:
world_size = [1, 2, 4]
elif device_count >= 2:
world_size = [1, 2]
else:
world_size = [1]
def test(self):
x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)


@ -56,22 +56,26 @@ def set_accelerator_visible():
if is_rocm_pytorch:
rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
num_gpus = len(list(gpu_ids))
num_accelerators = len(list(gpu_ids))
else:
nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n'))
else:
assert get_accelerator().device_name() == 'xpu'
num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
elif get_accelerator().device_name() == 'xpu':
import re
clinfo = subprocess.check_output(['clinfo'])
lines = clinfo.decode('utf-8').strip().split('\n')
num_gpus = 0
num_accelerators = 0
for line in lines:
match = re.search('Device Type.*GPU', line)
if match:
num_gpus += 1
num_accelerators += 1
else:
assert get_accelerator().device_name() == 'cpu'
cpu_sockets = int(
subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
num_accelerators = cpu_sockets
cuda_visible = ",".join(map(str, range(num_gpus)))
cuda_visible = ",".join(map(str, range(num_accelerators)))
# rotate list based on xdist worker id, example below
# wid=0 -> ['0', '1', '2', '3']


@ -54,7 +54,7 @@ class TestModelProfiling(DistributedTest):
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
pipe = pipeline(task, model, framework="pt", device=local_rank)
pipe = pipeline(task, model, framework="pt", device=get_accelerator().device_name(local_rank))
pipe.model = deepspeed.init_inference(pipe.model,
dtype=dtype,
mp_size=world_size,


@ -6,7 +6,7 @@
import argparse
import pytest
import deepspeed
from deepspeed.launcher.launch import parse_range_list
from deepspeed.utils.numa import parse_range_list
def basic_parser():