Mirror of https://github.com/microsoft/nni.git

[Bugbash] Bug fix (#5467)

Parent: ce6b2e8fc9
Commit: 32d47768a5
@@ -120,7 +120,7 @@ def main():
     model = model.to(device)
     configure_list = [{
         'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
-        'target_names': ['_input_', 'weight', '_output_'],
+        'target_names': ['_input_', 'weight'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
         'granularity': 'default',
@@ -81,3 +81,12 @@ def version_check(expect: dict, raise_error: bool = False) -> None:
             raise RuntimeError('Version check failed: ' + err_message)
         else:
             warnings.warn('Version check with warning: ' + err_message)
+
+
+def torch_version_is_2() -> bool:
+    if TORCH_VERSION is None:
+        return False
+    if TORCH_VERSION < (2, 0):
+        return False
+    else:
+        return True
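The new torch_version_is_2 helper simply reports whether the installed PyTorch is 2.x or newer. A behaviorally equivalent one-liner, assuming TORCH_VERSION is either None or a comparable (major, minor) tuple as parsed elsewhere in nni/common/version.py:

    def torch_version_is_2() -> bool:
        return TORCH_VERSION is not None and TORCH_VERSION >= (2, 0)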
@@ -25,9 +25,13 @@ def lsq_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace):
 
     qmax: int = target_space.qmax
     qmin: int = target_space.qmin
+    if target_space._scaler is not None:
+        scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True)  # type: ignore
+    else:
+        scale = target_space.scale
     #Quantize
     grad_scale_factor = 1.0 / ((qmax * target.numel()) ** 0.5) if (qmax * target.numel()) ** 0.5 != 0 else 1.0
-    scale = grad_scale(target_space.scale, grad_scale_factor)
+    scale = grad_scale(scale, grad_scale_factor)
     new_target = torch.clamp(target / scale, qmin, qmax)
     dequantized_target = round_pass(new_target) * scale
     return dequantized_target
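lsq_clamp_round leans on two straight-through helpers, grad_scale and round_pass, defined elsewhere in this file. A sketch of the conventional LSQ formulation of both (not necessarily NNI's exact code): the forward value is unchanged while the backward pass sees a scaled or identity gradient.

    import torch

    def grad_scale(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
        # forward: x; backward: gradient multiplied by scale_factor
        y_scaled = x * scale_factor
        return (x - y_scaled).detach() + y_scaled

    def round_pass(x: torch.Tensor) -> torch.Tensor:
        # forward: round(x); backward: identity (straight-through estimator)
        return (x.round() - x).detach() + x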
@@ -35,40 +39,16 @@ def lsq_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace):
 
 class DoferaGradClampRound(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
-        ctx.target_space = target_space
-        ctx.save_for_backward(target)
-        return target * 1
-
-    @staticmethod
-    def backward(ctx: Any, grad_output: Any) -> Any:
-        target_space = ctx.target_space
-        target, = ctx.saved_variables
-        grad_o = torch.abs(grad_output.detach())
-        dim_lis = list(range(len(grad_o.shape)))
-        dim_lis.pop(0)
-        max_grad = torch.amax(grad_o, dim=dim_lis, keepdim=True)
-        # generate uniform noise
-        uniform_k = torch.zeros_like(max_grad).to(target.device)
-        N_k = uniform_k.uniform_(-0.5, 0.5) / (2**(target_space.quant_bits) - 1)
-        q_grad_o = grad_output / (2 * max_grad) + 0.5 + N_k
-        quantized_grad = target_space.zero_point + q_grad_o / target_space.scale
-        quantized_grad = torch.round(torch.clamp(quantized_grad, target_space.qmin, target_space.qmax))
-        dequantized_grad = (quantized_grad - target_space.zero_point) * target_space.scale
-
-        return (dequantized_grad - 0.5) * 2 * max_grad, None
-
     @staticmethod
-    def dorefa_clamp_round_weight(target: torch.Tensor, target_space: QuantizationTargetSpace):
+    def dorefa_clamp_round_weight(target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
         # TODO process special case: quant_bit == 1
         target = target.tanh()
         target = target / (2 * target.abs().max()) + 0.5
         dequantized_target = ClampRound.apply(target, target_space)
 
-        return 2 * dequantized_target - 1
+        return 2 * dequantized_target - 1  # type: ignore
 
     @staticmethod
-    def dorefa_clamp_round_input(target: torch.Tensor, target_space: QuantizationTargetSpace):
+    def dorefa_clamp_round_output(target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
         target = torch.clamp(target, 0, 1)
         return ClampRound.apply(target, target_space)
 
@@ -95,9 +75,15 @@ class BNNClampRound(torch.autograd.Function):
 class ClampRound(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
-        transformed_target = target_space.zero_point + target / target_space.scale
+        if target_space._scaler is not None:
+            zero_point = target_space._scaler.expand(target_space.zero_point, target_space.shape, keepdim=True)  # type: ignore
+            scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True)  # type: ignore
+        else:
+            zero_point = target_space.zero_point
+            scale = target_space.scale
+        transformed_target = zero_point + target / scale
         quantized_target = torch.round(torch.clamp(transformed_target, target_space.qmin, target_space.qmax))
-        dequantized_target = (quantized_target - target_space.zero_point) * target_space.scale
+        dequantized_target = (quantized_target - zero_point) * scale
         return dequantized_target
 
     @staticmethod
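The fix threads the _scaler-expanded (per-channel) scale and zero_point through both the quantize and the dequantize step; previously the dequantize step reused the unexpanded per-tensor attributes. A toy round trip with hypothetical shapes illustrates the broadcast:

    import torch

    weight = torch.randn(8, 16)                              # hypothetical [out, in] weight
    scale = weight.abs().amax(dim=1, keepdim=True) / 127.0   # per-output-channel scale, shape [8, 1]
    zero_point = torch.zeros_like(scale)                     # symmetric int8 -> zero point 0
    q = torch.round(torch.clamp(zero_point + weight / scale, -128, 127))
    dq = (q - zero_point) * scale                            # must reuse the SAME expanded scale/zp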
@@ -108,9 +94,16 @@ class ClampRound(torch.autograd.Function):
 class QATClampRound(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
-        transformed_target = target_space.zero_point + target / target_space.scale
+        if target_space._scaler is not None:
+            zero_point = target_space._scaler.expand(target_space.zero_point, target_space.shape, keepdim=True)  # type: ignore
+            scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True)  # type: ignore
+        else:
+            zero_point = target_space.zero_point
+            scale = target_space.scale
+
+        transformed_target = zero_point + target / scale
         quantized_target = torch.round(torch.clamp(transformed_target, target_space.qmin, target_space.qmax))
-        dequantized_target = (quantized_target - target_space.zero_point) * target_space.scale
+        dequantized_target = (quantized_target - zero_point) * scale
         ctx.save_for_backward(transformed_target)
         ctx.target_space = target_space
         return dequantized_target
@@ -159,7 +152,7 @@ def movement_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
     assert target_space.mask is not None and target_space.shape is not None
     if target_space._scaler is not None:
         score = target_space._scaler.expand(score, target_space.shape)
-    return torch.mul(target, _StraightThrough.apply(score, target_space.mask))
+    return torch.mul(target, _StraightThrough.apply(score, target_space.mask))  # type: ignore
 
 
 def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
@@ -171,7 +164,7 @@ def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
     trans_mask = torch.where(target_space.mask == 1, torch.zeros_like(target_space.mask), SMALL_MASK_VALUE)
     if target_space._scaler is not None:
         score = target_space._scaler.expand(score, target_space.shape)
-    return torch.add(target, _StraightThrough.apply(score, trans_mask))
+    return torch.add(target, _StraightThrough.apply(score, trans_mask))  # type: ignore
 
 
 def slim_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
@@ -199,9 +192,8 @@ quant_apply_methods = {
     'bypass': bypass,
     'clamp_round': ClampRound.apply,
     'qat_clamp_round': QATClampRound.apply,
-    'dofera_clamp_round_weight': DoferaGradClampRound.dorefa_clamp_round_weight,
-    'dofera_clamp_round_input': DoferaGradClampRound.dorefa_clamp_round_input,
-    'dofera_clamp_round_output': DoferaGradClampRound.apply,
+    'dorefa_clamp_round_weight': DoferaGradClampRound.dorefa_clamp_round_weight,
+    'dorefa_clamp_round_output': DoferaGradClampRound.dorefa_clamp_round_output,
     "lsq_clamp_round": lsq_clamp_round,
     'bnn_clamp_round': BNNClampRound.apply,
 }
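The registry fix drops the misspelled 'dofera_*' keys: weights and outputs are re-registered under the correct 'dorefa_*' spelling, inputs now go through the generic 'clamp_round', and the gradient-quantizing autograd path ('dofera_clamp_round_output': DoferaGradClampRound.apply) is removed in favor of the backward hook added later in this commit. Apply methods are looked up by the string stored on each target space, roughly like this (a hypothetical dispatch site; the real call lives in the wrapper's forward path):

    def apply_quant(target, target_space):
        return quant_apply_methods[target_space.apply_method](target, target_space)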
@@ -298,6 +298,16 @@ class Quantizer(Compressor):
         # scalers are used to support different sparse/quant granularity
         register_scalers(self._target_spaces, self._set_default_sparse_granularity)  # type: ignore
 
+    def check_target(self, wrapper: ModuleWrapper, target_name: str) -> bool:
+        module_name = wrapper.name
+        if module_name not in self._target_spaces:
+            return False
+        ts = self._target_spaces[module_name]
+        if target_name not in ts:
+            return False
+
+        return True
+
     def _set_default_sparse_granularity(self, target_space: PruningTargetSpace) -> List[int] | str | None:
         return None
 
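check_target exists because, in a fused compression pipeline, a track callback may be invoked with another compressor's target name; membership must be checked against this quantizer's own _target_spaces rather than only the wrapper's combined dict. A sketch of the guard pattern the later hunks adopt:

    def some_track_func(self, wrapper, target_name, target):
        if not self.check_target(wrapper, target_name):
            return  # the name belongs to a different compressor in the fusion pipeline
        target_space = wrapper.quantization_target_spaces[target_name]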
@@ -54,7 +54,7 @@ class ModuleWrapper(torch.nn.Module):
         config
             The config is a dict which contains keys (not required): ``pruning``, ``quantization``, ``distillation``.
         fused_modules:
-            The List contains a series module names which need to fuse.
+            The List contains a series of modules which need to fuse.
         """
         super().__init__()
 
@@ -9,7 +9,7 @@ from typing import List, Dict, overload
 import torch
 from torch import Tensor
 
-from ..base.compressor import Quantizer
+from ..base.compressor import Compressor, Quantizer
 from ..base.wrapper import ModuleWrapper
 from ..base.target_space import TargetType
 from ..utils import Evaluator, _EVALUATOR_DOCSTRING
@@ -71,12 +71,18 @@ class BNNQuantizer(Quantizer):
         self.register_bnn_apply_method()
         self.register_track_func()
 
+    @classmethod
+    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
+        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
+
     def check_validation(self):
         for _, ts in self._target_spaces.items():
             for _, target_space in ts.items():
                 if target_space.quant_dtype is not None:
                     warn_msg = "BNNQuantizer will only quantize the value to 1 or -1; the quant_dtype value will not work"
                     _logger.warning(warn_msg)
+                if target_space._scaler is not None:
+                    raise ValueError("BNNQauntizer doesn't support for granularity, please set it to False")
 
     def register_track_func(self):
         for module_name, _ in self._target_spaces.items():
@@ -89,7 +95,7 @@ class BNNQuantizer(Quantizer):
                 target_space.apply_method = 'bnn_clamp_round'
 
     def init_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
-        if self.is_init or target_name not in wrapper.quantization_target_spaces:
+        if self.is_init or not self.check_target(wrapper, target_name):
             return
         target_space = wrapper.quantization_target_spaces[target_name]
         target_space.zero_point = torch.tensor(0.0).to(target.device)
@@ -2,15 +2,29 @@
 # Licensed under the MIT license.
 
 from __future__ import annotations
 import logging
 from typing import List, Dict, Union, overload
 
 import torch
+import torch.nn as nn
 from torch import Tensor
 
-from ..base.compressor import Quantizer
+from nni.common.version import torch_version_is_2
+
+from ..base.compressor import Compressor, Quantizer
 from ..base.wrapper import ModuleWrapper
-from ..base.target_space import TargetType
 from ..utils import Evaluator, _EVALUATOR_DOCSTRING
+from ..base.target_space import TargetType, QuantizationTargetSpace
+
+
+ACTIVATION_LIST = [
+    nn.ReLU, nn.RReLU, nn.LeakyReLU, nn.PReLU, nn.Softplus, nn.ELU, nn.CELU, nn.SELU, nn.GELU,
+    nn.ReLU6, nn.Sigmoid, nn.Tanh, nn.Softsign, nn.Hardtanh, nn.Threshold, nn.Tanhshrink,
+    nn.Softshrink, nn.Hardshrink, nn.LogSigmoid, nn.Softmin, nn.Softmax, nn.LogSoftmax, nn.Hardswish,
+]
 
 _logger = logging.getLogger(__name__)
+is_proper_torch_version = torch_version_is_2()
 
 
 class DoReFaQuantizer(Quantizer):
@@ -55,39 +69,133 @@ class DoReFaQuantizer(Quantizer):
         super().__init__(model, config_list, evaluator, existed_wrappers=existed_wrappers)
         self.evaluator: Evaluator
         self.is_init = False
 
+        self.check_validation()
         self.register_dorefa_apply_method()
         self.register_track_func()
 
+    @classmethod
+    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
+        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
+
+    def check_validation(self) -> None:
+        for ts in self._target_spaces.values():
+            for target_space in ts.values():
+                assert target_space.quant_scheme != None
+                if target_space.type is TargetType.PARAMETER and target_space.quant_scheme != 'affine':
+                    warn_msg = f'Only supports affine mode for weight quantization, bug got {target_space.quant_scheme}'
+                    _logger.warning(warn_msg)
+                elif target_space.type is TargetType.OUTPUT:
+                    module = target_space._wrapper.module
+                    # case 1: activation module
+                    # case 2: module with activation fused_modules
+                    fused_modules = target_space._wrapper.fused_modules
+                    if not isinstance(module, tuple(ACTIVATION_LIST)) and not (fused_modules and  # type: ignore
+                            any([isinstance(item, tuple(ACTIVATION_LIST)) for item in fused_modules[1:]])):  # type: ignore
+                        raise ValueError('Output quantization is only supported for activation function or' + \
+                            f'activation module fusion, but got {type(module)}')
+                    if target_space.quant_scheme != 'affine':
+                        warn_msg = f'Only supports affine mode for output quantization, bug got {target_space.quant_scheme}'
+                        _logger.warning(warn_msg)
+                if target_space._scaler is not None:
+                    raise ValueError('DoRefa Qauntizer doesn\'t support for granularity, please set it to False')
+
+    def _quant_dequant_gradient_hook(self, target_space: QuantizationTargetSpace) -> None:
+        def quant_dequant_gradient(module: nn.Module, grad_output):
+            tracked_max = torch.tensor(1.0 + 0.5 / (2**target_space.quant_bits - 1)).to(grad_output[0].device)
+            tracked_min = torch.tensor(0 - 0.5 / (2**target_space.quant_bits - 1)).to(grad_output[0].device)
+            scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
+                target_space.qmin, 'affine')
+            new_grad_output = []
+            for g_o in grad_output:
+                grad_o = torch.abs(g_o.clone().detach())
+                dim_lis = list(range(len(grad_o.shape)))
+                dim_lis.pop(0)
+                max_grad = torch.amax(grad_o, dim=dim_lis, keepdim=True)
+                # generate uniform noise
+                uniform_k = torch.zeros_like(max_grad).to(g_o.device)
+                N_k = uniform_k.uniform_(-0.5, 0.5) / (2**(target_space.quant_bits) - 1)
+                q_grad_o = g_o / (2 * max_grad) + 0.5 + N_k
+                quantized_grad = zero_point + q_grad_o / scale
+                quantized_grad = torch.round(torch.clamp(quantized_grad, target_space.qmin, target_space.qmax))
+                dequantized_grad = (quantized_grad - zero_point) * scale
+                new_grad_output.append((dequantized_grad - 0.5) * 2 * max_grad)
+
+            return tuple(new_grad_output)
+
+        target_space._wrapper.module.register_full_backward_pre_hook(quant_dequant_gradient)  # type: ignore
+
+    def register_output_backward_hook(self):
+        for ts in self._target_spaces.values():
+            is_output = any([target_space.type is TargetType.OUTPUT for target_space in ts.values()])
+            is_param = any([target_space.type is TargetType.PARAMETER for target_space in ts.values()])
+            if is_param and not is_output:
+                if is_proper_torch_version:  # torch version >= 2.0.0
+                    for target_space in ts.values():
+                        if target_space.type is TargetType.PARAMETER:
+                            self._quant_dequant_gradient_hook(target_space)
+                            break
+                else:
+                    warn_msg = f'Gradient quantization is only supported for torch version >= 2.0.0'
+                    _logger.warning(warn_msg)
+
     def register_dorefa_apply_method(self):
         for _, ts in self._target_spaces.items():
             for _, target_space in ts.items():
                 if target_space.type is TargetType.PARAMETER:
-                    target_space.apply_method = 'dofera_clamp_round_weight'
+                    target_space.apply_method = 'dorefa_clamp_round_weight'
                 elif target_space.type is TargetType.INPUT:
-                    target_space.apply_method = "dofera_clamp_round_input"
+                    target_space.apply_method = 'clamp_round'
                 elif target_space.type is TargetType.OUTPUT:
-                    target_space.apply_method = "dofera_clamp_round_output"
+                    target_space.apply_method = 'dorefa_clamp_round_output'
 
     def register_track_func(self):
         for module_name, _ in self._target_spaces.items():
             wrapper = self._module_wrappers[module_name]
             wrapper.register_track_func(self.initialize_scale_zp)
+            wrapper.register_track_func(self.update_scale_zp)
 
-    def initialize_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
-        if self.is_init or target_name not in wrapper.quantization_target_spaces:
+    def update_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor) -> None:
+        if not self.check_target(wrapper, target_name):
             return
         target_space = wrapper.quantization_target_spaces[target_name]
-        if target_space.type is TargetType.INPUT or "weight" in target_name: #zero_point and scale don't change anymore
+        if target_space.type is not TargetType.INPUT:
             return
         # track min max values
         current_amin = target.detach().reshape(-1).amin(-1)
         current_amax = target.detach().reshape(-1).amax(-1)
         # update scale and zero_point
         tracked_min = torch.min(current_amin, torch.zeros_like(current_amin))
         tracked_max = torch.max(current_amax, torch.zeros_like(current_amax))
         zero_point = torch.zeros_like(tracked_min)
         qmin, qmax = target_space.qmin, target_space.qmax
         assert isinstance(qmin, int) and isinstance(qmax, int)
         if target_space.quant_scheme in ['symmetric', None]:
             abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
             scale = abs_max / (float(qmax - qmin) / 2)
             scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
             # NOTE: here need to check, +1 because in pytorch, symmetric qint8 zp is 0, quint8 zp is 128.
             zero_point_val = (qmax + qmin + 1) // 2
             zero_point = torch.full_like(zero_point, zero_point_val)
         elif target_space.quant_scheme == 'affine':
             scale = (tracked_max - tracked_min) / float(qmax - qmin)
             scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
             zero_point = qmin - torch.round(tracked_min / scale)
         else:
             raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
         zero_point = torch.clamp(zero_point, qmin, qmax)
         target_space.scale, target_space.zero_point = scale, zero_point
 
+    def initialize_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
+        if self.is_init or not self.check_target(wrapper, target_name):
+            return
+        target_space = wrapper.quantization_target_spaces[target_name]
+        if target_space.type is TargetType.INPUT:
+            return
+        elif target_space.type in [TargetType.OUTPUT, TargetType.PARAMETER]:
+            tracked_max = torch.tensor(1.0).to(target.device)
+            tracked_min = torch.tensor(0.0).to(target.device)
+            scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
+                target_space.qmin, 'affine')
+        elif target_space.type is TargetType.OUTPUT:
+            tracked_max = torch.tensor(1.0 + 0.5 / (2**target_space.quant_bits - 1)).to(target.device)
+            tracked_min = torch.tensor(0 - 0.5 / (2**target_space.quant_bits - 1)).to(target.device)
+            scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
+                target_space.qmin, 'affine')
+        else:
+            raise RuntimeError(f'Unknown target_name {target_name}')
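For reference, the DoReFa-Net weight transform kept by this refactor maps a weight w to 2 * Q(tanh(w) / (2 * max|tanh(w)|) + 0.5) - 1, so the quantizer input lies in [0, 1] and the output in [-1, 1]. A self-contained numeric sketch, with plain round-to-k-bits standing in for ClampRound.apply:

    import torch

    def dorefa_weight_sketch(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
        levels = 2 ** bits - 1
        t = w.tanh()
        t01 = t / (2 * t.abs().max()) + 0.5          # normalize into [0, 1]
        q01 = torch.round(t01 * levels) / levels     # k-bit uniform quantization
        return 2 * q01 - 1                           # back to [-1, 1]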
@@ -103,6 +211,7 @@ class DoReFaQuantizer(Quantizer):
         self._fusion_compress(max_steps, max_epochs)
 
     def _fuse_preprocess(self, evaluator: Evaluator) -> None:
+        self.register_output_backward_hook()
         module_name_param_dict = self.patch_optimizer_param_group()
         if len(module_name_param_dict) > 0:
             evaluator.patch_optim_param_group(module_name_param_dict)
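The hook registration moves into _fuse_preprocess because Module.register_full_backward_pre_hook, the torch >= 2.0 API used by _quant_dequant_gradient_hook above, lets the returned tuple replace grad_output before the module's backward runs. A minimal standalone sketch of that API:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4)

    def noisy_grad(module: nn.Module, grad_output):
        # whatever tuple is returned here replaces grad_output
        return tuple(g * 0.5 for g in grad_output)

    m.register_full_backward_pre_hook(noisy_grad)  # requires torch >= 2.0
    m(torch.randn(2, 4)).sum().backward()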
@@ -121,7 +230,7 @@ def init_scale_zp(tracked_max: Tensor, tracked_min: Tensor, qmax: int, qmin: int,
         scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
         zero_point = qmin - torch.round(tracked_min / scale)
     elif quant_scheme in ['symmetric', None]:
-        raise ValueError(f"Unsupported quant_scheme {quant_scheme}")
+        raise ValueError(f'Unsupported quant_scheme {quant_scheme}')
     else:
         raise RuntimeError(f'Unknown quant_scheme {quant_scheme}')
 
@@ -2,14 +2,19 @@
 # Licensed under the MIT license.
 
 from __future__ import annotations
+import logging
 from typing import List, Dict, overload
 
 import torch
 from torch import Tensor
 
-from ..base.compressor import Quantizer
+from ..base.compressor import Compressor, Quantizer
 from ..base.wrapper import ModuleWrapper
 from ..utils import Evaluator, _EVALUATOR_DOCSTRING
+from ..base.target_space import TargetType
+
+
+_logger = logging.getLogger(__name__)
 
 
 class LsqQuantizer(Quantizer):
@@ -59,22 +64,46 @@ class LsqQuantizer(Quantizer):
         self.evaluator: Evaluator
         self.is_init = False
 
+        self.check_validation()
         self.register_scale()
         self.register_lsq_apply_method()
         self.register_track_func()
 
+    @classmethod
+    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
+        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
+
+    def check_validation(self) -> None:
+        for ts in self._target_spaces.values():
+            for target_space in ts.values():
+                if target_space.quant_scheme != 'symmetric':
+                    warn_msg = f"LsqQuantizer only supports symmetric mode, but got {target_space.quant_scheme}"
+                    _logger.warning(warn_msg)
+                if target_space.quant_dtype.startswith("uint") and target_space.type is TargetType.PARAMETER:
+                    warn_msg = f"In the LsqQuantizer, quantization of parameters only supports int type"
+                    _logger.warning(warn_msg)
+
     def register_track_func(self):
         for module_name, _ in self._target_spaces.items():
             wrapper = self._module_wrappers[module_name]
             wrapper.register_track_func(self.init_scale)
 
     def init_scale(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
-        if self.is_init or target_name not in wrapper.quantization_target_spaces:
+        def mean_reduce_func(converted_target: Tensor) -> torch.Tensor:
+            return converted_target.detach().mean(dim=-1)
+
+        if self.is_init or not self.check_target(wrapper, target_name):
             return
         target_space = wrapper.quantization_target_spaces[target_name]
         init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5)
-        target_space.scale.data = init_target  # type: ignore
-        target_space.zero_point = torch.tensor(0.0).to(target.device)
+        if not target_space._scaler:
+            target_space.scale.data = init_target  # type: ignore
+            target_space.zero_point = torch.tensor(0.0).to(target.device)
+        else:
+            new_target = init_target.expand(target.shape).to(target.device)
+            new_target_scale = target_space._scaler.shrink(new_target, mean_reduce_func, keepdim=True)
+            target_space.scale.data = new_target_scale  # type: ignore
+            target_space.zero_point = torch.zeros_like(new_target_scale)
 
     def register_lsq_apply_method(self):
         for _, ts in self._target_spaces.items():
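init_scale keeps the LSQ paper's initialization s = 2 * E[|w|] / sqrt(q_max); the new branch broadcasts that scalar to the target's shape and shrinks it back to the granularity's shape with a per-channel mean, so per-channel scales all start from the same statistic. A rough sketch of the effect, with a plain mean standing in for _scaler.shrink:

    import torch

    w = torch.randn(8, 16)
    qmax = 127
    init = w.detach().abs().mean() * 2 / (qmax ** 0.5)   # scalar, as in the per-tensor branch
    per_channel = init.expand(w.shape).mean(dim=-1)      # hypothetical stand-in for _scaler.shrink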
@@ -87,15 +116,7 @@ class LsqQuantizer(Quantizer):
             for target_name, _ in ts.items():
                 if hasattr(wrapper, f"{target_name}_scale"):
                     delattr(wrapper, f"{target_name}_scale")
-                try:
-                    device = next(wrapper.parameters()).device
-                except StopIteration:
-                    try:
-                        device = next(wrapper.buffers()).device
-                    except StopIteration:
-                        # NOTE: this will have risk in model parallel
-                        device = next(self.bound_model.parameters()).device
-                param = torch.nn.Parameter(torch.Tensor([1.0]).to(device))
+                param = torch.nn.Parameter()
                 wrapper.register_parameter(f"{target_name}_scale", param)
 
     def patch_optimizer_param_group(self):
@@ -7,7 +7,7 @@ from typing import List, Dict, Union, overload
 import torch
 from torch import Tensor
 
-from ..base.compressor import Quantizer
+from ..base.compressor import Compressor, Quantizer
 from ..base.wrapper import ModuleWrapper
 from ..utils import Evaluator, _EVALUATOR_DOCSTRING
 
@@ -55,6 +55,10 @@ class PtqQuantizer(Quantizer):
         self.register_ptq_apply_method()
         self.register_track_func()
 
+    @classmethod
+    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
+        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
+
     def register_ptq_apply_method(self):
         for _, ts in self._target_spaces.items():
             for _, target_space in ts.items():
@@ -8,7 +8,7 @@ from typing import Dict, List, overload
 import torch
 from torch import Tensor
 
-from ..base.compressor import Quantizer
+from ..base.compressor import Compressor, Quantizer
 from ..base.wrapper import ModuleWrapper
 from ..utils import Evaluator, _EVALUATOR_DOCSTRING
 
@@ -81,6 +81,11 @@ class QATQuantizer(Quantizer):
         self.register_qat_apply_method()
         self.register_track_func()
 
+    @classmethod
+    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict],
+                        quant_start_step: int = 0, evaluator: Evaluator | None = None):
+        return super().from_compressor(compressor, new_config_list, quant_start_step=quant_start_step, evaluator=evaluator)
+
     def register_qat_apply_method(self):
         if self.current_step < self.quant_start_step:
             for _, ts in self._target_spaces.items():
@@ -103,12 +108,13 @@ class QATQuantizer(Quantizer):
 
     def track_min_max_val(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
         # in a fused compression pipeline, the target name may be another compressor's target name
-        if not wrapper.training or target_name not in wrapper.quantization_target_spaces:
+        if not wrapper.training or not self.check_target(wrapper, target_name):
             return
         return track_min_max_val(wrapper, target_name, target)
 
     def update_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
-        if not wrapper.training or self.current_step < self.quant_start_step:
+        if not wrapper.training or self.current_step < self.quant_start_step \
+            or not self.check_target(wrapper, target_name):
             return
         if target_name in wrapper.quantization_target_spaces:
             target_space = wrapper.quantization_target_spaces[target_name]
@@ -32,14 +32,14 @@ def test_dorefa_forward_with_torch_model():
     torch.manual_seed(0)
     model = SimpleTorchModel().to(device)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['fc1', 'fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
         'granularity': 'default',
     },
     {
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['_input_', 'weight'],
         'op_names': ['conv1', 'conv2', 'conv3'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
@@ -54,14 +54,14 @@ def test_dorefa_forward_with_torch_model():
 def test_dorefa_forward_with_lighting_model():
     torch.manual_seed(0)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['_input_', 'weight'],
         'op_names': ['model.fc1', 'model.fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
         'granularity': 'default',
     },
     {
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['_input_', 'weight'],
         'op_names': ['model.conv1', 'model.conv2', 'model.conv3'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
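These test configs drop '_output_' and, for the fully connected layers, '_input_', matching the apply-method changes above (DoReFa output quantization is now only legal on activation modules). A hedged sketch of how such a config list is consumed, assuming the usual NNI 3.x quantizer entry point; model and evaluator are defined elsewhere in the test:

    from nni.contrib.compression.quantization import DoReFaQuantizer

    configure_list = [{
        'target_names': ['_input_', 'weight'],
        'op_names': ['conv1'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
    }]
    quantizer = DoReFaQuantizer(model, configure_list, evaluator)
    quantizer.compress(max_steps=100, max_epochs=None)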
@@ -32,11 +32,11 @@ def test_lsq_forward_with_torch_model():
     torch.manual_seed(0)
     model = SimpleTorchModel().to(device)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['fc1', 'fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_lsq_forward_with_torch_model():
 def test_lsq_forward_with_lighting_model():
     torch.manual_seed(0)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['model.fc1', 'model.fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],
@@ -32,11 +32,11 @@ def test_ptq_forward_with_torch_model():
     torch.manual_seed(0)
     model = SimpleTorchModel().to(device)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['fc1', 'fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_ptq_forward_with_torch_model():
 def test_ptq_forward_with_lighting_model():
     torch.manual_seed(0)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['model.fc1', 'model.fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],
@@ -32,11 +32,11 @@ def test_qat_forward_with_torch_model():
     torch.manual_seed(0)
     model = SimpleTorchModel().to(device)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['fc1', 'fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_qat_forward_with_torch_model():
 def test_qat_forward_with_lighting_model():
     torch.manual_seed(0)
     configure_list = [{
-        'target_names':['_input_', 'weight', '_output_'],
+        'target_names':['weight'],
         'op_names': ['model.fc1', 'model.fc2'],
         'quant_dtype': 'int8',
         'quant_scheme': 'affine',
-        'granularity': 'default',
+        'granularity': 'in_channel',
     },
     {
         'target_names':['_input_', 'weight', '_output_'],