Mirror of https://github.com/microsoft/DeepSpeed.git
[XPU] XPU accelerator support for Intel GPU device (#4547)
This PR adds XPU support for Intel GPU devices. With this change, DeepSpeed can support XPU devices without installing Intel Extension for DeepSpeed.

---------

Co-authored-by: Liangliang-Ma <1906710196@qq.com>
Co-authored-by: baodi <di.bao@intel.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Yizhou Wang <yizhou.wang@intel.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
This commit is contained in:
Parent
c8c57b8c24
Commit
f4f31317ed
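
Usage note (not part of this commit): the built-in XPU accelerator is picked up either automatically, when intel_extension_for_pytorch reports an XPU device, or explicitly through the DS_ACCELERATOR environment variable; "xpu" selects the new in-tree accelerator added below, while "xpu.external" keeps routing to intel_extension_for_deepspeed. A minimal selection sketch, assuming a working intel_extension_for_pytorch install with XPU support:

    # Illustrative sketch, not from this commit.
    import os
    os.environ["DS_ACCELERATOR"] = "xpu"       # or "xpu.external" for the external package

    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()
    print(acc.device_name())                   # "xpu"
    print(acc.communication_backend_name())    # "ccl"
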
@@ -20,7 +20,7 @@ try:
except ImportError as e:
    dsa2 = None

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']
SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'xpu.external', 'npu', 'mps']

ds_accelerator = None

@@ -60,10 +60,18 @@ def get_accelerator():
        accelerator_name = os.environ["DS_ACCELERATOR"]
        if accelerator_name == "xpu":
            try:
                from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401 # type: ignore
                import intel_extension_for_pytorch as ipex
                assert ipex._C._has_xpu(), "XPU_Accelerator requires an intel_extension_for_pytorch that supports XPU."
            except ImportError as e:
                raise ValueError(
                    f"XPU_Accelerator requires intel_extension_for_deepspeed, which is not installed on this system.")
                    f"XPU_Accelerator requires intel_extension_for_pytorch, which is not installed on this system.")
        elif accelerator_name == "xpu.external":
            try:
                import intel_extension_for_deepspeed  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError(
                    f"XPU_Accelerator external requires intel_extension_for_deepspeed, which is not installed on this system."
                )
        elif accelerator_name == "cpu":
            try:
                import intel_extension_for_pytorch  # noqa: F401 # type: ignore

@@ -104,15 +112,16 @@ def get_accelerator():

        try:
            from intel_extension_for_deepspeed import XPU_Accelerator  # noqa: F401,F811 # type: ignore

            accelerator_name = "xpu"
            accelerator_name = "xpu.external"
        except ImportError as e:
            pass
        if accelerator_name is None:
            try:
                import intel_extension_for_pytorch  # noqa: F401,F811 # type: ignore

                accelerator_name = "cpu"
                import intel_extension_for_pytorch as ipex
                if ipex._C._has_xpu():
                    accelerator_name = "xpu"
                else:
                    accelerator_name = "cpu"
            except ImportError as e:
                pass
        if accelerator_name is None:

@@ -145,8 +154,12 @@ def get_accelerator():
        from .cpu_accelerator import CPU_Accelerator

        ds_accelerator = CPU_Accelerator()
    elif accelerator_name == "xpu":
    elif accelerator_name == "xpu.external":
        # XPU_Accelerator is already imported in detection stage
        ds_accelerator = XPU_Accelerator()
    elif accelerator_name == "xpu":
        from .xpu_accelerator import XPU_Accelerator

        ds_accelerator = XPU_Accelerator()
    elif accelerator_name == "npu":
        from .npu_accelerator import NPU_Accelerator
|
@@ -0,0 +1,282 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
|
||||
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
|
||||
|
||||
|
||||
class XPU_Accelerator(DeepSpeedAccelerator):
|
||||
|
||||
def __init__(self):
|
||||
self._name = 'xpu'
|
||||
self._communication_backend_name = 'ccl'
|
||||
self.aligned_tensors = []
|
||||
|
||||
def is_synchronized_device(self):
|
||||
return False
|
||||
|
||||
# Device APIs
|
||||
def device_name(self, device_index=None):
|
||||
if device_index == None:
|
||||
return 'xpu'
|
||||
return 'xpu:{}'.format(device_index)
|
||||
|
||||
def device(self, device_index=None):
|
||||
return torch.xpu.device(device_index)
|
||||
|
||||
def set_device(self, device_index):
|
||||
torch.xpu.set_device(device_index)
|
||||
|
||||
def current_device(self):
|
||||
return torch.xpu.current_device()
|
||||
|
||||
def current_device_name(self):
|
||||
return 'xpu:{}'.format(torch.xpu.current_device())
|
||||
|
||||
def device_count(self):
|
||||
return torch.xpu.device_count()
|
||||
|
||||
def synchronize(self, device_index=None):
|
||||
return torch.xpu.synchronize(device_index)
|
||||
|
||||
# RNG APIs
|
||||
def random(self):
|
||||
return torch.xpu.random
|
||||
|
||||
def set_rng_state(self, new_state, device_index=None):
|
||||
if device_index == None:
|
||||
return torch.xpu.set_rng_state(new_state)
|
||||
return torch.xpu.set_rng_state(new_state, device_index)
|
||||
|
||||
def get_rng_state(self, device_index=None):
|
||||
if device_index == None:
|
||||
return torch.xpu.get_rng_state()
|
||||
return torch.xpu.get_rng_state(device_index)
|
||||
|
||||
def manual_seed(self, seed):
|
||||
return torch.xpu.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed):
|
||||
return torch.xpu.manual_seed_all(seed)
|
||||
|
||||
def initial_seed(self, seed):
|
||||
return torch.xpu.initial_seed(seed)
|
||||
|
||||
def default_generator(self, device_index):
|
||||
return torch.xpu.default_generators[device_index]
|
||||
|
||||
# Streams/Events
|
||||
@property
|
||||
def Stream(self):
|
||||
return torch.xpu.Stream
|
||||
|
||||
def stream(self, stream):
|
||||
return torch.xpu.stream(stream)
|
||||
|
||||
def current_stream(self, device_index=None):
|
||||
return torch.xpu.current_stream(device_index)
|
||||
|
||||
def default_stream(self, device_index=None):
|
||||
# torch.xpu does not support the sync behavior of default stream as cuda
|
||||
# use current_stream as workaround
|
||||
# see https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams
|
||||
return torch.xpu.current_stream(device_index)
|
||||
|
||||
@property
|
||||
def Event(self):
|
||||
return torch.xpu.Event
|
||||
|
||||
# Memory management
|
||||
def empty_cache(self):
|
||||
return torch.xpu.empty_cache()
|
||||
|
||||
def memory_allocated(self, device_index=None):
|
||||
return torch.xpu.memory_allocated(device_index)
|
||||
|
||||
def max_memory_allocated(self, device_index=None):
|
||||
return torch.xpu.max_memory_allocated(device_index)
|
||||
|
||||
def reset_max_memory_allocated(self, device_index=None):
|
||||
return torch.xpu.reset_max_memory_allocated(device_index)
|
||||
|
||||
def memory_cached(self, device_index=None):
|
||||
return torch.xpu.memory_reserved(device_index)
|
||||
|
||||
def max_memory_cached(self, device_index=None):
|
||||
return torch.xpu.max_memory_reserved(device_index)
|
||||
|
||||
def reset_max_memory_cached(self, device_index=None):
|
||||
return torch.xpu.reset_max_memory_reserved(device_index)
|
||||
|
||||
def memory_stats(self, device_index=None):
|
||||
return torch.xpu.memory_stats(device_index)
|
||||
|
||||
def reset_peak_memory_stats(self, device_index=None):
|
||||
return torch.xpu.reset_peak_memory_stats(device_index)
|
||||
|
||||
def memory_reserved(self, device_index=None):
|
||||
return torch.xpu.memory_reserved(device_index)
|
||||
|
||||
def max_memory_reserved(self, device_index=None):
|
||||
return torch.xpu.max_memory_reserved(device_index)
|
||||
|
||||
def total_memory(self, device_index=None):
|
||||
return torch.xpu.get_device_properties(device_index).total_memory
|
||||
|
||||
def available_memory(self, device_index=None):
|
||||
return self.total_memory(device_index) - self.memory_allocated(device_index)
|
||||
|
||||
# Misc
|
||||
def amp(self):
|
||||
return torch.xpu.amp
|
||||
|
||||
def is_available(self):
|
||||
return torch.xpu.is_available()
|
||||
|
||||
def range_push(self, msg):
|
||||
# TODO itt is currently not supported yet
|
||||
# return torch.profiler.itt.range_push(msg)
|
||||
return
|
||||
|
||||
def range_pop(self):
|
||||
# TODO itt is currently not supported yet
|
||||
# return torch.profiler.itt.range_pop()
|
||||
return
|
||||
|
||||
def lazy_call(self, callback):
|
||||
return torch.xpu.lazy_init._lazy_call(callback)
|
||||
|
||||
def communication_backend_name(self):
|
||||
return self._communication_backend_name
|
||||
|
||||
def is_triton_supported(self):
|
||||
return False
|
||||
|
||||
# Graph operations
|
||||
def create_graph(self):
|
||||
return None
|
||||
|
||||
def capture_to_graph(self, graph, pool=None, stream=None):
|
||||
from deepspeed.runtime.utils import noop_context
|
||||
return noop_context()
|
||||
|
||||
def replay_graph(self, graph):
|
||||
return
|
||||
|
||||
# Data types
|
||||
def is_bf16_supported(self):
|
||||
return True
|
||||
|
||||
def is_fp16_supported(self):
|
||||
return True
|
||||
|
||||
def supported_dtypes(self):
|
||||
return [torch.float, torch.half, torch.bfloat16]
|
||||
|
||||
# Tensor operations
|
||||
|
||||
@property
|
||||
def BFloat16Tensor(self):
|
||||
return torch.xpu.BFloat16Tensor
|
||||
|
||||
@property
|
||||
def ByteTensor(self):
|
||||
return torch.xpu.ByteTensor
|
||||
|
||||
@property
|
||||
def DoubleTensor(self):
|
||||
return torch.xpu.DoubleTensor
|
||||
|
||||
@property
|
||||
def FloatTensor(self):
|
||||
return torch.xpu.FloatTensor
|
||||
|
||||
@property
|
||||
def HalfTensor(self):
|
||||
return torch.xpu.HalfTensor
|
||||
|
||||
@property
|
||||
def IntTensor(self):
|
||||
return torch.xpu.IntTensor
|
||||
|
||||
@property
|
||||
def LongTensor(self):
|
||||
return torch.xpu.LongTensor
|
||||
|
||||
def pin_memory(self, tensor, align_bytes=1):
|
||||
if align_bytes == 1:
|
||||
return tensor.pin_memory(device=self.current_device_name())
|
||||
elif align_bytes == 0:
|
||||
from intel_extension_for_deepspeed.op_builder.async_io import AsyncIOBuilder
|
||||
self.aio_handle = AsyncIOBuilder().load().aio_handle(128 * 1024, 8, False, False, False)
|
||||
aligned_t = self.aio_handle.new_cpu_locked_tensor(tensor.numel(), tensor)
|
||||
aligned_t = aligned_t[:tensor.numel()].copy_(tensor)
|
||||
self.aligned_tensors.append([aligned_t.data_ptr(), aligned_t[-1].data_ptr()])
|
||||
return aligned_t
|
||||
|
||||
def is_pinned(self, tensor):
|
||||
if tensor.is_pinned(device=self.current_device_name()):
|
||||
return True
|
||||
else:
|
||||
for begin, end in self.aligned_tensors:
|
||||
if begin <= tensor.data_ptr() and tensor.data_ptr() <= end:
|
||||
return True
|
||||
return False
|
||||
|
||||
def op_builder_dir(self):
|
||||
try:
|
||||
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
|
||||
# if successful this also means we're doing a local install and not JIT compile path
|
||||
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
|
||||
return "op_builder.xpu"
|
||||
except ImportError:
|
||||
return "deepspeed.ops.op_builder.xpu"
|
||||
|
||||
def on_accelerator(self, tensor):
|
||||
device_str = str(tensor.device)
|
||||
if device_str.startswith('xpu:'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# create an instance of op builder and return, name specified by class_name
|
||||
def create_op_builder(self, op_name):
|
||||
builder_class = self.get_op_builder(op_name)
|
||||
if builder_class != None:
|
||||
return builder_class()
|
||||
return None
|
||||
|
||||
# return an op builder class, name specified by class_name
|
||||
def get_op_builder(self, class_name):
|
||||
try:
|
||||
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
|
||||
# if successful this also means we're doing a local install and not JIT compile path
|
||||
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
|
||||
from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
|
||||
except ImportError:
|
||||
from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
|
||||
|
||||
if class_name == "AsyncIOBuilder":
|
||||
return AsyncIOBuilder
|
||||
elif class_name == "CPUAdagradBuilder":
|
||||
return CPUAdagradBuilder
|
||||
elif class_name == "CPUAdamBuilder":
|
||||
return CPUAdamBuilder
|
||||
elif class_name == "FusedAdamBuilder":
|
||||
return FusedAdamBuilder
|
||||
else:
|
||||
return None
|
||||
|
||||
def build_extension(self):
|
||||
try:
|
||||
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
|
||||
except ImportError:
|
||||
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
|
||||
return DpcppBuildExtension
|
||||
|
||||
def export_envs(self):
|
||||
return []
|
|
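The XPU_Accelerator above maps DeepSpeed's accelerator abstraction onto torch.xpu. An illustrative sketch (not from this commit; it assumes Intel GPU hardware with intel_extension_for_pytorch and oneccl_bindings_for_pytorch installed) of going through this interface instead of calling torch.xpu directly:

    # Illustrative sketch; assumes an XPU-capable environment.
    import torch
    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()                    # resolves to XPU_Accelerator on Intel GPUs
    device = torch.device(acc.device_name(0))  # "xpu:0"
    x = torch.randn(1024, device=device)
    acc.synchronize()                          # wraps torch.xpu.synchronize

    pinned = acc.pin_memory(torch.empty(1024)) # align_bytes=1 path: tensor.pin_memory(...)
    assert acc.is_pinned(pinned)
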
@@ -23,7 +23,7 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
#define BLOCK_SIZE 512
#define ILP 4

typedef enum {
typedef enum : int {
    ADAM_MODE_0 = 0, // L2 regularization mode
    ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
} adamMode_t;
|
@@ -0,0 +1,196 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "cpu_adagrad.h"
|
||||
#include <torch/extension.h>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Adagrad_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float step_size = -1 * _alpha;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float momentum = grads[k];
|
||||
float variance = _exp_avg_sq[k];
|
||||
if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
|
||||
|
||||
variance += grad * grad;
|
||||
|
||||
grad = sqrt(variance);
|
||||
grad += _eps;
|
||||
grad = momentum / grad;
|
||||
param = grad * step_size + param;
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
// STORE UPDATE TERM TO GRAD'S MEMORY
|
||||
grads[k] = grad * step_size;
|
||||
_exp_avg_sq[k] = variance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Adagrad_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int create_adagrad_optimizer(int optimizer_id,
|
||||
float alpha = 1e-2,
|
||||
float eps = 1e-8,
|
||||
float weight_decay = 0,
|
||||
bool should_log = false)
|
||||
{
|
||||
auto opt = std::make_shared<Adagrad_Optimizer>(alpha, eps, weight_decay);
|
||||
|
||||
s_optimizers[optimizer_id] = opt;
|
||||
|
||||
if (should_log) {
|
||||
std::string avx_type = "";
|
||||
#if defined(__AVX512__)
|
||||
avx_type = "AVX512";
|
||||
#else
|
||||
#if defined(__AVX256__)
|
||||
avx_type = "AVX2";
|
||||
#else
|
||||
avx_type = "scalar";
|
||||
#endif
|
||||
#endif
|
||||
|
||||
printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n",
|
||||
optimizer_id,
|
||||
avx_type.c_str());
|
||||
printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Adagrad_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(
|
||||
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int ds_adagrad_step(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg_sq)
|
||||
{
|
||||
auto params_c = params.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adagrad_Optimizer> opt =
|
||||
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step);
|
||||
opt->update_state(lr, epsilon, weight_decay);
|
||||
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ds_adagrad_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int destroy_adagrad_optimizer(int optimizer_id)
|
||||
{
|
||||
s_optimizers.erase(optimizer_id);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
|
||||
m.def("adagrad_update_copy",
|
||||
&ds_adagrad_step_plus_copy,
|
||||
"DeepSpeed CPU Adagrad update and param copy (C++)");
|
||||
m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
|
||||
m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
|
||||
}
|
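For reference (not part of the commit), the scalar loop in Adagrad_Optimizer::Step_1 above applies the following per-element update; a plain-Python restatement with names chosen for illustration:

    import math

    def adagrad_step(param, grad, exp_avg_sq, lr, eps, weight_decay):
        # Mirrors the scalar path above: weight decay is folded into the squared-gradient
        # accumulator, while the update numerator keeps the raw gradient.
        g = grad + weight_decay * param if weight_decay > 0 else grad
        exp_avg_sq += g * g
        update = grad / (math.sqrt(exp_avg_sq) + eps)
        param -= lr * update
        return param, exp_avg_sq
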
|
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_adam.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("adam_update_copy",
          &ds_adam_step_plus_copy,
          "DeepSpeed CPU Adam update and param copy (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}
|
@@ -0,0 +1,247 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include "cpu_adam.h"
|
||||
|
||||
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
||||
|
||||
// C++ interface
|
||||
|
||||
void Adam_Optimizer::Step_1(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<1>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size) {
|
||||
float betta1_minus1 = 1 - _betta1;
|
||||
float betta2_minus1 = 1 - _betta2;
|
||||
|
||||
float step_size = -1 * _alpha / _bias_correction1;
|
||||
float w_decay = -1 * _alpha * _weight_decay;
|
||||
ds_half_precision_t* grads_cast_h;
|
||||
ds_half_precision_t* params_cast_h;
|
||||
if (half_precision) {
|
||||
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
|
||||
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
|
||||
}
|
||||
|
||||
for (size_t t = rounded_size; t < _param_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > _param_size) copy_size = _param_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#pragma omp parallel for
|
||||
for (size_t k = t; k < offset; k++) {
|
||||
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
|
||||
float param = half_precision ? (float)params_cast_h[k] : _params[k];
|
||||
float momentum = _exp_avg[k];
|
||||
float variance = _exp_avg_sq[k];
|
||||
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
|
||||
momentum = momentum * _betta1;
|
||||
momentum = grad * betta1_minus1 + momentum;
|
||||
|
||||
variance = variance * _betta2;
|
||||
grad = grad * grad;
|
||||
variance = grad * betta2_minus1 + variance;
|
||||
|
||||
grad = sqrt(variance);
|
||||
grad = grad * _bias_correction2 + _eps;
|
||||
grad = momentum / grad;
|
||||
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
|
||||
param = grad * step_size + param;
|
||||
if (half_precision)
|
||||
params_cast_h[k] = (ds_half_precision_t)param;
|
||||
else
|
||||
_params[k] = param;
|
||||
_exp_avg[k] = momentum;
|
||||
_exp_avg_sq[k] = variance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_4(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<4>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_1((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int create_adam_optimizer(int optimizer_id,
|
||||
float alpha,
|
||||
float betta1,
|
||||
float betta2,
|
||||
float eps,
|
||||
float weight_decay,
|
||||
bool adamw_mode,
|
||||
bool should_log)
|
||||
{
|
||||
auto opt =
|
||||
std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
|
||||
|
||||
s_optimizers[optimizer_id] = opt;
|
||||
|
||||
if (should_log) {
|
||||
std::string avx_type = "";
|
||||
#if defined(__AVX512__)
|
||||
avx_type = "AVX512";
|
||||
#else
|
||||
#if defined(__AVX256__)
|
||||
avx_type = "AVX2";
|
||||
#else
|
||||
avx_type = "scalar";
|
||||
#endif
|
||||
#endif
|
||||
|
||||
printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
|
||||
optimizer_id,
|
||||
avx_type.c_str());
|
||||
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
|
||||
alpha,
|
||||
betta1,
|
||||
betta2,
|
||||
weight_decay,
|
||||
(int)adamw_mode);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Adam_Optimizer::Step_8(float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t rounded_size = 0;
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
Step_AVX<8>(&rounded_size,
|
||||
_params,
|
||||
grads,
|
||||
_exp_avg,
|
||||
_exp_avg_sq,
|
||||
_param_size,
|
||||
dev_params,
|
||||
half_precision);
|
||||
#endif
|
||||
if (_param_size > rounded_size)
|
||||
Step_4((_params + rounded_size),
|
||||
(grads + rounded_size),
|
||||
(_exp_avg + rounded_size),
|
||||
(_exp_avg_sq + rounded_size),
|
||||
(_param_size - rounded_size),
|
||||
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
|
||||
half_precision);
|
||||
}
|
||||
|
||||
int ds_adam_step(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq)
|
||||
{
|
||||
auto params_c = params.contiguous();
|
||||
auto grads_c = grads.contiguous();
|
||||
auto exp_avg_c = exp_avg.contiguous();
|
||||
auto exp_avg_sq_c = exp_avg_sq.contiguous();
|
||||
|
||||
// assert(params.options().dtype() == grads.options().dtype());
|
||||
|
||||
float* params_ptr = (float*)params_c.data_ptr();
|
||||
float* grads_ptr = (float*)grads_c.data_ptr();
|
||||
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
|
||||
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
|
||||
|
||||
std::shared_ptr<Adam_Optimizer> opt =
|
||||
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
|
||||
opt->IncrementStep(step, beta1, beta2);
|
||||
opt->update_state(lr, epsilon, weight_decay, bias_correction);
|
||||
|
||||
opt->Step_8(params_ptr,
|
||||
grads_ptr,
|
||||
exp_avg_ptr,
|
||||
exp_avg_sq_ptr,
|
||||
params_c.numel(),
|
||||
nullptr,
|
||||
(params.options().dtype() == at::kHalf));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ds_adam_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params)
|
||||
{
|
||||
assert(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int destroy_adam_optimizer(int optimizer_id)
|
||||
{
|
||||
s_optimizers.erase(optimizer_id);
|
||||
|
||||
return 0;
|
||||
}
|
|
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_adam",
          &multi_tensor_adam_cuda,
          "Compute and apply gradient update to parameters for Adam optimizer");
}
|
@@ -0,0 +1,159 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
/*
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <cmath>
|
||||
#include "multi_tensor_apply.dp.hpp"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
typedef enum : int {
|
||||
ADAM_MODE_0 = 0, // L2 regularization mode
|
||||
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
|
||||
} adamMode_t;
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct AdamFunctor {
|
||||
__inline__ __attribute__((always_inline)) void operator()(int chunk_size,
|
||||
volatile int* noop_gmem,
|
||||
TensorListMetadata<4>& tl,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float beta1_correction,
|
||||
const float beta2_correction,
|
||||
const float epsilon,
|
||||
const float lr,
|
||||
adamMode_t mode,
|
||||
const float decay)
|
||||
{
|
||||
auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
|
||||
int tensor_loc = tl.block_to_tensor[item_ct1.get_group(2)];
|
||||
|
||||
int chunk_idx = tl.block_to_chunk[item_ct1.get_group(2)];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T* g = (T*)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T* p = (T*)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T* m = (T*)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T* v = (T*)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += item_ct1.get_local_range(2) * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + item_ct1.get_local_id(2) + ii * item_ct1.get_local_range(2);
|
||||
if (i < n && i < chunk_size) {
|
||||
r_g[ii] = g[i];
|
||||
r_p[ii] = p[i];
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
} else {
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (mode == ADAM_MODE_0) { // L2
|
||||
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sycl::sqrt(next_v_unbiased) + epsilon;
|
||||
MATH_T update = next_m_unbiased / denom;
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
} else { // weight decay
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sycl::sqrt(next_v_unbiased) + epsilon;
|
||||
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + item_ct1.get_local_id(2) + ii * item_ct1.get_local_range(2);
|
||||
if (i < n && i < chunk_size) {
|
||||
p[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_adam_cuda(int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float epsilon,
|
||||
const int step,
|
||||
const int mode,
|
||||
const int bias_correction,
|
||||
const float weight_decay)
|
||||
{
|
||||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Assume single type across p,g,m1,m2 now
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
|
||||
0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
lr,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
}
|
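For reference (not part of the commit), the per-element update performed by AdamFunctor above (with bias correction enabled, as set up in multi_tensor_adam_cuda), restated in plain Python with names chosen for illustration:

    import math

    def adam_step(param, grad, m, v, lr, beta1, beta2, eps, step, weight_decay, adamw_mode):
        if not adamw_mode:                         # ADAM_MODE_0: L2 regularization
            grad = grad + weight_decay * param
        m = beta1 * m + (1 - beta1) * grad         # first moment
        v = beta2 * v + (1 - beta2) * grad * grad  # second moment
        m_hat = m / (1 - beta1 ** step)            # bias corrections
        v_hat = v / (1 - beta2 ** step)
        update = m_hat / (math.sqrt(v_hat) + eps)
        if adamw_mode:                             # ADAM_MODE_1: decoupled weight decay
            update = update + weight_decay * param
        param = param - lr * update
        return param, m, v
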
|
@@ -0,0 +1,221 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
/*
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ipex.h>
|
||||
#include <sycl/sycl.hpp>
|
||||
#include "compat.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
sycl::queue* getCurrentCUDAStream()
|
||||
{
|
||||
auto device_type = c10::DeviceType::XPU;
|
||||
c10::impl::VirtualGuardImpl impl(device_type);
|
||||
c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
|
||||
auto& queue = xpu::get_queue_from_stream(c10_stream);
|
||||
return &queue;
|
||||
}
|
||||
|
||||
sycl::queue* getStreamFromPool(bool)
|
||||
{
|
||||
// not implemented
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
// #include <iostream>
|
||||
|
||||
// This header is the one-stop shop for all your multi-tensor apply needs.
|
||||
|
||||
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
|
||||
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
||||
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
|
||||
template <int n>
|
||||
struct TensorListMetadata {
|
||||
void* addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
class multi_tensor_apply_kernel {
|
||||
public:
|
||||
multi_tensor_apply_kernel(int chunk_size,
|
||||
volatile int* noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
ArgTypes... args)
|
||||
: chunk_size(chunk_size), noop_flag(noop_flag), tl(tl), callable(callable), args(args...)
|
||||
{
|
||||
}
|
||||
|
||||
// This should be identical to original __global__ function
|
||||
static void inline __global__function(int chunk_size,
|
||||
volatile int* noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
}
|
||||
|
||||
// If global function template contains parameter pack,
|
||||
// we only deal with parameter pack at the end of template parameter list
|
||||
template <typename Tuple, std::size_t... I>
|
||||
static void inline __tuple_expand_driver(int chunk_size,
|
||||
volatile int* noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
Tuple args,
|
||||
std::index_sequence<I...>)
|
||||
{
|
||||
__global__function(chunk_size, noop_flag, tl, callable, std::get<I>(args)...);
|
||||
}
|
||||
|
||||
//
|
||||
// Because __global__ function can't really use any reference types, we can sure that args
|
||||
// are all good behaviors
|
||||
//
|
||||
void operator()(sycl::nd_item<3>) const
|
||||
{
|
||||
__tuple_expand_driver(chunk_size,
|
||||
noop_flag,
|
||||
tl,
|
||||
callable,
|
||||
args,
|
||||
std::make_index_sequence<sizeof...(ArgTypes)>());
|
||||
}
|
||||
|
||||
private:
|
||||
int chunk_size;
|
||||
volatile int* noop_flag;
|
||||
T tl;
|
||||
U callable;
|
||||
std::tuple<ArgTypes...> args;
|
||||
};
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(int block_size,
|
||||
int chunk_size,
|
||||
const at::Tensor& noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
T callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kXPU, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++) {
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
#ifdef VERSION_GE_1_5
|
||||
contiguous_memory = (contiguous_memory ||
|
||||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
#endif
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
|
||||
"A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
/* const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); */
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk) {
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
/* multi_tensor_apply_kernel<TensorListMetadata<depth>, T, ArgTypes...>
|
||||
* fn(chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...); */
|
||||
if constexpr (sizeof(multi_tensor_apply_kernel(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...)) <
|
||||
2048) {
|
||||
((sycl::queue*)(stream))
|
||||
->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, loc_block_info) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
multi_tensor_apply_kernel(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...));
|
||||
} else {
|
||||
auto capture = multi_tensor_apply_kernel(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
|
||||
sycl::buffer params(const_cast<const decltype(capture)*>(&capture),
|
||||
sycl::range<1>(1));
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
auto device_params =
|
||||
params.template get_access<sycl::access_mode::read,
|
||||
sycl::target::constant_buffer>(cgh);
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, loc_block_info) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item) { device_params[0](item); });
|
||||
});
|
||||
}
|
||||
0;
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1) {
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
|
||||
// std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
} else {
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
|
||||
// std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,92 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
inline void has_capability_or_fail(const sycl::device& dev,
|
||||
const std::initializer_list<sycl::aspect>& props)
|
||||
{
|
||||
for (const auto& it : props) {
|
||||
if (dev.has(it)) continue;
|
||||
switch (it) {
|
||||
case sycl::aspect::fp64:
|
||||
throw std::runtime_error("'double' is not supported in '" +
|
||||
dev.get_info<sycl::info::device::name>() + "' device");
|
||||
break;
|
||||
case sycl::aspect::fp16:
|
||||
throw std::runtime_error("'half' is not supported in '" +
|
||||
dev.get_info<sycl::info::device::name>() + "' device");
|
||||
break;
|
||||
default:
|
||||
#define __SYCL_ASPECT(ASPECT, ID) \
|
||||
case sycl::aspect::ASPECT: return #ASPECT;
|
||||
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
|
||||
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
|
||||
auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string {
|
||||
switch (AspectNum) {
|
||||
#include <sycl/info/aspects.def>
|
||||
#include <sycl/info/aspects_deprecated.def>
|
||||
default: return "unknown aspect";
|
||||
}
|
||||
};
|
||||
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
|
||||
#undef __SYCL_ASPECT_DEPRECATED
|
||||
#undef __SYCL_ASPECT
|
||||
throw std::runtime_error("'" + getAspectNameStr(it) + "' is not supported in '" +
|
||||
dev.get_info<sycl::info::device::name>() + "' device");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void param_update_kernel(const float* input, sycl::half* output, int size)
|
||||
{
|
||||
auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
|
||||
int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
|
||||
|
||||
if (id < size) { output[id] = (sycl::half)input[id]; }
|
||||
}
|
||||
|
||||
void launch_param_update(const float* input, sycl::half* output, int size, sycl::queue* stream)
|
||||
{
|
||||
int threads = 1024;
|
||||
|
||||
sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1);
|
||||
sycl::range<3> block_dim(1, 1, threads);
|
||||
|
||||
{
|
||||
has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(grid_dim * block_dim, block_dim),
|
||||
[=](sycl::nd_item<3> item_ct1) { param_update_kernel(input, output, size); });
|
||||
}
|
||||
}
|
||||
|
||||
void param_update_kernel_half(const float* input, sycl::half* output, int size)
|
||||
{
|
||||
auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
|
||||
int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
|
||||
sycl::half2* output_cast = reinterpret_cast<sycl::half2*>(output);
|
||||
if (id < size) {
|
||||
float input_f = input[id];
|
||||
sycl::half2* input_h = reinterpret_cast<sycl::half2*>(&input_f);
|
||||
output_cast[id] = *input_h;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_param_update_half(const float* input, sycl::half* output, int size, sycl::queue* stream)
|
||||
{
|
||||
int threads = 1024;
|
||||
size /= 2;
|
||||
sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1);
|
||||
sycl::range<3> block_dim(1, 1, threads);
|
||||
|
||||
{
|
||||
has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(grid_dim * block_dim, block_dim),
|
||||
[=](sycl::nd_item<3> item_ct1) { param_update_kernel_half(input, output, size); });
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
|
@@ -0,0 +1,120 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#define NOMINMAX // Windows idiosyncrasy
|
||||
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
|
||||
|
||||
#include <stdio.h>
|
||||
#include <cassert>
|
||||
#include "simd.h"
|
||||
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
|
||||
#define STEP(SPAN) \
|
||||
void Step_##SPAN(float* _params, \
|
||||
float* grads, \
|
||||
float* _exp_avg_sq, \
|
||||
size_t _param_size, \
|
||||
ds_half_precision_t* dev_param = nullptr, \
|
||||
bool half_precision = false);
|
||||
|
||||
class Adagrad_Optimizer {
|
||||
public:
|
||||
Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
|
||||
: _alpha(alpha), _eps(eps), _weight_decay(weight_decay)
|
||||
{
|
||||
}
|
||||
~Adagrad_Optimizer() {}
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
void Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t param_size,
|
||||
ds_half_precision_t* dev_param = nullptr,
|
||||
bool half_precision = false);
|
||||
#endif
|
||||
STEP(1)
|
||||
STEP(4)
|
||||
STEP(8)
|
||||
inline void IncrementStep(size_t step)
|
||||
{
|
||||
_step++;
|
||||
if (_step != step) { _step = step; }
|
||||
}
|
||||
inline void update_state(float lr, float epsilon, float weight_decay)
|
||||
{
|
||||
_alpha = lr;
|
||||
_eps = epsilon;
|
||||
_weight_decay = weight_decay;
|
||||
}
|
||||
|
||||
private:
|
||||
float _alpha;
|
||||
float _eps;
|
||||
float _weight_decay;
|
||||
|
||||
float _betta1_t;
|
||||
float _betta2_t;
|
||||
size_t _step;
|
||||
};
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t new_rounded_size = 0;
|
||||
AVX_Data eps_4;
|
||||
eps_4.data = SIMD_SET(_eps);
|
||||
|
||||
float step_size = -1 * _alpha;
|
||||
AVX_Data step_size_4;
|
||||
step_size_4.data = SIMD_SET(step_size);
|
||||
|
||||
AVX_Data weight_decay4;
|
||||
if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);
|
||||
new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
|
||||
for (size_t t = 0; t < new_rounded_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
AVX_Data grad_4[span];
|
||||
simd_load<span>(grad_4, grads + i, half_precision);
|
||||
|
||||
AVX_Data momentum_4[span];
|
||||
simd_load<span>(momentum_4, grads + i, false);
|
||||
|
||||
AVX_Data variance_4[span];
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i, false);
|
||||
|
||||
AVX_Data param_4[span];
|
||||
simd_load<span>(param_4, _params + i, half_precision);
|
||||
|
||||
if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }
|
||||
|
||||
simd_fma<span>(variance_4, grad_4, grad_4, variance_4);
|
||||
simd_sqrt<span>(grad_4, variance_4);
|
||||
simd_add<span>(grad_4, grad_4, eps_4);
|
||||
simd_div<span>(grad_4, momentum_4, grad_4);
|
||||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + i, param_4, half_precision);
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4, false);
|
||||
}
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
}
|
||||
#endif
|
|
@@ -0,0 +1,237 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#define NOMINMAX // Windows idiosyncrasy
|
||||
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
|
||||
|
||||
#include <stdio.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cassert>
|
||||
#include "simd.h"
|
||||
|
||||
#include <cmath>
|
||||
typedef unsigned short ds_half_precision_t;
|
||||
|
||||
#define STEP(SPAN) \
|
||||
void Step_##SPAN(float* _params, \
|
||||
float* grads, \
|
||||
float* _exp_avg, \
|
||||
float* _exp_avg_sq, \
|
||||
size_t _param_size, \
|
||||
ds_half_precision_t* dev_param = nullptr, \
|
||||
bool half_precision = false);
|
||||
|
||||
class Adam_Optimizer {
|
||||
public:
|
||||
Adam_Optimizer(float alpha = 1e-3,
|
||||
float betta1 = 0.9,
|
||||
float betta2 = 0.999,
|
||||
float eps = 1e-8,
|
||||
float weight_decay = 0,
|
||||
bool adamw_mode = true)
|
||||
: _alpha(alpha),
|
||||
_betta1(betta1),
|
||||
_betta2(betta2),
|
||||
_eps(eps),
|
||||
_weight_decay(weight_decay),
|
||||
_betta1_t(1.0),
|
||||
_betta2_t(1.0),
|
||||
_step(0),
|
||||
_adamw_mode(adamw_mode)
|
||||
{
|
||||
}
|
||||
~Adam_Optimizer() {}
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
void Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t param_size,
|
||||
ds_half_precision_t* dev_param = nullptr,
|
||||
bool half_precision = false);
|
||||
#endif
|
||||
STEP(1)
|
||||
STEP(4)
|
||||
STEP(8)
|
||||
inline void IncrementStep(size_t step, float beta1, float beta2)
|
||||
{
|
||||
if (beta1 != _betta1 || beta2 != _betta2) {
|
||||
_step = step;
|
||||
_betta1 = beta1;
|
||||
_betta2 = beta2;
|
||||
_betta1_t = std::pow(_betta1, step);
|
||||
_betta2_t = std::pow(_betta2, step);
|
||||
} else {
|
||||
_step++;
|
||||
if (_step != step) {
|
||||
_betta1_t = std::pow(_betta1, step);
|
||||
_betta2_t = std::pow(_betta2, step);
|
||||
_step = step;
|
||||
} else {
|
||||
_betta1_t *= _betta1;
|
||||
_betta2_t *= _betta2;
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
|
||||
{
|
||||
_alpha = lr;
|
||||
_eps = epsilon;
|
||||
_weight_decay = weight_decay;
|
||||
|
||||
_bias_correction1 = 1.0f;
|
||||
_bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
_bias_correction1 = 1 - _betta1_t;
|
||||
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float _alpha;
|
||||
float _betta1;
|
||||
float _betta2;
|
||||
float _eps;
|
||||
float _weight_decay;
|
||||
|
||||
float _betta1_t;
|
||||
float _betta2_t;
|
||||
size_t _step;
|
||||
|
||||
float _bias_correction1;
|
||||
float _bias_correction2;
|
||||
|
||||
bool _adamw_mode;
|
||||
};
|
||||
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
template <int span>
|
||||
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
||||
float* _params,
|
||||
float* grads,
|
||||
float* _exp_avg,
|
||||
float* _exp_avg_sq,
|
||||
size_t _param_size,
|
||||
ds_half_precision_t* dev_params,
|
||||
bool half_precision)
|
||||
{
|
||||
size_t new_rounded_size = 0;
|
||||
int rshft = half_precision ? 1 : 0;
|
||||
|
||||
AVX_Data betta1_4;
|
||||
betta1_4.data = SIMD_SET(_betta1);
|
||||
AVX_Data betta2_4;
|
||||
betta2_4.data = SIMD_SET(_betta2);
|
||||
|
||||
float betta1_minus1 = 1 - _betta1;
|
||||
float betta2_minus1 = 1 - _betta2;
|
||||
AVX_Data betta1_minus1_4;
|
||||
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
|
||||
AVX_Data betta2_minus1_4;
|
||||
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
|
||||
|
||||
AVX_Data bias2_sqrt;
|
||||
bias2_sqrt.data = SIMD_SET(_bias_correction2);
|
||||
|
||||
AVX_Data eps_4;
|
||||
eps_4.data = SIMD_SET(_eps);
|
||||
|
||||
float step_size = -1 * _alpha / _bias_correction1;
|
||||
AVX_Data step_size_4;
|
||||
step_size_4.data = SIMD_SET(step_size);
|
||||
|
||||
float w_decay = -1 * _alpha * _weight_decay;
|
||||
AVX_Data weight_decay4;
|
||||
if (_weight_decay > 0)
|
||||
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
|
||||
new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
|
||||
for (size_t t = 0; t < new_rounded_size; t += TILE) {
|
||||
size_t copy_size = TILE;
|
||||
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
|
||||
size_t offset = copy_size + t;
|
||||
#pragma omp parallel for
|
||||
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
|
||||
AVX_Data grad_4[span];
|
||||
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
|
||||
|
||||
AVX_Data momentum_4[span];
|
||||
simd_load<span>(momentum_4, _exp_avg + i, false);
|
||||
|
||||
AVX_Data variance_4[span];
|
||||
simd_load<span>(variance_4, _exp_avg_sq + i, false);
|
||||
|
||||
AVX_Data param_4[span];
|
||||
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
|
||||
|
||||
if (_weight_decay > 0 && !_adamw_mode) {
|
||||
simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
|
||||
}
|
||||
|
||||
simd_mul<span>(momentum_4, momentum_4, betta1_4);
|
||||
simd_fma<span>(momentum_4, grad_4, betta1_minus1_4, momentum_4);
|
||||
simd_mul<span>(variance_4, variance_4, betta2_4);
|
||||
simd_mul<span>(grad_4, grad_4, grad_4);
|
||||
simd_fma<span>(variance_4, grad_4, betta2_minus1_4, variance_4);
|
||||
simd_sqrt<span>(grad_4, variance_4);
|
||||
simd_fma<span>(grad_4, grad_4, bias2_sqrt, eps_4);
|
||||
simd_div<span>(grad_4, momentum_4, grad_4);
|
||||
|
||||
if (_weight_decay > 0 && _adamw_mode) {
|
||||
simd_fma<span>(param_4, param_4, weight_decay4, param_4);
|
||||
}
|
||||
|
||||
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
|
||||
|
||||
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
|
||||
simd_store<span>(_exp_avg + i, momentum_4, false);
|
||||
simd_store<span>(_exp_avg_sq + i, variance_4, false);
|
||||
}
|
||||
}
|
||||
*rounded_size = new_rounded_size;
|
||||
}
|
||||
#endif
|
||||
|
||||
int create_adam_optimizer(int optimizer_id,
|
||||
float alpha = 1e-3,
|
||||
float betta1 = 0.9,
|
||||
float betta2 = 0.999,
|
||||
float eps = 1e-8,
|
||||
float weight_decay = 0,
|
||||
bool adamw_mode = true,
|
||||
bool should_log = false);
|
||||
|
||||
int ds_adam_step(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq);
|
||||
|
||||
int ds_adam_step_plus_copy(int optimizer_id,
|
||||
size_t step,
|
||||
float lr,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float epsilon,
|
||||
float weight_decay,
|
||||
bool bias_correction,
|
||||
torch::Tensor& params,
|
||||
torch::Tensor& grads,
|
||||
torch::Tensor& exp_avg,
|
||||
torch::Tensor& exp_avg_sq,
|
||||
torch::Tensor& gpu_params);
|
||||
|
||||
int destroy_adam_optimizer(int optimizer_id);
|
|
@@ -0,0 +1,198 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#if (__x86_64__ || __i386__)
|
||||
#include <cpuid.h>
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#define TILE (128 * 1024 * 1024)
|
||||
#if defined(__AVX512__) or defined(__AVX256__)
|
||||
|
||||
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
|
||||
|
||||
#if defined(__AVX512__)
|
||||
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
|
||||
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
|
||||
#define SIMD_SET(x) _mm512_set1_ps(x)
|
||||
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
|
||||
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
|
||||
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
|
||||
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
|
||||
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
|
||||
#define SIMD_AND(x, y) _mm512_and_ps(x, y)
|
||||
#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y)
|
||||
#define SIMD_OR(x, y) _mm512_or_ps(x, y)
|
||||
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
|
||||
#define SIMD_WIDTH 16
|
||||
|
||||
#define SIMD_LOAD2(x, h) \
|
||||
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
|
||||
#define SIMD_STORE2(x, d, h) \
|
||||
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
|
||||
: _mm512_storeu_ps(x, d))
|
||||
|
||||
#define INTV __m256i
|
||||
#elif defined(__AVX256__)
|
||||
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
|
||||
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
|
||||
#define SIMD_SET(x) _mm256_set1_ps(x)
|
||||
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
|
||||
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
|
||||
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
|
||||
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
|
||||
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
|
||||
#define SIMD_AND(x, y) _mm256_and_ps(x, y)
|
||||
#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y)
|
||||
#define SIMD_OR(x, y) _mm256_or_ps(x, y)
|
||||
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
|
||||
#define SIMD_WIDTH 8
|
||||
|
||||
#define SIMD_LOAD2(x, h) \
|
||||
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))
|
||||
#define SIMD_STORE2(x, d, h) \
|
||||
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
|
||||
: _mm256_storeu_ps(x, d))
|
||||
|
||||
#define INTV __m128i
|
||||
#endif
|
||||
|
||||
union AVX_Data {
|
||||
#if defined(__AVX512__)
|
||||
__m512 data;
|
||||
#elif defined(__AVX256__)
|
||||
__m256 data;
|
||||
#endif
|
||||
// float data_f[16];
|
||||
};
|
||||
|
||||
template <int span>
|
||||
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
|
||||
{
|
||||
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
|
||||
{
|
||||
size_t width = (half_precision ? 1 : SIMD_WIDTH);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) {
|
||||
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data);
|
||||
}
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) {
|
||||
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data);
|
||||
}
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) {
|
||||
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data);
|
||||
}
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_sqrt(AVX_Data* dst, AVX_Data* src)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) {
|
||||
dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data);
|
||||
}
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); }
|
||||
}
|
||||
template <int span>
|
||||
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
|
||||
{
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); }
|
||||
}
|
||||
|
||||
#endif
|
|
@@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <sycl/sycl.hpp>
/* #include <dpct/dpct.hpp> */
#include <ATen/ATen.h>

// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
//   const at::Type& payload;
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                          \
    switch (TYPE) {                                                              \
        case at::ScalarType::Float: {                                            \
            using scalar_t_##LEVEL = float;                                      \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::Half: {                                             \
            using scalar_t_##LEVEL = at::Half;                                   \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::BFloat16: {                                         \
            using scalar_t_##LEVEL = at::BFloat16;                               \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                   \
    switch (TYPE) {                                                              \
        case at::ScalarType::Double: {                                           \
            using scalar_t_##LEVEL = double;                                     \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::Float: {                                            \
            using scalar_t_##LEVEL = float;                                      \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::Half: {                                             \
            using scalar_t_##LEVEL = at::Half;                                   \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::BFloat16: {                                         \
            using scalar_t_##LEVEL = at::BFloat16;                               \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)                        \
    switch (TYPE) {                                                              \
        case at::ScalarType::Double: {                                           \
            using scalar_t_##LEVEL = double;                                     \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        case at::ScalarType::Float: {                                            \
            using scalar_t_##LEVEL = float;                                      \
            __VA_ARGS__;                                                         \
            break;                                                               \
        }                                                                        \
        default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

template <typename T>
__inline__ __attribute__((always_inline)) T reduce_block_into_lanes(
    T* x,
    T val,
    int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
    auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>();
    int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2);
    int blockSize = item_ct1.get_local_range(2) *
                    item_ct1.get_local_range(1);  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        /*
        DPCT1118:1: SYCL group functions and algorithms must be encountered in converged control
        flow. You may need to adjust the code.
        */
        /*
        DPCT1065:6: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if
        there is no access to global memory.
        */
        item_ct1.barrier();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        /*
        DPCT1118:2: SYCL group functions and algorithms must be encountered in converged control
        flow. You may need to adjust the code.
        */
        /*
        DPCT1065:7: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if
        there is no access to global memory.
        */
        item_ct1.barrier();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        /*
        DPCT1118:3: SYCL group functions and algorithms must be encountered in converged control
        flow. You may need to adjust the code.
        */
        /*
        DPCT1065:8: Consider replacing sycl::nd_item::barrier() with
        sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if
        there is no access to global memory.
        */
        item_ct1.barrier();
    }

    return final;
}
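As a reference only (not part of the commit): for a power-of-two `lanes` that divides the work-group size, `reduce_block_into_lanes` leaves lane j holding the sum of every contributed value whose index is congruent to j modulo `lanes`. A NumPy sketch of that contract:

# Reference-only sketch of what reduce_block_into_lanes computes per work-group:
# the per-work-item contributions are folded down to `lanes` partial sums, where
# partial j accumulates elements j, j + lanes, j + 2*lanes, ...
import numpy as np

def reduce_block_into_lanes_ref(vals: np.ndarray, lanes: int = 1) -> np.ndarray:
    # vals holds one contribution per work-item; lanes is assumed to be a
    # power of two that divides len(vals), matching the device code's intent.
    return vals.reshape(-1, lanes).sum(axis=0)

vals = np.arange(64, dtype=np.float32)        # pretend the work-group size is 64
print(reduce_block_into_lanes_ref(vals, 1))   # [2016.]
print(reduce_block_into_lanes_ref(vals, 2))   # [ 992. 1024.]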
@@ -108,6 +108,7 @@ def assert_no_cuda_mismatch(name=""):
class OpBuilder(ABC):
    _rocm_version = None
    _is_rocm_pytorch = None
    _is_sycl_enabled = None
    _loaded_ops = {}

    def __init__(self, name):
@@ -135,6 +136,9 @@ class OpBuilder(ABC):
    def hipify_extension(self):
        pass

    def sycl_extension(self):
        pass

    @staticmethod
    def validate_torch_version(torch_info):
        install_torch_version = torch_info['version']
@@ -186,6 +190,22 @@ class OpBuilder(ABC):
        OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
        return OpBuilder._is_rocm_pytorch

    @staticmethod
    def is_sycl_enabled():
        if OpBuilder._is_sycl_enabled is not None:
            return OpBuilder._is_sycl_enabled

        _is_sycl_enabled = False
        try:
            result = subprocess.run(["c2s", "--version"], capture_output=True)
        except:
            pass
        else:
            _is_sycl_enabled = True

        OpBuilder._is_sycl_enabled = _is_sycl_enabled
        return OpBuilder._is_sycl_enabled

    @staticmethod
    def installed_rocm_version():
        if OpBuilder._rocm_version:
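The new `OpBuilder.is_sycl_enabled()` probe simply checks whether the `c2s` compatibility tool can be executed and caches the answer on the class. A hedged sketch of how build-time code might branch on it (the print messages are illustrative):

# Hedged sketch: branching on the new SYCL toolchain probe. The result is
# cached after the first call, so repeated queries are cheap.
from deepspeed.ops.op_builder.builder import OpBuilder

if OpBuilder.is_sycl_enabled():
    print("SYCL toolchain detected; XPU op builders can be used.")
else:
    print("No SYCL toolchain found; falling back to CPU/CUDA builders.")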
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .cpu_adam import CPUAdamBuilder
from .cpu_adagrad import CPUAdagradBuilder
from .fused_adam import FusedAdamBuilder
from .async_io import AsyncIOBuilder
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import distutils.spawn
import subprocess
import torch

from .builder import OpBuilder


class AsyncIOBuilder(OpBuilder):
    BUILD_VAR = "DS_BUILD_AIO"
    NAME = "async_io"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.aio.{self.NAME}_op'

    def sources(self):
        return [
            'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/py_ds_aio.cpp',
            'csrc/aio/py_lib/deepspeed_py_aio.cpp', 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp',
            'csrc/aio/py_lib/deepspeed_aio_thread.cpp', 'csrc/aio/common/deepspeed_aio_utils.cpp',
            'csrc/aio/common/deepspeed_aio_common.cpp', 'csrc/aio/common/deepspeed_aio_types.cpp',
            'csrc/aio/py_lib/deepspeed_pin_tensor.cpp'
        ]

    def include_paths(self):
        return ['csrc/aio/py_lib', 'csrc/aio/common']

    def cxx_args(self):
        # -O0 for improved debugging, since performance is bound by I/O
        CPU_ARCH = self.cpu_arch()
        SIMD_WIDTH = self.simd_width()
        TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2])
        if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1:
            CPP_STD = '-std=c++17'
        else:
            CPP_STD = '-std=c++14'
        return [
            '-g',
            '-Wall',
            '-O0',
            CPP_STD,
            '-shared',
            '-fPIC',
            '-Wno-reorder',
            CPU_ARCH,
            '-fopenmp',
            SIMD_WIDTH,
            '-laio',
        ]

    def extra_ldflags(self):
        return ['-laio']

    def check_for_libaio_pkg(self):
        libs = dict(
            dpkg=["-l", "libaio-dev", "apt"],
            pacman=["-Q", "libaio", "pacman"],
            rpm=["-q", "libaio-devel", "yum"],
        )

        found = False
        for pkgmgr, data in libs.items():
            flag, lib, tool = data
            path = distutils.spawn.find_executable(pkgmgr)
            if path is not None:
                cmd = f"{pkgmgr} {flag} {lib}"
                result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
                if result.wait() == 0:
                    found = True
                else:
                    self.warning(f"{self.NAME}: please install the {lib} package with {tool}")
                break
        return found

    def is_compatible(self, verbose=True):
        # Check for the existence of libaio by using distutils
        # to compile and link a test program that calls io_submit,
        # which is a function provided by libaio that is used in the async_io op.
        # If needed, one can define -I and -L entries in CFLAGS and LDFLAGS
        # respectively to specify the directories for libaio.h and libaio.so.
        aio_compatible = self.has_function('io_pgetevents', ('aio', ))
        if verbose and not aio_compatible:
            self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.")

            # Check for the libaio package via known package managers
            # to print suggestions on which package to install.
            self.check_for_libaio_pkg()

            self.warning(
                "If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found."
            )
        return super().is_compatible(verbose) and aio_compatible
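A short usage sketch for the builder above, assuming the package is importable as `deepspeed.ops.op_builder.xpu` once installed (import path assumed; libaio may or may not be present on the host):

# Hedged sketch: probe libaio support before loading the async_io op.
# is_compatible() links a small test program against libaio; load() imports a
# pre-built op or JIT-compiles it.
from deepspeed.ops.op_builder.xpu import AsyncIOBuilder

builder = AsyncIOBuilder()
if builder.is_compatible(verbose=True):
    aio_op = builder.load()
    print("async_io op ready:", aio_op)
else:
    print("libaio headers/library not found; install libaio-dev (or the distro equivalent).")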
@@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import time
import importlib

try:
    # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
    # if successful this also means we're doing a local install and not JIT compile path
    from op_builder import __deepspeed__  # noqa: F401 # type: ignore
    from op_builder.builder import OpBuilder, TORCH_MAJOR, TORCH_MINOR
except ImportError:
    from deepspeed.ops.op_builder.builder import OpBuilder, TORCH_MAJOR, TORCH_MINOR


class SYCLOpBuilder(OpBuilder):

    def builder(self):
        try:
            from intel_extension_for_pytorch.xpu.cpp_extension import DPCPPExtension
        except ImportError:
            from intel_extension_for_pytorch.xpu.utils import DPCPPExtension

        print("dpcpp sources = {}".format(self.sources()))
        dpcpp_ext = DPCPPExtension(name=self.absolute_name(),
                                   sources=self.strip_empty_entries(self.sources()),
                                   include_dirs=self.strip_empty_entries(self.include_paths()),
                                   extra_compile_args={
                                       'cxx': self.strip_empty_entries(self.cxx_args()),
                                   },
                                   extra_link_args=self.strip_empty_entries(self.fixed_aotflags()))
        return dpcpp_ext

    def version_dependent_macros(self):
        # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
        version_ge_1_1 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
            version_ge_1_1 = ['-DVERSION_GE_1_1']
        version_ge_1_3 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
            version_ge_1_3 = ['-DVERSION_GE_1_3']
        version_ge_1_5 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
            version_ge_1_5 = ['-DVERSION_GE_1_5']
        return version_ge_1_1 + version_ge_1_3 + version_ge_1_5

    def cxx_args(self):
        cxx_flags = [
            '-fsycl', '-fsycl-targets=spir64_gen', '-g', '-gdwarf-4', '-O3', '-std=c++17', '-fPIC', '-DMKL_ILP64',
            '-fno-strict-aliasing'
        ]
        if os.environ.get('USE_MKL_GEMM'):
            cxx_flags.append('-DUSE_MKL_GEMM')
        return cxx_flags

    def extra_ldflags(self):
        return [
            '-fPIC', '-fsycl', '-fsycl-targets=spir64_gen', '-fsycl-max-parallel-link-jobs=8',
            '-Xs "-options -cl-poison-unsupported-fp64-kernels,cl-intel-enable-auto-large-GRF-mode"',
            '-Xs "-device pvc"', '-Wl,-export-dynamic'
        ]

    def fixed_aotflags(self):
        return [
            '-fsycl', '-fsycl-targets=spir64_gen', '-fsycl-max-parallel-link-jobs=8', '-Xs',
            "-options -cl-poison-unsupported-fp64-kernels,cl-intel-enable-auto-large-GRF-mode", '-Xs', "-device pvc"
        ]

    def load(self, verbose=True):
        from deepspeed.git_version_info import installed_ops, torch_info  # noqa: F401
        if installed_ops.get(self.name, False):
            return importlib.import_module(self.absolute_name())
        else:
            return self.jit_load(verbose)

    def jit_load(self, verbose=True):
        if not self.is_compatible(verbose):
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
            )
        try:
            import ninja  # noqa: F401
        except ImportError:
            raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")

        self.jit_mode = True
        from intel_extension_for_pytorch.xpu.cpp_extension import load

        start_build = time.time()
        # Recognize relative paths as absolute paths for jit load

        sources = [self.deepspeed_src_path(path) for path in self.sources()]
        extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()]

        # Torch will try and apply whatever CCs are in the arch list at compile time,
        # we have already set the intended targets ourselves we know that will be
        # needed at runtime. This prevents CC collisions such as multiple __half
        # implementations. Stash arch list to reset after build.
        '''
        torch_arch_list = None
        if "TORCH_CUDA_ARCH_LIST" in os.environ:
            torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
            os.environ["TORCH_CUDA_ARCH_LIST"] = ""
        '''

        op_module = load(
            name=self.name,
            sources=self.strip_empty_entries(sources),
            extra_include_paths=self.strip_empty_entries(extra_include_paths),
            extra_cflags=self.strip_empty_entries(self.cxx_args()),
            # extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()),
            extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
            verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")
        '''
        # Reset arch list so we are not silently removing it for other possible use cases
        if torch_arch_list:
            os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
        '''
        return op_module
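`SYCLOpBuilder` offers two build paths: `builder()` returns a `DPCPPExtension` for ahead-of-time compilation at install time, while `load()` falls back to `jit_load()` (via intel_extension_for_pytorch's `cpp_extension.load`) when the op was not pre-built. A hedged sketch using one of the concrete builders defined below; the import path and the presence of a SYCL toolchain plus intel_extension_for_pytorch are assumptions:

# Hedged sketch: the two build paths provided by SYCLOpBuilder.
from deepspeed.ops.op_builder.xpu import FusedAdamBuilder

builder = FusedAdamBuilder()

# Path 1: ahead-of-time build, typically collected by the install machinery as
# a DPCPPExtension carrying the -fsycl flags and AOT targets above.
ext = builder.builder()

# Path 2: JIT build/load at runtime, used when the op was not pre-installed.
fused_adam_op = builder.load()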
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import SYCLOpBuilder


class CPUAdagradBuilder(SYCLOpBuilder):
    BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
    NAME = "cpu_adagrad"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adagrad.{self.NAME}_op'

    def sources(self):
        return ['csrc/xpu/adagrad/cpu_adagrad.cpp', 'csrc/xpu/common/custom_cuda_kernel.dp.cpp']

    def include_paths(self):
        return ['csrc/xpu/includes']
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import SYCLOpBuilder


class CPUAdamBuilder(SYCLOpBuilder):
    BUILD_VAR = "DS_BUILD_CPU_ADAM"
    NAME = "cpu_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        if self.build_for_cpu:
            return ['csrc/xpu/adam/cpu_adam.cpp', 'csrc/xpu/adam/cpu_adam_impl.cpp']

        return [
            'csrc/xpu/adam/cpu_adam.cpp', 'csrc/xpu/adam/cpu_adam_impl.cpp',
            'csrc/xpu/common/custom_cuda_kernel.dp.cpp'
        ]

    def include_paths(self):
        return ['csrc/xpu/includes']
@@ -0,0 +1,26 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class FusedAdamBuilder(SYCLOpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return ['csrc/xpu/adam/fused_adam_frontend.cpp', 'csrc/xpu/adam/multi_tensor_adam.dp.cpp']

    def include_paths(self):
        return ['csrc/xpu/includes', 'csrc/xpu/adam']

    def cxx_args(self):
        args = super().cxx_args()
        return args + self.version_dependent_macros()
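End to end, these builders are resolved through DeepSpeed's accelerator and op-builder machinery, so the existing optimizer wrappers keep working unchanged on XPU systems. A hedged sketch using module paths that already exist in DeepSpeed (actual behavior depends on the toolchain installed on the machine):

# Hedged end-to-end sketch: DeepSpeedCPUAdam is backed by the CPUAdamBuilder
# registered for the active accelerator, so user code stays device-agnostic.
# The tiny model and single step here are purely illustrative.
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(128, 128)
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()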