int8 specialization for AVX2 Quantize routine (#120)

Summary:
This adds a specialization for `int8` to the AVX2 `Quantize` routine.

I also tried adding a specialization for `int32` (the final datatype we support in PyTorch quantization), but it seemed to introduce numerical issues stemming from the difference between these two implementations:

https://github.com/pytorch/FBGEMM/blob/master/include/fbgemm/QuantUtils.h#L63

vs

https://github.com/pytorch/FBGEMM/blob/master/src/QuantUtilsAvx2.cc#L82
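
As a side note, one plausible source of that divergence (not verified in this PR) is that the AVX2 path does all of its arithmetic and clipping in single-precision `float`, which cannot represent every value in the `int32` range. A minimal standalone sketch of the precision limit, independent of FBGEMM:

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // float has a 24-bit significand, so near 2^31 adjacent representable
  // floats are 128 apart; most int32 values in that range round away.
  std::int32_t x = 2147483583;             // 2^31 - 65
  float as_float = static_cast<float>(x);  // nearest float is 2147483520.0f
  std::printf("%d -> %.1f\n", x, as_float);

  // Even INT32_MAX itself is not representable: it rounds up to 2^31,
  // which is already outside the int32 range.
  float max_as_float =
      static_cast<float>(std::numeric_limits<std::int32_t>::max());
  std::printf("INT32_MAX -> %.1f\n", max_as_float);  // 2147483648.0
  return 0;
}
```

For 8-bit types this is a non-issue, since every value in [-128, 255] is exactly representable as a `float`.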
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/120

Reviewed By: driazati

Differential Revision: D17115198

Pulled By: jamesr66a

fbshipit-source-id: 119145bb99235a7545389afa61483060200cc2b7
James Reed 2019-08-29 11:11:31 -07:00 committed by Facebook Github Bot
Parent 280fa17349
Commit d4bfa96cda
3 changed files with 45 additions and 22 deletions


@@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams {
////////////////////////////////////////////////////////////////////////////////
// Utility functions
template <typename T=std::uint8_t>
void QuantizeAvx2(
const float* src,
std::uint8_t* dst,
T* dst,
int len,
const TensorQuantizationParams& qparams);


@@ -164,30 +164,34 @@ void ChooseRequantizationMultiplier(
dst[i] = Quantize<T>(src[i], qparams); \
} \
}
FBGEMM_SPECIALIZED_QUANTIZE(int8_t)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZE
template <>
void Quantize<uint8_t>(
const float* src,
uint8_t* dst,
int len,
const TensorQuantizationParams& qparams) {
bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support();
bool fma_support = cpuinfo_has_x86_fma3();
if (avx2_support && fma_support && qparams.precision == 8) {
// fast path
QuantizeAvx2(src, dst, len, qparams);
} else {
for (std::size_t i = 0; i < len; ++i) {
dst[i] = Quantize<uint8_t>(src[i], qparams);
}
}
#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \
template <> \
void Quantize<T>( \
const float* src, \
T* dst, \
int len, \
const TensorQuantizationParams& qparams) { \
bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
bool fma_support = cpuinfo_has_x86_fma3(); \
if (avx2_support && fma_support && qparams.precision == 8) { \
/* fast path */ \
QuantizeAvx2<T>(src, dst, len, qparams); \
} else { \
for (std::size_t i = 0; i < len; ++i) { \
dst[i] = Quantize<T>(src[i], qparams); \
} \
} \
}
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t)
#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
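
As a quick usage sketch of the two macro-generated specializations (the quantization parameter values below are made up for illustration, not from this PR): both `Quantize<int8_t>` and `Quantize<uint8_t>` take the `QuantizeAvx2<T>` fast path when the CPU has AVX2 and FMA and `qparams.precision == 8`, and otherwise fall back to the scalar per-element `Quantize<T>`.

```cpp
#include <cstdint>
#include <vector>

#include "fbgemm/QuantUtils.h"

int main() {
  std::vector<float> src = {-1.5f, -0.25f, 0.0f, 0.75f, 2.0f};
  std::vector<std::int8_t> dst(src.size());

  fbgemm::TensorQuantizationParams qparams;
  qparams.scale = 0.05f;   // illustrative values
  qparams.zero_point = 0;
  qparams.precision = 8;   // 8-bit output enables the AVX2 fast path

  // Dispatches to QuantizeAvx2<int8_t> on AVX2+FMA machines, scalar otherwise.
  fbgemm::Quantize<std::int8_t>(
      src.data(), dst.data(), static_cast<int>(src.size()), qparams);
  return 0;
}
```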
#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \
template <> \
void QuantizeGroupwise<T, layout_t::KCX>( \


@@ -18,13 +18,16 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
// Utility functions
template <typename T>
void QuantizeAvx2(
const float* src,
uint8_t* dst,
T* dst,
int len,
const TensorQuantizationParams& qparams) {
#if defined(__AVX2__) && defined(__FMA__)
constexpr int VLEN = 8;
constexpr float min_val = std::numeric_limits<T>::min();
constexpr float max_val = std::numeric_limits<T>::max();
std::size_t i = 0;
__m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
__m256i shuffle_mask_v = _mm256_set_epi8(
@@ -67,8 +70,8 @@ void QuantizeAvx2(
__m256 transformed_v = _mm256_fmadd_ps(
src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
__m256 clipped_v = _mm256_min_ps(
_mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
_mm256_set1_ps(255.f));
_mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
_mm256_set1_ps(max_val));
__m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
// An instruction sequence to save 8 32-bit integers as 8 8-bit integers
@@ -80,7 +83,7 @@ void QuantizeAvx2(
for (; i < len; ++i) {
float transformed = qparams.zero_point + src[i] / qparams.scale;
float clipped = std::min(std::max(transformed, 0.f), 255.f);
float clipped = std::min(std::max(transformed, min_val), max_val);
// Not exactly the same behavior as the vectorized code.
// The vectorized code above always rounds to even in halfway cases
// (https://software.intel.com/en-us/node/523819), but std::nearbyint
@@ -95,6 +98,21 @@ void QuantizeAvx2(
#endif
}
// Instantiate QuantizeAvx2 for known datatypes
template
void QuantizeAvx2<uint8_t>(
const float* src,
uint8_t* dst,
int len,
const TensorQuantizationParams& qparams);
template
void QuantizeAvx2<int8_t>(
const float* src,
int8_t* dst,
int len,
const TensorQuantizationParams& qparams);
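
The explicit instantiations above are needed because `QuantizeAvx2` is declared as a template in `QuantUtils.h` but defined here in `QuantUtilsAvx2.cc`; without them, callers in other translation units would hit undefined-symbol link errors. A minimal sketch of the same pattern with made-up names:

```cpp
// widget.h -- declaration only; the definition lives in widget.cc
#include <cstdint>

template <typename T>
void Process(const float* src, T* dst, int len);

// widget.cc -- definition plus explicit instantiations of the supported types
template <typename T>
void Process(const float* src, T* dst, int len) {
  for (int i = 0; i < len; ++i) {
    dst[i] = static_cast<T>(src[i]);
  }
}

// Emit object code for these T so other translation units can link against them.
template void Process<std::uint8_t>(const float*, std::uint8_t*, int);
template void Process<std::int8_t>(const float*, std::int8_t*, int);
```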
void FindMinMax(const float* a, float* min, float* max, int len) {
if (len <= 0) {
*min = 0.0f;