DeepSpeed/accelerator/cuda_accelerator.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

import os
import pkgutil
import importlib

from .abstract_accelerator import DeepSpeedAccelerator

# During the setup stage torch may not be installed. Ignoring the missing
# import allows the op-builder-related APIs to be executed anyway.
try:
    import torch.cuda
except ImportError:
    pass


class CUDA_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl'

    def is_synchronized_device(self):
        return False

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'cuda'
        return 'cuda:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return 'cuda:{}'.format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)
        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()
        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)
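
    def initial_seed(self, seed):
        # torch.cuda.initial_seed() takes no arguments; `seed` is kept only to
        # preserve the existing method signature and is ignored.
        return torch.cuda.initial_seed()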

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    # Data types
    def is_bf16_supported(self):
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        # fp16 is treated as supported on compute capability 7.0 and above.
        major, _ = torch.cuda.get_device_capability()
        return major >= 7

    def supported_dtypes(self):
        return [torch.float, torch.half, torch.bfloat16]

    # Misc
    def amp(self):
        if hasattr(torch.cuda, 'amp'):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.cuda.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.cuda.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.cuda.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.cuda.FloatTensor

    @property
    def HalfTensor(self):
        return torch.cuda.HalfTensor

    @property
    def IntTensor(self):
        return torch.cuda.IntTensor

    @property
    def LongTensor(self):
        return torch.cuda.LongTensor

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        return device_str.startswith('cuda:')

    def op_builder_dir(self):
        try:
            # Is op_builder from DeepSpeed or a third-party version? This import
            # should only succeed if it is DeepSpeed's own. If it succeeds, we are
            # doing a local install rather than taking the JIT compile path.
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder"
        except ImportError:
            return "deepspeed.ops.op_builder"

    # dict that holds the class name <--> class type mapping, i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # This dict is filled lazily by _lazy_init_class_dict().
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mappings into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self references, and
                # skip subdirectories that contain ops for other backends (cpu, npu, etc.).
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of the op builder named by class_name and return it
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return None

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension