[CANN] Support cpu offload optimizer for Ascend NPU (#4568)

Support the cpu_adam, cpu_adagrad and cpu_lion optimizers for Ascend NPU. All of
these optimizers run on the host; the difference between backends is how the
updated parameters are copied back to the device. This commit adds a new symbol,
"__ENABLE_CANN__", which guards the code paths adapted to the NPU. The NPU
builder adds the header files and libraries required for compilation, following
CANN's compilation manual. Note that there is no FusedLion implementation for
the NPU, so the test_cpu_lion test case should stay disabled until a FusedLion
optimizer is implemented.
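
The backend split boils down to this parameter write-back. A minimal sketch of
the pattern, simplified from the diff below (it omits half-precision sizing and
the double buffering):

```cpp
#if defined(__ENABLE_CANN__)
#include "acl/acl.h"  // provides aclrtMemcpy and aclrtMemcpyKind
#endif
#include <cstddef>

// The optimizer math always runs on the host; only this write-back step is
// backend-specific, which is what the new __ENABLE_CANN__ guard selects.
static void copy_params_to_device(const float* host_buf, float* dev_params, std::size_t count)
{
#if defined(__ENABLE_CANN__)
    std::size_t bytes = count * sizeof(float);
    // Blocking host-to-device copy, as in the CANN branches of this PR.
    aclrtMemcpy(dev_params, bytes, host_buf, bytes, aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
#else
    // CUDA builds enqueue launch_param_update(...) on a stream instead;
    // pure-CPU builds update parameters in place and have nothing to copy.
    (void)host_buf;
    (void)dev_params;
    (void)count;
#endif
}
```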

In addition, when the NPU is selected as the accelerator, ds_report shows
torch_npu and CANN information.

With this PR, all DeepSpeed test cases in
[huggingface/accelerate](https://github.com/huggingface/accelerate/tree/main/tests/deepspeed)
pass.

This is part of the feature list for Ascend NPU support; see #4567.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
hipudding 2023-11-14 21:37:16 +08:00, committed by GitHub
Parent efd4556345
Commit c1ba6a104f
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files: 343 additions and 29 deletions

@ -47,6 +47,8 @@ void Adagrad_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
@ -62,7 +64,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
@ -79,6 +81,17 @@ void Adagrad_Optimizer::Step_1(float* _params,
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
}
}
@ -180,7 +193,7 @@ int ds_adagrad_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
@ -196,7 +209,7 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
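
A design note rather than part of the diff: the CANN write-back above uses the
blocking aclrtMemcpy, while the CUDA path enqueues its copy on a stream. Purely
as a hypothetical sketch (this function is not in the PR), an asynchronous CANN
variant on the tile's stream could look like this:

```cpp
#include <cstddef>
#include "acl/acl.h"

// Hypothetical: push one tile's write-back onto that tile's NPU stream instead
// of blocking. The host buffer must be pinned (aclrtMallocHost), and the caller
// must aclrtSynchronizeStream(stream) before refilling the buffer -- which is
// exactly what the (t / TILE) >= 2 checks in the diff already do.
static aclError copy_tile_async(float* dev_params,
                                const float* host_buf,
                                std::size_t count,
                                aclrtStream stream)
{
    std::size_t bytes = count * sizeof(float);
    return aclrtMemcpyAsync(dev_params, bytes, host_buf, bytes,
                            aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE, stream);
}
```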

@ -61,6 +61,8 @@ void Adam_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
@ -81,7 +83,7 @@ void Adam_Optimizer::Step_1(float* _params,
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
@ -96,6 +98,17 @@ void Adam_Optimizer::Step_1(float* _params,
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
@ -239,7 +252,7 @@ int ds_adam_step(int optimizer_id,
nullptr,
(params.options().dtype() == at::kHalf));
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
@ -257,18 +270,18 @@ int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
torch::Tensor& device_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto device_params_c = device_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
@ -281,7 +294,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
device_params_ptr,
(params.options().dtype() == at::kHalf));
opt->SynchronizeStreams();

@ -18,6 +18,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif
@ -41,6 +45,11 @@ public:
_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_buf_index = false;
#endif
}
@ -49,6 +58,9 @@ public:
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
@ -69,6 +81,11 @@ public:
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step)
{
@ -95,6 +112,11 @@ private:
bool _buf_index;
float* _doubled_buffer[2];
cudaStream_t _streams[2];
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};
@ -125,6 +147,8 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@ -149,7 +173,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
@ -167,6 +191,17 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
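
Taken together, the additions in this header give the CANN build the same
double-buffering scheme the CUDA path already had: two pinned host staging
buffers from aclrtMallocHost, two NPU streams from torch_npu, a stream sync
before a buffer is reused, and aclrtFreeHost on teardown. A condensed,
illustrative stand-in (not code from the diff; TILE sizing and the optimizer
state are omitted):

```cpp
#include <cstddef>
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"

class NpuDoubleBuffer {
public:
    explicit NpuDoubleBuffer(std::size_t tile_elems)
    {
        // Pinned host memory, so host<->NPU copies of each tile are efficient.
        aclrtMallocHost((void**)&_buf[0], tile_elems * sizeof(float));
        aclrtMallocHost((void**)&_buf[1], tile_elems * sizeof(float));
    }
    ~NpuDoubleBuffer()
    {
        aclrtFreeHost(_buf[0]);
        aclrtFreeHost(_buf[1]);
    }
    // Called before a buffer is refilled (the (t / TILE) >= 2 case): make sure
    // the copy that used this buffer two tiles ago has finished.
    void wait(int buf_index) { aclrtSynchronizeStream(_streams[buf_index].stream()); }
    void wait_all()
    {
        for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
    }
    float* buffer(int buf_index) { return _buf[buf_index]; }

private:
    float* _buf[2] = {nullptr, nullptr};
    c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
                                      c10_npu::getNPUStreamFromPool()};
};
```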

@ -19,6 +19,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
@ -57,6 +61,11 @@ public:
_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_buf_index = false;
#endif
}
@ -65,6 +74,9 @@ public:
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
@ -87,6 +99,11 @@ public:
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
@ -142,6 +159,11 @@ private:
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};
@ -192,6 +214,9 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@ -227,7 +252,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
@ -246,6 +271,17 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;

@ -19,6 +19,10 @@
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
@ -46,6 +50,11 @@ public:
_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_buf_index = false;
#endif
}
@ -54,6 +63,9 @@ public:
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
@ -75,6 +87,11 @@ public:
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
@ -102,6 +119,11 @@ private:
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};
@ -149,6 +171,8 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
@ -178,7 +202,7 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
@ -196,6 +220,17 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;

@ -54,6 +54,8 @@ void Lion_Optimizer::Step_1(float* _params,
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
@ -72,7 +74,7 @@ void Lion_Optimizer::Step_1(float* _params,
}
momentum = momentum * _betta2;
momentum = grad * betta2_minus1 + momentum;
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
@ -86,6 +88,17 @@ void Lion_Optimizer::Step_1(float* _params,
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
_buf_index = !_buf_index;
}
#endif
@ -201,7 +214,7 @@ int ds_lion_step(int optimizer_id,
nullptr,
(params.options().dtype() == at::kHalf));
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
@ -218,7 +231,7 @@ int ds_lion_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();

@ -80,6 +80,32 @@ def nvcc_version():
return ".".join(release)
def installed_cann_path():
if "ASCEND_HOME_PATH" in os.environ and os.path.exists(os.environ["ASCEND_HOME_PATH"]):
return os.environ["ASCEND_HOME_PATH"]
return None
def installed_cann_version():
import re
ascend_path = installed_cann_path()
if ascend_path is None:
return "CANN_HOME does not exist, unable to compile NPU op(s)"
cann_version = ""
for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
if cann_version:
break
install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
if install_files:
filepath = os.path.join(dirpath, install_files[0])
with open(filepath, "r") as f:
for line in f:
if line.find("version") != -1:
cann_version = line.strip().split("=")[-1]
break
return cann_version
def get_shm_size():
try:
shm_stats = os.statvfs('/dev/shm')
@ -122,6 +148,11 @@ def debug_report():
("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
])
elif get_accelerator().device_name() == 'npu':
import torch_npu
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']}"),
("torch_npu install path", torch_npu.__path__), ("torch_npu version", torch_npu.__version__),
("ascend_cann version", installed_cann_version())])
else:
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])

@ -4,6 +4,8 @@
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# NPU related operators will be added in the future.
from .fused_adam import FusedAdamBuilder
from .no_impl import NotImplementedBuilder
from .cpu_adam import CPUAdamBuilder
from .cpu_adagrad import CPUAdagradBuilder
from .cpu_lion import CPULionBuilder

@ -3,6 +3,13 @@
# DeepSpeed Team
import re
import os
try:
import torch_npu
except ImportError as e:
pass
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
@ -13,22 +20,67 @@ except ImportError:
class NPUOpBuilder(OpBuilder):
_ascend_path = None
_torch_npu_path = None
_cann_version = None
def builder(self):
from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
def __init__(self, name):
super().__init__(name)
self._ascend_path = self.installed_cann_path()
self._torch_npu_path = os.path.join(os.path.dirname(os.path.abspath(torch_npu.__file__)))
try:
self._cann_version = self.installed_cann_version(self.name)
except BaseException:
print(f"{self.name} ascend_cann is missing, npu ops cannot be compiled!")
compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}
def cann_defs(self):
if self._cann_version:
return '-D__ENABLE_CANN__'
return '-D__DISABLE_CANN__'
cpp_ext = ExtensionBuilder(name=self.absolute_name(),
sources=self.strip_empty_entries(self.sources()),
include_dirs=self.strip_empty_entries(self.include_paths()),
libraries=self.strip_empty_entries(self.libraries_args()),
extra_compile_args=compile_args)
def installed_cann_path(self):
if "ASCEND_HOME_PATH" in os.environ and os.path.exists(os.environ["ASCEND_HOME_PATH"]):
return os.environ["ASCEND_HOME_PATH"]
return None
return cpp_ext
def installed_cann_version(self, name=""):
ascend_path = self.installed_cann_path()
assert ascend_path is not None, "CANN_HOME does not exist, unable to compile NPU op(s)"
cann_version = ""
for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
if cann_version:
break
install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
if install_files:
filepath = os.path.join(dirpath, install_files[0])
with open(filepath, "r") as f:
for line in f:
if line.find("version") != -1:
cann_version = line.strip().split("=")[-1]
break
return cann_version
def include_paths(self):
paths = super().include_paths()
paths += [os.path.join(self._ascend_path, 'include'), os.path.join(self._torch_npu_path, 'include')]
return paths
def cxx_args(self):
return []
args = super().cxx_args()
args += ['-O3', '-std=c++17', '-g', '-Wno-reorder', '-fopenmp']
args += ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath']
args += [
self.cann_defs(),
self.cpu_arch(),
self.simd_width(), '-L' + os.path.join(self._ascend_path, 'lib64'),
'-L' + os.path.join(self._torch_npu_path, 'lib')
]
return args
def libraries_args(self):
return []
def extra_ldflags(self):
flags = super().extra_ldflags()
flags += [
'-L' + os.path.join(self._ascend_path, 'lib64'), '-lascendcl',
'-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
]
return flags
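
To illustrate what this builder wires up, here is a hypothetical stand-alone
smoke test, not part of this PR: it only does something useful when compiled
with the -D__ENABLE_CANN__ define and the CANN include/library paths that
NPUOpBuilder assembles above. The compile command in the comment is a
placeholder and assumes ASCEND_HOME_PATH is set, as the builder requires.

```cpp
// check_cann.cpp -- hypothetical smoke test for the flags NPUOpBuilder emits:
//
//   g++ -O3 -std=c++17 -fopenmp -D__ENABLE_CANN__ \
//       -I"$ASCEND_HOME_PATH/include" -L"$ASCEND_HOME_PATH/lib64" \
//       check_cann.cpp -lascendcl -o check_cann
//
// (The real extension additionally adds the torch_npu include/lib paths and
// links -ltorch_npu for the NPU stream helpers.)
#include <cstdio>
#if defined(__ENABLE_CANN__)
#include "acl/acl.h"
#endif

int main()
{
#if defined(__ENABLE_CANN__)
    aclInit(nullptr);  // bring up the CANN runtime
    void* pinned = nullptr;
    // Same pinned-host allocation the CPU optimizers use for staging buffers.
    if (aclrtMallocHost(&pinned, 1 << 20) == 0) {  // aclError 0 means success
        std::puts("CANN runtime OK");
        aclrtFreeHost(pinned);
    } else {
        std::puts("aclrtMallocHost failed");
    }
    aclFinalize();
#else
    std::puts("built without -D__ENABLE_CANN__");
#endif
    return 0;
}
```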

@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import NPUOpBuilder
class CPUAdagradBuilder(NPUOpBuilder):
BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
NAME = "cpu_adagrad"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.adagrad.{self.NAME}_op'
def sources(self):
return ['csrc/adagrad/cpu_adagrad.cpp']
def include_paths(self):
args = super().include_paths()
args += ['csrc/includes']
return args

@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import NPUOpBuilder
class CPUAdamBuilder(NPUOpBuilder):
BUILD_VAR = "DS_BUILD_CPU_ADAM"
NAME = "cpu_adam"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.adam.{self.NAME}_op'
def sources(self):
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
def include_paths(self):
args = super().include_paths()
args += ['csrc/includes']
return args

@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import NPUOpBuilder
class CPULionBuilder(NPUOpBuilder):
BUILD_VAR = "DS_BUILD_CPU_LION"
NAME = "cpu_lion"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.lion.{self.NAME}_op'
def sources(self):
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']
def include_paths(self):
args = super().include_paths()
args += ['csrc/includes']
return args

@ -22,3 +22,12 @@ class NotImplementedBuilder(NPUOpBuilder):
def sources(self):
return []
def cxx_args(self):
return []
def extra_ldflags(self):
return []
def include_paths(self):
return []