This commit is contained in:
Bonytu 2023-03-30 13:13:11 +08:00 committed by GitHub
Parent ce6b2e8fc9
Commit 32d47768a5
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 246 additions and 89 deletions

View file

@@ -120,7 +120,7 @@ def main():
model = model.to(device)
configure_list = [{
'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
'target_names': ['_input_', 'weight', '_output_'],
'target_names': ['_input_', 'weight'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
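
For orientation, a config list like this is handed directly to a quantizer constructor. A minimal sketch of the wiring, reusing model and configure_list from this example plus an evaluator, and assuming the NNI 3.x import path (the exact quantizer this script uses is not shown in the hunk):

# hypothetical wiring; the diff only shows that quantizers take (model, config_list, evaluator)
from nni.contrib.compression.quantization import QATQuantizer  # import path is an assumption

quantizer = QATQuantizer(model, configure_list, evaluator)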

View file

@@ -81,3 +81,12 @@ def version_check(expect: dict, raise_error: bool = False) -> None:
raise RuntimeError('Version check failed: ' + err_message)
else:
warnings.warn('Version check with warning: ' + err_message)
def torch_version_is_2() -> bool:
if TORCH_VERSION is None:
return False
if TORCH_VERSION < (2, 0):
return False
else:
return True
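
For reference, a self-contained sketch of how this gate fits together, assuming TORCH_VERSION is a tuple parsed from torch.__version__ (parse_torch_version below is hypothetical, not the helper in nni.common.version):

from __future__ import annotations
import torch

def parse_torch_version(version: str = torch.__version__) -> tuple[int, ...] | None:
    # "2.0.1+cu118" -> (2, 0, 1); None when the version string is not purely numeric
    try:
        return tuple(int(part) for part in version.split('+')[0].split('.')[:3])
    except ValueError:
        return None

TORCH_VERSION = parse_torch_version()

def torch_version_is_2() -> bool:
    # same behavior as the helper above: unknown or pre-2.0 versions gate the feature off
    return TORCH_VERSION is not None and TORCH_VERSION >= (2, 0)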

View file

@@ -25,9 +25,13 @@ def lsq_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace)
qmax: int = target_space.qmax
qmin: int = target_space.qmin
if target_space._scaler is not None:
scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True) # type: ignore
else:
scale = target_space.scale
# Quantize
grad_scale_factor = 1.0 / ((qmax * target.numel()) ** 0.5) if (qmax * target.numel()) ** 0.5 != 0 else 1.0
scale = grad_scale(target_space.scale, grad_scale_factor)
scale = grad_scale(scale, grad_scale_factor)
new_target = torch.clamp(target / scale, qmin, qmax)
dequantized_target = round_pass(new_target) * scale
return dequantized_target
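
grad_scale and round_pass are the usual LSQ straight-through helpers; they are commonly implemented roughly as below (a sketch, not code from this diff):

import torch

def grad_scale(x: torch.Tensor, scale_factor: float) -> torch.Tensor:
    # forward: returns x unchanged; backward: gradient is multiplied by scale_factor
    y = x * scale_factor
    return y + (x - y).detach()

def round_pass(x: torch.Tensor) -> torch.Tensor:
    # forward: round(x); backward: identity (straight-through estimator)
    return (x.round() - x).detach() + x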
@@ -35,40 +39,16 @@ def lsq_clamp_round(target: torch.Tensor, target_space: QuantizationTargetSpace)
class DoferaGradClampRound(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
ctx.target_space = target_space
ctx.save_for_backward(target)
return target * 1
@staticmethod
def backward(ctx: Any, grad_output: Any) -> Any:
target_space = ctx.target_space
target, = ctx.saved_variables
grad_o = torch.abs(grad_output.detach())
dim_lis = list(range(len(grad_o.shape)))
dim_lis.pop(0)
max_grad = torch.amax(grad_o, dim=dim_lis, keepdim=True)
# generate uniform noise
uniform_k = torch.zeros_like(max_grad).to(target.device)
N_k = uniform_k.uniform_(-0.5, 0.5) / (2**(target_space.quant_bits) - 1)
q_grad_o = grad_output / (2 * max_grad) + 0.5 + N_k
quantized_grad = target_space.zero_point + q_grad_o / target_space.scale
quantized_grad = torch.round(torch.clamp(quantized_grad, target_space.qmin, target_space.qmax))
dequantized_grad = (quantized_grad - target_space.zero_point) * target_space.scale
return (dequantized_grad - 0.5) * 2 * max_grad, None
@staticmethod
def dorefa_clamp_round_weight(target: torch.Tensor, target_space: QuantizationTargetSpace):
def dorefa_clamp_round_weight(target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
# TODO process special case: quant_bit == 1
target = target.tanh()
target = target / (2 * target.abs().max()) + 0.5
dequantized_target = ClampRound.apply(target, target_space)
return 2 * dequantized_target - 1
return 2 * dequantized_target - 1 # type: ignore
@staticmethod
def dorefa_clamp_round_input(target: torch.Tensor, target_space: QuantizationTargetSpace):
def dorefa_clamp_round_output(target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
target = torch.clamp(target, 0, 1)
return ClampRound.apply(target, target_space)
@@ -95,9 +75,15 @@ class BNNClampRound(torch.autograd.Function):
class ClampRound(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
transformed_target = target_space.zero_point + target / target_space.scale
if target_space._scaler is not None:
zero_point = target_space._scaler.expand(target_space.zero_point, target_space.shape, keepdim=True) # type: ignore
scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True) # type: ignore
else:
zero_point = target_space.zero_point
scale = target_space.scale
transformed_target = zero_point + target / scale
quantized_target = torch.round(torch.clamp(transformed_target, target_space.qmin, target_space.qmax))
dequantized_target = (quantized_target - target_space.zero_point) * target_space.scale
dequantized_target = (quantized_target - zero_point) * scale
return dequantized_target
@staticmethod
@@ -108,9 +94,16 @@ class ClampRound(torch.autograd.Function):
class QATClampRound(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, target: torch.Tensor, target_space: QuantizationTargetSpace) -> Any:
transformed_target = target_space.zero_point + target / target_space.scale
if target_space._scaler is not None:
zero_point = target_space._scaler.expand(target_space.zero_point, target_space.shape, keepdim=True) # type: ignore
scale = target_space._scaler.expand(target_space.scale, target_space.shape, keepdim=True) # type: ignore
else:
zero_point = target_space.zero_point
scale = target_space.scale
transformed_target = zero_point + target / scale
quantized_target = torch.round(torch.clamp(transformed_target, target_space.qmin, target_space.qmax))
dequantized_target = (quantized_target - target_space.zero_point) * target_space.scale
dequantized_target = (quantized_target - zero_point) * scale
ctx.save_for_backward(transformed_target)
ctx.target_space = target_space
return dequantized_target
@@ -159,7 +152,7 @@ def movement_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
assert target_space.mask is not None and target_space.shape is not None
if target_space._scaler is not None:
score = target_space._scaler.expand(score, target_space.shape)
return torch.mul(target, _StraightThrough.apply(score, target_space.mask))
return torch.mul(target, _StraightThrough.apply(score, target_space.mask)) # type: ignore
def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
@@ -171,7 +164,7 @@ def movement_add_mask(target: torch.Tensor, target_space: PruningTargetSpace):
trans_mask = torch.where(target_space.mask == 1, torch.zeros_like(target_space.mask), SMALL_MASK_VALUE)
if target_space._scaler is not None:
score = target_space._scaler.expand(score, target_space.shape)
return torch.add(target, _StraightThrough.apply(score, trans_mask))
return torch.add(target, _StraightThrough.apply(score, trans_mask)) # type: ignore
def slim_mul_mask(target: torch.Tensor, target_space: PruningTargetSpace):
@@ -199,9 +192,8 @@ quant_apply_methods = {
'bypass': bypass,
'clamp_round': ClampRound.apply,
'qat_clamp_round': QATClampRound.apply,
'dofera_clamp_round_weight': DoferaGradClampRound.dorefa_clamp_round_weight,
'dofera_clamp_round_input': DoferaGradClampRound.dorefa_clamp_round_input,
'dofera_clamp_round_output': DoferaGradClampRound.apply,
'dorefa_clamp_round_weight': DoferaGradClampRound.dorefa_clamp_round_weight,
'dorefa_clamp_round_output': DoferaGradClampRound.dorefa_clamp_round_output,
"lsq_clamp_round": lsq_clamp_round,
'bnn_clamp_round': BNNClampRound.apply,
}
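
For reference, the affine quantize-dequantize round trip that ClampRound.forward performs, as a standalone sketch (local names, not the registry keys above):

import torch

def affine_quant_dequant(target: torch.Tensor, scale: torch.Tensor,
                         zero_point: torch.Tensor, qmin: int, qmax: int) -> torch.Tensor:
    # quantize: shift/scale onto the integer grid, clamp to the representable range, round
    quantized = torch.round(torch.clamp(zero_point + target / scale, qmin, qmax))
    # dequantize: undo the shift/scale to get the simulated-quantization float value
    return (quantized - zero_point) * scale

x_hat = affine_quant_dequant(torch.randn(4), scale=torch.tensor(0.05),
                             zero_point=torch.tensor(0.0), qmin=-128, qmax=127)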

View file

@@ -298,6 +298,16 @@ class Quantizer(Compressor):
# scalers are used to support different sparse/quant granularity
register_scalers(self._target_spaces, self._set_default_sparse_granularity) # type: ignore
def check_target(self, wrapper: ModuleWrapper, target_name: str) -> bool:
module_name = wrapper.name
if module_name not in self._target_spaces:
return False
ts = self._target_spaces[module_name]
if target_name not in ts:
return False
return True
def _set_default_sparse_granularity(self, target_space: PruningTargetSpace) -> List[int] | str | None:
return None
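
check_target reduces to a membership test over the registered target spaces; the same logic over plain dicts, as a standalone sketch:

from typing import Any, Dict

def check_target(target_spaces: Dict[str, Dict[str, Any]],
                 module_name: str, target_name: str) -> bool:
    # True only when this compressor registered the (module, target) pair
    return target_name in target_spaces.get(module_name, {})

spaces = {'conv1': {'weight': object()}}
assert check_target(spaces, 'conv1', 'weight')
assert not check_target(spaces, 'conv1', '_output_')  # e.g. another compressor's target
assert not check_target(spaces, 'fc1', 'weight')      # module not wrapped by this compressor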

View file

@@ -54,7 +54,7 @@ class ModuleWrapper(torch.nn.Module):
config
The config is a dict which contains keys (not required): ``pruning``, ``quantization``, ``distillation``.
fused_modules:
The List contains a series module names which need to fuse.
The list contains a series of modules which need to be fused.
"""
super().__init__()

View file

@@ -9,7 +9,7 @@ from typing import List, Dict, overload
import torch
from torch import Tensor
from ..base.compressor import Quantizer
from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..base.target_space import TargetType
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
@@ -71,12 +71,18 @@ class BNNQuantizer(Quantizer):
self.register_bnn_apply_method()
self.register_track_func()
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
def check_validation(self):
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():
if target_space.quant_dtype is not None:
warn_msg = "BNNQuantizer will only quantize the value to 1 or -1; the quant_dtype value will not work"
_logger.warning(warn_msg)
if target_space._scaler is not None:
raise ValueError("BNNQauntizer doesn't support for granularity, please set it to False")
def register_track_func(self):
for module_name, _ in self._target_spaces.items():
@@ -89,7 +95,7 @@ class BNNQuantizer(Quantizer):
target_space.apply_method = 'bnn_clamp_round'
def init_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if self.is_init or target_name not in wrapper.quantization_target_spaces:
if self.is_init or not self.check_target(wrapper, target_name):
return
target_space = wrapper.quantization_target_spaces[target_name]
target_space.zero_point = torch.tensor(0.0).to(target.device)
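
For intuition, binary quantization of the kind bnn_clamp_round applies boils down to a sign function with a straight-through gradient; a minimal sketch (not the exact BNNClampRound implementation):

import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # every value becomes -1 or +1
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # straight-through: pass gradients only where |x| <= 1
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)

x = torch.randn(5, requires_grad=True)
SignSTE.apply(x).sum().backward()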

View file

@@ -2,15 +2,29 @@
# Licensed under the MIT license.
from __future__ import annotations
import logging
from typing import List, Dict, Union, overload
import torch
import torch.nn as nn
from torch import Tensor
from ..base.compressor import Quantizer
from nni.common.version import torch_version_is_2
from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..base.target_space import TargetType
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
from ..base.target_space import TargetType, QuantizationTargetSpace
ACTIVATION_LIST = [
nn.ReLU, nn.RReLU, nn.LeakyReLU, nn.PReLU, nn.Softplus, nn.ELU, nn.CELU, nn.SELU, nn.GELU,
nn.ReLU6, nn.Sigmoid, nn.Tanh, nn.Softsign, nn.Hardtanh, nn.Threshold, nn.Tanhshrink,
nn.Softshrink, nn.Hardshrink, nn.LogSigmoid, nn.Softmin, nn.Softmax, nn.LogSoftmax, nn.Hardswish,
]
_logger = logging.getLogger(__name__)
is_proper_torch_version = torch_version_is_2()
class DoReFaQuantizer(Quantizer):
@@ -55,39 +69,133 @@ class DoReFaQuantizer(Quantizer):
super().__init__(model, config_list, evaluator, existed_wrappers=existed_wrappers)
self.evaluator: Evaluator
self.is_init = False
self.check_validation()
self.register_dorefa_apply_method()
self.register_track_func()
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
def check_validation(self) -> None:
for ts in self._target_spaces.values():
for target_space in ts.values():
assert target_space.quant_scheme is not None
if target_space.type is TargetType.PARAMETER and target_space.quant_scheme != 'affine':
warn_msg = f'Only supports affine mode for weight quantization, but got {target_space.quant_scheme}'
_logger.warning(warn_msg)
elif target_space.type is TargetType.OUTPUT:
module = target_space._wrapper.module
# case 1: activation module
# case 2: module with activation fused_modules
fused_modules = target_space._wrapper.fused_modules
if not isinstance(module, tuple(ACTIVATION_LIST)) and not (fused_modules and # type: ignore
any([isinstance(item, tuple(ACTIVATION_LIST)) for item in fused_modules[1:]])): # type: ignore
raise ValueError('Output quantization is only supported for activation function or ' + \
f'activation module fusion, but got {type(module)}')
if target_space.quant_scheme != 'affine':
warn_msg = f'Only supports affine mode for output quantization, but got {target_space.quant_scheme}'
_logger.warning(warn_msg)
if target_space._scaler is not None:
raise ValueError("DoReFaQuantizer doesn't support granularity, please set it to False")
def _quant_dequant_gradient_hook(self, target_space: QuantizationTargetSpace) -> None:
def quant_dequant_gradient(module: nn.Module, grad_output):
tracked_max = torch.tensor(1.0 + 0.5 / (2**target_space.quant_bits - 1)).to(grad_output[0].device)
tracked_min = torch.tensor(0 - 0.5 / (2**target_space.quant_bits - 1)).to(grad_output[0].device)
scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
target_space.qmin, 'affine')
new_grad_output = []
for g_o in grad_output:
grad_o = torch.abs(g_o.clone().detach())
dim_lis = list(range(len(grad_o.shape)))
dim_lis.pop(0)
max_grad = torch.amax(grad_o, dim=dim_lis, keepdim=True)
# generate uniform noise
uniform_k = torch.zeros_like(max_grad).to(g_o.device)
N_k = uniform_k.uniform_(-0.5, 0.5) / (2**(target_space.quant_bits) - 1)
q_grad_o = g_o / (2 * max_grad) + 0.5 + N_k
quantized_grad = zero_point + q_grad_o / scale
quantized_grad = torch.round(torch.clamp(quantized_grad, target_space.qmin, target_space.qmax))
dequantized_grad = (quantized_grad - zero_point) * scale
new_grad_output.append((dequantized_grad - 0.5) * 2 * max_grad)
return tuple(new_grad_output)
target_space._wrapper.module.register_full_backward_pre_hook(quant_dequant_gradient) # type: ignore
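
register_full_backward_pre_hook is the torch >= 2.0 API that lets a hook rewrite grad_output before the module's backward runs, which is what the version gate below protects; a minimal standalone example of the mechanism:

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

def halve_grad(module: nn.Module, grad_output):
    # returning a tuple replaces grad_output for the module's backward pass
    return tuple(g * 0.5 for g in grad_output)

lin.register_full_backward_pre_hook(halve_grad)  # requires torch >= 2.0
lin(torch.randn(2, 4)).sum().backward()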
def register_output_backward_hook(self):
for ts in self._target_spaces.values():
is_output = any([target_space.type is TargetType.OUTPUT for target_space in ts.values()])
is_param = any([target_space.type is TargetType.PARAMETER for target_space in ts.values()])
if is_param and not is_output:
if is_proper_torch_version: # torch version >= 2.0.0
for target_space in ts.values():
if target_space.type is TargetType.PARAMETER:
self._quant_dequant_gradient_hook(target_space)
break
else:
warn_msg = 'Gradient quantization is only supported for torch version >= 2.0.0'
_logger.warning(warn_msg)
def register_dorefa_apply_method(self):
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():
if target_space.type is TargetType.PARAMETER:
target_space.apply_method = 'dofera_clamp_round_weight'
target_space.apply_method = 'dorefa_clamp_round_weight'
elif target_space.type is TargetType.INPUT:
target_space.apply_method = "dofera_clamp_round_input"
target_space.apply_method = 'clamp_round'
elif target_space.type is TargetType.OUTPUT:
target_space.apply_method = "dofera_clamp_round_output"
target_space.apply_method = 'dorefa_clamp_round_output'
def register_track_func(self):
for module_name, _ in self._target_spaces.items():
wrapper = self._module_wrappers[module_name]
wrapper.register_track_func(self.initialize_scale_zp)
wrapper.register_track_func(self.update_scale_zp)
def initialize_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if self.is_init or target_name not in wrapper.quantization_target_spaces:
def update_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor) -> None:
if not self.check_target(wrapper, target_name):
return
target_space = wrapper.quantization_target_spaces[target_name]
if target_space.type is TargetType.INPUT or "weight" in target_name: #zero_point and scale don't change anymore
if target_space.type is not TargetType.INPUT:
return
# track min max values
current_amin = target.detach().reshape(-1).amin(-1)
current_amax = target.detach().reshape(-1).amax(-1)
# update scale and zero_point
tracked_min = torch.min(current_amin, torch.zeros_like(current_amin))
tracked_max = torch.max(current_amax, torch.zeros_like(current_amax))
zero_point = torch.zeros_like(tracked_min)
qmin, qmax = target_space.qmin, target_space.qmax
assert isinstance(qmin, int) and isinstance(qmax, int)
if target_space.quant_scheme in ['symmetric', None]:
abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
scale = abs_max / (float(qmax - qmin) / 2)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
# NOTE: needs verification; +1 because in PyTorch the symmetric qint8 zero point is 0 and the quint8 zero point is 128.
zero_point_val = (qmax + qmin + 1) // 2
zero_point = torch.full_like(zero_point, zero_point_val)
elif target_space.quant_scheme == 'affine':
scale = (tracked_max - tracked_min) / float(qmax - qmin)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
zero_point = qmin - torch.round(tracked_min / scale)
else:
raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
zero_point = torch.clamp(zero_point, qmin, qmax)
target_space.scale, target_space.zero_point = scale, zero_point
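
A worked example of the affine branch above, for int8 with a tracked range of [-0.5, 1.5]:

import torch

tracked_min, tracked_max = torch.tensor(-0.5), torch.tensor(1.5)
qmin, qmax = -128, 127                                    # int8
scale = (tracked_max - tracked_min) / float(qmax - qmin)  # 2.0 / 255 ~= 0.00784
zero_point = qmin - torch.round(tracked_min / scale)      # -128 - round(-63.75) = -64
zero_point = torch.clamp(zero_point, qmin, qmax)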
def initialize_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if self.is_init or not self.check_target(wrapper, target_name):
return
target_space = wrapper.quantization_target_spaces[target_name]
if target_space.type is TargetType.INPUT:
return
elif target_space.type in [TargetType.OUTPUT, TargetType.PARAMETER]:
tracked_max = torch.tensor(1.0).to(target.device)
tracked_min = torch.tensor(0.0).to(target.device)
scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
target_space.qmin, 'affine')
elif target_space.type is TargetType.OUTPUT:
tracked_max = torch.tensor(1.0 + 0.5 / (2**target_space.quant_bits - 1)).to(target.device)
tracked_min = torch.tensor(0 - 0.5 / (2**target_space.quant_bits - 1)).to(target.device)
scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax, \
target_space.qmin, 'affine')
else:
raise RuntimeError(f'Unknown target_name {target_name}')
@@ -103,6 +211,7 @@ class DoReFaQuantizer(Quantizer):
self._fusion_compress(max_steps, max_epochs)
def _fuse_preprocess(self, evaluator: Evaluator) -> None:
self.register_output_backward_hook()
module_name_param_dict = self.patch_optimizer_param_group()
if len(module_name_param_dict) > 0:
evaluator.patch_optim_param_group(module_name_param_dict)
@@ -121,7 +230,7 @@ def init_scale_zp(tracked_max: Tensor, tracked_min: Tensor, qmax: int, qmin: int
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
zero_point = qmin - torch.round(tracked_min / scale)
elif quant_scheme in ['symmetric', None]:
raise ValueError(f"Unsupported quant_scheme {quant_scheme}")
raise ValueError(f'Unsupported quant_scheme {quant_scheme}')
else:
raise RuntimeError(f'Unknown quant_scheme {quant_scheme}')
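
For reference, the math in dorefa_clamp_round_weight as a standalone sketch, with ClampRound's quant-dequant replaced by plain k-bit rounding and no straight-through gradient:

import torch

def dorefa_weight_transform(w: torch.Tensor, k: int) -> torch.Tensor:
    # DoReFa-Net weight quantization: squash with tanh, normalize into [0, 1],
    # quantize uniformly to k bits, then rescale back into [-1, 1]
    w = w.tanh()
    w = w / (2 * w.abs().max()) + 0.5
    levels = 2 ** k - 1  # note: k == 1 collapses to a single level and needs special handling
    w_q = torch.round(w * levels) / levels
    return 2 * w_q - 1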

View file

@@ -2,14 +2,19 @@
# Licensed under the MIT license.
from __future__ import annotations
import logging
from typing import List, Dict, overload
import torch
from torch import Tensor
from ..base.compressor import Quantizer
from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
from ..base.target_space import TargetType
_logger = logging.getLogger(__name__)
class LsqQuantizer(Quantizer):
@@ -59,22 +64,46 @@ class LsqQuantizer(Quantizer):
self.evaluator: Evaluator
self.is_init = False
self.check_validation()
self.register_scale()
self.register_lsq_apply_method()
self.register_track_func()
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
def check_validation(self) -> None:
for ts in self._target_spaces.values():
for target_space in ts.values():
if target_space.quant_scheme != 'symmetric':
warn_msg = f"LsqQuantizer only supports symmetric mode, but got {target_space.quant_scheme}"
_logger.warning(warn_msg)
if target_space.quant_dtype.startswith("uint") and target_space.type is TargetType.PARAMETER:
warn_msg = f"In the LsqQuantizer, quantization of parameters only supports int type"
_logger.warning(warn_msg)
def register_track_func(self):
for module_name, _ in self._target_spaces.items():
wrapper = self._module_wrappers[module_name]
wrapper.register_track_func(self.init_scale)
def init_scale(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if self.is_init or target_name not in wrapper.quantization_target_spaces:
def mean_reduce_func(converted_target: Tensor) -> torch.Tensor:
return converted_target.detach().mean(dim=-1)
if self.is_init or not self.check_target(wrapper, target_name):
return
target_space = wrapper.quantization_target_spaces[target_name]
init_target = target.data.detach().abs().mean() * 2 / (target_space.qmax ** 0.5)
target_space.scale.data = init_target # type: ignore
target_space.zero_point = torch.tensor(0.0).to(target.device)
if not target_space._scaler:
target_space.scale.data = init_target # type: ignore
target_space.zero_point = torch.tensor(0.0).to(target.device)
else:
new_target = init_target.expand(target.shape).to(target.device)
new_target_scale = target_space._scaler.shrink(new_target, mean_reduce_func, keepdim=True)
target_space.scale.data = new_target_scale # type: ignore
target_space.zero_point = torch.zeros_like(new_target_scale)
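
The initialization above follows the LSQ rule s = 2 * E[|w|] / sqrt(qmax); as a standalone sketch:

import torch

def lsq_init_scale(target: torch.Tensor, qmax: int) -> torch.Tensor:
    # per-tensor LSQ scale initialization; per-channel variants reduce with a
    # mean over the kept dimensions, as _scaler.shrink does above
    return target.detach().abs().mean() * 2 / (qmax ** 0.5)

s = lsq_init_scale(torch.randn(16, 8), qmax=127)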
def register_lsq_apply_method(self):
for _, ts in self._target_spaces.items():
@@ -87,15 +116,7 @@ class LsqQuantizer(Quantizer):
for target_name, _ in ts.items():
if hasattr(wrapper, f"{target_name}_scale"):
delattr(wrapper, f"{target_name}_scale")
try:
device = next(wrapper.parameters()).device
except StopIteration:
try:
device = next(wrapper.buffers()).device
except StopIteration:
# NOTE: this will have risk in model parallel
device = next(self.bound_model.parameters()).device
param = torch.nn.Parameter(torch.Tensor([1.0]).to(device))
param = torch.nn.Parameter()
wrapper.register_parameter(f"{target_name}_scale", param)
def patch_optimizer_param_group(self):

View file

@@ -7,7 +7,7 @@ from typing import List, Dict, Union, overload
import torch
from torch import Tensor
from ..base.compressor import Quantizer
from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
@@ -55,6 +55,10 @@ class PtqQuantizer(Quantizer):
self.register_ptq_apply_method()
self.register_track_func()
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, evaluator=evaluator)
def register_ptq_apply_method(self):
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():

View file

@@ -8,7 +8,7 @@ from typing import Dict, List, overload
import torch
from torch import Tensor
from ..base.compressor import Quantizer
from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
@@ -81,6 +81,11 @@ class QATQuantizer(Quantizer):
self.register_qat_apply_method()
self.register_track_func()
@classmethod
def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict],
quant_start_step: int = 0, evaluator: Evaluator | None = None):
return super().from_compressor(compressor, new_config_list, quant_start_step=quant_start_step, evaluator=evaluator)
def register_qat_apply_method(self):
if self.current_step < self.quant_start_step:
for _, ts in self._target_spaces.items():
@@ -103,12 +108,13 @@ class QATQuantizer(Quantizer):
def track_min_max_val(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
# in a fused compression pipeline, the target name may be another compressor's target name
if not wrapper.training or target_name not in wrapper.quantization_target_spaces:
if not wrapper.training or not self.check_target(wrapper, target_name):
return
return track_min_max_val(wrapper, target_name, target)
def update_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if not wrapper.training or self.current_step < self.quant_start_step:
if not wrapper.training or self.current_step < self.quant_start_step \
or not self.check_target(wrapper, target_name):
return
if target_name in wrapper.quantization_target_spaces:
target_space = wrapper.quantization_target_spaces[target_name]

View file

@@ -32,14 +32,14 @@ def test_dorefa_forward_with_torch_model():
torch.manual_seed(0)
model = SimpleTorchModel().to(device)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['fc1', 'fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
},
{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['_input_', 'weight'],
'op_names': ['conv1', 'conv2', 'conv3'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
@@ -54,14 +54,14 @@ def test_dorefa_forward_with_torch_model():
def test_dorefa_forward_with_lighting_model():
torch.manual_seed(0)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['_input_', 'weight'],
'op_names': ['model.fc1', 'model.fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
},
{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['_input_', 'weight'],
'op_names': ['model.conv1', 'model.conv2', 'model.conv3'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',

View file

@@ -32,11 +32,11 @@ def test_lsq_forward_with_torch_model():
torch.manual_seed(0)
model = SimpleTorchModel().to(device)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['fc1', 'fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_lsq_forward_with_torch_model():
def test_lsq_forward_with_lighting_model():
torch.manual_seed(0)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['model.fc1', 'model.fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],

View file

@@ -32,11 +32,11 @@ def test_ptq_forward_with_torch_model():
torch.manual_seed(0)
model = SimpleTorchModel().to(device)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['fc1', 'fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_ptq_forward_with_torch_model():
def test_ptq_forward_with_lighting_model():
torch.manual_seed(0)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['model.fc1', 'model.fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],

View file

@@ -32,11 +32,11 @@ def test_qat_forward_with_torch_model():
torch.manual_seed(0)
model = SimpleTorchModel().to(device)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['fc1', 'fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],
@@ -54,11 +54,11 @@ def test_qat_forward_with_torch_model():
def test_qat_forward_with_lighting_model():
torch.manual_seed(0)
configure_list = [{
'target_names':['_input_', 'weight', '_output_'],
'target_names':['weight'],
'op_names': ['model.fc1', 'model.fc2'],
'quant_dtype': 'int8',
'quant_scheme': 'affine',
'granularity': 'default',
'granularity': 'in_channel',
},
{
'target_names':['_input_', 'weight', '_output_'],