# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
import sys

from .abstract_accelerator import DeepSpeedAccelerator

# During the setup stage torch may not be installed yet; passing on a missing
# torch import still allows the op builder related APIs to be executed.
try:
    import torch.cuda
except ImportError:
    pass

# Delay importing pynvml to avoid an import error when CUDA is not available
pynvml = None


class CUDA_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
        self._compile_backend = "inductor"
        if pynvml is None:
            self._init_pynvml()

    def _init_pynvml(self):
        global pynvml
        try:
            import pynvml
        except ImportError:
            return
        try:
            pynvml.nvmlInit()
        except pynvml.NVMLError:
            pynvml = None
            return

    def is_synchronized_device(self):
        return False

    def use_host_timers(self):
        return self.is_synchronized_device()

    def resolves_data_dependency(self):
        return self.is_synchronized_device()

    def handles_memory_backpressure(self):
        return self.is_synchronized_device()

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return 'cuda'
        return 'cuda:{}'.format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return 'cuda:{}'.format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)
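
    # Illustrative behavior: device_name() returns 'cuda' and device_name(1) returns 'cuda:1';
    # the remaining device APIs above are thin wrappers around the corresponding torch.cuda calls.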

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self):
        return torch.cuda.initial_seed()

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    def _get_nvml_gpu_id(self, torch_gpu_id):
        """
        credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020

        Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.

        If the latter isn't set, return the same id.
        """
        # if CUDA_VISIBLE_DEVICES is used, automagically remap the id since pynvml ignores this env var
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
            return ids[torch_gpu_id]  # remap
        else:
            return torch_gpu_id
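
    # Illustrative example (assumed environment): with CUDA_VISIBLE_DEVICES="2,3", torch
    # device 0 maps to NVML index 2 and torch device 1 maps to NVML index 3; with the
    # variable unset, the id is returned unchanged.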

    def available_memory(self, device_index=None):
        if pynvml:
            if device_index is None:
                device_index = self.current_device()
            handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            return info.free
        else:
            return self.total_memory(device_index) - self.memory_allocated(device_index)

    # Data types
    def is_bf16_supported(self):
        if not torch.cuda.is_available():
            return True
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        if not torch.cuda.is_available():
            return True
        # See https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix
        # FP16 on compute capability 6.x is deprecated
        allow_deprecated_fp16 = os.environ.get('DS_ALLOW_DEPRECATED_FP16', '0') == '1'
        major, _ = torch.cuda.get_device_capability()
        if major >= 7:
            return True
        elif major == 6 and allow_deprecated_fp16:
            return True
        else:
            return False

    def supported_dtypes(self):
        supported_dtypes = [torch.float]
        if self.is_fp16_supported():
            supported_dtypes.append(torch.half)
        if self.is_bf16_supported():
            supported_dtypes.append(torch.bfloat16)
        return supported_dtypes
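
    # Illustrative example (assumed hardware): on an A100 (compute capability 8.0) this is
    # expected to return [torch.float, torch.half, torch.bfloat16].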

    # Misc
    def amp(self):
        if hasattr(torch.cuda, 'amp'):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop()
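
    # Usage sketch (illustrative, `accel`, `model` and `inputs` are assumed to exist):
    #   accel.range_push("forward"); model(inputs); accel.range_pop()
    # marks a named NVTX region that shows up in profiler timelines.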

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
        else:
            return False

    # Graph operations
    def create_graph(self):
        return torch.cuda.CUDAGraph()

    def capture_to_graph(self, graph, pool=None, stream=None):
        return torch.cuda.graph(graph, pool, stream)

    def replay_graph(self, graph):
        graph.replay()
        return
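
    # Usage sketch (illustrative, `accel`, `model` and `static_in` are assumed to exist):
    #   g = accel.create_graph()
    #   with accel.capture_to_graph(g):
    #       static_out = model(static_in)  # capture a short, shape-stable kernel sequence
    #   accel.replay_graph(g)              # re-launch the captured kernels with low overhead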

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='cuda')

    @property
    def ByteTensor(self):
        return functools.partial(torch.tensor, dtype=torch.uint8, device='cuda')

    @property
    def DoubleTensor(self):
        return functools.partial(torch.tensor, dtype=torch.double, device='cuda')

    @property
    def FloatTensor(self):
        return functools.partial(torch.tensor, dtype=torch.float, device='cuda')

    @property
    def HalfTensor(self):
        return functools.partial(torch.tensor, dtype=torch.half, device='cuda')

    @property
    def IntTensor(self):
        return functools.partial(torch.tensor, dtype=torch.int, device='cuda')

    @property
    def LongTensor(self):
        return functools.partial(torch.tensor, dtype=torch.long, device='cuda')
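
    # Illustrative usage: these properties return tensor factories rather than tensor types,
    # e.g. `accel.FloatTensor([1.0, 2.0])` (with `accel` an instance of this class) builds a
    # float32 tensor on the 'cuda' device via torch.tensor.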

    def pin_memory(self, tensor, align_bytes=1):
        return tensor.pin_memory()

    def is_pinned(self, tensor):
        return tensor.is_pinned()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('cuda:'):
            return True
        else:
            return False

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful, this also means we're doing a local install and not the JIT compile path
            from op_builder import __deepspeed__  # noqa: F401 # type: ignore
            return "op_builder"
        except ImportError:
            return "deepspeed.ops.op_builder"

    # dict that holds class name <--> class type mapping i.e.
    # 'AsyncIOBuilder': <class 'op_builder.async_io.AsyncIOBuilder'>
    # this dict will be filled at init stage
    class_dict = None

    def _lazy_init_class_dict(self):
        if self.class_dict is not None:
            return
        else:
            self.class_dict = {}
            # begin initialize for create_op_builder()
            # put all valid class name <--> class type mapping into class_dict
            op_builder_dir = self.op_builder_dir()
            op_builder_module = importlib.import_module(op_builder_dir)
            op_builder_absolute_path = os.path.dirname(op_builder_module.__file__)
            for _, module_name, _ in pkgutil.iter_modules([op_builder_absolute_path]):
                # avoid self references,
                # skip sub-directories which contain ops for other backends (cpu, npu, etc.)
                if module_name != 'all_ops' and module_name != 'builder' and not os.path.isdir(
                        os.path.join(op_builder_absolute_path, module_name)):
                    module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
                    for member_name in module.__dir__():
                        if member_name.endswith(
                                'Builder'
                        ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder" and member_name != "TorchCPUOpBuilder":  # avoid abstract classes
                            if member_name not in self.class_dict:
                                self.class_dict[member_name] = getattr(module, member_name)
            # end initialize for create_op_builder()

    # create an instance of the op builder specified by class_name and return it
    def create_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]()
        else:
            return None

    # return the op builder class specified by class_name
    def get_op_builder(self, class_name):
        self._lazy_init_class_dict()
        if class_name in self.class_dict:
            return self.class_dict[class_name]
        else:
            return None
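
    # Usage sketch (illustrative): builders are looked up by class name, e.g.
    #   accel.get_op_builder('AsyncIOBuilder')     # -> builder class, or None if unknown
    #   accel.create_op_builder('AsyncIOBuilder')  # -> builder instance, or None if unknown
    # where `accel` is assumed to be an instance of this class.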

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        return ['NCCL']

    def visible_devices_envs(self):
        return ['CUDA_VISIBLE_DEVICES']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        supported_backends = torch._dynamo.list_backends(exclude_tags=())
        if backend in supported_backends:
            self._compile_backend = backend
        else:
            raise ValueError(
                f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}")
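
    # Usage sketch (illustrative): the stored backend is intended for torch.compile, e.g.
    #   accel.set_compile_backend("eager")  # must be a backend known to torch._dynamo
    #   compiled = torch.compile(model, backend=accel.get_compile_backend())
    # where `accel` and `model` are assumed to exist; unknown backend names raise ValueError.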