Mirror of https://github.com/microsoft/DeepSpeed.git
feat: add Lion optimizer (#4331)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Parent: d72edb3b0d
Commit: 8e64c3b550
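The engine change later in this diff routes the config string "Lion" to FusedLion, or to DeepSpeedCPULion when ZeRO offloads the optimizer to CPU. As a minimal usage sketch (assumed, not part of this commit; the config keys mirror the unit tests and docstrings below, and the model is only a placeholder):

import torch
import deepspeed

ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        # Dispatched to FusedLion, or DeepSpeedCPULion with ZeRO optimizer offload.
        "type": "Lion",
        "params": {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 0.01}
    },
    "fp16": {"enabled": True},
}

model = torch.nn.Linear(16, 16)
engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                               model=model,
                                               model_parameters=model.parameters())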
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_lion.h"

// C++ interface

void multi_tensor_lion(int chunk_size,
                       at::Tensor noop_flag,
                       std::vector<std::vector<at::Tensor>> tensor_lists, /*gpmv*/
                       const float lr,
                       const float beta1,
                       const float beta2,
                       const int step,
                       const int mode,
                       const float weight_decay)
{
    static bool initialized = false;
    if (!initialized) {
        create_lion_optimizer(0);
        initialized = true;
    }
    for (int i = 0; i < tensor_lists[0].size(); i++) {
        ds_lion_step(0,
                     step,
                     lr,
                     beta1,
                     beta2,
                     weight_decay,
                     tensor_lists[1][i],
                     tensor_lists[0][i],
                     tensor_lists[2][i]);
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_lion",
          &multi_tensor_lion,
          "Compute and apply gradient update to parameters for Lion optimizer");
}
@@ -0,0 +1,233 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#define NOMINMAX  // Windows idiosyncrasy
                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <torch/extension.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN)                                             \
    void Step_##SPAN(float* _params,                           \
                     float* grads,                             \
                     float* _exp_avg,                          \
                     size_t _param_size,                       \
                     ds_half_precision_t* dev_param = nullptr, \
                     bool half_precision = false);

class Lion_Optimizer {
public:
    Lion_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float weight_decay = 0)
        : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
    {
#if defined(__ENABLE_CUDA__)
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

        _streams[0] = TrainingContext::Instance().GetCurrentStream();
        _streams[1] = TrainingContext::Instance().GetNewStream();
        _buf_index = false;
#endif
    }
    ~Lion_Optimizer()
    {
#if defined(__ENABLE_CUDA__)
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
#endif
    }

#if defined(__AVX512__) or defined(__AVX256__)
    template <int span>
    void Step_AVX(size_t* rounded_size,
                  float* _params,
                  float* grads,
                  float* _exp_avg,
                  size_t param_size,
                  ds_half_precision_t* dev_param = nullptr,
                  bool half_precision = false);
#endif
    STEP(1)
    STEP(4)
    STEP(8)
#if defined(__ENABLE_CUDA__)
    inline void SynchronizeStreams()
    {
        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
    }
#endif
    inline void IncrementStep(size_t step, float beta1, float beta2)
    {
        _step++;
        if (_step != step || beta1 != _betta1 || beta2 != _betta2) {
            _step = step;
            _betta1 = beta1;
            _betta2 = beta2;
        }
    }
    inline void update_state(float lr, float weight_decay)
    {
        _alpha = lr;
        _weight_decay = weight_decay;
    }

private:
    float _alpha;
    float _betta1;
    float _betta2;
    float _weight_decay;
    size_t _step;

#if defined(__ENABLE_CUDA__)
    float* _doubled_buffer[2];
    cudaStream_t _streams[2];
    bool _buf_index;
#endif
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
                              float* _params,
                              float* grads,
                              float* _exp_avg,
                              size_t _param_size,
                              ds_half_precision_t* dev_params,
                              bool half_precision)
{
    size_t new_rounded_size = 0;
    int rshft = half_precision ? 1 : 0;

    constexpr float neg1 = -1.0f;
    AVX_Data neg1_4;
    neg1_4.data = SIMD_SET(neg1);

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    float step_size = -_alpha;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float after_decay = 1.0f - _alpha * _weight_decay;
    AVX_Data after_decay_4;
    if (_weight_decay > 0) after_decay_4.data = SIMD_SET(after_decay);

    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
    for (size_t t = 0; t < new_rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
        size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
            AVX_Data grad_4[span];
            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);

            AVX_Data momentum_4[span];
            simd_load<span>(momentum_4, _exp_avg + i, false);

            AVX_Data param_4[span];
            simd_load<span>(param_4, _params + (i >> rshft), half_precision);

            AVX_Data tmp_4[span];

            simd_mul<span>(tmp_4, momentum_4, betta1_4);
            simd_fma<span>(tmp_4, grad_4, betta1_minus1_4, tmp_4);
            // We already used intrinsics, so consider the machine representation fixed.
            simd_and<span>(tmp_4, tmp_4, neg1_4);
            simd_xor<span>(tmp_4, tmp_4, step_size_4);
            if (_weight_decay > 0) {
                simd_fma<span>(param_4, param_4, after_decay_4, tmp_4);
            } else {
                simd_add<span>(param_4, param_4, tmp_4);
            }

            simd_mul<span>(momentum_4, momentum_4, betta2_4);
            simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);

            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
            }
#endif
            simd_store<span>(_exp_avg + i, momentum_4, false);
        }
#if defined(__ENABLE_CUDA__)
        if (dev_params) {
            if (half_precision)
                launch_param_update_half(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            else
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

            _buf_index = !_buf_index;
        }
#endif
    }
    *rounded_size = new_rounded_size;
}
#endif

int create_lion_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float weight_decay = 0,
                          bool should_log = false);

int ds_lion_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float weight_decay,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg);

int ds_lion_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float weight_decay,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& gpu_params);

int destroy_lion_optimizer(int optimizer_id);
@@ -24,6 +24,10 @@
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_AND(x, y) _mm512_and_ps(x, y)
#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y)
#define SIMD_OR(x, y) _mm512_or_ps(x, y)
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16

#define SIMD_LOAD2(x, h) \
@@ -42,10 +46,14 @@
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_AND(x, y) _mm256_and_ps(x, y)
#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y)
#define SIMD_OR(x, y) _mm256_or_ps(x, y)
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
#define SIMD_WIDTH 8

#define SIMD_LOAD2(x, h) \
    ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))

#define SIMD_STORE2(x, d, h)                                                                \
    ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
         : _mm256_storeu_ps(x, d))
@@ -136,5 +144,55 @@ inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) {
        dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data);
    }
}
template <int span>
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); }
}

#endif
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "cpu_lion.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
    m.def("lion_update_copy",
          &ds_lion_step_plus_copy,
          "DeepSpeed CPU Lion update and param copy (C++)");
    m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
    m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
}
@@ -0,0 +1,255 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cpu_lion.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Lion_Optimizer::Step_1(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size) {
        float betta1_minus1 = 1 - _betta1;
        float betta2_minus1 = 1 - _betta2;

        float alpha = _alpha;
        float after_decay = 1 - alpha * _weight_decay;
        ds_half_precision_t* grads_cast_h;
        ds_half_precision_t* params_cast_h;
        if (half_precision) {
            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
        }

        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
                float param = half_precision ? (float)params_cast_h[k] : _params[k];
                float momentum = _exp_avg[k];
                float tmp = momentum * _betta1;
                tmp = grad * betta1_minus1 + tmp;
                // Rely on portable C++ methods to manipulate the sign bit of a floating-point
                // number.
                tmp = -std::copysignf(alpha, tmp);
                if (_weight_decay > 0) {
                    param = param * after_decay + tmp;
                } else {
                    param = param + tmp;
                }
                momentum = momentum * _betta2;
                momentum = grad * betta2_minus1 + momentum;
#if defined(__ENABLE_CUDA__)
                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
                if (half_precision)
                    params_cast_h[k] = (ds_half_precision_t)param;
                else
                    _params[k] = param;
                _exp_avg[k] = momentum;
            }
#if defined(__ENABLE_CUDA__)
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

                _buf_index = !_buf_index;
            }
#endif
        }
    }
}

void Lion_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_1((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}

int create_lion_optimizer(int optimizer_id,
                          float alpha,
                          float betta1,
                          float betta2,
                          float weight_decay,
                          bool should_log)
{
    auto opt = std::make_shared<Lion_Optimizer>(alpha, betta1, betta2, weight_decay);

    s_optimizers[optimizer_id] = opt;

    if (should_log) {
        std::string avx_type = "";
#if defined(__AVX512__)
        avx_type = "AVX512";
#else
#if defined(__AVX256__)
        avx_type = "AVX2";
#else
        avx_type = "scalar";
#endif
#endif

        printf("Lion Optimizer #%d is created with %s arithmetic capability.\n",
               optimizer_id,
               avx_type.c_str());
        printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f\n",
               alpha,
               betta1,
               betta2,
               weight_decay);
    }

    return 0;
}

void Lion_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            size_t _param_size,
                            ds_half_precision_t* dev_params,
                            bool half_precision)
{
    size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
    Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
               half_precision);
}

int ds_lion_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float weight_decay,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_c = exp_avg.contiguous();

    // assert(params.options().dtype() == grads.options().dtype());

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

    std::shared_ptr<Lion_Optimizer> opt =
        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, weight_decay);

    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                params_c.numel(),
                nullptr,
                (params.options().dtype() == at::kHalf));

#if defined(__ENABLE_CUDA__)
    opt->SynchronizeStreams();
#endif
    return 0;
}

int ds_lion_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float weight_decay,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

    std::shared_ptr<Lion_Optimizer> opt =
        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, weight_decay);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                params_c.numel(),
                gpu_params_ptr,
                (params.options().dtype() == at::kHalf));

    opt->SynchronizeStreams();
#else
    assert(false);
#endif
    return 0;
}

int destroy_lion_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);

    return 0;
}
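For reference, the scalar path Step_1 above implements the Lion update from "Symbolic Discovery of Optimization Algorithms" (cited in the Python docstrings later in this diff). With learning rate \alpha, weight decay \lambda, momentum m, and gradient g:

c_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
\theta_t = (1 - \alpha\lambda)\,\theta_{t-1} - \alpha\,\operatorname{sign}(c_t)
m_t = \beta_2 m_{t-1} + (1 - \beta_2) g_t

The line `tmp = -std::copysignf(alpha, tmp);` computes the -\alpha\,\operatorname{sign}(c_t) term, and the `after_decay` factor applies the decoupled weight decay.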
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

void multi_tensor_lion_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const int step,
                            const float weight_decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_lion",
          &multi_tensor_lion_cuda,
          "Compute and apply gradient update to parameters for Lion optimizer");
}
@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];  // I fear this needs to be a full int.
    int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO: Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory = (contiguous_memory ||
                                 tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
                    chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

                AT_CUDA_CHECK(cudaGetLastError());

                // Reset.  The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4

using MATH_T = float;

template <typename T>
struct LionFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<3>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float lr,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if(*noop_gmem == 1)
        //   return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        MATH_T after_decay = 1.0f - lr * decay;

        // see note in multi_tensor_scale_kernel.cu
        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                MATH_T c = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                MATH_T update = c > 0 ? (-lr) : lr;
                r_p[ii] = r_p[ii] * after_decay + update;
                r_m[ii] = beta2 * r_m[ii] + (1 - beta2) * r_g[ii];
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                }
            }
        }
    }
};

void multi_tensor_lion_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const int step,
                            const float weight_decay)
{
    using namespace at;

    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                   0,
                                   "lion",
                                   multi_tensor_apply<3>(BLOCK_SIZE,
                                                         chunk_size,
                                                         noop_flag,
                                                         tensor_lists,
                                                         LionFunctor<scalar_t_0>(),
                                                         beta1,
                                                         beta2,
                                                         lr,
                                                         weight_decay);)

    AT_CUDA_CHECK(cudaGetLastError());
}
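As a sanity-check aid, here is a minimal NumPy sketch (an assumed helper, not part of this commit) of the per-element update performed by LionFunctor; note the kernel maps c == 0 to a step of +lr (`c > 0 ? (-lr) : lr`), which the sketch mirrors rather than using a true sign function:

import numpy as np

def lion_reference(p, g, m, lr, beta1, beta2, weight_decay):
    # Interpolated direction, decoupled decay, signed step, then momentum update,
    # in the same order as the second unrolled loop of LionFunctor.
    c = beta1 * m + (1 - beta1) * g
    update = np.where(c > 0, -lr, lr)
    p = p * (1.0 - lr * weight_decay) + update
    m = beta2 * m + (1 - beta2) * g
    return p, m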
@@ -6,6 +6,7 @@
from . import adam
from . import adagrad
from . import lamb
from . import lion
#from ..git_version_info_installed import installed_ops as __installed_ops__
#if __installed_ops__['sparse_attn']:
from . import sparse_attention
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .cpu_lion import DeepSpeedCPULion
from .fused_lion import FusedLion
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from cpuinfo import get_cpu_info
from deepspeed.utils import logger
from deepspeed.utils.logging import should_log_le
from deepspeed.ops.op_builder import CPULionBuilder


class DeepSpeedCPULion(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self, model_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, fp32_optimizer_states=True):
        """Fast vectorized implementation of Lion optimizer on CPU:

        See Symbolic Discovery of Optimization Algorithms (https://doi.org/10.48550/arXiv.2302.06675).

        .. note::
                We recommend using our `config
                <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
                to allow :meth:`deepspeed.initialize` to build this optimizer
                for you.


        Arguments:
            model_params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups.
            lr (float, optional): learning rate. (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square. (default: (0.9, 0.999))
            weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            full_precision_optimizer_states: creates momentum and variance in full precision regardless of
                        the precision of the parameters (default: True)
        """

        default_args = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super(DeepSpeedCPULion, self).__init__(model_params, default_args)

        cpu_info = get_cpu_info()
        self.cpu_vendor = cpu_info["vendor_id_raw"].lower() if "vendor_id_raw" in cpu_info else "unknown"
        if "amd" in self.cpu_vendor:
            for group_id, group in enumerate(self.param_groups):
                for param_id, p in enumerate(group['params']):
                    if p.dtype == torch.half:
                        logger.warning("FP16 params for CPULion may not work on AMD CPUs")
                        break
                else:
                    continue
                break

        self.opt_id = DeepSpeedCPULion.optimizer_id
        DeepSpeedCPULion.optimizer_id = DeepSpeedCPULion.optimizer_id + 1
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_lion = CPULionBuilder().load()

        self.ds_opt_lion.create_lion(self.opt_id, lr, betas[0], betas[1], weight_decay, should_log_le("info"))

    def __del__(self):
        # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
        # is used multiple times in the same process (notebook or pytest worker)
        self.ds_opt_lion.destroy_lion(self.opt_id)

    def __setstate__(self, state):
        super(DeepSpeedCPULion, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.
            fp16_param_groups: FP16 GPU parameters to update. Performing the
                copy here reduces communication time. Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        # converting the fp16 params to a group of parameter
        if type(fp16_param_groups) is list:
            if type(fp16_param_groups[0]) is not list:
                fp16_param_groups = [fp16_param_groups]
        elif fp16_param_groups is not None:
            fp16_param_groups = [[fp16_param_groups]]

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPULion param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    #use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] += 1
                beta1, beta2 = group['betas']

                if fp16_param_groups is not None:
                    self.ds_opt_lion.lion_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2,
                                                      group['weight_decay'], p.data, p.grad.data, state['exp_avg'],
                                                      fp16_param_groups[group_id][param_id].data)
                else:
                    self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2,
                                                 group['weight_decay'], p.data, p.grad.data, state['exp_avg'])
        return loss
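A minimal direct-use sketch of the class above, mirroring what the CPU unit test later in this diff does (tensor sizes and hyperparameters are illustrative only):

import torch
from deepspeed.ops.lion import DeepSpeedCPULion

param = torch.nn.Parameter(torch.randn(64, device='cpu'))
opt = DeepSpeedCPULion([param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

param.grad = torch.randn(64, device='cpu')
opt.step()  # updates param and its exp_avg state in place via the lion_update binding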
@@ -0,0 +1,131 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
This file is modified from fused_adam.py
"""

import torch
from .multi_tensor_apply import MultiTensorApply

multi_tensor_applier = MultiTensorApply(2048 * 32)
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import FusedLionBuilder


class FusedLion(torch.optim.Optimizer):
    """Implements Lion algorithm.

    Currently GPU-only.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        set_grad_none (bool, optional): whether set grad to None when zero_grad()
            method is called. (default: True)

    .. _Symbolic Discovery of Optimization Algorithms:
        https://doi.org/10.48550/arXiv.2302.06675
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0., set_grad_none=True):

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super(FusedLion, self).__init__(params, defaults)
        self.set_grad_none = set_grad_none

        fused_lion_cuda = FusedLionBuilder().load()
        # Skip buffer
        self._dummy_overflow_buf = get_accelerator().IntTensor([0])
        self.multi_tensor_lion = fused_lion_cuda.multi_tensor_lion

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedLion, self).zero_grad()

    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError('FusedLion has been updated.')
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            if len(group['params']) == 0:
                continue
            beta1, beta2 = group['betas']

            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if 'step' not in group:
                group['step'] = 0

            # create lists for multi-tensor apply
            g_16, p_16, m_16 = [], [], []
            g_bf, p_bf, m_bf = [], [], []
            g_32, p_32, m_32 = [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise NotImplementedError('FusedLion does not support sparse gradients')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # DeepSpeed ZeRO 3 processes each subgroup a time, so we need to keep tracking step count for each tensor separately.
                    # While this is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time.
                    # In order to keep backward compatibility for the existing checkpoints, we use group['state'] to initialize state['step'] if it exists.
                    state['step'] = group.get('step', 0)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                elif p.dtype == torch.bfloat16:
                    g_bf.append(p.grad)
                    p_bf.append(p)
                    m_bf.append(state['exp_avg'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                else:
                    raise RuntimeError('FusedLion only support fp16, bf16 and fp32.')

            if len(g_16) > 0:
                state['step'] += 1
                multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_16, p_16, m_16], group['lr'],
                                     beta1, beta2, state['step'], group['weight_decay'])

            if len(g_bf) > 0:
                state['step'] += 1
                multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_bf, p_bf, m_bf], group['lr'],
                                     beta1, beta2, state['step'], group['weight_decay'])

            if len(g_32) > 0:
                state['step'] += 1
                multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_32, p_32, m_32], group['lr'],
                                     beta1, beta2, state['step'], group['weight_decay'])

        return loss
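A minimal direct-use sketch of FusedLion, following the GPU side of the equivalence test later in this diff (device and dtype choices are illustrative; an accelerator must be available):

import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.lion import FusedLion

device = get_accelerator().device_name()
param = torch.nn.Parameter(torch.randn(64, device=device, dtype=torch.half))
opt = FusedLion([param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

param.grad = torch.randn_like(param)
opt.step()  # launches multi_tensor_lion over the fp16 bucket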
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
"""


class MultiTensorApply(object):

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
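For context, this is how FusedLion.step above drives the wrapper; the call below simply restates that invocation with placeholder names for the lists and buffers:

# Restated from fused_lion.py (list and buffer names here are placeholders):
applier = MultiTensorApply(2048 * 32)
applier(fused_lion_cuda.multi_tensor_lion,        # CUDA op bound in the fused-lion frontend
        dummy_overflow_buf,                       # noop-flag int tensor
        [grad_list, param_list, momentum_list],   # depth-3 tensor lists
        lr, beta1, beta2, step, weight_decay)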
@@ -77,9 +77,10 @@ ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
MUADAM_OPTIMIZER = 'muadam'
MUADAMW_OPTIMIZER = 'muadamw'
MUSGD_OPTIMIZER = 'musgd'
LION_OPTIMIZER = 'lion'
DEEPSPEED_OPTIMIZERS = [
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
-    ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER
+    ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER
]

# extra optimizer parameters for adam/adamw
@@ -37,7 +37,8 @@ from deepspeed.runtime.bf16_optimizer import BF16_Optimizer

from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
    ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
-    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER
+    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
+    MUSGD_OPTIMIZER, LION_OPTIMIZER

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
@@ -1296,6 +1297,13 @@ class DeepSpeedEngine(Module):
            optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
            if not self.fp16_enabled():
                logger.warning(f"Currently the convergence of 1-bit Lamb is only verified under FP16")
        elif self.optimizer_name() == LION_OPTIMIZER:
            if self.zero_use_cpu_optimizer():
                from deepspeed.ops.lion import DeepSpeedCPULion
                optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters)
            else:
                from deepspeed.ops.lion import FusedLion
                optimizer = FusedLion(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == MUADAM_OPTIMIZER:
            try:
                from mup import MuAdam
@@ -12,6 +12,7 @@ from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator

@@ -37,7 +38,8 @@ class ZeRORuntimeException(Exception):


ZERO_SUPPORTED_OPTIMIZERS = [
-    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
+    torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad,
+    DeepSpeedCPULion, FusedLion
]

# Add apex FusedAdam to supported list if apex is installed
@@ -60,7 +60,9 @@ Available `DS_BUILD` options include:
* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op
* `DS_BUILD_CCL_COMM` builds the communication collective libs
* `DS_BUILD_CPU_ADAM` builds the CPUAdam op
* `DS_BUILD_CPU_LION` builds the CPULion op
* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex))
* `DS_BUILD_FUSED_LION` builds the FusedLion op
* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op
* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op
* `DS_BUILD_QUANTIZER` builds the quantizer op
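These presumably follow the same install-time pattern as the other `DS_BUILD` options documented on this page, e.g. `DS_BUILD_CPU_LION=1 DS_BUILD_FUSED_LION=1 pip install deepspeed` to pre-build just the Lion ops (the command is an assumption; only the flag names come from this diff).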
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from .builder import TorchCPUOpBuilder


class CPULionBuilder(TorchCPUOpBuilder):
    BUILD_VAR = "DS_BUILD_CPU_LION"
    NAME = "cpu_lion"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.lion.{self.NAME}_op'

    def sources(self):
        if self.build_for_cpu:
            return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']

        return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']

    def libraries_args(self):
        args = super().libraries_args()
        if self.build_for_cpu:
            return args

        if not self.is_rocm_pytorch():
            args += ['curand']

        return args

    def include_paths(self):
        import torch
        if self.build_for_cpu:
            CUDA_INCLUDE = []
        elif not self.is_rocm_pytorch():
            CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
        else:
            CUDA_INCLUDE = [
                os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
                os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
                os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
            ]
        return ['csrc/includes'] + CUDA_INCLUDE
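The builder is what DeepSpeedCPULion uses to JIT-load the extension. A small sketch of the loaded module's surface, restating the pybind definitions from earlier in this diff (argument values are illustrative):

from deepspeed.ops.op_builder import CPULionBuilder

ds_opt_lion = CPULionBuilder().load()  # builds csrc/lion/cpu_lion.cpp + cpu_lion_impl.cpp on first use
ds_opt_lion.create_lion(0, 1e-3, 0.9, 0.999, 0.01, True)  # optimizer_id, lr, beta1, beta2, weight_decay, should_log
# ds_opt_lion.lion_update(...) / ds_opt_lion.lion_update_copy(...) perform the actual steps (see cpu_lion.py above)
ds_opt_lion.destroy_lion(0)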
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CUDAOpBuilder

import sys


class FusedLionBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_LION"
    NAME = "fused_lion"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.lion.{self.NAME}_op'

    def sources(self):
        return ['csrc/lion/fused_lion_frontend.cpp', 'csrc/lion/multi_tensor_lion.cu']

    def include_paths(self):
        return ['csrc/includes', 'csrc/lion']

    def cxx_args(self):
        args = super().cxx_args()
        return args + self.version_dependent_macros()

    def nvcc_args(self):
        nvcc_flags = ['-O3'] + self.version_dependent_macros()
        if not self.is_rocm_pytorch():
            nvcc_flags.extend(
                ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo', '--use_fast_math'] +
                self.compute_capability_args())
        return nvcc_flags
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import numpy as np
import pytest
from cpuinfo import get_cpu_info

import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.lion import FusedLion
from deepspeed.ops.op_builder import CPULionBuilder
from unit.common import DistributedTest

if not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME]:
    pytest.skip("cpu-lion is not compatible", allow_module_level=True)

pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower()


def check_equal(first, second, atol=1e-2, verbose=False):
    x = first.detach().numpy()
    y = second.detach().numpy()
    print("ATOL", atol)
    if verbose:
        print("x = {}".format(x.flatten()))
        print("y = {}".format(y.flatten()))
        print('-' * 80)
    np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol)


def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2):
    for i in range(10):
        param1.grad = torch.randn(model_size, device=param1.device).to(param1.dtype)
        param2.grad = param1.grad.clone().detach().to(device=param2.device, dtype=param2.dtype)

        optimizer1.step()
        optimizer2.step()

    tolerance = param1.float().norm().detach().numpy() * 1e-2
    check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True)


@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"])
@pytest.mark.parametrize('model_size',
                         [
                             (64),
                             (22),
                             #(55),
                             (128),
                             (1024),
                             (1048576),
                         ]) # yapf: disable
class TestCPULion(DistributedTest):
    world_size = 1
    reuse_dist_env = True
    requires_cuda_env = False
    if not get_accelerator().is_available():
        init_distributed = False
        set_dist_env = False

    @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.")
    def test_fused_lion_equal(self, dtype, model_size):
        if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
            pytest.skip("cpu-lion with half precision not supported on AMD CPUs")

        from deepspeed.ops.lion import DeepSpeedCPULion

        cpu_data = torch.randn(model_size, device='cpu').to(dtype)
        cpu_param = torch.nn.Parameter(cpu_data)
        cuda_param = torch.nn.Parameter(cpu_data.to(get_accelerator().device_name()))

        cpu_optimizer = DeepSpeedCPULion([cpu_param])
        cuda_optimizer = FusedLion([cuda_param])

        _compare_optimizers(model_size=model_size,
                            param1=cpu_param,
                            optimizer1=cpu_optimizer,
                            param2=cuda_param,
                            optimizer2=cuda_optimizer)


class TestCPULionGPUError(DistributedTest):

    def test_cpu_lion_gpu_error(self):
        model_size = 64
        from deepspeed.ops.lion import DeepSpeedCPULion
        device = get_accelerator().device_name(0)  # 'cuda:0' or 'xpu:0'
        param = torch.nn.Parameter(torch.randn(model_size, device=device))
        optimizer = DeepSpeedCPULion([param])

        param.grad = torch.randn(model_size, device=device)
        with pytest.raises(AssertionError):
            optimizer.step()
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
import pytest

from deepspeed.ops.lion import FusedLion
from deepspeed.ops.lion import DeepSpeedCPULion
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
# yapf: disable
#'optimizer, zero_offload, resulting_optimizer
lion_configs = [["Lion", False, FusedLion],
                ["Lion", True, DeepSpeedCPULion]]


@pytest.mark.parametrize(
    'optimizer, zero_offload, resulting_optimizer',
    lion_configs)
class TestLionConfigs(DistributedTest):
    world_size = 1
    reuse_dist_env = True

    def test(self,
             optimizer,
             zero_offload,
             resulting_optimizer):
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": optimizer,
                "params": {
                    "lr": 0.00015,
                }
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 2,
                "cpu_offload": zero_offload
            }
        }
        model = SimpleModel(10)
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              model_parameters=model.parameters())
        # get base optimizer under zero
        ds_optimizer = model.optimizer.optimizer
        opt_class = resulting_optimizer
        assert isinstance(ds_optimizer, opt_class)