### Integration of LoCo Method into ZeRO++

#### Overview
This PR integrates the **LoCo** method, described in
[this paper](https://arxiv.org/abs/2407.04480), into the ZeRO++
framework of DeepSpeed. The key enhancement is applying error-feedback
compensation to the 4-bit gradients before communication. This
approach ***improves pre-training loss without additional time
overhead***, though it requires extra GPU memory; the size of the
increase depends on model size and training configuration.
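
To make the mechanism concrete: the new kernels add a persistent error buffer to the gradient, quantize the compensated value for communication, and fold the resulting quantization residual back into the buffer as a moving average weighted by `err_beta`. The sketch below is a simplified, single-tensor illustration of this update; `fake_quant` is a toy stand-in for the fused CUDA kernels introduced in this PR, not the actual implementation.

```python
import torch

def fake_quant(x: torch.Tensor, num_bits: int = 4):
    """Toy symmetric quantizer; a stand-in for the fused CUDA kernels."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return q, scale

def loco_compensated_quant(grad: torch.Tensor,
                           err_buf: torch.Tensor,
                           err_beta: float = 0.8,
                           num_bits: int = 4):
    """One LoCo-style error-feedback step (simplified, not the fused kernel)."""
    compensated = grad + err_buf                   # g + e: compensate before quantizing
    q, scale = fake_quant(compensated, num_bits)   # what would be communicated
    dequant = q * scale                            # what the receiver reconstructs
    new_err = compensated - dequant                # residual introduced by quantization
    # moving-average update of the error buffer, as in the CUDA kernels
    err_buf.mul_(err_beta).add_(new_err, alpha=1.0 - err_beta)
    return q, scale

# The error buffer persists across steps (analogous to p.intra_ef_buf / p.inter_ef_buf).
grad = torch.randn(1024)
err_buf = torch.zeros_like(grad)
q, scale = loco_compensated_quant(grad, err_buf, err_beta=0.8, num_bits=4)
```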

#### Experimental Results
We conducted pre-training experiments using the Llama2 architecture,
adjusting the number of layers and hidden size. The experiments
included:
- **A smaller-scale model with 0.8B parameters trained on 30B tokens**.
- **A larger-scale model with 8B parameters trained on 5B tokens**.

The training data was sampled from **Redpajama-V2**.
<p align="center">
<img
src="https://github.com/user-attachments/assets/e7db9487-728c-4a17-9806-c15afa12f62e"
width="49%" />
<img
src="https://github.com/user-attachments/assets/3efec895-b71d-43ab-b5ce-65468ba8b9f1"
width="49%" />
</p>

**Findings**:
- **Smaller Models (0.8B parameters)**: Significant gains were observed
when applying the LoCo method.
- **Larger Models (8B parameters)**: The gains were present but less
pronounced. This could be due to:
  1. The relatively small training data volume (5B tokens).
  2. The lower pre-training loss of larger models, which makes significant improvements harder to achieve.

However, even a smaller pre-training loss gap in larger models can
translate to meaningful gains in downstream tasks.

#### Example Script
For reference, the
[run.sh](https://github.com/user-attachments/files/17679552/zeroplus-7b3.zip)
script used for the 8B-parameter, 5B-token experiment is attached. The
experiment was conducted on the **DeepSpeed-Megatron** platform.
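
For completeness, the sketch below shows how LoCo could be enabled through the new `zeropp_loco_param` config field added in this PR; it takes effect together with `zero_quantized_gradients`. The values outside `zero_optimization` are placeholders, and the `deepspeed.initialize` call is only a usage hint, not part of this PR.

```python
# Illustrative config fragment only; the zero_optimization keys below are the
# ones used by this PR, everything else is a placeholder.
import deepspeed  # shown for context; the initialize call is commented out

ds_config = {
    "train_micro_batch_size_per_gpu": 1,      # placeholder
    "bf16": {"enabled": True},                # placeholder precision setting
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_gradients": True,     # enables qgZ (quantized gradient comm)
        "zeropp_loco_param": {                # new in this PR: LoCo error feedback
            "err_beta": 0.8,                  # moving-average coefficient (default 0.8)
            "reset_T": 1024,                  # clear the error buffer every reset_T steps
        },
    },
}

# model/optimizer construction is assumed to happen elsewhere in the script:
# engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
```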



#### Acknowledgments
Special thanks to @GuanhuaWang for ongoing communication and guidance
throughout this work.

---

We appreciate your consideration of this PR and welcome any feedback or
questions!

---------

Co-authored-by: ChuanxinTang <tangchuanxin.chn@gmail.com>
Co-authored-by: root <pan.jiachun@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Commit 1b58ba5ec0 (parent 06f1d3609e), authored by xyxie on 2024-12-11 02:31:11 +08:00, committed via GitHub.
10 changed files: 851 additions and 5 deletions.


@ -52,6 +52,36 @@ void launch_swizzled_quant(int8_t* q_data,
int devices_per_node,
cudaStream_t stream);
void launch_loco_swizzled_quant(int8_t* quantized_data,
float* quantized_scales,
const __half* uncompressed_data,
__half* error_feedback,
const float err_beta,
int num_bits,
quantize::Type quant_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream);
void launch_loco_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
__half2* error_feedback,
const float err_beta,
cudaStream_t stream);
void launch_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,


@ -24,6 +24,7 @@ constexpr int max_threads = 1024;
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/
template <Type qType, int numBits>
class Params {
public:


@ -176,6 +176,53 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
return output;
}
std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
at::Tensor& error_feedback,
float err_beta,
int groups,
int num_bits,
quantize::Type quant_type,
int pipeline_size,
int nodes,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({groups, scales_elems}, scales_options);
auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int quantization_scalar = 8 / num_bits;
const int compressed_vals = at::numel(input_vals) / quantization_scalar;
auto output = torch::empty({compressed_vals}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;
launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
reinterpret_cast<float*>(scales.data_ptr()),
reinterpret_cast<const __half*>(input_vals.data_ptr()),
reinterpret_cast<__half*>(error_feedback.data_ptr()),
err_beta,
num_bits,
quant_type,
groups,
elems_per_group,
pipeline_size,
nodes,
devices_per_node,
at::cuda::getCurrentCUDAStream());
return {output, scales};
}
std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
int groups,
int num_bits,
@ -265,6 +312,61 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
return {output, scales};
}
std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
at::Tensor& input_scales,
at::Tensor& error_feedback,
float err_beta,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
auto scales = torch::empty({out_groups, scales_elems}, scales_options);
auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
sz[sz.size() - 1] = sz.back() / devices_per_node;
const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
auto output = torch::empty(sz, output_options);
const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;
launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
devices_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / devices_per_node,
elems_per_in_group,
(__half2*)error_feedback.data_ptr(),
err_beta,
at::cuda::getCurrentCUDAStream());
return {output, scales};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
@ -295,4 +397,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"Dequantize int8 to half (experimental)");
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
m.def("loco_quantized_reduction",
&loco_quantized_reduction,
"LoCo Quantization and Reduction Kernel");
}


@ -261,3 +261,297 @@ void launch_dequant_reduce(int8_t* reduced_data,
}
}
}
/*
Modified loco_dequant_reduce function that performs dequantization and reduction,
and incorporates error-feedback by updating the error_feedback tensor in-place.
*/
template <int numBits, int numTensors, int totalChunks, quantize::Type quantType>
__global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
int num_tensors,
__half2* error_feedback,
const float err_beta)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
constexpr int mem_granularity = (numBits == 8) ? 8 : 4;
constexpr int elems_per_load = mem_granularity / sizeof(int8_t);
constexpr int storage_values = 16 / sizeof(__half2);
const int block_offset = tb.group_index().x * elems_per_out_group;
const int elem_offset = tb.thread_index().x * elems_per_load;
const int base_offset = block_offset + elem_offset;
const int stride = tb.group_dim().x * elems_per_load;
constexpr int scaling_factor = elems_per_load / storage_values;
const int block_offset_err = block_offset / scaling_factor;
const int elem_offset_err = tb.thread_index().x * storage_values;
const int base_offset_err = block_offset_err + elem_offset_err;
const int stride_err = tb.group_dim().x * storage_values;
__half2 local_buffer[totalChunks * storage_values];
__half2 err_buffer[totalChunks * storage_values];
quantize::GroupStats<quantType> stats;
#pragma unroll
for (int i = 0; i < totalChunks; i++) {
__half2* iteration_buffer = local_buffer + i * storage_values;
__half2* iter_err_buffer = err_buffer + i * storage_values;
#pragma unroll
for (int j = 0; j < storage_values; j++) {
iteration_buffer[j] = reduce::init<rop::Add, __half2>();
}
const int iter_offset = i * stride + base_offset;
const int iter_offset_err = i * stride_err + base_offset_err;
const int iter_scale_idx = iter_offset / elems_per_in_group;
bool do_loads = i * stride + elem_offset < elems_per_out_group;
if (numTensors > 0) {
#pragma unroll
for (int j = 0; j < numTensors; j++) {
if (do_loads) {
int8_t load_buffer[elems_per_load];
mem_access::load_global<mem_granularity>(
load_buffer, input_data + j * elems_per_in_tensor + iter_offset);
quantize::Params<quantType, numBits> params(
input_scales + j * groups_per_in_tensor, iter_scale_idx);
__half2 dequant_buffer[storage_values];
dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);
#pragma unroll
for (int k = 0; k < storage_values; k++) {
iteration_buffer[k] =
reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
}
}
}
} else {
#pragma unroll 4
for (int j = 0; j < num_tensors; j++) {
if (do_loads) {
int8_t load_buffer[elems_per_load];
mem_access::load_global<mem_granularity>(
load_buffer, input_data + j * elems_per_in_tensor + iter_offset);
quantize::Params<quantType, numBits> params(
input_scales + j * groups_per_in_tensor, iter_scale_idx);
__half2 dequant_buffer[storage_values];
dequantize::chunk<numBits, quantType>(dequant_buffer, load_buffer, params);
#pragma unroll
for (int k = 0; k < storage_values; k++) {
iteration_buffer[k] =
reduce::element<rop::Add>(iteration_buffer[k], dequant_buffer[k]);
}
}
}
}
mem_access::load_global<quantize::granularity>(
iter_err_buffer, error_feedback + iter_offset_err, do_loads);
#pragma unroll
for (int k = 0; k < storage_values; k++) {
iteration_buffer[k] = __hadd2(iteration_buffer[k], iter_err_buffer[k]);
stats.update(iteration_buffer[k]);
}
}
auto params = stats.template get_params<numBits, 1024>(tb, warp);
// Initialize dequantization parameters based on params
auto de_params = params;
de_params.scale = 1.0f / params.scale;
if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; }
if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }
#pragma unroll
for (int i = 0; i < totalChunks; i++) {
const int iter_offset = i * stride + base_offset;
const int iter_offset_err = i * stride_err + base_offset_err;
__half2* iteration_buffer = local_buffer + i * storage_values;
__half2* iter_err_buffer = err_buffer + i * storage_values;
if (i * stride + elem_offset < elems_per_out_group) {
// ----------- Begin Error-Feedback Modification -----------
int8_t local_output[elems_per_load];
quantize::_chunk<numBits, quantType>(local_output, iteration_buffer, params);
mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
// Dequantize the quantized output to compute the dequantized value
__half2 dequant_buffer[storage_values];
dequantize::chunk<numBits, quantType>(dequant_buffer, local_output, de_params);
#pragma unroll
for (int k = 0; k < storage_values; k++) {
// __half2 to float2
float2 iter_buf_f = __half22float2(iteration_buffer[k]);
float2 dequant_buf_f = __half22float2(dequant_buffer[k]);
// Update within float precision
float2 new_error_f;
new_error_f.x = iter_buf_f.x - dequant_buf_f.x;
new_error_f.y = iter_buf_f.y - dequant_buf_f.y;
float2 iter_err_buf_f = __half22float2(iter_err_buffer[k]);
iter_err_buf_f.x = err_beta * iter_err_buf_f.x + (1.0f - err_beta) * new_error_f.x;
iter_err_buf_f.y = err_beta * iter_err_buf_f.y + (1.0f - err_beta) * new_error_f.y;
// float2 back to __half2
iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f);
}
mem_access::store_global<quantize::granularity>(error_feedback + iter_offset_err,
iter_err_buffer);
}
}
}
#define LAUNCH_LOCO_DEQUANT_REDUCE(num_chunks) \
loco_dequant_reduce<numBits, numTensors, num_chunks, quantType> \
<<<grid, block, 0, stream>>>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_tensors, \
error_feedback, \
err_beta);
template <int numBits, int numTensors, quantize::Type quantType>
void launch_loco_dequant_reduce_impl(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
int num_tensors,
__half2* error_feedback,
const float err_beta,
cudaStream_t stream)
{
constexpr int elems_per_thread = numBits;
const int one_step_threads =
next_pow2((elems_per_out_group + elems_per_thread - 1) / (elems_per_thread));
const int threads = (one_step_threads < 1024) ? one_step_threads : 1024;
dim3 block(threads);
dim3 grid(out_groups);
const int elems_per_step = threads * elems_per_thread;
const int unroll_raw = (elems_per_out_group + elems_per_step - 1) / elems_per_step;
const int unroll = (unroll_raw >= 4) ? pow2_round<1>(unroll_raw) : unroll_raw;
if (unroll == 1) {
LAUNCH_LOCO_DEQUANT_REDUCE(1);
} else if (unroll == 2) {
LAUNCH_LOCO_DEQUANT_REDUCE(2);
} else if (unroll == 3) {
LAUNCH_LOCO_DEQUANT_REDUCE(3);
} else if (unroll == 4) {
LAUNCH_LOCO_DEQUANT_REDUCE(4);
} else if (unroll == 6) {
LAUNCH_LOCO_DEQUANT_REDUCE(6);
} else if (unroll == 8) {
LAUNCH_LOCO_DEQUANT_REDUCE(8);
} else if (unroll == 10) {
LAUNCH_LOCO_DEQUANT_REDUCE(10);
} else if (unroll == 12) {
LAUNCH_LOCO_DEQUANT_REDUCE(12);
} else {
assert(false);
}
}
#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
launch_loco_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
out_groups, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_gpus, \
error_feedback, \
err_beta, \
stream);
void launch_loco_dequant_reduce(int8_t* reduced_data,
float* reduced_scales,
const int8_t* input_data,
const float* input_scales,
int num_gpus,
int num_bits,
quantize::Type quant_type,
int out_groups,
int elems_per_out_group,
int elems_per_in_tensor,
int groups_per_in_tensor,
int elems_per_in_group,
__half2* error_feedback,
const float err_beta,
cudaStream_t stream)
{
if (quant_type == quantize::Type::Symmetric) {
if (num_bits == 4) {
if (num_gpus == 8) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Symmetric);
} else if (num_gpus == 16) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Symmetric);
} else {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Symmetric);
}
} else if (num_bits == 8) {
if (num_gpus == 8) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Symmetric);
} else if (num_gpus == 16) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Symmetric);
} else {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Symmetric);
}
}
} else if (quant_type == quantize::Type::Asymmetric) {
if (num_bits == 4) {
if (num_gpus == 8) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 8, quantize::Type::Asymmetric);
} else if (num_gpus == 16) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, 16, quantize::Type::Asymmetric);
} else {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(4, -1, quantize::Type::Asymmetric);
}
} else if (num_bits == 8) {
if (num_gpus == 8) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 8, quantize::Type::Asymmetric);
} else if (num_gpus == 16) {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, 16, quantize::Type::Asymmetric);
} else {
LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(8, -1, quantize::Type::Asymmetric);
}
}
}
}


@ -3,6 +3,7 @@
// DeepSpeed Team
#include "dequantization_utils.h"
#include "memory_access_utils.h"
#include "quantization_utils.h"
#include "reduction_utils.h"
@ -194,3 +195,233 @@ void launch_swizzled_quant(int8_t* q_data,
}
}
}
template <int numBits, int totalChunks, int threads, quantize::Type quantType>
__global__ void loco_swizzled_quant_kernel(int8_t* quantized_data,
float* quantized_scales,
const __half* uncompressed_data,
__half* error_feedback,
const float err_beta,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// Indexing offsets for the input data, same as in the regular swizzled quantization kernel
const int block_rank_data =
blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
const int block_offset_data = block_rank_data * elems_per_group;
const int elem_offset = tb.thread_index().x * quantize::h_per_load;
const int base_offset_data = block_offset_data + elem_offset;
const int stride = tb.size() * quantize::h_per_load;
const __half* uncompressed_data_base = uncompressed_data + base_offset_data;
const int partition_id = blockIdx.z;
const int partition_offset = partition_id / devices_per_node;
const int partition_base = (partition_id % devices_per_node) * nodes;
const int pipelining_offset = blockIdx.y * (devices_per_node * nodes);
const int output_partition = (pipelining_offset + partition_base + partition_offset);
const int block_rank_err = output_partition * gridDim.x + blockIdx.x;
const int block_offset_err = block_rank_err * elems_per_group;
const int base_offset_err = block_offset_err + elem_offset;
__half* error_feedback_base = error_feedback + base_offset_err;
__half2 local_buffer[totalChunks * quantize::h2_per_load];
__half2 err_buffer[totalChunks * quantize::h2_per_load];
quantize::GroupStats<quantType> stats;
#pragma unroll
for (int i = 0; i < totalChunks; i++) {
__half2* iteration_buffer = local_buffer + i * quantize::h2_per_load;
__half2* iter_err_buffer = err_buffer + i * quantize::h2_per_load;
const int i_stride = i * stride;
bool do_loads = (elem_offset + i_stride) < elems_per_group;
mem_access::load_global<quantize::granularity>(
iteration_buffer, uncompressed_data_base + i_stride, do_loads);
mem_access::load_global<quantize::granularity>(
iter_err_buffer, error_feedback_base + i_stride, do_loads);
#pragma unroll
for (int j = 0; j < quantize::h2_per_load; j++) {
iteration_buffer[j] = __hadd2(iteration_buffer[j], iter_err_buffer[j]);
stats.update(iteration_buffer[j]);
}
}
auto params = stats.template get_params<numBits, threads>(tb, warp);
// Initialize dequantization parameters based on params
auto de_params = params;
de_params.scale = 1.0f / params.scale;
if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; }
if (threadIdx.x == 0) { params.store(quantized_scales, block_rank_err); }
constexpr int out_scalar_effect = 8 / numBits;
const int out_block_offset = block_rank_err * elems_per_group / out_scalar_effect;
const int out_base_offset = out_block_offset + elem_offset / out_scalar_effect;
int8_t* out_base = quantized_data + out_base_offset;
const int out_stride = stride / out_scalar_effect;
constexpr int num_int8_out = quantize::h_per_load / out_scalar_effect;
#pragma unroll
for (int i = 0; i < totalChunks; i++) {
const int i_stride = i * stride;
__half2* iteration_buffer = local_buffer + i * quantize::h2_per_load;
__half2* iter_err_buffer = err_buffer + i * quantize::h2_per_load;
if (i_stride + elem_offset < elems_per_group) {
int8_t local_output[quantize::h_per_load / out_scalar_effect];
quantize::_chunk<numBits, quantType>(local_output, iteration_buffer, params);
mem_access::store_global<num_int8_out>(out_base + i * out_stride, local_output);
// Dequantize the quantized output to compute the dequantized value
__half2 dequant_buffer[quantize::h2_per_load];
dequantize::chunk<numBits, quantType>(dequant_buffer, local_output, de_params);
// Compute new error: sum - dequant_buffer
#pragma unroll
for (int k = 0; k < quantize::h2_per_load; k++) {
// __half2 to float2
float2 iter_buf_f = __half22float2(iteration_buffer[k]);
float2 dequant_buf_f = __half22float2(dequant_buffer[k]);
// Update within float precision
float2 new_error_f;
new_error_f.x = iter_buf_f.x - dequant_buf_f.x;
new_error_f.y = iter_buf_f.y - dequant_buf_f.y;
float2 iter_err_buf_f = __half22float2(iter_err_buffer[k]);
iter_err_buf_f.x = err_beta * iter_err_buf_f.x + (1.0f - err_beta) * new_error_f.x;
iter_err_buf_f.y = err_beta * iter_err_buf_f.y + (1.0f - err_beta) * new_error_f.y;
// float2 back to __half2
iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f);
}
__half2* error_feedback_base_h2 = reinterpret_cast<__half2*>(error_feedback_base);
mem_access::store_global<quantize::granularity>(error_feedback_base_h2 + i_stride / 2,
iter_err_buffer);
}
}
}
#define LAUNCH_LOCO_SWIZZLE_QUANT(total_chunks, threads) \
loco_swizzled_quant_kernel<numBits, total_chunks, threads, qType> \
<<<grid, block, 0, stream>>>(output_data, \
params, \
input_data, \
error_feedback, \
err_beta, \
groups, \
elems_per_group, \
pipelining, \
nodes, \
devices_per_node);
template <int numBits, quantize::Type qType>
void launch_loco_swizzled_quant_impl(int8_t* output_data,
float* params,
const __half* input_data,
__half* error_feedback,
const float err_beta,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream)
{
const int one_step_threads =
next_pow2((elems_per_group + swiz_quant::h_per_step - 1) / swiz_quant::h_per_step);
const int max_threads = (one_step_threads < swiz_quant::max_threads) ? one_step_threads
: swiz_quant::max_threads;
const int threads = (max_threads < swiz_quant::min_threads) ? swiz_quant::min_threads
: max_threads;
dim3 block(threads);
const int groups_per_partition = groups / (nodes * devices_per_node);
assert(groups_per_partition % pipelining == 0);
const int contiguous_groups = groups_per_partition / pipelining;
const int partitions = nodes * devices_per_node;
dim3 grid(contiguous_groups, pipelining, partitions);
const int elems_per_step = threads * swiz_quant::h_per_step;
const int external_unroll = ((elems_per_group + elems_per_step - 1) / elems_per_step);
const int total_unroll = external_unroll * swiz_quant::step_granularity;
assert(total_unroll % 2 == 0);
if (threads == 32) {
LAUNCH_LOCO_SWIZZLE_QUANT(2, 32);
} else if (threads == 64) {
LAUNCH_LOCO_SWIZZLE_QUANT(2, 64);
} else if (threads == 128) {
LAUNCH_LOCO_SWIZZLE_QUANT(2, 128);
} else if (threads == 256) {
LAUNCH_LOCO_SWIZZLE_QUANT(2, 256);
} else if (threads == 512) {
if (total_unroll == 2) {
LAUNCH_LOCO_SWIZZLE_QUANT(2, 512);
} else if (total_unroll == 4) {
LAUNCH_LOCO_SWIZZLE_QUANT(4, 512);
} else if (total_unroll == 6) {
LAUNCH_LOCO_SWIZZLE_QUANT(6, 512);
} else if (total_unroll == 8) {
LAUNCH_LOCO_SWIZZLE_QUANT(8, 512);
} else if (total_unroll == 10) {
LAUNCH_LOCO_SWIZZLE_QUANT(10, 512);
}
}
}
#define DISPATCH_LOCO_SWIZZLE_QUANT(num_bits, qtype) \
launch_loco_swizzled_quant_impl<num_bits, qtype>(output_data, \
params, \
input_data, \
error_feedback, \
err_beta, \
groups, \
elems_per_group, \
pipelining, \
nodes, \
devices_per_node, \
stream);
void launch_loco_swizzled_quant(int8_t* output_data,
float* params,
const __half* input_data,
__half* error_feedback,
const float err_beta,
int num_bits,
quantize::Type q_type,
int groups,
int elems_per_group,
int pipelining,
int nodes,
int devices_per_node,
cudaStream_t stream)
{
if (num_bits == 4) {
if (q_type == quantize::Type::Asymmetric) {
DISPATCH_LOCO_SWIZZLE_QUANT(4, quantize::Type::Asymmetric);
} else if (q_type == quantize::Type::Symmetric) {
DISPATCH_LOCO_SWIZZLE_QUANT(4, quantize::Type::Symmetric);
}
} else if (num_bits == 8) {
if (q_type == quantize::Type::Asymmetric) {
DISPATCH_LOCO_SWIZZLE_QUANT(8, quantize::Type::Asymmetric);
} else if (q_type == quantize::Type::Symmetric) {
DISPATCH_LOCO_SWIZZLE_QUANT(8, quantize::Type::Symmetric);
}
}
}


@ -8,7 +8,7 @@ bandwidth utilization
"""
import math
from typing import List
from typing import List, Any
import torch
from torch import Tensor
from deepspeed import comm as dist
@ -76,6 +76,83 @@ def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]:
return output_lst
@instrument_w_nvtx
@torch.no_grad()
def all_to_all_loco_quant_reduce(
params: List[Tensor],
groups: {},
loco_param: Any = None,
) -> List[Tensor]:
global quantizer_module
global loco_idx
if quantizer_module is None:
quantizer_module = op_builder.QuantizerBuilder().load()
local_world_size = get_accelerator().device_count()
global_world_size = dist.get_world_size()
num_nodes = global_world_size // local_world_size
this_rank = dist.get_rank()
intra_idx = int(this_rank / local_world_size)
inter_idx = this_rank % local_world_size
output_lst: List[Tensor] = [None] * len(params)
for idx, p in enumerate(params):
tensor = p.grad
if tensor.dim() == 1:
output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
elif tensor.numel() % (2 * global_world_size) != 0:
# Due to the constraint of 2-stage all-to-all, the input tensor must be divisible by 2 * global_world_size
# Otherwise, all-to-all cannot be performed because of shape mismatch.
# See more at https://github.com/microsoft/DeepSpeed/pull/5056
logger.warning(
f"qgZ falls back to reduce_scatter because tensor size = {tensor.numel()} is not divisible by (2 * global_world_size) = {2 * global_world_size}. Please consider allocating a new world to enable qgZ"
)
output_lst[idx] = reduce_scatter_coalesced([tensor])[0]
else:
err_beta = loco_param['err_beta']
reset_T = loco_param['reset_T']
if not hasattr(p, 'intra_ef_buf') or loco_idx > reset_T:
loco_idx = 0
intra_err = torch.zeros_like(p.grad)
inter_err = torch.zeros(tensor.numel() // local_world_size, device=tensor.device, dtype=tensor.dtype)
else:
intra_err = quantizer_module.dequantize(p.intra_ef_buf[0], p.intra_ef_buf[1],
p.intra_ef_buf[1].numel(), 8, quantizer_module.Symmetric)
inter_err = quantizer_module.dequantize(p.inter_ef_buf[0], p.inter_ef_buf[1],
p.inter_ef_buf[1].numel(), 8, quantizer_module.Symmetric)
intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size)
inter_quant_group = intra_quant_group // local_world_size
intra_quant_int4, intra_q_scales = quantizer_module.loco_swizzle_quant(tensor, intra_err, err_beta,
intra_quant_group, 4,
quantizer_module.Symmetric, 1,
num_nodes, local_world_size)
local_output = torch.empty_like(intra_quant_int4)
scale_output = torch.empty_like(intra_q_scales)
all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}'])
all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}'])
p.intra_ef_buf = quantizer_module.quantize(intra_err, intra_quant_group, 8, quantizer_module.Symmetric)
global_input_tensor, global_scales = quantizer_module.loco_quantized_reduction(
local_output, scale_output, inter_err, err_beta, intra_quant_group, inter_quant_group, 4,
quantizer_module.Symmetric, local_world_size)
global_output = torch.empty_like(global_input_tensor)
global_scale_output = torch.empty_like(global_scales)
all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}'])
all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}'])
p.inter_ef_buf = quantizer_module.quantize(inter_err, inter_quant_group, 8, quantizer_module.Symmetric)
final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(),
4, quantizer_module.Symmetric)
assert final_output.numel(
) % num_nodes == 0, f"final_output.numel()={final_output.numel()} is not divisible by num_nodes={num_nodes}"
output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1)
loco_idx += 1
return output_lst
@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(


@ -912,6 +912,9 @@ class DeepSpeedEngine(Module):
def zero_quantized_gradients(self):
return self._config.zero_config.zero_quantized_gradients
def zeropp_loco_param(self):
return self._config.zero_config.zeropp_loco_param
def dump_state(self):
return self._config.dump_state
@ -1191,7 +1194,8 @@ class DeepSpeedEngine(Module):
# Query the groups module to get information about various parallel groups
self.local_all_to_all_group = None
if self.zero_quantized_gradients():
log_dist("Using quantized gradients", ranks=[0])
message = "Using LoCo quantized gradients" if self.zeropp_loco_param() else "Using quantized gradients"
log_dist(message, ranks=[0])
self.local_all_to_all_group = groups._get_local_all_to_all_group()
self.data_parallel_group = groups._get_data_parallel_group()
self.dp_world_size = groups._get_data_parallel_world_size()
@ -1667,6 +1671,7 @@ class DeepSpeedEngine(Module):
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
zeropp_loco_param=self.zeropp_loco_param(),
)
else:

Просмотреть файл

@ -4,7 +4,7 @@
# DeepSpeed Team
import sys
from typing import Optional
from typing import Optional, Dict, Any
from enum import Enum
from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
@ -44,6 +44,7 @@ ZeRO optimization should be enabled as:
"zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
"zeropp_loco_param": {...},
}
}
"""
@ -310,6 +311,16 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Boolean indicating whether to use quantized zero gradients
for efficient all_2_all_reduce comm
"""
zeropp_loco_param: Optional[Dict[str, Any]] = None
"""
This dictionary contains the parameters for LoCo-ZeRO++, with two keys:
- `err_beta`: A coefficient for the moving average of quantization errors before and after gradient computation.
It ranges between 0 and 1, with a default value of 0.8.
- `reset_T`: The number of steps after which the moving-average error buffer is cleared. The default value is 1024.
These parameters can be adjusted based on performance needs. Example configuration in ds config:
"zeropp_loco_param": { "err_beta": 0.8, "reset_T": 1024 }.
See LoCo paper for more details: (https://arxiv.org/abs/2407.04480).
"""
mics_shard_size: int = Field(-1, json_schema_extra={"new_param": "mics_shard_size"})


@ -17,7 +17,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
@ -158,6 +158,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
zeropp_loco_param=None,
):
see_memory_usage("Stage 3 initialize beginning", force=True)
@ -284,6 +285,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.partition_count = dist.get_world_size(group=self.dp_process_group)
self.zeropp_loco_param = zeropp_loco_param
if mpu is None:
self.model_parallel_group = None
self.model_parallel_rank = 0
@ -1383,7 +1386,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
global_world_size = dist.get_world_size()
num_nodes = global_world_size // local_world_size
if self.all2all_process_group is not None and num_nodes > 1:
grad_partitions_for_rank = all_to_all_quant_reduce(full_grads_for_rank, self.all2all_process_group)
grad_partitions_for_rank = (all_to_all_loco_quant_reduce(params_to_reduce, self.all2all_process_group,
self.zeropp_loco_param)
if self.zeropp_loco_param is not None else all_to_all_quant_reduce(
full_grads_for_rank, self.all2all_process_group))
else:
grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group)
@ -2009,6 +2015,25 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
see_memory_usage('After overflow after clearing gradients', force=False)
def _loco_err_buf_update(self, overflow: bool, scale=1.0):
"""
Loco Error Buffer update.
"""
if not overflow and scale == 1.0: return
if dist.get_rank() == 0:
logger.info(f"update loco-zero++ error buffer with overflow: {overflow}")
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if hasattr(p, 'intra_ef_buf'):
if overflow:
del p.intra_ef_buf
del p.inter_ef_buf
else:
p.intra_ef_buf[1] *= scale
p.inter_ef_buf[1] *= scale
@instrument_w_nvtx
def _overflow_check_and_loss_scale_update(self):
@ -2023,6 +2048,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.overflow:
self._overflow_clean_up(prev_scale)
#update loco error buf
self._loco_err_buf_update(self.overflow, self.loss_scale / prev_scale)
return self.overflow
@instrument_w_nvtx


@ -96,3 +96,66 @@ class TestAllToAllQuantReduceFallback(DistributedTest):
elif dist.get_rank() == 1:
assert output.shape == (24, )
assert torch.allclose(output, torch.zeros_like(output))
class TestLocoQuantized(DistributedTest):
world_size = 1
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("tensor_size", [(16, 16), (64, 64)])
@pytest.mark.parametrize("devices_per_node", [4, 8])
def test_loco_quantized_reduction(self, num_bits, tensor_size, devices_per_node):
from deepspeed.ops.op_builder import QuantizerBuilder
if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
pytest.skip("QuantizerBuilder is not implemented")
quantizer_module = QuantizerBuilder().load()
tensor = torch.randn(tensor_size, device='cuda', dtype=torch.half)
num_nodes = 2 # Fake world size
total_elements = tensor.numel()
total_devices = devices_per_node * num_nodes
num_groups = max(tensor.shape[0], tensor.shape[1], total_devices)
# Initialize error_feedback tensor
error_feedback = torch.randn(tensor_size, device=tensor.device, dtype=tensor.dtype)
error_feedback_ori = error_feedback.clone()
# Swizzle the original tensor
tensor_reshaped = tensor.reshape(num_nodes, devices_per_node, total_elements // total_devices)
swizzled_tensor = tensor_reshaped.permute(1, 0, 2).reshape(tensor.size())
# Perform loco_swizzle_quant
output, scales = quantizer_module.loco_swizzle_quant(tensor, error_feedback, 0.0, num_groups, num_bits,
quantizer_module.Symmetric, 1, num_nodes,
devices_per_node)
# Compare swizzled_tensor with the output of loco_swizzle_quant
dequantized = quantizer_module.dequantize(output, scales, scales.numel(), num_bits,
quantizer_module.Symmetric).view(tensor.size())
assert torch.allclose(swizzled_tensor + error_feedback_ori, dequantized + error_feedback)
# Calculate elements per group and groups per partition
elements_per_group = total_elements // num_groups
groups_per_partition = num_groups // devices_per_node
# Reshape dequantized data to match the grouping in loco_quantized_reduction
dequantized_reshaped = dequantized.view(devices_per_node, groups_per_partition, elements_per_group)
# Perform reduction across devices_per_node dimension
reduced_dequantized = dequantized_reshaped.cumsum(dim=0)[-1]
# Initialize error_feedback tensor
error_feedback = torch.randn(reduced_dequantized.shape, device=tensor.device, dtype=dequantized.dtype)
error_feedback_ori = error_feedback.clone()
# perform loco_quantized_reduction
output, scales = quantizer_module.loco_quantized_reduction(output, scales, error_feedback, 0.0, num_groups,
num_groups // devices_per_node, num_bits,
quantizer_module.Symmetric, devices_per_node)
dequantized_reduced = quantizer_module.dequantize(output, scales, scales.numel(), num_bits,
quantizer_module.Symmetric).view(error_feedback.size())
assert torch.allclose(reduced_dequantized + error_feedback_ori, dequantized_reduced + error_feedback)