зеркало из https://github.com/microsoft/DeepSpeed.git
[CANN] Support cpu offload optimizer for Ascend NPU (#4568)
Support cpu_adam, cpu_adagrad and cpu_lion optimizer for Ascend NPU. All these optimizer are running on host, the difference between each backend is the way to copy params back to device. This commit add a new symbol called "__ENABLE_CANN__". This symbol can compile code adapted to NPU. The NPU builder adds the required header files and libraries for compiling, according to CANN's compilation manual. Note that there's no FusedLion implementation for NPU, test_cpu_lion test case should disabled until FusedLion optimizer implemented. Besides, when NPU is selected as the accelerator, ds_report will show torch_npu and CANN informations. With this PR, deepspeed test cases in [huggingface/accelerate](https://github.com/huggingface/accelerate/tree/main/tests/deepspeed) are all passed. It's a part of feature list for Ascend NPU support, @see #4567 --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Родитель
efd4556345
Коммит
c1ba6a104f
|
@ -47,6 +47,8 @@ void Adagrad_Optimizer::Step_1(float* _params,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
|
@ -62,7 +64,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
|
|||
grad += _eps;
|
||||
grad = momentum / grad;
|
||||
param = grad * step_size + param;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
|
@ -79,6 +81,17 @@ void Adagrad_Optimizer::Step_1(float* _params,
|
|||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -180,7 +193,7 @@ int ds_adagrad_step(int optimizer_id,
|
|||
opt->update_state(lr, epsilon, weight_decay);
|
||||
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
|
@ -196,7 +209,7 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
|
|||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto gpu_params_c = gpu_params.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
|
|
@ -61,6 +61,8 @@ void Adam_Optimizer::Step_1(float* _params,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
|
@ -81,7 +83,7 @@ void Adam_Optimizer::Step_1(float* _params,
|
|||
grad = momentum / grad;
|
||||
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
|
||||
param = grad * step_size + param;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
|
@ -96,6 +98,17 @@ void Adam_Optimizer::Step_1(float* _params,
|
|||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
|
@ -239,7 +252,7 @@ int ds_adam_step(int optimizer_id,
|
|||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
|
@ -257,18 +270,18 @@ int ds_adam_step_plus_copy(int optimizer_id,
|
|||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
torch::Tensor& device_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto gpu_params_c = gpu_params.contiguous();
|
||||
auto device_params_c = device_params.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
|
||||
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
|
@ -281,7 +294,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
|
|||
exp_avg_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
gpu_params_ptr,
|
||||
device_params_ptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
opt->SynchronizeStreams();
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
#endif
|
||||
|
@ -41,6 +45,11 @@ public:
|
|||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
|
@ -49,6 +58,9 @@ public:
|
|||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
|
@ -69,6 +81,11 @@ public:
|
|||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
inline void IncrementStep(size_t step)
|
||||
{
|
||||
|
@ -95,6 +112,11 @@ private:
|
|||
bool _buf_index;
|
||||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -125,6 +147,8 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
|
@ -149,7 +173,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + i, param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
|
@ -167,6 +191,17 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memoryCopySize /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
|
|
|
@ -19,6 +19,10 @@
|
|||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
#include <cmath>
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
|
@ -57,6 +61,11 @@ public:
|
|||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
|
@ -65,6 +74,9 @@ public:
|
|||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -87,6 +99,11 @@ public:
|
|||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
inline void IncrementStep(size_t step, float beta1, float beta2)
|
||||
{
|
||||
|
@ -142,6 +159,11 @@ private:
|
|||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
bool _buf_index;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -192,6 +214,9 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream());
|
||||
}
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
|
@ -227,7 +252,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
|
@ -246,6 +271,17 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memoryCopySize /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
|
|
|
@ -19,6 +19,10 @@
|
|||
#include "cuda.h"
|
||||
#include "custom_cuda_layers.h"
|
||||
typedef __half ds_half_precision_t;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
#include "acl/acl.h"
|
||||
#include "torch_npu/csrc/core/npu/NPUStream.h"
|
||||
typedef c10::Half ds_half_precision_t;
|
||||
#else
|
||||
#include <cmath>
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
|
@ -46,6 +50,11 @@ public:
|
|||
|
||||
_streams[0] = TrainingContext::Instance().GetCurrentStream();
|
||||
_streams[1] = TrainingContext::Instance().GetNewStream();
|
||||
_buf_index = false;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
|
||||
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
|
||||
|
||||
_buf_index = false;
|
||||
#endif
|
||||
}
|
||||
|
@ -54,6 +63,9 @@ public:
|
|||
#if defined(__ENABLE_CUDA__)
|
||||
cudaFreeHost(_doubled_buffer[0]);
|
||||
cudaFreeHost(_doubled_buffer[1]);
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
aclrtFreeHost(_doubled_buffer[0]);
|
||||
aclrtFreeHost(_doubled_buffer[1]);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -75,6 +87,11 @@ public:
|
|||
{
|
||||
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
inline void SynchronizeStreams()
|
||||
{
|
||||
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
|
||||
}
|
||||
#endif
|
||||
inline void IncrementStep(size_t step, float beta1, float beta2)
|
||||
{
|
||||
|
@ -102,6 +119,11 @@ private:
|
|||
float* _doubled_buffer[2];
|
||||
cudaStream_t _streams[2];
|
||||
bool _buf_index;
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
float* _doubled_buffer[2];
|
||||
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
|
||||
c10_npu::getNPUStreamFromPool()};
|
||||
bool _buf_index;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
@ -149,6 +171,8 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
|
@ -178,7 +202,7 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);
|
||||
|
||||
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
|
||||
}
|
||||
|
@ -196,6 +220,17 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
|||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
if (half_precision) memoryCopySize /= 2;
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
#endif
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
|
|
|
@ -54,6 +54,8 @@ void Lion_Optimizer::Step_1(float* _params,
|
|||
size_t offset = copy_size + t;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
|
||||
#endif
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
|
@ -72,7 +74,7 @@ void Lion_Optimizer::Step_1(float* _params,
|
|||
}
|
||||
momentum = momentum * _betta2;
|
||||
momentum = grad * betta2_minus1 + momentum;
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
|
||||
#endif
|
||||
if (half_precision)
|
||||
|
@ -86,6 +88,17 @@ void Lion_Optimizer::Step_1(float* _params,
|
|||
launch_param_update(
|
||||
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#elif defined(__ENABLE_CANN__)
|
||||
if (dev_params) {
|
||||
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
|
||||
aclrtMemcpy(dev_params + t,
|
||||
memcpy_size,
|
||||
_doubled_buffer[_buf_index],
|
||||
memcpy_size,
|
||||
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
|
||||
_buf_index = !_buf_index;
|
||||
}
|
||||
#endif
|
||||
|
@ -201,7 +214,7 @@ int ds_lion_step(int optimizer_id,
|
|||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
opt->SynchronizeStreams();
|
||||
#endif
|
||||
return 0;
|
||||
|
@ -218,7 +231,7 @@ int ds_lion_step_plus_copy(int optimizer_id,
|
|||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
#if defined(__ENABLE_CUDA__)
|
||||
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
|
||||
auto params_c = params.contiguous();
|
||||
auto gpu_params_c = gpu_params.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
|
|
|
@ -80,6 +80,32 @@ def nvcc_version():
|
|||
return ".".join(release)
|
||||
|
||||
|
||||
def installed_cann_path():
|
||||
if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]):
|
||||
return os.environ["ASCEND_HOME_PATH"]
|
||||
return None
|
||||
|
||||
|
||||
def installed_cann_version():
|
||||
import re
|
||||
ascend_path = installed_cann_path()
|
||||
if ascend_path is None:
|
||||
return f"CANN_HOME does not exist, unable to compile NPU op(s)"
|
||||
cann_version = ""
|
||||
for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
|
||||
if cann_version:
|
||||
break
|
||||
install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
|
||||
if install_files:
|
||||
filepath = os.path.join(dirpath, install_files[0])
|
||||
with open(filepath, "r") as f:
|
||||
for line in f:
|
||||
if line.find("version") != -1:
|
||||
cann_version = line.strip().split("=")[-1]
|
||||
break
|
||||
return cann_version
|
||||
|
||||
|
||||
def get_shm_size():
|
||||
try:
|
||||
shm_stats = os.statvfs('/dev/shm')
|
||||
|
@ -122,6 +148,11 @@ def debug_report():
|
|||
("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
|
||||
(f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
|
||||
])
|
||||
elif get_accelerator().device_name() == 'npu':
|
||||
import torch_npu
|
||||
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']}"),
|
||||
("torch_npu install path", torch_npu.__path__), ("torch_npu version", torch_npu.__version__),
|
||||
("ascend_cann version", installed_cann_version())])
|
||||
else:
|
||||
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# DeepSpeed Team
|
||||
'''Copyright The Microsoft DeepSpeed Team'''
|
||||
|
||||
# NPU related operators will be added in the future.
|
||||
from .fused_adam import FusedAdamBuilder
|
||||
from .no_impl import NotImplementedBuilder
|
||||
from .cpu_adam import CPUAdamBuilder
|
||||
from .cpu_adagrad import CPUAdagradBuilder
|
||||
from .cpu_lion import CPULionBuilder
|
||||
|
|
|
@ -3,6 +3,13 @@
|
|||
|
||||
# DeepSpeed Team
|
||||
|
||||
import re
|
||||
import os
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
|
||||
# if successful this also means we're doing a local install and not JIT compile path
|
||||
|
@ -13,22 +20,67 @@ except ImportError:
|
|||
|
||||
|
||||
class NPUOpBuilder(OpBuilder):
|
||||
_ascend_path = None
|
||||
_torch_npu_path = None
|
||||
_cann_version = None
|
||||
|
||||
def builder(self):
|
||||
from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
self._ascend_path = self.installed_cann_path()
|
||||
self._torch_npu_path = os.path.join(os.path.dirname(os.path.abspath(torch_npu.__file__)))
|
||||
try:
|
||||
self._cann_version = self.installed_cann_version(self.name)
|
||||
except BaseException:
|
||||
print(f"{self.name} ascend_cann is missing, npu ops cannot be compiled!")
|
||||
|
||||
compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())}
|
||||
def cann_defs(self):
|
||||
if self._cann_version:
|
||||
return '-D__ENABLE_CANN__'
|
||||
return '-D__DISABLE_CANN__'
|
||||
|
||||
cpp_ext = ExtensionBuilder(name=self.absolute_name(),
|
||||
sources=self.strip_empty_entries(self.sources()),
|
||||
include_dirs=self.strip_empty_entries(self.include_paths()),
|
||||
libraries=self.strip_empty_entries(self.libraries_args()),
|
||||
extra_compile_args=compile_args)
|
||||
def installed_cann_path(self):
|
||||
if "ASCEND_HOME_PATH" in os.environ or os.path.exists(os.environ["ASCEND_HOME_PATH"]):
|
||||
return os.environ["ASCEND_HOME_PATH"]
|
||||
return None
|
||||
|
||||
return cpp_ext
|
||||
def installed_cann_version(self, name=""):
|
||||
ascend_path = self.installed_cann_path()
|
||||
assert ascend_path is not None, "CANN_HOME does not exist, unable to compile NPU op(s)"
|
||||
cann_version = ""
|
||||
for dirpath, _, filenames in os.walk(os.path.realpath(ascend_path)):
|
||||
if cann_version:
|
||||
break
|
||||
install_files = [file for file in filenames if re.match(r"ascend_.*_install\.info", file)]
|
||||
if install_files:
|
||||
filepath = os.path.join(dirpath, install_files[0])
|
||||
with open(filepath, "r") as f:
|
||||
for line in f:
|
||||
if line.find("version") != -1:
|
||||
cann_version = line.strip().split("=")[-1]
|
||||
break
|
||||
return cann_version
|
||||
|
||||
def include_paths(self):
|
||||
paths = super().include_paths()
|
||||
paths += [os.path.join(self._ascend_path, 'include'), os.path.join(self._torch_npu_path, 'include')]
|
||||
return paths
|
||||
|
||||
def cxx_args(self):
|
||||
return []
|
||||
args = super().cxx_args()
|
||||
args += ['-O3', '-std=c++17', '-g', '-Wno-reorder', '-fopenmp']
|
||||
args += ['-fstack-protector-all', '-Wl,-z,relro,-z,now,-z,noexecstack', '-Wl,--disable-new-dtags,--rpath']
|
||||
args += [
|
||||
self.cann_defs(),
|
||||
self.cpu_arch(),
|
||||
self.simd_width(), '-L' + os.path.join(self._ascend_path, 'lib64'),
|
||||
'-L' + os.path.join(self._torch_npu_path, 'lib')
|
||||
]
|
||||
return args
|
||||
|
||||
def libraries_args(self):
|
||||
return []
|
||||
def extra_ldflags(self):
|
||||
flags = super().extra_ldflags()
|
||||
flags += [
|
||||
'-L' + os.path.join(self._ascend_path, 'lib64'), '-lascendcl',
|
||||
'-L' + os.path.join(self._torch_npu_path, 'lib'), '-ltorch_npu'
|
||||
]
|
||||
return flags
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .builder import NPUOpBuilder
|
||||
|
||||
|
||||
class CPUAdagradBuilder(NPUOpBuilder):
|
||||
BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
|
||||
NAME = "cpu_adagrad"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name=self.NAME)
|
||||
|
||||
def absolute_name(self):
|
||||
return f'deepspeed.ops.adagrad.{self.NAME}_op'
|
||||
|
||||
def sources(self):
|
||||
return ['csrc/adagrad/cpu_adagrad.cpp']
|
||||
|
||||
def include_paths(self):
|
||||
args = super().include_paths()
|
||||
args += ['csrc/includes']
|
||||
return args
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .builder import NPUOpBuilder
|
||||
|
||||
|
||||
class CPUAdamBuilder(NPUOpBuilder):
|
||||
BUILD_VAR = "DS_BUILD_CPU_ADAM"
|
||||
NAME = "cpu_adam"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name=self.NAME)
|
||||
|
||||
def absolute_name(self):
|
||||
return f'deepspeed.ops.adam.{self.NAME}_op'
|
||||
|
||||
def sources(self):
|
||||
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']
|
||||
|
||||
def include_paths(self):
|
||||
args = super().include_paths()
|
||||
args += ['csrc/includes']
|
||||
return args
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from .builder import NPUOpBuilder
|
||||
|
||||
|
||||
class CPULionBuilder(NPUOpBuilder):
|
||||
BUILD_VAR = "DS_BUILD_CPU_LION"
|
||||
NAME = "cpu_lion"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name=self.NAME)
|
||||
|
||||
def absolute_name(self):
|
||||
return f'deepspeed.ops.lion.{self.NAME}_op'
|
||||
|
||||
def sources(self):
|
||||
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']
|
||||
|
||||
def include_paths(self):
|
||||
args = super().include_paths()
|
||||
args += ['csrc/includes']
|
||||
return args
|
|
@ -22,3 +22,12 @@ class NotImplementedBuilder(NPUOpBuilder):
|
|||
|
||||
def sources(self):
|
||||
return []
|
||||
|
||||
def cxx_args(self):
|
||||
return []
|
||||
|
||||
def extra_ldflags(self):
|
||||
return []
|
||||
|
||||
def include_paths(self):
|
||||
return []
|
||||
|
|
Загрузка…
Ссылка в новой задаче