[Relay][Pass] CanonicalizeCast (#3280)
This commit is contained in:
Родитель
fa351045e6
Коммит
04e816241f
|
@ -534,6 +534,13 @@ TVM_DLL Pass CanonicalizeOps();
|
|||
*/
|
||||
TVM_DLL Pass AlterOpLayout();
|
||||
|
||||
/*!
|
||||
* \brief Canonicalize cast expressions to make operator fusion more efficient.
|
||||
*
|
||||
* \return The pass.
|
||||
*/
|
||||
TVM_DLL Pass CanonicalizeCast();
|
||||
|
||||
} // namespace transform
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -445,6 +445,16 @@ def PartialEvaluate():
|
|||
"""
|
||||
return _transform.PartialEvaluate()
|
||||
|
||||
def CanonicalizeCast():
|
||||
"""
|
||||
Canonicalize cast expressions to make operator fusion more efficient.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : tvm.relay.Pass
|
||||
The registered pass that canonicalizes cast expression.
|
||||
"""
|
||||
return _transform.CanonicalizeCast()
|
||||
|
||||
def _wrap_class_module_pass(pass_cls, pass_info):
|
||||
"""Wrap a python class as function pass"""
|
||||
|
|
|
@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
|
|||
pass_seqs.push_back(transform::CombineParallelConv2D(3));
|
||||
pass_seqs.push_back(transform::FoldConstant());
|
||||
pass_seqs.push_back(transform::FoldScaleAxis());
|
||||
pass_seqs.push_back(transform::CanonicalizeCast());
|
||||
pass_seqs.push_back(transform::CanonicalizeOps());
|
||||
|
||||
// Alter layout transformation is only applied to homogeneous execution yet.
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* Copyright (c) 2019 by Contributors
|
||||
* \file canonicalize_cast.cc
|
||||
* \brief Canonicalize cast expressions to make operator fusion more efficient.
|
||||
*/
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/transform.h>
|
||||
#include "pattern_util.h"
|
||||
#include "pass_util.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a
|
||||
// copy of it in each branch such that after fusion the previous function have output with fewer
|
||||
// bits.
|
||||
//
|
||||
// Consider the following example:
|
||||
// \code
|
||||
// def @main(x: int8) {
|
||||
// %1 = cast(%x, f32)
|
||||
// %2 = exp(%1)
|
||||
// %3 = log(%1)
|
||||
// (%3, 4)
|
||||
// }
|
||||
// \endcode
|
||||
//
|
||||
// We would like to prevent sharing of the cast expression such that operator fusion can produce
|
||||
// more efficient result as below.
|
||||
// \code
|
||||
// def @main(x: int8) {
|
||||
// %1 = fn (%p1: i8) {
|
||||
// exp(cast(%p1, f32)
|
||||
// }
|
||||
// %3 = %1(%x)
|
||||
// %2 = fn (%p1: i8) {
|
||||
// log(cast(%p1, f32)
|
||||
// }
|
||||
// %4 = %2(%x)
|
||||
// (%3, 4)
|
||||
// }
|
||||
// \endcode
|
||||
class CastCanonicalizer : public ExprMutator {
|
||||
public:
|
||||
Expr VisitExpr_(const CallNode* call) {
|
||||
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
|
||||
|
||||
if (const OpNode* opnode = call->op.as<OpNode>()) {
|
||||
auto pattern = fpattern[GetRef<Op>(opnode)];
|
||||
if (pattern <= kBroadcast) {
|
||||
Array<Expr> call_args = call->args;
|
||||
bool unchanged = true;
|
||||
for (size_t i = 0; i < call_args.size(); ++i) {
|
||||
Expr arg = call_args[i];
|
||||
Expr new_arg = GetNewCallArg(arg);
|
||||
if (!arg.same_as(new_arg)) {
|
||||
call_args.Set(i, new_arg);
|
||||
unchanged = false;
|
||||
}
|
||||
}
|
||||
if (unchanged) {
|
||||
return GetRef<Expr>(call);
|
||||
}
|
||||
return CallNode::make(call->op, call_args, call->attrs, call->type_args);
|
||||
}
|
||||
}
|
||||
|
||||
Expr new_expr = ExprMutator::VisitExpr_(call);
|
||||
return new_expr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<const Node*, size_t> ref_counter_;
|
||||
|
||||
Expr GetNewCallArg(const Expr& e) {
|
||||
// if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
|
||||
|
||||
static auto& cast = Op::Get("cast");
|
||||
Expr new_expr = this->VisitExpr(e);
|
||||
|
||||
if (const CallNode* call = e.as<CallNode>()) {
|
||||
if (call->op.same_as(cast)) {
|
||||
auto attrs = call->attrs.as<CastAttrs>();
|
||||
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
|
||||
CHECK(from_type);
|
||||
|
||||
if (from_type->dtype.bits() < attrs->dtype.bits()) {
|
||||
if (++ref_counter_[call] > 1) {
|
||||
const CallNode* new_call = new_expr.as<CallNode>();
|
||||
CHECK(new_call);
|
||||
CHECK(new_call->op.same_as(cast));
|
||||
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
|
||||
new_call->type_args);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return new_expr;
|
||||
}
|
||||
};
|
||||
|
||||
Expr CanonicalizeCast(const Expr& e) {
|
||||
return CastCanonicalizer().Mutate(e);
|
||||
}
|
||||
|
||||
namespace transform {
|
||||
|
||||
Pass CanonicalizeCast() {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(CanonicalizeCast(f));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
|
||||
{ir::StringImm::make("InferType")});
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._transform.CanonicalizeCast")
|
||||
.set_body_typed(CanonicalizeCast);
|
||||
|
||||
} // namespace transform
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -0,0 +1,70 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import tvm
|
||||
import tvm.relay as relay
|
||||
import tvm.relay.module as _module
|
||||
import tvm.relay.transform as _transform
|
||||
|
||||
|
||||
def test_canonicalize_cast():
|
||||
def before(data, conv_weight, bias1, bias2):
|
||||
x = relay.nn.conv2d(data, conv_weight,
|
||||
channels=16,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
out_dtype="int8")
|
||||
x1 = relay.cast(x, dtype="int32")
|
||||
y1 = relay.add(x1, bias1)
|
||||
y2 = relay.add(x1, bias2)
|
||||
y = relay.add(y1, y2)
|
||||
return relay.Function([data, conv_weight, bias1, bias2], y)
|
||||
|
||||
def expected(data, conv_weight, bias1, bias2):
|
||||
x = relay.nn.conv2d(data, conv_weight,
|
||||
channels=16,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
out_dtype="int8")
|
||||
x1 = relay.cast(x, dtype="int32")
|
||||
x2 = relay.cast(x, dtype="int32")
|
||||
y1 = relay.add(x1, bias1)
|
||||
y2 = relay.add(x2, bias2)
|
||||
y = relay.add(y1, y2)
|
||||
return relay.Function([data, conv_weight, bias1, bias2], y)
|
||||
|
||||
def check(shape):
|
||||
data = relay.var("data", shape=shape, dtype="int8")
|
||||
conv_weight = relay.var("weight")
|
||||
bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
|
||||
bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
|
||||
y = before(data, conv_weight, bias1, bias2)
|
||||
mod = _module.Module.from_expr(y)
|
||||
seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
|
||||
_transform.InferType()])
|
||||
with _transform.PassContext(opt_level=3):
|
||||
mod = seq(mod)
|
||||
y = mod[mod.entry_func.name_hint]
|
||||
y_expected = expected(data, conv_weight, bias1, bias2)
|
||||
y_expected = relay.ir_pass.infer_type(y_expected)
|
||||
assert relay.ir_pass.alpha_equal(y, y_expected)
|
||||
|
||||
check((1, 16, 7, 7))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_canonicalize_cast()
|
Загрузка…
Ссылка в новой задаче