diff --git a/csrc/cpu/lion/fused_lion.cpp b/csrc/cpu/lion/fused_lion.cpp
new file mode 100644
index 000000000..708df7f01
--- /dev/null
+++ b/csrc/cpu/lion/fused_lion.cpp
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "cpu_lion.h"
+
+// C++ interface
+
+void multi_tensor_lion(int chunk_size,
+                       at::Tensor noop_flag,
+                       std::vector<std::vector<at::Tensor>> tensor_lists, /*gpmv*/
+                       const float lr,
+                       const float beta1,
+                       const float beta2,
+                       const int step,
+                       const int mode,
+                       const float weight_decay)
+{
+    static bool initialized = false;
+    if (!initialized) {
+        create_lion_optimizer(0);
+        initialized = true;
+    }
+    for (int i = 0; i < tensor_lists[0].size(); i++) {
+        ds_lion_step(0,
+                     step,
+                     lr,
+                     beta1,
+                     beta2,
+                     weight_decay,
+                     tensor_lists[1][i],
+                     tensor_lists[0][i],
+                     tensor_lists[2][i]);
+    }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("multi_tensor_lion",
+          &multi_tensor_lion,
+          "Compute and apply gradient update to parameters for Lion optimizer");
+}
diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h
new file mode 100644
index 000000000..76034ceb3
--- /dev/null
+++ b/csrc/includes/cpu_lion.h
@@ -0,0 +1,233 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#define NOMINMAX  // Windows idiosyncrasy
+                  // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
+
+#include <stdio.h>
+#include <torch/extension.h>
+#include <cassert>
+#include "simd.h"
+
+#if defined(__ENABLE_CUDA__)
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include "cuda.h"
+#include "custom_cuda_layers.h"
+typedef __half ds_half_precision_t;
+#else
+#include <cmath>
+typedef unsigned short ds_half_precision_t;
+#endif
+
+#define STEP(SPAN)                                             \
+    void Step_##SPAN(float* _params,                           \
+                     float* grads,                             \
+                     float* _exp_avg,                          \
+                     size_t _param_size,                       \
+                     ds_half_precision_t* dev_param = nullptr, \
+                     bool half_precision = false);
+
+class Lion_Optimizer {
+public:
+    Lion_Optimizer(float alpha = 1e-3,
+                   float betta1 = 0.9,
+                   float betta2 = 0.999,
+                   float weight_decay = 0)
+        : _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
+    {
+#if defined(__ENABLE_CUDA__)
+        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+
+        _streams[0] = TrainingContext::Instance().GetCurrentStream();
+        _streams[1] = TrainingContext::Instance().GetNewStream();
+        _buf_index = false;
+#endif
+    }
+    ~Lion_Optimizer()
+    {
+#if defined(__ENABLE_CUDA__)
+        cudaFreeHost(_doubled_buffer[0]);
+        cudaFreeHost(_doubled_buffer[1]);
+#endif
+    }
+
+#if defined(__AVX512__) or defined(__AVX256__)
+    template <int span>
+    void Step_AVX(size_t* rounded_size,
+                  float* _params,
+                  float* grads,
+                  float* _exp_avg,
+                  size_t param_size,
+                  ds_half_precision_t* dev_param = nullptr,
+                  bool half_precision = false);
+#endif
+    STEP(1)
+    STEP(4)
+    STEP(8)
+#if defined(__ENABLE_CUDA__)
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
+    }
+#endif
+    inline void IncrementStep(size_t step, float beta1, float beta2)
+    {
+        _step++;
+        if (_step != step || beta1 != _betta1 || beta2 != _betta2) {
+            _step = step;
+            _betta1 = beta1;
+            _betta2 = beta2;
+        }
+    }
+    inline void update_state(float lr, float weight_decay)
+    {
+        _alpha = lr;
+        _weight_decay = weight_decay;
+    }
+
+private:
+    float _alpha;
+    float _betta1;
+    float _betta2;
+    float _weight_decay;
+    size_t _step;
+
+#if defined(__ENABLE_CUDA__)
+    float* _doubled_buffer[2];
+    cudaStream_t _streams[2];
+    bool _buf_index;
+#endif
+};
+
+#if defined(__AVX512__) or defined(__AVX256__)
+template <int span>
+void Lion_Optimizer::Step_AVX(size_t* rounded_size,
+                              float* _params,
+                              float* grads,
+                              float* _exp_avg,
+                              size_t _param_size,
+                              ds_half_precision_t* dev_params,
+                              bool half_precision)
+{
+    size_t new_rounded_size = 0;
+    int rshft = half_precision ? 1 : 0;
+
+    constexpr float neg1 = -1.0f;
+    AVX_Data neg1_4;
+    neg1_4.data = SIMD_SET(neg1);
+
+    AVX_Data betta1_4;
+    betta1_4.data = SIMD_SET(_betta1);
+    AVX_Data betta2_4;
+    betta2_4.data = SIMD_SET(_betta2);
+
+    float betta1_minus1 = 1 - _betta1;
+    float betta2_minus1 = 1 - _betta2;
+    AVX_Data betta1_minus1_4;
+    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
+    AVX_Data betta2_minus1_4;
+    betta2_minus1_4.data = SIMD_SET(betta2_minus1);
+
+    float step_size = -_alpha;
+    AVX_Data step_size_4;
+    step_size_4.data = SIMD_SET(step_size);
+
+    float after_decay = 1.0f - _alpha * _weight_decay;
+    AVX_Data after_decay_4;
+    if (_weight_decay > 0) after_decay_4.data = SIMD_SET(after_decay);
+
+    new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span);
+    for (size_t t = 0; t < new_rounded_size; t += TILE) {
+        size_t copy_size = TILE;
+        if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
+        size_t offset = copy_size + t;
+#if defined(__ENABLE_CUDA__)
+        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#endif
+#pragma omp parallel for
+        for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
+            AVX_Data grad_4[span];
+            simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
+
+            AVX_Data momentum_4[span];
+            simd_load<span>(momentum_4, _exp_avg + i, false);
+
+            AVX_Data param_4[span];
+            simd_load<span>(param_4, _params + (i >> rshft), half_precision);
+
+            AVX_Data tmp_4[span];
+
+            simd_mul<span>(tmp_4, momentum_4, betta1_4);
+            simd_fma<span>(tmp_4, grad_4, betta1_minus1_4, tmp_4);
+            // We already used intrinsics, so consider the machine representation fixed.
+            simd_and<span>(tmp_4, tmp_4, neg1_4);
+            simd_xor<span>(tmp_4, tmp_4, step_size_4);
+            if (_weight_decay > 0) {
+                simd_fma<span>(param_4, param_4, after_decay_4, tmp_4);
+            } else {
+                simd_add<span>(param_4, param_4, tmp_4);
+            }
+
+            simd_mul<span>(momentum_4, momentum_4, betta2_4);
+            simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);
+
+            simd_store<span>(_params + (i >> rshft), param_4, half_precision);
+#if defined(__ENABLE_CUDA__)
+            if (dev_params) {
+                simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
+            }
+#endif
+            simd_store<span>(_exp_avg + i, momentum_4, false);
+        }
+#if defined(__ENABLE_CUDA__)
+        if (dev_params) {
+            if (half_precision)
+                launch_param_update_half(
+                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
+            else
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
+
+            _buf_index = !_buf_index;
+        }
+#endif
+    }
+    *rounded_size = new_rounded_size;
+}
+#endif
+
+int create_lion_optimizer(int optimizer_id,
+                          float alpha = 1e-3,
+                          float betta1 = 0.9,
+                          float betta2 = 0.999,
+                          float weight_decay = 0,
+                          bool should_log = false);
+
+int ds_lion_step(int optimizer_id,
+                 size_t step,
+                 float lr,
+                 float beta1,
+                 float beta2,
+                 float weight_decay,
+                 torch::Tensor& params,
+                 torch::Tensor& grads,
+                 torch::Tensor& exp_avg);
+
+int ds_lion_step_plus_copy(int optimizer_id,
+                           size_t step,
+                           float lr,
+                           float beta1,
+                           float beta2,
+                           float weight_decay,
+                           torch::Tensor& params,
+                           torch::Tensor& grads,
+                           torch::Tensor& exp_avg,
+                           torch::Tensor& gpu_params);
+
+int destroy_lion_optimizer(int optimizer_id);
diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h
index 712dd5b32..f77568be7 100644
--- a/csrc/includes/simd.h
+++ b/csrc/includes/simd.h
@@ -24,6 +24,10 @@
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
+#define SIMD_AND(x, y) _mm512_and_ps(x, y)
+#define SIMD_ANDNOT(x, y) _mm512_andnot_ps(x, y)
+#define SIMD_OR(x, y) _mm512_or_ps(x, y)
+#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
 #define SIMD_WIDTH 16
 
 #define SIMD_LOAD2(x, h) \
@@ -42,10 +46,14 @@
 #define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
+#define SIMD_AND(x, y) _mm256_and_ps(x, y)
+#define SIMD_ANDNOT(x, y) _mm256_andnot_ps(x, y)
+#define SIMD_OR(x, y) _mm256_or_ps(x, y)
+#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
 #define SIMD_WIDTH 8
+
 #define SIMD_LOAD2(x, h) \
     ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))
-
 #define SIMD_STORE2(x, d, h) \
     ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
          : _mm256_storeu_ps(x, d))
@@ -136,5 +144,55 @@ inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
 {
 #pragma unroll
     for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
 }
+template <int span>
+inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r.data); }
+}
+template <int span>
+inline void simd_and(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_AND(src_a_l[i].data, src_a_r[i].data); }
+}
+template <int span>
+inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r.data); }
+}
+template <int span>
+inline void simd_andnot(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) {
+        dst[i].data = SIMD_ANDNOT(src_a_l[i].data, src_a_r[i].data);
+    }
+}
+template <int span>
+inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r.data); }
+}
+template <int span>
+inline void simd_or(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_OR(src_a_l[i].data, src_a_r[i].data); }
+}
+template <int span>
+inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r.data); }
+}
+template <int span>
+inline void simd_xor(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
+{
+#pragma unroll
+    for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_XOR(src_a_l[i].data, src_a_r[i].data); }
+}
 #endif
diff --git a/csrc/lion/cpu_lion.cpp b/csrc/lion/cpu_lion.cpp
new file mode 100644
index 000000000..a0562eac9
--- /dev/null
+++ b/csrc/lion/cpu_lion.cpp
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "cpu_lion.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
+    m.def("lion_update_copy",
+          &ds_lion_step_plus_copy,
+          "DeepSpeed CPU Lion update and param copy (C++)");
+    m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
+    m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
+}
diff --git a/csrc/lion/cpu_lion_impl.cpp b/csrc/lion/cpu_lion_impl.cpp
new file mode 100644
index 000000000..5c24e23b4
--- /dev/null
+++ b/csrc/lion/cpu_lion_impl.cpp
@@ -0,0 +1,255 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <torch/extension.h>
+#include <cassert>
+#include <cmath>
+#include <iostream>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include "cpu_lion.h"
+
+#if defined(__ENABLE_CUDA__)
+#include <cuda_runtime_api.h>
+#include "cublas_v2.h"
+#include "cuda.h"
+#include "curand.h"
+#include "custom_cuda_layers.h"
+#endif
+
+static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
+
+// C++ interface
+
+void Lion_Optimizer::Step_1(float* _params,
+                            float* grads,
+                            float* _exp_avg,
+                            size_t _param_size,
+                            ds_half_precision_t* dev_params,
+                            bool half_precision)
+{
+    size_t rounded_size = 0;
+#if defined(__AVX512__) or defined(__AVX256__)
+    Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
+#endif
+    if (_param_size > rounded_size) {
+        float betta1_minus1 = 1 - _betta1;
+        float betta2_minus1 = 1 - _betta2;
+
+        float alpha = _alpha;
+        float after_decay = 1 - alpha * _weight_decay;
+        ds_half_precision_t* grads_cast_h;
+        ds_half_precision_t* params_cast_h;
+        if (half_precision) {
+            grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
+            params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
+        }
+
+        for (size_t t = rounded_size; t < _param_size; t += TILE) {
+            size_t copy_size = TILE;
+            if ((t + TILE) > _param_size) copy_size = _param_size - t;
+            size_t offset = copy_size + t;
+#if defined(__ENABLE_CUDA__)
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+#endif
+#pragma omp parallel for
+            for (size_t k = t; k < offset; k++) {
+                float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
+                float param = half_precision ? (float)params_cast_h[k] : _params[k];
+                float momentum = _exp_avg[k];
+                float tmp = momentum * _betta1;
+                tmp = grad * betta1_minus1 + tmp;
+                // Rely on portable C++ methods to manipulate the sign bit of a floating-point
+                // number.
+                tmp = -std::copysignf(alpha, tmp);
+                if (_weight_decay > 0) {
+                    param = param * after_decay + tmp;
+                } else {
+                    param = param + tmp;
+                }
+                momentum = momentum * _betta2;
+                momentum = grad * betta2_minus1 + momentum;
+#if defined(__ENABLE_CUDA__)
+                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
+#endif
+                if (half_precision)
+                    params_cast_h[k] = (ds_half_precision_t)param;
+                else
+                    _params[k] = param;
+                _exp_avg[k] = momentum;
+            }
+#if defined(__ENABLE_CUDA__)
+            if (dev_params) {
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
+
+                _buf_index = !_buf_index;
+            }
+#endif
+        }
+    }
+}
+
+void Lion_Optimizer::Step_4(float* _params,
+                            float* grads,
+                            float* _exp_avg,
+                            size_t _param_size,
+                            ds_half_precision_t* dev_params,
+                            bool half_precision)
+{
+    size_t rounded_size = 0;
+#if defined(__AVX512__) or defined(__AVX256__)
+    Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
+#endif
+    if (_param_size > rounded_size)
+        Step_1((_params + rounded_size),
+               (grads + rounded_size),
+               (_exp_avg + rounded_size),
+               (_param_size - rounded_size),
+               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
+               half_precision);
+}
+
+int create_lion_optimizer(int optimizer_id,
+                          float alpha,
+                          float betta1,
+                          float betta2,
+                          float weight_decay,
+                          bool should_log)
+{
+    auto opt = std::make_shared<Lion_Optimizer>(alpha, betta1, betta2, weight_decay);
+
+    s_optimizers[optimizer_id] = opt;
+
+    if (should_log) {
+        std::string avx_type = "";
+#if defined(__AVX512__)
+        avx_type = "AVX512";
+#else
+#if defined(__AVX256__)
+        avx_type = "AVX2";
+#else
+        avx_type = "scalar";
+#endif
+#endif
+
+        printf("Lion Optimizer #%d is created with %s arithmetic capability.\n",
+               optimizer_id,
+               avx_type.c_str());
+        printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f\n",
+               alpha,
+               betta1,
+               betta2,
+               weight_decay);
+    }
+
+    return 0;
+}
+
+void Lion_Optimizer::Step_8(float* _params,
+                            float* grads,
+                            float* _exp_avg,
+                            size_t _param_size,
+                            ds_half_precision_t* dev_params,
+                            bool half_precision)
+{
+    size_t rounded_size = 0;
+#if defined(__AVX512__) or defined(__AVX256__)
+    Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
+#endif
+    if (_param_size > rounded_size)
+        Step_4((_params + rounded_size),
+               (grads + rounded_size),
+               (_exp_avg + rounded_size),
+               (_param_size - rounded_size),
+               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
+               half_precision);
+}
+
+int ds_lion_step(int optimizer_id,
+                 size_t step,
+                 float lr,
+                 float beta1,
+                 float beta2,
+                 float weight_decay,
+                 torch::Tensor& params,
+                 torch::Tensor& grads,
+                 torch::Tensor& exp_avg)
+{
+    auto params_c = params.contiguous();
+    auto grads_c = grads.contiguous();
+    auto exp_avg_c = exp_avg.contiguous();
+
+    // assert(params.options().dtype() == grads.options().dtype());
+
+    float* params_ptr = (float*)params_c.data_ptr();
+    float* grads_ptr = (float*)grads_c.data_ptr();
+    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+
+    std::shared_ptr<Lion_Optimizer> opt =
+        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, weight_decay);
+
+    opt->Step_8(params_ptr,
+                grads_ptr,
+                exp_avg_ptr,
+                params_c.numel(),
+                nullptr,
+                (params.options().dtype() == at::kHalf));
+
+#if defined(__ENABLE_CUDA__)
+    opt->SynchronizeStreams();
+#endif
+    return 0;
+}
+
+int ds_lion_step_plus_copy(int optimizer_id,
+                           size_t step,
+                           float lr,
+                           float beta1,
+                           float beta2,
+                           float weight_decay,
+                           torch::Tensor& params,
+                           torch::Tensor& grads,
+                           torch::Tensor& exp_avg,
+                           torch::Tensor& gpu_params)
+{
+#if defined(__ENABLE_CUDA__)
+    auto params_c = params.contiguous();
+    auto gpu_params_c = gpu_params.contiguous();
+    auto exp_avg_c = exp_avg.contiguous();
+    auto grads_c = grads.contiguous();
+
+    float* params_ptr = (float*)params_c.data_ptr();
+    float* grads_ptr = (float*)grads_c.data_ptr();
+    ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
+    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+
+    std::shared_ptr<Lion_Optimizer> opt =
+        std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, weight_decay);
+    opt->Step_8(params_ptr,
+                grads_ptr,
+                exp_avg_ptr,
+                params_c.numel(),
+                gpu_params_ptr,
+                (params.options().dtype() == at::kHalf));
+
+    opt->SynchronizeStreams();
+#else
+    assert(false);
+#endif
+    return 0;
+}
+
+int destroy_lion_optimizer(int optimizer_id)
+{
+    s_optimizers.erase(optimizer_id);
+
+    return 0;
+}
diff --git a/csrc/lion/fused_lion_frontend.cpp b/csrc/lion/fused_lion_frontend.cpp
new file mode 100644
index 000000000..e523f97ca
--- /dev/null
+++ b/csrc/lion/fused_lion_frontend.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +void multi_tensor_lion_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const int step, + const float weight_decay); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_lion", + &multi_tensor_lion_cuda, + "Compute and apply gradient update to parameters for Lion optimizer"); +} diff --git a/csrc/lion/multi_tensor_apply.cuh b/csrc/lion/multi_tensor_apply.cuh new file mode 100644 index 000000000..12f41cb49 --- /dev/null +++ b/csrc/lion/multi_tensor_apply.cuh @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +#include +#include "compat.h" + +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; +}; + +template +__global__ void multi_tensor_apply_kernel(int chunk_size, + volatile int* noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); +} + +template +void multi_tensor_apply(int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + // TODO: Print which tensor fails. 
+ bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << + // std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/csrc/lion/multi_tensor_lion.cu b/csrc/lion/multi_tensor_lion.cu new file mode 100644 index 000000000..f5fe6dfdd --- /dev/null +++ b/csrc/lion/multi_tensor_lion.cu @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Copyright NVIDIA/apex +This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ + +#include +#include +#include +#include +// Another possibility: +// #include + +#include + +#include "multi_tensor_apply.cuh" +#include "type_shim.h" + +#define BLOCK_SIZE 512 +#define ILP 4 + +using MATH_T = float; + +template +struct LionFunctor { + __device__ __forceinline__ void operator()(int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<3>& tl, + const float beta1, + const float beta2, + const float lr, + const float decay) + { + // I'd like this kernel to propagate infs/nans. 
+ // if(*noop_gmem == 1) + // return; + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + + // potentially use to pass in list of scalar + // int tensor_num = tl.start_tensor_this_launch + tensor_loc; + + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* g = (T*)tl.addresses[0][tensor_loc]; + g += chunk_idx * chunk_size; + + T* p = (T*)tl.addresses[1][tensor_loc]; + p += chunk_idx * chunk_size; + + T* m = (T*)tl.addresses[2][tensor_loc]; + m += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + MATH_T after_decay = 1.0f - lr * decay; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + MATH_T c = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + MATH_T update = c > 0 ? (-lr) : lr; + r_p[ii] = r_p[ii] * after_decay + update; + r_m[ii] = beta2 * r_m[ii] + (1 - beta2) * r_g[ii]; + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + } + } + } + } +}; + +void multi_tensor_lion_cuda(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const int step, + const float weight_decay) +{ + using namespace at; + + // Assume single type across p,g,m1,m2 now + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "lion", + multi_tensor_apply<3>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + LionFunctor(), + beta1, + beta2, + lr, + weight_decay);) + + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index b5a03c458..ba1c9c1fd 100755 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -6,6 +6,7 @@ from . import adam from . import adagrad from . import lamb +from . import lion #from ..git_version_info_installed import installed_ops as __installed_ops__ #if __installed_ops__['sparse_attn']: from . import sparse_attention diff --git a/deepspeed/ops/lion/__init__.py b/deepspeed/ops/lion/__init__.py new file mode 100755 index 000000000..2f90e5ec2 --- /dev/null +++ b/deepspeed/ops/lion/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cpu_lion import DeepSpeedCPULion +from .fused_lion import FusedLion diff --git a/deepspeed/ops/lion/cpu_lion.py b/deepspeed/ops/lion/cpu_lion.py new file mode 100755 index 000000000..a91a00643 --- /dev/null +++ b/deepspeed/ops/lion/cpu_lion.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from cpuinfo import get_cpu_info +from deepspeed.utils import logger +from deepspeed.utils.logging import should_log_le +from deepspeed.ops.op_builder import CPULionBuilder + + +class DeepSpeedCPULion(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, model_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, fp32_optimizer_states=True): + """Fast vectorized implementation of Lion optimizer on CPU: + + See Symbolic Discovery of Optimization Algorithms (https://doi.org/10.48550/arXiv.2302.06675). + + .. note:: + We recommend using our `config + `_ + to allow :meth:`deepspeed.initialize` to build this optimizer + for you. + + + Arguments: + model_params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + full_precision_optimizer_states: creates momentum and variance in full precision regardless of + the precision of the parameters (default: True) + """ + + default_args = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super(DeepSpeedCPULion, self).__init__(model_params, default_args) + + cpu_info = get_cpu_info() + self.cpu_vendor = cpu_info["vendor_id_raw"].lower() if "vendor_id_raw" in cpu_info else "unknown" + if "amd" in self.cpu_vendor: + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + if p.dtype == torch.half: + logger.warning("FP16 params for CPULion may not work on AMD CPUs") + break + else: + continue + break + + self.opt_id = DeepSpeedCPULion.optimizer_id + DeepSpeedCPULion.optimizer_id = DeepSpeedCPULion.optimizer_id + 1 + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_lion = CPULionBuilder().load() + + self.ds_opt_lion.create_lion(self.opt_id, lr, betas[0], betas[1], weight_decay, should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_lion.destroy_lion(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPULion, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None, fp16_param_groups=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + fp16_param_groups: FP16 GPU parameters to update. Performing the + copy here reduces communication time. Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. 
+ """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # intended device for step + device = torch.device('cpu') + + # converting the fp16 params to a group of parameter + if type(fp16_param_groups) is list: + if type(fp16_param_groups[0]) is not list: + fp16_param_groups = [fp16_param_groups] + elif fp16_param_groups is not None: + fp16_param_groups = [[fp16_param_groups]] + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + assert p.device == device, f"CPULion param is on {p.device} and must be 'cpu', make " \ + "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + #memory_format=torch.preserve_format) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + if fp16_param_groups is not None: + self.ds_opt_lion.lion_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2, + group['weight_decay'], p.data, p.grad.data, state['exp_avg'], + fp16_param_groups[group_id][param_id].data) + else: + self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2, + group['weight_decay'], p.data, p.grad.data, state['exp_avg']) + return loss diff --git a/deepspeed/ops/lion/fused_lion.py b/deepspeed/ops/lion/fused_lion.py new file mode 100644 index 000000000..7332a7f96 --- /dev/null +++ b/deepspeed/ops/lion/fused_lion.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +This file is modified from fused_adam.py +""" + +import torch +from .multi_tensor_apply import MultiTensorApply + +multi_tensor_applier = MultiTensorApply(2048 * 32) +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import FusedLionBuilder + + +class FusedLion(torch.optim.Optimizer): + """Implements Lion algorithm. + + Currently GPU-only. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + set_grad_none (bool, optional): whether set grad to None when zero_grad() + method is called. (default: True) + + .. 
_Symbolic Discovery of Optimization Algorithms: + https://doi.org/10.48550/arXiv.2302.06675 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0., set_grad_none=True): + + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super(FusedLion, self).__init__(params, defaults) + self.set_grad_none = set_grad_none + + fused_lion_cuda = FusedLionBuilder().load() + # Skip buffer + self._dummy_overflow_buf = get_accelerator().IntTensor([0]) + self.multi_tensor_lion = fused_lion_cuda.multi_tensor_lion + + def zero_grad(self): + if self.set_grad_none: + for group in self.param_groups: + for p in group['params']: + p.grad = None + else: + super(FusedLion, self).zero_grad() + + def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + if any(p is not None for p in [grads, output_params, scale, grad_norms]): + raise RuntimeError('FusedLion has been updated.') + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + if len(group['params']) == 0: + continue + beta1, beta2 = group['betas'] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' not in group: + group['step'] = 0 + + # create lists for multi-tensor apply + g_16, p_16, m_16 = [], [], [] + g_bf, p_bf, m_bf = [], [], [] + g_32, p_32, m_32 = [], [], [] + + for p in group['params']: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise NotImplementedError('FusedLion does not support sparse gradients') + + state = self.state[p] + # State initialization + if len(state) == 0: + # DeepSpeed ZeRO 3 processes each subgroup a time, so we need to keep tracking step count for each tensor separately. + # While this is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time. + # In order to keep backward compatibility for the existing checkpoints, we use group['state'] to initialize state['step'] if it exists. 
+ state['step'] = group.get('step', 0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + + if p.dtype == torch.float16: + g_16.append(p.grad.data) + p_16.append(p.data) + m_16.append(state['exp_avg']) + elif p.dtype == torch.bfloat16: + g_bf.append(p.grad) + p_bf.append(p) + m_bf.append(state['exp_avg']) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state['exp_avg']) + else: + raise RuntimeError('FusedLion only support fp16, bf16 and fp32.') + + if len(g_16) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_16, p_16, m_16], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + if len(g_bf) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_bf, p_bf, m_bf], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + if len(g_32) > 0: + state['step'] += 1 + multi_tensor_applier(self.multi_tensor_lion, self._dummy_overflow_buf, [g_32, p_32, m_32], group['lr'], + beta1, beta2, state['step'], group['weight_decay']) + + return loss diff --git a/deepspeed/ops/lion/multi_tensor_apply.py b/deepspeed/ops/lion/multi_tensor_apply.py new file mode 100644 index 000000000..0ba228505 --- /dev/null +++ b/deepspeed/ops/lion/multi_tensor_apply.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +Copyright NVIDIA/apex +This file is adapted from NVIDIA/apex, commit a109f85 +""" + + +class MultiTensorApply(object): + + def __init__(self, chunk_size): + self.chunk_size = chunk_size + + def __call__(self, op, noop_flag_buffer, tensor_lists, *args): + return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 77f1cd51b..c31b96712 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -77,9 +77,10 @@ ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' MUADAM_OPTIMIZER = 'muadam' MUADAMW_OPTIMIZER = 'muadamw' MUSGD_OPTIMIZER = 'musgd' +LION_OPTIMIZER = 'lion' DEEPSPEED_OPTIMIZERS = [ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, - ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER + ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER, LION_OPTIMIZER ] # extra optimizer parameters for adam/adamw diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 22d7c882e..b4c8ef56c 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -37,7 +37,8 @@ from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ - TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, MUSGD_OPTIMIZER + TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ + MUSGD_OPTIMIZER, LION_OPTIMIZER from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ @@ -1296,6 +1297,13 @@ class DeepSpeedEngine(Module): optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): logger.warning(f"Currently the convergence of 1-bit Lamb is only 
verified under FP16") + elif self.optimizer_name() == LION_OPTIMIZER: + if self.zero_use_cpu_optimizer(): + from deepspeed.ops.lion import DeepSpeedCPULion + optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters) + else: + from deepspeed.ops.lion import FusedLion + optimizer = FusedLion(model_parameters, **optimizer_parameters) elif self.optimizer_name() == MUADAM_OPTIMIZER: try: from mup import MuAdam diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 0250796f7..0bf1ca4a8 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -12,6 +12,7 @@ from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad from deepspeed.ops.adam import FusedAdam +from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator @@ -37,7 +38,8 @@ class ZeRORuntimeException(Exception): ZERO_SUPPORTED_OPTIMIZERS = [ - torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad + torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad, + DeepSpeedCPULion, FusedLion ] # Add apex FusedAdam to supported list if apex is installed diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md index 055843399..c2b4c04ca 100755 --- a/docs/_tutorials/advanced-install.md +++ b/docs/_tutorials/advanced-install.md @@ -60,7 +60,9 @@ Available `DS_BUILD` options include: * `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op * `DS_BUILD_CCL_COMM` builds the communication collective libs * `DS_BUILD_CPU_ADAM` builds the CPUAdam op +* `DS_BUILD_CPU_LION` builds the CPULion op * `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)) +* `DS_BUILD_FUSED_LION` builds the FusedLion op * `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op * `DS_BUILD_FUSED_LAMB` builds the FusedLamb op * `DS_BUILD_QUANTIZER` builds the quantizer op diff --git a/op_builder/cpu_lion.py b/op_builder/cpu_lion.py new file mode 100644 index 000000000..5c16d10eb --- /dev/null +++ b/op_builder/cpu_lion.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from .builder import TorchCPUOpBuilder + + +class CPULionBuilder(TorchCPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_LION" + NAME = "cpu_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + if self.build_for_cpu: + return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp'] + + return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu'] + + def libraries_args(self): + args = super().libraries_args() + if self.build_for_cpu: + return args + + if not self.is_rocm_pytorch(): + args += ['curand'] + + return args + + def include_paths(self): + import torch + if self.build_for_cpu: + CUDA_INCLUDE = [] + elif not self.is_rocm_pytorch(): + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + else: + CUDA_INCLUDE = [ + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"), + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"), + ] + return ['csrc/includes'] + CUDA_INCLUDE diff --git a/op_builder/fused_lion.py b/op_builder/fused_lion.py new file mode 100644 index 000000000..b900a8f23 --- /dev/null +++ b/op_builder/fused_lion.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CUDAOpBuilder + +import sys + + +class FusedLionBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_LION" + NAME = "fused_lion" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.lion.{self.NAME}_op' + + def sources(self): + return ['csrc/lion/fused_lion_frontend.cpp', 'csrc/lion/multi_tensor_lion.cu'] + + def include_paths(self): + return ['csrc/includes', 'csrc/lion'] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() + + def nvcc_args(self): + nvcc_flags = ['-O3'] + self.version_dependent_macros() + if not self.is_rocm_pytorch(): + nvcc_flags.extend( + ['-allow-unsupported-compiler' if sys.platform == "win32" else '', '-lineinfo', '--use_fast_math'] + + self.compute_capability_args()) + return nvcc_flags diff --git a/tests/unit/ops/lion/test_cpu_lion.py b/tests/unit/ops/lion/test_cpu_lion.py new file mode 100644 index 000000000..61a069af3 --- /dev/null +++ b/tests/unit/ops/lion/test_cpu_lion.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +import pytest +from cpuinfo import get_cpu_info + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.lion import FusedLion +from deepspeed.ops.op_builder import CPULionBuilder +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[CPULionBuilder.NAME]: + pytest.skip("cpu-lion is not compatible", allow_module_level=True) + +pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + print("ATOL", atol) + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2): + for i in range(10): + param1.grad = torch.randn(model_size, device=param1.device).to(param1.dtype) + param2.grad = param1.grad.clone().detach().to(device=param2.device, dtype=param2.dtype) + + optimizer1.step() + optimizer2.step() + + tolerance = param1.float().norm().detach().numpy() * 1e-2 + check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True) + + +@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + #(55), + (128), + (1024), + (1048576), + ]) # yapf: disable +class TestCPULion(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + def test_fused_lion_equal(self, dtype, model_size): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-lion with half precision not supported on AMD CPUs") + + from deepspeed.ops.lion import DeepSpeedCPULion + + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + cpu_param = torch.nn.Parameter(cpu_data) + cuda_param = torch.nn.Parameter(cpu_data.to(get_accelerator().device_name())) + + cpu_optimizer = DeepSpeedCPULion([cpu_param]) + cuda_optimizer = FusedLion([cuda_param]) + + _compare_optimizers(model_size=model_size, + param1=cpu_param, + optimizer1=cpu_optimizer, + param2=cuda_param, + optimizer2=cuda_optimizer) + + +class TestCPULionGPUError(DistributedTest): + + def test_cpu_lion_gpu_error(self): + model_size = 64 + from deepspeed.ops.lion import DeepSpeedCPULion + device = get_accelerator().device_name(0) # 'cuda:0' or 'xpu:0' + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + optimizer = DeepSpeedCPULion([param]) + + param.grad = torch.randn(model_size, device=device) + with pytest.raises(AssertionError): + optimizer.step() diff --git a/tests/unit/ops/lion/test_lion.py b/tests/unit/ops/lion/test_lion.py new file mode 100644 index 000000000..b2c3ac2f5 --- /dev/null +++ b/tests/unit/ops/lion/test_lion.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +import pytest + +from deepspeed.ops.lion import FusedLion +from deepspeed.ops.lion import DeepSpeedCPULion +from unit.common import DistributedTest +from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator + +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) +# yapf: disable +#'optimizer, zero_offload, resulting_optimizer +lion_configs = [["Lion", False, FusedLion], + ["Lion", True, DeepSpeedCPULion]] + +@pytest.mark.parametrize( + 'optimizer, zero_offload, resulting_optimizer', + lion_configs) +class TestLionConfigs(DistributedTest): + world_size = 1 + reuse_dist_env = True + + def test(self, + optimizer, + zero_offload, + resulting_optimizer): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": optimizer, + "params": { + "lr": 0.00015, + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": zero_offload + } + } + model = SimpleModel(10) + model, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters()) + # get base optimizer under zero + ds_optimizer = model.optimizer.optimizer + opt_class = resulting_optimizer + assert isinstance(ds_optimizer, opt_class)
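
Reviewer note: the scalar Step_1 path, the vectorized Step_AVX path, and the CUDA LionFunctor all implement the same update rule. Below is a minimal PyTorch sketch of that rule for reference only; the function and tensor names are illustrative and not part of this PR, and the behaviour at exactly zero differs slightly from the `c > 0 ? (-lr) : lr` branch in the CUDA kernel.

import torch

def lion_reference_step(param, grad, exp_avg, lr=1e-3, beta1=0.9, beta2=0.999, weight_decay=0.0):
    # interpolated direction: c = beta1 * m + (1 - beta1) * g
    c = beta1 * exp_avg + (1 - beta1) * grad
    # decoupled weight decay, then a fixed-magnitude step against sign(c);
    # this mirrors `tmp = -std::copysignf(alpha, tmp)` in cpu_lion_impl.cpp
    param.mul_(1 - lr * weight_decay).add_(torch.sign(c), alpha=-lr)
    # momentum update: m = beta2 * m + (1 - beta2) * g
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
    return param, exp_avg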
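
Reviewer note: with the config.py and engine.py changes above, the optimizer can be selected by name from the DeepSpeed config; when ZeRO CPU offload is enabled the engine builds DeepSpeedCPULion, otherwise FusedLion. A usage sketch mirroring tests/unit/ops/lion/test_lion.py; the model and hyperparameter values are placeholders.

import torch
import deepspeed

net = torch.nn.Linear(10, 10)  # placeholder model

ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Lion",
        "params": {"lr": 1e-4, "betas": [0.9, 0.999], "weight_decay": 0.01}
    },
    "fp16": {"enabled": True},
    # cpu_offload=True routes to DeepSpeedCPULion, False to FusedLion
    "zero_optimization": {"stage": 2, "cpu_offload": True}
}

engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                               model=net,
                                               model_parameters=net.parameters())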
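
Reviewer note: both classes can also be used directly as standalone torch optimizers, which is what the unit tests do. A sketch under the assumption that an accelerator is available for FusedLion (sizes and hyperparameters are placeholders); the ops are JIT-built on first use, or can be prebuilt with DS_BUILD_CPU_LION=1 and DS_BUILD_FUSED_LION=1 as documented in the advanced-install hunk.

import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion

device = get_accelerator().device_name()                        # e.g. 'cuda'
cpu_param = torch.nn.Parameter(torch.randn(64))                 # stays on CPU
acc_param = torch.nn.Parameter(torch.randn(64, device=device))  # lives on the accelerator

cpu_opt = DeepSpeedCPULion([cpu_param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)
acc_opt = FusedLion([acc_param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)

cpu_param.grad = torch.randn_like(cpu_param)
acc_param.grad = torch.randn_like(acc_param)
cpu_opt.step()
acc_opt.step()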
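
Reviewer note: for intuition only, a rough host-side Python model of the batching logic in multi_tensor_apply.cuh (this is not the CUDA code; the names and the depth-1 limits are illustrative). Each tensor is cut into chunk_size pieces, and the accumulated chunk descriptors are flushed into one kernel launch whenever the per-launch tensor or block limits are reached, or the last chunk is seen.

def plan_launches(numels, chunk_size, max_tensors=110, max_blocks=320):
    """Rough model of the chunk batching in multi_tensor_apply.cuh (depth=1 limits)."""
    launches = []          # each entry is the list of (tensor_idx, chunk_idx) blocks in one launch
    blocks = []
    tensors_in_launch = 0
    for t, numel in enumerate(numels):
        tensors_in_launch += 1
        chunks_this_tensor = (numel + chunk_size - 1) // chunk_size
        for chunk in range(chunks_this_tensor):
            blocks.append((t, chunk))
            tensors_full = tensors_in_launch == max_tensors and chunk == chunks_this_tensor - 1
            blocks_full = len(blocks) == max_blocks
            last_chunk = t == len(numels) - 1 and chunk == chunks_this_tensor - 1
            if tensors_full or blocks_full or last_chunk:
                launches.append(blocks)  # one multi_tensor_apply_kernel launch
                blocks = []
                # if the current tensor still has chunks left it stays registered for the next launch
                tensors_in_launch = 0 if chunk == chunks_this_tensor - 1 else 1
    return launches

# e.g. plan_launches([2**20, 3000, 70], chunk_size=2048 * 32) -> one launch covering 18 blocks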