[QUANTIZE] Memorizing the quantize node mapping (#3233)
* [QUANTIZE] Support for clip operator * [QUANTIZE] Memorizing the quantize node mapping. * [QUANTIZE] Remove use_stop_fusion and skip_k_conv in qconfig * update * update * update * update
This commit is contained in:
Родитель
b796e335db
Коммит
bfb4884e47
|
@ -17,7 +17,6 @@
|
|||
"""The interface of expr function exposed from C++."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import logging
|
||||
from ... import build_module as _build
|
||||
from ... import container as _container
|
||||
from ..._ffi.function import _init_api, register_func
|
||||
|
@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func):
|
|||
# pylint: disable=broad-except
|
||||
try:
|
||||
f = _build.lower(sch, inputs, name=func_name)
|
||||
logging.debug("lower function %s", func_name)
|
||||
logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
|
||||
# logging.debug("lower function %s", func_name)
|
||||
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
|
||||
except Exception:
|
||||
msg = traceback.format_exc()
|
||||
msg += "Error during compile function\n"
|
||||
|
|
|
@ -22,7 +22,7 @@ import warnings
|
|||
import topi
|
||||
from . import _quantize
|
||||
from .quantize import QAnnotateKind, current_qconfig
|
||||
from .quantize import _conv_counter, _set_conv_counter
|
||||
from .quantize import annotate_context
|
||||
from .. import expr as _expr
|
||||
from .. import op as _op
|
||||
from ..op import op as _reg
|
||||
|
@ -116,7 +116,6 @@ def register_annotate_function(op_name, frewrite=None, level=10):
|
|||
return _register(frewrite) if frewrite is not None else _register
|
||||
|
||||
|
||||
@register_func("relay.quantize.attach_simulated_quantize")
|
||||
def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
|
||||
"""Attach a simulated quantize operation after input data expr.
|
||||
|
||||
|
@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
|
|||
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
|
||||
return data
|
||||
|
||||
actx = annotate_context()
|
||||
key = tuple([data, kind, sign, rounding])
|
||||
if key in actx.qnode_map:
|
||||
return actx.qnode_map[key]
|
||||
|
||||
dom_scale = _expr.var("dom_scale")
|
||||
clip_min = _expr.var("clip_min")
|
||||
clip_max = _expr.var("clip_max")
|
||||
return _quantize.simulated_quantize(
|
||||
qnode = _quantize.simulated_quantize(
|
||||
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
|
||||
actx.qnode_map[key] = qnode
|
||||
return qnode
|
||||
|
||||
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
|
||||
|
||||
|
||||
@register_annotate_function("nn.contrib_conv2d_NCHWc")
|
||||
|
@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
|
|||
"""Rewrite function for conv2d. Lhs of conv will be quantized to
|
||||
input field, and rhs of conv will be quantized to weight field.
|
||||
Output would be in activation field"""
|
||||
cnt = _conv_counter()
|
||||
if cnt < current_qconfig().skip_k_conv:
|
||||
_set_conv_counter(cnt + 1)
|
||||
return None
|
||||
|
||||
actx = annotate_context()
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt in leave_alone_indices:
|
||||
_set_conv_counter(cnt + 1)
|
||||
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if actx.conv2d_counter() in skipped_indices:
|
||||
actx.count_conv2d()
|
||||
return None
|
||||
|
||||
_set_conv_counter(cnt + 1)
|
||||
actx.count_conv2d()
|
||||
|
||||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
|
||||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
|
||||
|
@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx):
|
|||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
|
||||
|
||||
|
||||
def check_to_skip():
|
||||
"""Check the index of conv2d layer to decide whether to skip the current operator."""
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if annotate_context().conv2d_counter() - 1 in skipped_indices:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@register_annotate_function("nn.dense")
|
||||
def dense_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
|
||||
dense will be quantized to weight field. Output would be in activation field."""
|
||||
cnt = _conv_counter()
|
||||
if cnt < current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
|
||||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
|
||||
|
@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx):
|
|||
@register_annotate_function("multiply")
|
||||
def multiply_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for multiply."""
|
||||
cnt = _conv_counter()
|
||||
if cnt <= current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
|
||||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
|
||||
|
@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx):
|
|||
@register_annotate_function("add")
|
||||
def add_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for add."""
|
||||
cnt = _conv_counter()
|
||||
if cnt <= current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
|
||||
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
|
||||
|
@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx):
|
|||
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
|
||||
|
||||
|
||||
@register_annotate_function("stop_fusion")
|
||||
def stop_fusion_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for add."""
|
||||
if check_to_skip():
|
||||
return None
|
||||
|
||||
x_expr, x_kind = _get_expr_kind(new_args[0])
|
||||
if x_kind is None:
|
||||
return None
|
||||
|
||||
ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
|
||||
ret_expr = _forward_op(ref_call, [ret_expr])
|
||||
return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
|
||||
|
||||
|
||||
def identity_rewrite(ref_call, new_args, ctx):
|
||||
"""Simply forward the original operation"""
|
||||
cnt = _conv_counter()
|
||||
if cnt <= current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
x_expr, x_kind = _get_expr_kind(new_args[0])
|
||||
if x_kind is None:
|
||||
|
@ -283,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx):
|
|||
return QAnnotateExpr(ret_expr, x_kind)
|
||||
|
||||
|
||||
register_annotate_function("clip", identity_rewrite)
|
||||
register_annotate_function("nn.relu", identity_rewrite)
|
||||
register_annotate_function("strided_slice", identity_rewrite)
|
||||
register_annotate_function("nn.avg_pool2d", identity_rewrite)
|
||||
|
@ -290,13 +298,8 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite)
|
|||
|
||||
def pool2d_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for max pool2d"""
|
||||
cnt = _conv_counter()
|
||||
if cnt <= current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
expr, x_kind = _get_expr_kind(new_args[0])
|
||||
|
||||
|
@ -314,13 +317,8 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite)
|
|||
@register_annotate_function("concatenate")
|
||||
def concatenate_rewrite(ref_call, new_args, ctx):
|
||||
"""Rewrite function for concatenate"""
|
||||
cnt = _conv_counter()
|
||||
if cnt <= current_qconfig().skip_k_conv:
|
||||
if check_to_skip():
|
||||
return None
|
||||
if current_qconfig().skip_conv_layers is not None:
|
||||
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
|
||||
if cnt - 1 in leave_alone_indices:
|
||||
return None
|
||||
|
||||
input_tuple = new_args[0]
|
||||
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
|
||||
|
|
|
@ -71,12 +71,10 @@ class QConfig(NodeBase):
|
|||
"dtype_weight": "int8",
|
||||
"dtype_activation": "int32",
|
||||
"global_scale": 8.0,
|
||||
"skip_k_conv": 1,
|
||||
"skip_conv_layers": None,
|
||||
"skip_conv_layers": [0],
|
||||
"round_for_shift": True,
|
||||
"store_lowbit_output": True,
|
||||
"debug_enabled_ops": None,
|
||||
"use_stop_fusion": True
|
||||
}
|
||||
|
||||
# pylint: disable=no-member
|
||||
|
@ -138,11 +136,8 @@ def qconfig(**kwargs):
|
|||
global_scale: float
|
||||
The global scale for calibration.
|
||||
|
||||
skip_k_conv: int
|
||||
The number of skipped conv2d.
|
||||
|
||||
skip_conv_layers: list
|
||||
Different way of specifying which layers to avoid. Provide a list of indices
|
||||
Specifying which layers to be skipped. Provide a list of indices
|
||||
that indicate which conv2d layers to leave untouched.
|
||||
|
||||
round_for_shift: boolean
|
||||
|
@ -152,9 +147,10 @@ def qconfig(**kwargs):
|
|||
Whether to store low-bit integer back as output before dequantizing.
|
||||
Some accelerators need this, e.g. VTA.
|
||||
|
||||
use_stop_fusion: boolean
|
||||
Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit
|
||||
results to be stored in memory.
|
||||
debug_enabled_ops: None or list of str
|
||||
Partially quantize specified operators for debugging. The default value
|
||||
is None, which means will try to call all operartors' annotate rewrite
|
||||
function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -166,18 +162,35 @@ def qconfig(**kwargs):
|
|||
return _make.node("relay.quantize.QConfig", **node_args)
|
||||
|
||||
|
||||
CONV_COUNTER = 0
|
||||
class AnnotateContext(object):
|
||||
"""A global singleton annotate scope"""
|
||||
Current = None
|
||||
|
||||
def __init__(self):
|
||||
self.qnode_map = dict()
|
||||
self._conv2d_counter = 0
|
||||
|
||||
def __enter__(self):
|
||||
self._conv2d_counter = 0
|
||||
return self
|
||||
|
||||
def conv2d_counter(self):
|
||||
"""Get the counter for conv2d."""
|
||||
return self._conv2d_counter
|
||||
|
||||
def count_conv2d(self):
|
||||
"""Increase the value of the conv2d counter by one."""
|
||||
self._conv2d_counter += 1
|
||||
|
||||
def __exit__(self, ptype, value, traceback):
|
||||
pass
|
||||
|
||||
|
||||
def _conv_counter():
|
||||
"""Get the global counter for conv2d."""
|
||||
return CONV_COUNTER
|
||||
|
||||
|
||||
def _set_conv_counter(n):
|
||||
"""Set the value of the global conv2d counter."""
|
||||
global CONV_COUNTER
|
||||
CONV_COUNTER = n
|
||||
def annotate_context():
|
||||
"""Get the global singleton scope"""
|
||||
if AnnotateContext.Current is None:
|
||||
AnnotateContext.Current = AnnotateContext()
|
||||
return AnnotateContext.Current
|
||||
|
||||
|
||||
def calibrate(graph, mod=None, ctx=None):
|
||||
|
@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None):
|
|||
|
||||
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
|
||||
name="QuantizeCalibrate")
|
||||
_set_conv_counter(0) # reset counter
|
||||
quantize_seq = _transform.Sequential([annotate(),
|
||||
calibrate_pass,
|
||||
realize(),
|
||||
_transform.FoldConstant()])
|
||||
with _transform.PassContext(opt_level=3,
|
||||
required_pass=["QuantizeAnnotate",
|
||||
"QuantizeCalibrate",
|
||||
"QuantizeRealize"]):
|
||||
mod = optimize(mod)
|
||||
mod = quantize_seq(mod)
|
||||
with annotate_context():
|
||||
with _transform.PassContext(opt_level=3,
|
||||
required_pass=["QuantizeAnnotate",
|
||||
"QuantizeCalibrate",
|
||||
"QuantizeRealize"]):
|
||||
mod = optimize(mod)
|
||||
mod = quantize_seq(mod)
|
||||
return mod[mod.entry_func.name_hint]
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
|
@ -393,7 +393,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
|
|||
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
|
||||
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
|
||||
auto new_arg = Cast(ret[i], cfg->dtype_input);
|
||||
if (cfg->use_stop_fusion) {
|
||||
if (cfg->store_lowbit_output) {
|
||||
new_arg = StopFusion(new_arg);
|
||||
}
|
||||
ret.Set(i, Cast(new_arg, dtype));
|
||||
|
@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call,
|
|||
RELAY_REGISTER_OP("add")
|
||||
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
|
||||
|
||||
Expr ClipRealize(const Call& ref_call,
|
||||
const Array<Expr>& new_args,
|
||||
const NodeRef& ctx) {
|
||||
CHECK_EQ(new_args.size(), 1);
|
||||
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
|
||||
const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
|
||||
auto attrs = make_node<ClipAttrs>();
|
||||
double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
|
||||
attrs->a_min = ref_attrs->a_min / dom_scale;
|
||||
attrs->a_max = ref_attrs->a_max / dom_scale;
|
||||
|
||||
Expr ret = CallNode::make(ref_call->op,
|
||||
{n->data}, Attrs(attrs), ref_call->type_args);
|
||||
return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
|
||||
}
|
||||
CHECK(!new_args[0]->derived_from<TempExprNode>());
|
||||
return Expr(nullptr);
|
||||
}
|
||||
|
||||
RELAY_REGISTER_OP("clip")
|
||||
.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
|
||||
|
||||
|
||||
Expr ConcatenateRealize(const Call& ref_call,
|
||||
const Array<Expr>& new_args,
|
||||
|
@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
|
|||
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
|
||||
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
|
||||
p->stream << "global_scale=" << op->global_scale << ", ";
|
||||
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
|
||||
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
|
||||
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
|
||||
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
|
||||
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
|
||||
p->stream << "use_stop_fusion==" << op->use_stop_fusion;
|
||||
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
|
||||
p->stream << ")";
|
||||
});
|
||||
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
|
@ -125,12 +125,10 @@ class QConfigNode : public Node {
|
|||
DataType dtype_weight = Int(8);
|
||||
DataType dtype_activation = Int(32);
|
||||
double global_scale = 8.0;
|
||||
int skip_k_conv = 1;
|
||||
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
|
||||
bool round_for_shift = true;
|
||||
bool store_lowbit_output = true;
|
||||
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
|
||||
bool use_stop_fusion = true;
|
||||
|
||||
void VisitAttrs(AttrVisitor* v) final {
|
||||
v->Visit("nbit_input", &nbit_input);
|
||||
|
@ -140,12 +138,10 @@ class QConfigNode : public Node {
|
|||
v->Visit("dtype_weight", &dtype_weight);
|
||||
v->Visit("dtype_activation", &dtype_activation);
|
||||
v->Visit("global_scale", &global_scale);
|
||||
v->Visit("skip_k_conv", &skip_k_conv);
|
||||
v->Visit("skip_conv_layers", &skip_conv_layers);
|
||||
v->Visit("round_for_shift", &round_for_shift);
|
||||
v->Visit("store_lowbit_output", &store_lowbit_output);
|
||||
v->Visit("debug_enabled_ops", &debug_enabled_ops);
|
||||
v->Visit("use_stop_fusion", &use_stop_fusion);
|
||||
}
|
||||
|
||||
static constexpr const char* _type_key = "relay.quantize.QConfig";
|
||||
|
|
|
@ -81,7 +81,7 @@ def test_quantize_pass():
|
|||
graph = make_graph(data)
|
||||
dataset, params = make_dataset(graph, 10)
|
||||
|
||||
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
|
||||
with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
|
||||
round_for_shift=False, store_lowbit_output=False):
|
||||
qgraph0 = qtz.quantize(graph, params)
|
||||
qgraph0 = relay.ir_pass.infer_type(qgraph0)
|
||||
|
|
Загрузка…
Ссылка в новой задаче