Flexible-bit quantizer-dequantizer library with fp6/fp12/fp8 support

Requires an Ampere or newer architecture, since the initial focus of this
op is only on `bfloat16` input types.

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
Jeff Rasley 2024-04-04 12:58:08 -07:00 committed by GitHub
Parent 4621ba4cd4
Commit 3fbd01ccca
No key found matching this signature
GPG key ID: B5690EEEBB952194
12 changed files with 1014 additions and 1 deletion

.github/workflows/nv-pre-compile-ops.yml

@@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
-    DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+    DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report


@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <curand.h>
#include <stdlib.h>
#include <sys/time.h>
#include <cassert>
#include <iostream>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <vector>
#define WARP_SIZE 32
class FPContext {
public:
FPContext() : _seed(42), _curr_offset(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, _seed);
}
virtual ~FPContext() {}
static FPContext& Instance()
{
static FPContext _ctx;
return _ctx;
}
curandGenerator_t& GetRandGenerator() { return _gen; }
cudaStream_t GetCurrentStream()
{
// get current pytorch stream.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
uint64_t _seed;
uint64_t _curr_offset;
};
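FPContext hands each kernel launch a (seed, offset) pair for counter-based Philox RNG: the seed stays fixed while the offset advances, so successive stochastic-rounding launches draw from disjoint random streams without re-seeding the generator. A minimal Python model of that bookkeeping (illustrative only; assumes the offset starts at zero, as in the constructor above):

class RngOffsetCounter:
    # Models FPContext::IncrementOffset: fixed seed, monotonically advancing offset.
    def __init__(self, seed=42):
        self.seed = seed
        self.offset = 0

    def increment(self, offset_inc):
        pair = (self.seed, self.offset)
        self.offset += offset_inc
        return pair

ctx = RngOffsetCounter()
assert ctx.increment(16) == (42, 0)
assert ctx.increment(16) == (42, 16)  # the next launch gets a fresh stream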


@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <cuda.h>
#include <stdint.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#define QUANT_SWITCH(Q_BITS, ...) \
[&] { \
if (12 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (13 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (10 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (11 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (28 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (29 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (6 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (7 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (2 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} \
}()
#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \
[&] { \
if (12 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 3; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (10 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 5; \
__VA_ARGS__(); \
} else if (28 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 7; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (6 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 3; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_Q_MANTISA_BITS = 1; \
constexpr int CONST_Q_EXPONENT_BITS = 2; \
__VA_ARGS__(); \
} \
}()
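Both switch macros are keyed on a compact product encoding rather than the raw bit width: the host launchers pass exponent_bits * mantisa_bits (plus the stochastic-rounding flag on the quantize path), and that product happens to be distinct for every format supported here. A small sketch of the key derivation (plain Python, illustrative only):

def quant_switch_key(q_bits, q_mantisa_bits, stochastic_rounding=False):
    # One bit is reserved for the sign; compare the QUANT_SWITCH call in quantize.cu.
    exponent_bits = q_bits - q_mantisa_bits - 1
    return exponent_bits * q_mantisa_bits + int(stochastic_rounding)

assert quant_switch_key(8, 3) == 12   # fp8  E4M3 (13 with stochastic rounding)
assert quant_switch_key(8, 2) == 10   # fp8  E5M2 (11)
assert quant_switch_key(12, 7) == 28  # fp12 E4M7 (29)
assert quant_switch_key(6, 2) == 6    # fp6  E3M2 (7)
assert quant_switch_key(4, 1) == 2    # fp4  E2M1 (the else branch when stochastic)

DEQUANT_SWITCH is keyed on the same product, q_mantisa_bits * q_exponent_bits, without the rounding flag.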
template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
int group_size,
cudaStream_t stream,
float q_range,
int q_bits,
int q_mantisa_bits,
int stochastic_rounding);
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream);


@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "quantize.h"
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>
#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \
(uint8_t*)out.data_ptr(), \
num_groups, \
group_size, \
at::cuda::getCurrentCUDAStream(), \
q_range, \
q_bits, \
q_mantisa_bits, \
stochastic_rounding); \
}
at::Tensor quantize(torch::Tensor& val,
int group_size,
int stochastic_rounding,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
auto options = at::TensorOptions()
.dtype(torch::kInt8)
.layout(val.layout())
.device(val.device())
.requires_grad(false);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4
// in case accuracy is not matching!
int num_groups = total_elems / group_size;
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);
DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8);
#endif
return out;
}
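Each quantized group is laid out as group_size * q_bits / 8 payload bytes followed by a 4-byte fp32 scale, which is where the group_size * q_bits / 8 + 4 row width above comes from. A quick sanity check of that sizing (illustrative Python):

def quantized_group_bytes(group_size, q_bits):
    # Packed payload plus the 4-byte scale stored at the end of each group.
    assert (group_size * q_bits) % 8 == 0
    return group_size * q_bits // 8 + 4

assert quantized_group_bytes(512, 8) == 516   # fp8
assert quantized_group_bytes(512, 6) == 388   # fp6
assert quantized_group_bytes(512, 12) == 772  # fp12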
#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
(C_TYPE*)val.data_ptr(), \
num_groups, \
group_size, \
q_mantisa_bits, \
q_exponent_bits, \
at::cuda::getCurrentCUDAStream()); \
return; \
}
void dequantize(torch::Tensor& val,
torch::Tensor& val_q,
int group_size,
int q_mantisa_bits,
int q_exponent_bits)
{
int total_elems = at::numel(val);
int num_groups = total_elems / group_size;
DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7);
#endif
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
}


@@ -0,0 +1,427 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <stdexcept>
#include "context.h"
#include "memory_access_utils.h"
#include "quantize.h"
#include "reduction_utils.h"
#include <cuda.h>
#include <stdint.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
using ROp = reduce::ROpType;
namespace quantization {
constexpr int access_granularity = 16;
constexpr int quanitzed_access_granularity = 4;
constexpr int quanitzed_access_granularity_6bits = 2;
constexpr int threads = 256;
constexpr int warps = threads / 32;
} // namespace quantization
template <int _mantisa_bits, int q_mantisa_bits, int stochastic_rounding>
__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state)
{
constexpr uint32_t mantisa_mask = (1 << (_mantisa_bits - q_mantisa_bits)) - 1;
uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask)
: 1 << (_mantisa_bits - q_mantisa_bits - 1);
mantisa += offset;
dst_exponent += (((mantisa & ~mantisa_mask) == (1 << _mantisa_bits)) ? 1 : 0);
}
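round operates directly on the mantissa bits about to be discarded: nearest rounding adds half of the dropped range, stochastic rounding adds a random value masked into that range, and the destination exponent is bumped when the addition carries out of the mantissa field. A minimal Python model of the nearest path (illustrative; the stochastic path draws its offset from the Philox-backed curand state instead):

def round_to_nearest(mantisa, exponent, src_mantisa_bits, dst_mantisa_bits):
    dropped = src_mantisa_bits - dst_mantisa_bits
    mantisa += 1 << (dropped - 1)           # half a ULP of the kept precision
    if (mantisa >> src_mantisa_bits) & 1:   # carry out of the mantissa field
        exponent += 1
    kept = (mantisa >> dropped) & ((1 << dst_mantisa_bits) - 1)
    return kept, exponent

# A 23-bit fp32 mantissa of all ones rounded down to 3 bits carries into the
# exponent and leaves a zero mantissa.
assert round_to_nearest((1 << 23) - 1, 127, 23, 3) == (0, 128)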
template <int _mantisa_bits, int _exponent_bits, int q_mantisa_bits, int q_exponent_bits>
__device__ void clip(uint32_t& exponent, uint32_t& mantisa)
{
constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1));
constexpr uint32_t min_exponent =
(1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1);
if (exponent > max_exponent) {
exponent = max_exponent;
mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1; //.11 .. 10
}
if (exponent < min_exponent) {
exponent = min_exponent;
mantisa = 0;
}
}
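clip saturates the biased source exponent into the window that survives re-biasing into the target format: anything above the window becomes the largest finite quantized value (note the shifted all-ones mantissa pattern above), anything below is flushed toward zero. For fp32 inputs (_exponent_bits = 8) quantized to E4M3, the window works out as follows (illustrative Python):

def clip_window(src_exponent_bits, dst_exponent_bits):
    # Biased source exponents that remain representable after re-biasing.
    max_exponent = (1 << (dst_exponent_bits - 1)) + (1 << (src_exponent_bits - 1))
    min_exponent = (1 << (src_exponent_bits - 1)) - ((1 << (dst_exponent_bits - 1)) - 1)
    return min_exponent, max_exponent

# fp32 -> E4M3: only biased fp32 exponents in [121, 136] survive; the rest saturate.
assert clip_window(8, 4) == (121, 136)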
template <typename T,
int unroll,
int _mantisa_bits,
int _exponent_bits,
int total_q_bits = 8,
int q_mantisa_bits = 3,
int stochastic_rounding = 0>
__global__ void apply_quantization(T* val,
uint8_t* q_val,
int group_size,
std::pair<uint64_t, uint64_t> seed,
float q_range)
{
int tidx = threadIdx.x;
int wid = tidx >> 5;
int lane = tidx & 0x1f;
int gid = blockIdx.x * quantization::warps + wid;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
// CG helpers
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
constexpr uint32_t load_stride = vector_size * hw_warp_size;
constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size;
const uint32_t thread_offset = lane * vector_size;
const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8);
const uint32_t base_load_offset = gid * group_size + thread_offset;
const uint32_t base_store_offset =
gid * ((group_size * total_q_bits / 8) + 4) +
store_thread_offset;  // 4 bytes per group are reserved for the fp32 scale
const T* load_base_ptr = val + base_load_offset;
T tmp_buf[unroll * vector_size];
T cur_max;
reduce::init<ROp::Max>(&cur_max);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
#pragma unroll
for (int i = 0; i < unroll; i++) {
if (i * load_stride + thread_offset < group_size) {
mem_access::load_global<quantization::access_granularity>(
&tmp_buf[vector_size * i], load_base_ptr + i * load_stride);
for (int j = 0; j < vector_size; j++)
cur_max = reduce::element<ROp::Max>(cur_max, __habs(tmp_buf[i * vector_size + j]));
}
}
reduce::_block<T, 1, ROp::Max>(tb, warp, &cur_max);
int mantisa_mask = ((1 << q_mantisa_bits) - 1);
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
uint8_t* store_base_ptr = q_val + base_store_offset;
float scale = (float)q_range / conversion::to<float>(cur_max);
#pragma unroll
for (int i = 0; i < unroll; i++) {
if (i * load_stride + thread_offset < group_size) {
uint64_t q_buf = 0;
uint64_t q_buf1 = 0;
#pragma unroll
for (int j = 0; j < vector_size; j++) {
float val_f = conversion::to<float>(tmp_buf[i * vector_size + j]) * scale;
uint32_t* data = reinterpret_cast<uint32_t*>(&val_f);
uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits;
uint32_t dst_mantisa = (data[0] & _mantisa_mask);
uint32_t dst_exponent = cur_exponent;
round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>(
dst_mantisa, dst_exponent, &state);
if (cur_exponent != 0)
clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>(
dst_exponent, dst_mantisa);
dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits);
if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;
if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6)
q_buf = q_buf |
((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) |
(dst_exponent << q_mantisa_bits) | dst_mantisa))
<< j * total_q_bits);
else if (total_q_bits == 12) {
if (j < 5)
q_buf =
q_buf |
((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
(dst_exponent << q_mantisa_bits) | dst_mantisa))
<< j * total_q_bits);
else
q_buf1 =
q_buf1 |
((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
(dst_exponent << q_mantisa_bits) | dst_mantisa))
<< (j - 5) * total_q_bits);
}
}
if (total_q_bits == 12) {
uint64_t last_nibble_mask = 0xf;
last_nibble_mask = q_buf1 & last_nibble_mask;
q_buf = (last_nibble_mask << 60) | q_buf;
q_buf1 >>= 4;
}
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf);
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf1);
if (total_q_bits == 6) {
mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
store_base_ptr + i * store_stride, int8_data);
mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
store_base_ptr + i * store_stride +
quantization::quanitzed_access_granularity_6bits,
int8_data + quantization::quanitzed_access_granularity_6bits);
mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
store_base_ptr + i * store_stride +
quantization::quanitzed_access_granularity_6bits * 2,
int8_data + 2 * quantization::quanitzed_access_granularity_6bits);
} else {
mem_access::store_global<quantization::quanitzed_access_granularity>(
store_base_ptr + i * store_stride, int8_data);
if (total_q_bits > 4) {
mem_access::store_global<quantization::quanitzed_access_granularity>(
store_base_ptr + i * store_stride +
quantization::quanitzed_access_granularity,
int8_data + quantization::quanitzed_access_granularity);
if (total_q_bits == 12) {
mem_access::store_global<quantization::quanitzed_access_granularity>(
store_base_ptr + i * store_stride +
quantization::quanitzed_access_granularity * 2,
int8_data1);
}
}
}
}
}
if (lane == 0) {
float q_scale = conversion::to<float>(cur_max) / (float)q_range;
uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&q_scale);
uint32_t scale_offset =
gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8);
if (total_q_bits != 6)
mem_access::store_global<quantization::quanitzed_access_granularity>(
q_val + scale_offset, scale_as_int8);
else {
mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
q_val + scale_offset, scale_as_int8);
mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
q_val + scale_offset + quantization::quanitzed_access_granularity_6bits,
scale_as_int8 + quantization::quanitzed_access_granularity_6bits);
}
}
}
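The 12-bit path is the irregular one: each lane packs eight values, i.e. 96 bits, with the first five landing in q_buf and the remaining three in q_buf1; the low nibble of q_buf1 is then folded into the top of q_buf so the pair can be written as one 8-byte and one 4-byte store. A small Python model of that pack/unpack shuffle (illustrative only):

def pack12(vals):
    # Mirrors the total_q_bits == 12 branch of apply_quantization.
    assert len(vals) == 8 and all(0 <= v < (1 << 12) for v in vals)
    q_buf = q_buf1 = 0
    for j, v in enumerate(vals):
        if j < 5:
            q_buf |= v << (j * 12)
        else:
            q_buf1 |= v << ((j - 5) * 12)
    q_buf |= (q_buf1 & 0xF) << 60  # fold q_buf1's low nibble into q_buf's top
    q_buf1 >>= 4
    return q_buf, q_buf1           # stored as 8 + 4 bytes

def unpack12(q_buf, q_buf1):
    # Mirrors the quantized_bits == 12 branch of apply_dequantization.
    vals = [(q_buf >> (j * 12)) & 0xFFF for j in range(5)]
    vals.append((((q_buf1 & 0xFF) << 4) | (q_buf >> 60)) & 0xFFF)
    vals.append((q_buf1 >> 8) & 0xFFF)
    vals.append((q_buf1 >> 20) & 0xFFF)
    return vals

assert unpack12(*pack12(list(range(8)))) == list(range(8))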
template <typename T,
int unroll,
int q_mantisa_bits,
int total_q_bits = 16,
int _mantisa_bits = 3,
int _exponent_bits = 4>
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
{
int tidx = threadIdx.x;
int wid = tidx >> 5;
int lane = tidx & 0x1f;
int gid = blockIdx.x * quantization::warps + wid;
constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
constexpr uint32_t load_stride = vector_size * hw_warp_size;
const uint32_t thread_offset = lane * vector_size;
const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
const uint32_t base_load_offset =
gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset
const uint32_t base_store_offset = gid * group_size + thread_offset;
const uint8_t* load_base_ptr = val + base_load_offset;
int mantisa_mask = ((1 << q_mantisa_bits) - 1);
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
T* store_base_ptr = q_val + base_store_offset;
float scale;  // per-group scale, loaded from the 4-byte tail of the group below
uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
if (quantized_bits == 6) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
scale_as_int8,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
quantization::quanitzed_access_granularity_6bits);
} else
mem_access::load_global<quantization::quanitzed_access_granularity>(
scale_as_int8,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
#pragma unroll
for (int i = 0; i < unroll; i++) {
if (i * load_stride + thread_offset < group_size) {
uint64_t q_buf_in;
uint64_t q_buf_in1;
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
uint32_t loading_offset = i * load_stride * quantized_bits / 8;
if (quantized_bits == 6) {
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data, load_base_ptr + loading_offset);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity_6bits);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity_6bits * 2);
} else {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data, load_base_ptr + loading_offset);
if (quantized_bits > 4) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data + quantization::quanitzed_access_granularity,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity);
if (quantized_bits == 12) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data1,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity * 2);
}
}
}
T store_buf[vector_size];
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
for (int j = 0; j < vector_size; j++) {
uint16_t new_data;
if (j < 5 || quantized_bits != 12) {
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
} else {
if (j == 5) {
new_data = (uint16_t)(q_buf_in1);
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
} else
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
}
uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
uint16_t dst_mantisa = (new_data & _mantisa_mask);
if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;
q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
(dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
float up_cast = conversion::to<float>(store_buf[j]);
store_buf[j] = conversion::to<T>(up_cast * scale);
}
mem_access::store_global<quantization::access_granularity>(
store_base_ptr + i * load_stride, store_buf);
}
}
}
#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT) \
case COUNT: \
apply_quantization<T, \
COUNT, \
mantisa, \
exponent, \
CONST_Q_BITS, \
CONST_Q_MANTISA_BITS, \
CONST_STOCHASTIC_ROUNDING> \
<<<grid, block, 0, stream>>>(val, q_val, group_size, seed, q_range); \
break;
template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
int group_size,
cudaStream_t stream,
float q_range,
int q_bits,
int q_mantisa_bits,
int stochastic_rounding)
{
const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
const dim3 block(quantization::threads);
std::pair<uint64_t, uint64_t> seed = FPContext::Instance().IncrementOffset(16);
constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] {
switch (copy_unroll) {
LAUNCH_FOR_QUANTIZATION_UNROLL(1)
LAUNCH_FOR_QUANTIZATION_UNROLL(2)
LAUNCH_FOR_QUANTIZATION_UNROLL(3)
LAUNCH_FOR_QUANTIZATION_UNROLL(4)
LAUNCH_FOR_QUANTIZATION_UNROLL(5)
LAUNCH_FOR_QUANTIZATION_UNROLL(6)
}
});
}
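The unroll factor is derived from how many values one warp covers per iteration: each lane loads access_granularity = 16 bytes, so a warp covers 32 * 16 / sizeof(T) elements. The switch above instantiates unroll factors 1 through 6, which for bf16/fp16 inputs corresponds to group sizes up to 1536 under this scheme. A quick check of the arithmetic (illustrative Python):

def copy_unroll(group_size, elem_bytes, warp_size=32, access_granularity=16):
    vals_per_unroll = warp_size * access_granularity // elem_bytes
    return (group_size + vals_per_unroll - 1) // vals_per_unroll

assert copy_unroll(512, 2) == 2   # bf16: one warp covers 256 values per step
assert copy_unroll(1536, 2) == 6  # the largest group size the switch handles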
#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \
template void launch_quantization<T, mantisa, exponent>( \
T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int);
// fp8(E4M3), nearest-rounding
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
#endif
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \
case COUNT: \
apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
<<<grid, block, 0, stream>>>(val, q_val, group_size); \
break;
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream)
{
const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
const dim3 block(quantization::threads);
constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
switch (copy_unroll) {
LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
}
});
}
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
template void launch_dequantization<T, mantisa>(uint8_t*, T*, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);


@@ -868,6 +868,35 @@ __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(voi
#endif
}
template <>
__device__ __forceinline__ void store_global<2>(void* dst, const void* src)
{
const int16_t* data = reinterpret_cast<const int16_t*>(src);
int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
dst_cast[0] = data[0];
}
template <>
__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst,
const void* src)
{
const int16_t* data = reinterpret_cast<const int16_t*>(src);
int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
dst_cast[0] = data[0];
}
template <>
__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst,
const void* src)
{
const int16_t* data = reinterpret_cast<const int16_t*>(src);
int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
dst_cast[0] = data[0];
}
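These 2-byte specializations exist because the 6-bit path cannot use 4-byte accesses: a lane's eight bf16 values quantize to 8 * 6 = 48 bits, i.e. 6 bytes, which the quantize kernel writes as three stores of quanitzed_access_granularity_6bits = 2 bytes each. A quick arithmetic check (illustrative Python):

def payload_bytes_per_lane(vector_size, q_bits):
    # Bytes produced per lane per unrolled iteration.
    return vector_size * q_bits // 8

assert payload_bytes_per_lane(8, 8) == 8    # two 4-byte stores
assert payload_bytes_per_lane(8, 12) == 12  # three 4-byte stores
assert payload_bytes_per_lane(8, 6) == 6    # three 2-byte stores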
/////////// Store Shared ///////////
template <>


@@ -159,6 +159,12 @@ DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
return lhs + rhs;
}
template <>
DS_D_INLINE double element<ROpType::Add>(const double lhs, const double rhs)
{
return lhs + rhs;
}
template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
@@ -189,6 +195,19 @@ DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
#endif
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 element<ROpType::Max>(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs)
{
#if __CUDA_ARCH__ >= 800
// Intrinsic limited to Ampere + newer
return __hmax(lhs, rhs);
#else
return (lhs > rhs) ? lhs : rhs;
#endif
}
#endif
template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
@@ -220,6 +239,21 @@ DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
#endif
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat162 element<ROpType::Max>(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs)
{
#if __CUDA_ARCH__ >= 800
return __hmax2(lhs, rhs);
#else
__nv_bfloat162 ret_val;
ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
return ret_val;
#endif
}
#endif
template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
@@ -295,6 +329,11 @@ DS_D_INLINE float init<ROpType::Add>()
{
return 0.0f;
}
template <>
DS_D_INLINE double init<ROpType::Add>()
{
return 0.0;
}
template <>
DS_D_INLINE float init<ROpType::Min>()
@@ -331,6 +370,15 @@ DS_D_INLINE __half init<ROpType::Max>()
return __half(neg_inf);
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
{
constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
return __nv_bfloat16(neg_inf);
}
#endif
template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{


@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .quantize import FP_Quantize


@@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.ops.op_builder import FPQuantizerBuilder
fp_quant_module = None
class FP_Quantize:

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()
        self.group_size = group_size
        self.orig_dtype = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bits=8"

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

        out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

        if return_meta_tensor:
            data, scale = out.split(self.group_size, dim=-1)
            return data.contiguous().reshape(input.shape), scale.contiguous()

        return out

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                '[De-quantization Error]: quantized data should have the same size as the original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out
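A minimal end-to-end sketch of the class above (assumes a CUDA device on which the fp_quantizer op builds; sizes are illustrative):

import torch
from deepspeed.ops.fp_quantizer import FP_Quantize

fpq = FP_Quantize(group_size=512)
x = torch.randn(4, 1024, dtype=torch.bfloat16, device='cuda')

# fp8 (E4M3) round trip; quantize() records dtype/shape for dequantize().
x_q = fpq.quantize(x.clone(), q_bits=8)
x_dq = fpq.dequantize(x_q, q_bits=8)
print((x_dq - x).abs().mean())

# With q_bits=8 the packed payload and per-group scales can be kept separate.
data, scale = fpq.quantize(x.clone(), q_bits=8, return_meta_tensor=True)
x_dq2 = fpq.dequantize(data, q_bits=8, scale=scale)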


@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import CUDAOpBuilder, installed_cuda_version
class FPQuantizerBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_FP_QUANTIZER"
    NAME = "fp_quantizer"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'

    def is_compatible(self, verbose=True):
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile inference kernels")
            return False

        cuda_okay = True
        if not self.is_rocm_pytorch() and torch.cuda.is_available():  #ignore-cuda
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major  #ignore-cuda
            if cuda_capability < 8:
                self.warning("NVIDIA Inference is only supported on Ampere and newer architectures")
                cuda_okay = False
            if cuda_capability >= 8:
                if torch_cuda_major < 11 or sys_cuda_major < 11:
                    self.warning("On Ampere and higher architectures please use CUDA 11+")
                    cuda_okay = False
        return super().is_compatible(verbose) and cuda_okay

    def filter_ccs(self, ccs):
        ccs_retained = []
        ccs_pruned = []
        for cc in ccs:
            if int(cc[0]) >= 8:
                ccs_retained.append(cc)
            else:
                ccs_pruned.append(cc)
        if len(ccs_pruned) > 0:
            self.warning(f"Filtered compute capabilities {ccs_pruned}")
        return ccs_retained

    def sources(self):
        return [
            "csrc/fp_quantizer/quantize.cu",
            "csrc/fp_quantizer/quantize.cpp",
        ]

    def extra_ldflags(self):
        return ['-lcurand']

    def include_paths(self):
        return ['csrc/fp_quantizer/includes', 'csrc/includes']


@@ -10,6 +10,7 @@ pytest<=8.0.0
pytest-forked
pytest-randomly
pytest-xdist
qtorch==0.3.0
recommonmark
sphinx
sphinx-rtd-theme


@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
from deepspeed.ops.fp_quantizer import FP_Quantize
from deepspeed.ops.op_builder import FPQuantizerBuilder
if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)

# warning: this import silently JIT builds a set of kernels and may take a minute
from qtorch.quant import float_quantize


def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_size=1024):
    ori_dt = input.dtype
    ori_shape = input.shape
    last_dim = group_size
    input = input.view(-1, last_dim)
    q_bits = exp_bits + man_bits + 1

    input_to_float = input.float()
    if q_bits == 8:
        q_range = 480.
    elif q_bits == 6:
        q_range = 28.
    elif q_bits == 12:
        q_range = 510.
    else:
        assert (0), \
            "Please specify the right quantization range for the selected precision!"
    input_max = input_to_float.abs().amax(dim=-1, keepdim=True)
    return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \
        input_max / q_range).to(ori_dt)).reshape(ori_shape)


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
def test_fp_quant_meta(dtype):
    group_size = 128
    q_bits = 8
    exp_bits = 4
    man_bits = 3

    fpq = FP_Quantize(group_size=group_size)
    for i in range(10):
        x = torch.rand(4, 1024, dtype=dtype, device='cuda')

        ds_x = x.clone()
        x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True)
        x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor)

        qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)
        qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
        ds_error = (x_dequantized - x).abs().sum() / x.numel()

        assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"])
def test_fp_quant(dtype, q_bits):
    group_size = 128
    fpq = FP_Quantize(group_size=group_size)
    for i in range(10):
        x = torch.rand(4, 1024, dtype=dtype, device='cuda')

        ds_x = x.clone()
        x_quantized = fpq.quantize(ds_x, q_bits=q_bits)
        x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits)

        if q_bits == 8:
            exp_bits = 4
            man_bits = 3
        elif q_bits == 6:
            exp_bits = 3
            man_bits = 2
        elif q_bits == 12:
            exp_bits = 4
            man_bits = 7
        else:
            raise ValueError(f"unknown {q_bits=}")

        qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)
        qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
        ds_error = (x_dequantized - x).abs().sum() / x.numel()

        assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"