Mirror of https://github.com/microsoft/DeepSpeed.git
FP [6,8,12] quantizer op (#5336)
Flexible-bit quantizer-dequantizer library with fp6/fp12/fp8 support. Requires Ampere or newer architectures, due to the initial focus of this op on `bfloat16` input types. Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
This commit is contained in:
Parent
4621ba4cd4
Commit
3fbd01ccca
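As a quick orientation (editorial note, not part of the diff): the new op is exposed on the Python side through deepspeed.ops.fp_quantizer.FP_Quantize, whose quantize/dequantize round trip is exercised by the unit tests added below. A minimal usage sketch, assuming a CUDA device, a bfloat16 input, and a last dimension divisible by group_size:

import torch
from deepspeed.ops.fp_quantizer import FP_Quantize

# group_size values share one fp32 scale; 128 matches the new unit tests
fpq = FP_Quantize(group_size=128)

x = torch.rand(4, 1024, dtype=torch.bfloat16, device='cuda')

x_q = fpq.quantize(x.clone(), q_bits=8)  # fp8 (E4M3 by default); q_bits=6/12 select fp6/fp12
x_deq = fpq.dequantize(x_q, q_bits=8)    # back to the original dtype and shape

print((x_deq - x).abs().mean())          # small quantization error expected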
@ -36,7 +36,7 @@ jobs:
          #python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
      - name: Compile DeepSpeed Ops
        run: |
-         DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+         DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
      - name: DS Report
        run: |
          ds_report
@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <stdlib.h>
#include <sys/time.h>
#include <map>
#include <memory>
#include <stack>
#include <string>
#define WARP_SIZE 32

class FPContext {
public:
    FPContext() : _seed(42)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
    }

    virtual ~FPContext() {}

    static FPContext& Instance()
    {
        static FPContext _ctx;
        return _ctx;
    }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    {
        // get current pytorch stream.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    uint64_t _seed;
    uint64_t _curr_offset;
};
@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

#define QUANT_SWITCH(Q_BITS, ...) \
    [&] { \
        if (12 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 8; \
            constexpr int CONST_Q_MANTISA_BITS = 3; \
            __VA_ARGS__(); \
        } else if (13 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 8; \
            constexpr int CONST_Q_MANTISA_BITS = 3; \
            __VA_ARGS__(); \
        } else if (10 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 8; \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            __VA_ARGS__(); \
        } else if (11 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 8; \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            __VA_ARGS__(); \
        } else if (28 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 12; \
            constexpr int CONST_Q_MANTISA_BITS = 7; \
            __VA_ARGS__(); \
        } else if (29 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 12; \
            constexpr int CONST_Q_MANTISA_BITS = 7; \
            __VA_ARGS__(); \
        } else if (6 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 6; \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            __VA_ARGS__(); \
        } else if (7 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 6; \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            __VA_ARGS__(); \
        } else if (2 == Q_BITS) { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
            constexpr int CONST_Q_BITS = 4; \
            constexpr int CONST_Q_MANTISA_BITS = 1; \
            __VA_ARGS__(); \
        } else { \
            constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
            constexpr int CONST_Q_BITS = 4; \
            constexpr int CONST_Q_MANTISA_BITS = 1; \
            __VA_ARGS__(); \
        } \
    }()

#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \
    [&] { \
        if (12 == Q_MANTISA_EXPONENT_BITS) { \
            constexpr int CONST_Q_MANTISA_BITS = 3; \
            constexpr int CONST_Q_EXPONENT_BITS = 4; \
            __VA_ARGS__(); \
        } else if (10 == Q_MANTISA_EXPONENT_BITS) { \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            constexpr int CONST_Q_EXPONENT_BITS = 5; \
            __VA_ARGS__(); \
        } else if (28 == Q_MANTISA_EXPONENT_BITS) { \
            constexpr int CONST_Q_MANTISA_BITS = 7; \
            constexpr int CONST_Q_EXPONENT_BITS = 4; \
            __VA_ARGS__(); \
        } else if (6 == Q_MANTISA_EXPONENT_BITS) { \
            constexpr int CONST_Q_MANTISA_BITS = 2; \
            constexpr int CONST_Q_EXPONENT_BITS = 3; \
            __VA_ARGS__(); \
        } else { \
            constexpr int CONST_Q_MANTISA_BITS = 1; \
            constexpr int CONST_Q_EXPONENT_BITS = 2; \
            __VA_ARGS__(); \
        } \
    }()

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
                         uint8_t* q_val,
                         int num_groups,
                         int group_size,
                         cudaStream_t stream,
                         float q_range,
                         int q_bits,
                         int q_mantisa_bits,
                         int stochastic_rounding);

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
                           T* q_val,
                           int num_groups,
                           int group_size,
                           int q_mantisa_bits,
                           int q_exponent_bits,
                           cudaStream_t stream);
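Editorial note (not part of the diff): the key that QUANT_SWITCH branches on is not the raw bit width; as computed in launch_quantization in quantize.cu below, it is (q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, i.e. exponent bits times mantissa bits, plus one when stochastic rounding is requested. A small sketch of that mapping, under this reading:

def quant_switch_key(q_bits, q_mantisa_bits, stochastic_rounding=0):
    # exponent_bits * mantissa_bits, plus 1 when stochastic rounding is on
    return (q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding

assert quant_switch_key(8, 3) == 12      # fp8 E4M3 -> CONST_Q_BITS=8, CONST_Q_MANTISA_BITS=3
assert quant_switch_key(8, 3, 1) == 13   # fp8 E4M3 with stochastic rounding
assert quant_switch_key(8, 2) == 10      # fp8 E5M2
assert quant_switch_key(6, 2) == 6       # fp6 E3M2
assert quant_switch_key(12, 4) == 28     # fp12 (dispatched with 7 mantissa bits in the kernel)
assert quant_switch_key(4, 1) == 2       # fp4 E2M1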
@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "quantize.h"

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>

#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
    if (val.options().dtype() == torch::T_TYPE) { \
        launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \
                                                       (uint8_t*)out.data_ptr(), \
                                                       num_groups, \
                                                       group_size, \
                                                       at::cuda::getCurrentCUDAStream(), \
                                                       q_range, \
                                                       q_bits, \
                                                       q_mantisa_bits, \
                                                       stochastic_rounding); \
    }

at::Tensor quantize(torch::Tensor& val,
                    int group_size,
                    int stochastic_rounding,
                    int q_bits,
                    int q_mantisa_bits)
{
    int total_elems = at::numel(val);
    auto options = at::TensorOptions()
                       .dtype(torch::kInt8)
                       .layout(val.layout())
                       .device(val.device())
                       .requires_grad(false);
    float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) :  // fp8 ranges
                        (q_bits == 12 ? 510.0 :                               // fp12 range
                             (q_bits == 6 ? 28.0 :                            // fp6 range
                                  6.0));  // fp4 range (using power 2); TODO (Reza): add the power-4
                                          // in case accuracy is not matching!
    int num_groups = total_elems / group_size;
    auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

    DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
    DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8);
#endif

    return out;
}

#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \
    if (val.options().dtype() == torch::T_TYPE) { \
        launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
                                               (C_TYPE*)val.data_ptr(), \
                                               num_groups, \
                                               group_size, \
                                               q_mantisa_bits, \
                                               q_exponent_bits, \
                                               at::cuda::getCurrentCUDAStream()); \
        return; \
    }

void dequantize(torch::Tensor& val,
                torch::Tensor& val_q,
                int group_size,
                int q_mantisa_bits,
                int q_exponent_bits)
{
    int total_elems = at::numel(val);

    int num_groups = total_elems / group_size;

    DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
    DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7);
#endif
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("quantize", &quantize, "quantize function");
    m.def("dequantize", &dequantize, "dequantize function");
}
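Editorial note (not part of the diff): quantize above returns an int8 tensor of shape [num_groups, group_size * q_bits / 8 + 4]; each row carries the packed values of one group followed by 4 bytes holding that group's fp32 scale, which the kernel writes at the end of the row. A sketch of the row-size arithmetic under those assumptions:

def quantized_row_bytes(group_size, q_bits):
    # packed payload plus a 4-byte fp32 scale per group
    return group_size * q_bits // 8 + 4

assert quantized_row_bytes(512, 8) == 516    # fp8:  1 byte per value
assert quantized_row_bytes(512, 6) == 388    # fp6:  3 bytes per 4 values
assert quantized_row_bytes(512, 12) == 772   # fp12: 3 bytes per 2 values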
@ -0,0 +1,427 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <stdexcept>
#include "context.h"
#include "memory_access_utils.h"
#include "quantize.h"
#include "reduction_utils.h"

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>
#include <curand_kernel.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>

using ROp = reduce::ROpType;

namespace quantization {

constexpr int access_granularity = 16;
constexpr int quanitzed_access_granularity = 4;
constexpr int quanitzed_access_granularity_6bits = 2;
constexpr int threads = 256;
constexpr int warps = threads / 32;

}  // namespace quantization

template <int _mantisa_bits, int q_mantisa_bits, int stochastic_rounding>
__device__ void round(uint32_t& mantisa, uint32_t& dst_exponent, curandStatePhilox4_32_10_t* state)
{
    constexpr uint32_t mantisa_mask = (1 << (_mantisa_bits - q_mantisa_bits)) - 1;
    uint32_t offset = stochastic_rounding ? (curand_poisson(state, 10) & mantisa_mask)
                                          : 1 << (_mantisa_bits - q_mantisa_bits - 1);
    mantisa += offset;
    dst_exponent += (((mantisa & ~mantisa_mask) == (1 << _mantisa_bits)) ? 1 : 0);
}

template <int _mantisa_bits, int _exponent_bits, int q_mantisa_bits, int q_exponent_bits>
__device__ void clip(uint32_t& exponent, uint32_t& mantisa)
{
    constexpr uint32_t max_exponent = (1 << (q_exponent_bits - 1)) + (1 << (_exponent_bits - 1));
    constexpr uint32_t min_exponent =
        (1 << (_exponent_bits - 1)) - ((1 << (q_exponent_bits - 1)) - 1);
    if (exponent > max_exponent) {
        exponent = max_exponent;
        mantisa = (((uint32_t)-1) >> (32 - q_mantisa_bits)) << 1;  //.11 .. 10
    }
    if (exponent < min_exponent) {
        exponent = min_exponent;
        mantisa = 0;
    }
}

template <typename T,
          int unroll,
          int _mantisa_bits,
          int _exponent_bits,
          int total_q_bits = 8,
          int q_mantisa_bits = 3,
          int stochastic_rounding = 0>
__global__ void apply_quantization(T* val,
                                   uint8_t* q_val,
                                   int group_size,
                                   std::pair<uint64_t, uint64_t> seed,
                                   float q_range)
{
    int tidx = threadIdx.x;
    int wid = tidx >> 5;
    int lane = tidx & 0x1f;
    int gid = blockIdx.x * quantization::warps + wid;

    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint32_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint32_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
    // CG helpers
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    constexpr uint32_t load_stride = vector_size * hw_warp_size;
    constexpr uint32_t store_stride = (total_q_bits * vector_size / 8) * hw_warp_size;
    const uint32_t thread_offset = lane * vector_size;
    const uint32_t store_thread_offset = lane * (total_q_bits * vector_size / 8);
    const uint32_t base_load_offset = gid * group_size + thread_offset;
    const uint32_t base_store_offset =
        gid * ((group_size * total_q_bits / 8) + 4) +
        store_thread_offset;  // 4-byte for saving the scale per group
    const T* load_base_ptr = val + base_load_offset;
    T tmp_buf[unroll * vector_size];
    T cur_max;
    reduce::init<ROp::Max>(&cur_max);

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#pragma unroll
    for (int i = 0; i < unroll; i++) {
        if (i * load_stride + thread_offset < group_size) {
            mem_access::load_global<quantization::access_granularity>(
                &tmp_buf[vector_size * i], load_base_ptr + i * load_stride);
            for (int j = 0; j < vector_size; j++)
                cur_max = reduce::element<ROp::Max>(cur_max, __habs(tmp_buf[i * vector_size + j]));
        }
    }
    reduce::_block<T, 1, ROp::Max>(tb, warp, &cur_max);

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    uint8_t* store_base_ptr = q_val + base_store_offset;
    float scale = (float)q_range / conversion::to<float>(cur_max);
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        if (i * load_stride + thread_offset < group_size) {
            uint64_t q_buf = 0;
            uint64_t q_buf1 = 0;
#pragma unroll
            for (int j = 0; j < vector_size; j++) {
                float val_f = conversion::to<float>(tmp_buf[i * vector_size + j]) * scale;
                uint32_t* data = reinterpret_cast<uint32_t*>(&val_f);
                uint32_t sign = (data[0] & _sign_mask) >> (_mantisa_bits + _exponent_bits);
                uint32_t cur_exponent = (data[0] & _exponent_mask) >> _mantisa_bits;
                uint32_t dst_mantisa = (data[0] & _mantisa_mask);

                uint32_t dst_exponent = cur_exponent;

                round<_mantisa_bits, q_mantisa_bits, stochastic_rounding>(
                    dst_mantisa, dst_exponent, &state);
                if (cur_exponent != 0)
                    clip<_mantisa_bits, _exponent_bits, q_mantisa_bits, q_exponent_bits>(
                        dst_exponent, dst_mantisa);

                dst_mantisa = (dst_mantisa & mantisa_mask) >> (_mantisa_bits - q_mantisa_bits);

                if (dst_exponent != (1 << q_exponent_bits) - 1)
                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                                   (1 << (q_exponent_bits - 1)) - 1;
                if (total_q_bits == 8 || total_q_bits == 4 || total_q_bits == 6)
                    q_buf = q_buf |
                            ((uint64_t)((uint8_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                  (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << j * total_q_bits);
                else if (total_q_bits == 12) {
                    if (j < 5)
                        q_buf =
                            q_buf |
                            ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                   (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << j * total_q_bits);
                    else
                        q_buf1 =
                            q_buf1 |
                            ((uint64_t)((uint16_t)(sign << (q_exponent_bits + q_mantisa_bits) |
                                                   (dst_exponent << q_mantisa_bits) | dst_mantisa))
                             << (j - 5) * total_q_bits);
                }
            }
            if (total_q_bits == 12) {
                uint64_t last_nibble_mask = 0xf;
                last_nibble_mask = q_buf1 & last_nibble_mask;
                q_buf = (last_nibble_mask << 60) | q_buf;
                q_buf1 >>= 4;
            }
            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf);
            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf1);
            if (total_q_bits == 6) {
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride, int8_data);
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride +
                        quantization::quanitzed_access_granularity_6bits,
                    int8_data + quantization::quanitzed_access_granularity_6bits);
                mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                    store_base_ptr + i * store_stride +
                        quantization::quanitzed_access_granularity_6bits * 2,
                    int8_data + 2 * quantization::quanitzed_access_granularity_6bits);
            } else {
                mem_access::store_global<quantization::quanitzed_access_granularity>(
                    store_base_ptr + i * store_stride, int8_data);

                if (total_q_bits > 4) {
                    mem_access::store_global<quantization::quanitzed_access_granularity>(
                        store_base_ptr + i * store_stride +
                            quantization::quanitzed_access_granularity,
                        int8_data + quantization::quanitzed_access_granularity);
                    if (total_q_bits == 12) {
                        mem_access::store_global<quantization::quanitzed_access_granularity>(
                            store_base_ptr + i * store_stride +
                                quantization::quanitzed_access_granularity * 2,
                            int8_data1);
                    }
                }
            }
        }
    }
    if (lane == 0) {
        float q_scale = conversion::to<float>(cur_max) / (float)q_range;
        uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&q_scale);
        uint32_t scale_offset =
            gid * ((group_size * total_q_bits / 8) + 4) + (group_size * total_q_bits / 8);
        if (total_q_bits != 6)
            mem_access::store_global<quantization::quanitzed_access_granularity>(
                q_val + scale_offset, scale_as_int8);
        else {
            mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                q_val + scale_offset, scale_as_int8);
            mem_access::store_global<quantization::quanitzed_access_granularity_6bits>(
                q_val + scale_offset + quantization::quanitzed_access_granularity_6bits,
                scale_as_int8 + quantization::quanitzed_access_granularity_6bits);
        }
    }
}

template <typename T,
          int unroll,
          int q_mantisa_bits,
          int total_q_bits = 16,
          int _mantisa_bits = 3,
          int _exponent_bits = 4>
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
{
    int tidx = threadIdx.x;
    int wid = tidx >> 5;
    int lane = tidx & 0x1f;
    int gid = blockIdx.x * quantization::warps + wid;
    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);

    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    constexpr uint32_t load_stride = vector_size * hw_warp_size;
    const uint32_t thread_offset = lane * vector_size;
    const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
    const uint32_t base_load_offset =
        gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset;  // 4-byte scale offset
    const uint32_t base_store_offset = gid * group_size + thread_offset;
    const uint8_t* load_base_ptr = val + base_load_offset;

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    T* store_base_ptr = q_val + base_store_offset;
    float scale;  //= q_scale[gid];

    uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
    if (quantized_bits == 6) {
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8,
            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
        mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
            scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
                quantization::quanitzed_access_granularity_6bits);
    } else
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8,
            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));

#pragma unroll
    for (int i = 0; i < unroll; i++) {
        if (i * load_stride + thread_offset < group_size) {
            uint64_t q_buf_in;
            uint64_t q_buf_in1;
            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
            uint32_t loading_offset = i * load_stride * quantized_bits / 8;
            if (quantized_bits == 6) {
                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                    int8_data, load_base_ptr + loading_offset);
                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                    int8_data + quantization::quanitzed_access_granularity_6bits,
                    load_base_ptr + loading_offset +
                        quantization::quanitzed_access_granularity_6bits);
                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                    int8_data + quantization::quanitzed_access_granularity_6bits * 2,
                    load_base_ptr + loading_offset +
                        quantization::quanitzed_access_granularity_6bits * 2);
            } else {
                mem_access::load_global<quantization::quanitzed_access_granularity>(
                    int8_data, load_base_ptr + loading_offset);
                if (quantized_bits > 4) {
                    mem_access::load_global<quantization::quanitzed_access_granularity>(
                        int8_data + quantization::quanitzed_access_granularity,
                        load_base_ptr + loading_offset +
                            quantization::quanitzed_access_granularity);
                    if (quantized_bits == 12) {
                        mem_access::load_global<quantization::quanitzed_access_granularity>(
                            int8_data1,
                            load_base_ptr + loading_offset +
                                quantization::quanitzed_access_granularity * 2);
                    }
                }
            }
            T store_buf[vector_size];
            uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
            for (int j = 0; j < vector_size; j++) {
                uint16_t new_data;
                if (j < 5 || quantized_bits != 12) {
                    new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
                } else {
                    if (j == 5) {
                        new_data = (uint16_t)(q_buf_in1);
                        new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
                    } else
                        new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
                }

                uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
                uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
                uint16_t dst_mantisa = (new_data & _mantisa_mask);

                if (dst_exponent != (1 << q_exponent_bits) - 1)
                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                                   (1 << (q_exponent_bits - 1)) - 1;

                q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
                            (dst_exponent << q_mantisa_bits) |
                            (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
                float up_cast = conversion::to<float>(store_buf[j]);
                store_buf[j] = conversion::to<T>(up_cast * scale);
            }
            mem_access::store_global<quantization::access_granularity>(
                store_base_ptr + i * load_stride, store_buf);
        }
    }
}

#define LAUNCH_FOR_QUANTIZATION_UNROLL(COUNT) \
    case COUNT: \
        apply_quantization<T, \
                           COUNT, \
                           mantisa, \
                           exponent, \
                           CONST_Q_BITS, \
                           CONST_Q_MANTISA_BITS, \
                           CONST_STOCHASTIC_ROUNDING> \
            <<<grid, block, 0, stream>>>(val, q_val, group_size, seed, q_range); \
        break;

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
                         uint8_t* q_val,
                         int num_groups,
                         int group_size,
                         cudaStream_t stream,
                         float q_range,
                         int q_bits,
                         int q_mantisa_bits,
                         int stochastic_rounding)
{
    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
    const dim3 block(quantization::threads);

    std::pair<uint64_t, uint64_t> seed = FPContext::Instance().IncrementOffset(16);

    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);

    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
    QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] {
        switch (copy_unroll) {
            LAUNCH_FOR_QUANTIZATION_UNROLL(1)
            LAUNCH_FOR_QUANTIZATION_UNROLL(2)
            LAUNCH_FOR_QUANTIZATION_UNROLL(3)
            LAUNCH_FOR_QUANTIZATION_UNROLL(4)
            LAUNCH_FOR_QUANTIZATION_UNROLL(5)
            LAUNCH_FOR_QUANTIZATION_UNROLL(6)
        }
    });
}
#define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \
    template void launch_quantization<T, mantisa, exponent>( \
        T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int);
// fp8(E4M3), nearest-rounding
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
#endif
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);

#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \
    case COUNT: \
        apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
            <<<grid, block, 0, stream>>>(val, q_val, group_size); \
        break;

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
                           T* q_val,
                           int num_groups,
                           int group_size,
                           int q_mantisa_bits,
                           int q_exponent_bits,
                           cudaStream_t stream)
{
    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
    const dim3 block(quantization::threads);

    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;

    DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
        switch (copy_unroll) {
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
            LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
        }
    });
}
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
    template void launch_dequantization<T, mantisa>(uint8_t*, T*, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);
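Editorial note (not part of the diff): in both launchers each quantization group is handled by one warp, with quantization::warps = 8 warps per 256-thread block, and the unroll factor is the number of 16-byte vectorized accesses a warp needs to cover one group. A sketch of that arithmetic for 16-bit inputs, assuming hw_warp_size is 32 (matching WARP_SIZE in context.h):

import math

hw_warp_size = 32
access_granularity = 16            # bytes per vectorized load/store
threads, warps = 256, 256 // 32    # one group per warp, 8 warps per block

def copy_unroll(group_size, elem_bytes=2):   # bf16/fp16 inputs
    vals_per_unroll = hw_warp_size * access_granularity // elem_bytes
    return math.ceil(group_size / vals_per_unroll)

def grid_blocks(num_groups):
    return math.ceil(num_groups / warps)

assert copy_unroll(128) == 1     # group size used by the new unit tests
assert copy_unroll(512) == 2     # FP_Quantize default group size
assert copy_unroll(1536) == 6    # largest group size the unroll switch covers
assert grid_blocks(100) == 13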
@ -868,6 +868,35 @@ __device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(voi
#endif
}

template <>
__device__ __forceinline__ void store_global<2>(void* dst, const void* src)
{
    const int16_t* data = reinterpret_cast<const int16_t*>(src);

    int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
    dst_cast[0] = data[0];
}

template <>
__device__ __forceinline__ void store_global<2, StorePolicy::CacheGlobal>(void* dst,
                                                                           const void* src)
{
    const int16_t* data = reinterpret_cast<const int16_t*>(src);

    int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
    dst_cast[0] = data[0];
}

template <>
__device__ __forceinline__ void store_global<2, StorePolicy::CacheStreaming>(void* dst,
                                                                              const void* src)
{
    const int16_t* data = reinterpret_cast<const int16_t*>(src);

    int16_t* dst_cast = reinterpret_cast<int16_t*>(dst);
    dst_cast[0] = data[0];
}

/////////// Store Shared ///////////

template <>
@ -159,6 +159,12 @@ DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
    return lhs + rhs;
}

template <>
DS_D_INLINE double element<ROpType::Add>(const double lhs, const double rhs)
{
    return lhs + rhs;
}

template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
@ -189,6 +195,19 @@ DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
#endif
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 element<ROpType::Max>(const __nv_bfloat16 lhs, const __nv_bfloat16 rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    return (lhs > rhs) ? lhs : rhs;
#endif
}
#endif

template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
@ -220,6 +239,21 @@ DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
#endif
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat162 element<ROpType::Max>(const __nv_bfloat162 lhs, const __nv_bfloat162 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmax2(lhs, rhs);
#else
    __nv_bfloat162 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
#endif

template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
@ -295,6 +329,11 @@ DS_D_INLINE float init<ROpType::Add>()
{
    return 0.0f;
}
template <>
DS_D_INLINE double init<ROpType::Add>()
{
    return (double)0.0f;
}

template <>
DS_D_INLINE float init<ROpType::Min>()
@ -331,6 +370,15 @@ DS_D_INLINE __half init<ROpType::Max>()
    return __half(neg_inf);
}

#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
{
    constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
    return __nv_bfloat16(neg_inf);
}
#endif

template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .quantize import FP_Quantize
@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.ops.op_builder import FPQuantizerBuilder

fp_quant_module = None


class FP_Quantize:

    def __init__(self, group_size=512) -> None:
        global fp_quant_module
        if fp_quant_module is None:
            fp_quant_module = FPQuantizerBuilder().load()

        self.group_size = group_size
        self.orig_dtype = None

    def quantize(self,
                 input,
                 q_bits=8,
                 q_mantisa_bits=3,
                 stochastic_mode=False,
                 return_meta_tensor=False) -> torch.Tensor:
        assert input.dtype == torch.bfloat16, "only support bf16 for now"
        if return_meta_tensor:
            assert q_bits == 8, "meta tensor is only supported with q_bit=8"

        self.orig_dtype = input.dtype
        self.orig_shape = input.shape

        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

        out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

        if return_meta_tensor:
            data, scale = out.split(self.group_size, dim=-1)
            return data.contiguous().reshape(input.shape), scale.contiguous()

        return out

    def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(self.orig_shape, dtype=self.orig_dtype,
                             device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out
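Editorial note (not part of the diff): with return_meta_tensor=True (fp8 only), quantize splits each group row into the group_size quantized bytes and the trailing 4 scale bytes, and dequantize(..., scale=...) re-concatenates them before calling into the kernel. A minimal sketch of that round trip, assuming a CUDA device:

import torch
from deepspeed.ops.fp_quantizer import FP_Quantize

fpq = FP_Quantize(group_size=128)
x = torch.rand(4, 1024, dtype=torch.bfloat16, device='cuda')

# data keeps the input's shape (one byte per value at q_bits=8); scale holds 4 bytes per group
data, scale = fpq.quantize(x.clone(), q_bits=8, return_meta_tensor=True)
x_deq = fpq.dequantize(data, q_bits=8, scale=scale)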
@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CUDAOpBuilder, installed_cuda_version


class FPQuantizerBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_FP_QUANTIZER"
    NAME = "fp_quantizer"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'

    def is_compatible(self, verbose=True):
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile inference kernels")
            return False

        cuda_okay = True
        if not self.is_rocm_pytorch() and torch.cuda.is_available():  #ignore-cuda
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major  #ignore-cuda
            if cuda_capability < 8:
                self.warning("NVIDIA Inference is only supported on Ampere and newer architectures")
                cuda_okay = False
            if cuda_capability >= 8:
                if torch_cuda_major < 11 or sys_cuda_major < 11:
                    self.warning("On Ampere and higher architectures please use CUDA 11+")
                    cuda_okay = False
        return super().is_compatible(verbose) and cuda_okay

    def filter_ccs(self, ccs):
        ccs_retained = []
        ccs_pruned = []
        for cc in ccs:
            if int(cc[0]) >= 8:
                ccs_retained.append(cc)
            else:
                ccs_pruned.append(cc)
        if len(ccs_pruned) > 0:
            self.warning(f"Filtered compute capabilities {ccs_pruned}")
        return ccs_retained

    def sources(self):
        return [
            "csrc/fp_quantizer/quantize.cu",
            "csrc/fp_quantizer/quantize.cpp",
        ]

    def extra_ldflags(self):
        return ['-lcurand']

    def include_paths(self):
        return ['csrc/fp_quantizer/includes', 'csrc/includes']
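Editorial note (not part of the diff): the builder follows the usual DeepSpeed op-builder pattern; the FP_Quantize wrapper above JIT-loads it on first use, and the CI change at the top of the diff sets DS_BUILD_FP_QUANTIZER=0 to skip pre-compiling it. A minimal sketch of a manual compatibility check and JIT load, assuming a CUDA build of torch:

from deepspeed.ops.op_builder import FPQuantizerBuilder

builder = FPQuantizerBuilder()
if builder.is_compatible():           # needs compute capability >= 8 and CUDA 11+
    fp_quant_module = builder.load()  # JIT-builds quantize.cu / quantize.cpp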
@ -10,6 +10,7 @@ pytest<=8.0.0
pytest-forked
pytest-randomly
pytest-xdist
qtorch==0.3.0
recommonmark
sphinx
sphinx-rtd-theme
@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed

from deepspeed.ops.fp_quantizer import FP_Quantize
from deepspeed.ops.op_builder import FPQuantizerBuilder

if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
    pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)

# warning: this import silently JIT builds a set of kernels and may take a minute
from qtorch.quant import float_quantize


def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_size=1024):
    ori_dt = input.dtype
    ori_shape = input.shape
    last_dim = group_size
    input = input.view(-1, last_dim)

    q_bits = exp_bits + man_bits + 1
    input_to_float = input.float()
    if q_bits == 8:
        q_range = 480.
    elif q_bits == 6:
        q_range = 28.
    elif q_bits == 12:
        q_range = 510.
    else:
        assert (0), \
            "Please specify the right quantization range for the selected precision!"
    input_max = input_to_float.abs().amax(dim=-1, keepdim=True)
    return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \
        input_max / q_range).to(ori_dt)).reshape(ori_shape)


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
def test_fp_quant_meta(dtype):
    group_size = 128
    q_bits = 8
    exp_bits = 4
    man_bits = 3

    fpq = FP_Quantize(group_size=group_size)
    for i in range(10):
        x = torch.rand(4, 1024, dtype=dtype, device='cuda')

        ds_x = x.clone()
        x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True)
        x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor)

        qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)
        qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
        ds_error = (x_dequantized - x).abs().sum() / x.numel()

        assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"])
def test_fp_quant(dtype, q_bits):
    group_size = 128
    fpq = FP_Quantize(group_size=group_size)

    for i in range(10):
        x = torch.rand(4, 1024, dtype=dtype, device='cuda')

        ds_x = x.clone()
        x_quantized = fpq.quantize(ds_x, q_bits=q_bits)
        x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits)

        if q_bits == 8:
            exp_bits = 4
            man_bits = 3
        elif q_bits == 6:
            exp_bits = 3
            man_bits = 2
        elif q_bits == 12:
            exp_bits = 4
            man_bits = 7
        else:
            raise ValueError(f"unknown {q_bits=}")

        qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)

        qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
        ds_error = (x_dequantized - x).abs().sum() / x.numel()

        assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"