[RELAY][PASS] CombineParallelConv2D (#2089)
This commit is contained in:
Родитель
ff5dffa440
Коммит
53ac89ede7
|
@ -13,6 +13,7 @@ from .backend import graph_runtime_codegen as _graph_gen
|
|||
# List of optimization pass and level when switch on
|
||||
OPT_PASS_LEVEL = {
|
||||
"SimplifyInference": 0,
|
||||
"CombineParallelConv2D": 1,
|
||||
"OpFusion": 1,
|
||||
"FoldConstant": 2,
|
||||
"FoldScaleAxis": 3,
|
||||
|
@ -144,6 +145,10 @@ def optimize(func, params=None):
|
|||
func = ir_pass.infer_type(func)
|
||||
func = ir_pass.simplify_inference(func)
|
||||
|
||||
if cfg.pass_enabled("CombineParallelConv2D"):
|
||||
func = ir_pass.infer_type(func)
|
||||
func = ir_pass.combine_parallel_conv2d(func)
|
||||
|
||||
if cfg.pass_enabled("FoldScaleAxis"):
|
||||
func = ir_pass.infer_type(func)
|
||||
func = ir_pass.backward_fold_scale_axis(func)
|
||||
|
|
|
@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1):
|
|||
Transformed expression, containing fused result.
|
||||
"""
|
||||
return _ir_pass.FuseOps(expr, opt_level)
|
||||
|
||||
|
||||
def combine_parallel_conv2d(expr):
|
||||
"""Fold multiple conv2d into one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
transformed_expr : tvm.relay.Expr
|
||||
Transformed expression
|
||||
"""
|
||||
return _ir_pass.CombineParallelConv2D(expr)
|
||||
|
|
|
@ -0,0 +1,328 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
*
|
||||
* \file combine_parallel_conv2d.cc
|
||||
* \brief Combine parallel 2d convolutions into a single convolution.
|
||||
*
|
||||
* This pass replaces convolutions that share the same input node and the same
|
||||
* arguments (except that the number of output channels can be different) with a
|
||||
* single convolution. The weight of the new 2d convolution is the concatenation
|
||||
* of the original weights. Elemwise and broadcast ops following conv2d are also
|
||||
* combined if possible.
|
||||
*
|
||||
* This prevents launching multiple kernels in networks with multiple
|
||||
* convolution branches, such as Inception block.
|
||||
*/
|
||||
|
||||
#include <tvm/relay/pass.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <tvm/relay/attrs/transform.h>
|
||||
#include <tvm/relay/op_attr_types.h>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "./expr_subst.h"
|
||||
#include "./pattern_util.h"
|
||||
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
using Branch = std::vector<const CallNode*>;
|
||||
using Group = std::vector<Branch>;
|
||||
|
||||
/*
|
||||
Find parallel branches starting with conv2d as shown below and then group branches by kernel
|
||||
shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
|
||||
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
|
||||
which should be handled in ParallelConv2DCombiner.
|
||||
|
||||
data
|
||||
/ \
|
||||
conv2d conv2d
|
||||
| |
|
||||
op op
|
||||
| |
|
||||
*/
|
||||
class BranchGroupFinder : private ExprVisitor {
|
||||
public:
|
||||
std::vector<Group> Find(const Expr& expr) {
|
||||
this->VisitExpr(expr);
|
||||
|
||||
std::vector<Group> groups;
|
||||
for (const auto& root : conv_roots_) {
|
||||
const auto& convs = children_map_.at(root);
|
||||
for (const CallNode* conv : convs) {
|
||||
auto&& branch = CreateBranch(conv);
|
||||
// add the branch to a group, or create a new group
|
||||
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
|
||||
CHECK(!group.empty() && !group[0].empty());
|
||||
return IsCompatibleConv2D(conv, group[0][0]);
|
||||
});
|
||||
if (it != groups.end()) {
|
||||
it->push_back(branch);
|
||||
} else {
|
||||
groups.emplace_back();
|
||||
// each group has at least one branch
|
||||
groups.back().push_back(branch);
|
||||
}
|
||||
}
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
|
||||
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;
|
||||
|
||||
// Two 2d convolutions can be combined if they have the same attributes or
|
||||
// only have different output channels.
|
||||
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
|
||||
AttrsEqual eq;
|
||||
static const Layout kOIHW("OIHW");
|
||||
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
|
||||
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
|
||||
CHECK(attrs_a);
|
||||
CHECK(attrs_b);
|
||||
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
|
||||
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
|
||||
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
|
||||
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);
|
||||
|
||||
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
|
||||
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
|
||||
eq(attrs_a->data_layout, attrs_b->data_layout) &&
|
||||
eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
|
||||
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
|
||||
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
|
||||
eq(shape_a[3], shape_b[3]);
|
||||
}
|
||||
|
||||
// Create a branch starting from conv2d.
|
||||
Branch CreateBranch(const CallNode* conv) {
|
||||
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
|
||||
// each branch has at least one element, the first element is always conv2d
|
||||
Branch branch{conv};
|
||||
auto it = children_map_.find(GetRef<Expr>(branch.back()));
|
||||
while (it != children_map_.end() && it->second.size() == 1) {
|
||||
const CallNode* call = it->second[0];
|
||||
auto pattern = fpattern[Downcast<Op>(call->op)];
|
||||
if (pattern <= kBroadcast) {
|
||||
branch.push_back(it->second[0]);
|
||||
it = children_map_.find(GetRef<Expr>(branch.back()));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return branch;
|
||||
}
|
||||
|
||||
void VisitExpr_(const CallNode* n) final {
|
||||
static const Op& conv2d = Op::Get("nn.conv2d");
|
||||
ExprVisitor::VisitExpr_(n);
|
||||
if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) {
|
||||
conv_roots_.insert(n->args[0]);
|
||||
children_map_[n->args[0]].push_back(n);
|
||||
} else {
|
||||
for (size_t i = 0; i < n->args.size(); i++) {
|
||||
children_map_[n->args[i]].push_back(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ParallelConv2DCombiner {
|
||||
public:
|
||||
Expr Combine(const Expr& expr) {
|
||||
auto groups = BranchGroupFinder().Find(expr);
|
||||
for (const Group& group : groups) {
|
||||
if (group.size() < 2) continue;
|
||||
CombineBranches(group);
|
||||
}
|
||||
return ExprSubst(expr, std::move(subst_map_));
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
|
||||
|
||||
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
|
||||
int64_t num_filters = 0; // number of filters of the transformed weight
|
||||
Array<Expr> weights;
|
||||
for (const auto& branch : branches) {
|
||||
auto conv2d = branch[0];
|
||||
weights.push_back(conv2d->args[1]);
|
||||
auto channels = GetConv2DSuperChannelsDim(conv2d);
|
||||
num_filters += channels;
|
||||
}
|
||||
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
|
||||
CHECK_NE(index, std::string::npos);
|
||||
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
|
||||
MakeConstScalar(Int(32), num_filters));
|
||||
}
|
||||
|
||||
Call MakeCombinedConv2D(const Group& branches) {
|
||||
static const Op& conv2d = Op::Get("nn.conv2d");
|
||||
Expr data = branches[0][0]->args[0];
|
||||
Expr new_weight;
|
||||
IndexExpr new_channels;
|
||||
std::tie(new_weight, new_channels) = TransformWeight(branches);
|
||||
|
||||
const CallNode* group_root = branches[0][0];
|
||||
const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
|
||||
CHECK(attrs);
|
||||
const auto new_attrs = make_node<Conv2DAttrs>();
|
||||
new_attrs->strides = attrs->strides;
|
||||
new_attrs->padding = attrs->padding;
|
||||
new_attrs->dilation = attrs->dilation;
|
||||
new_attrs->groups = attrs->groups;
|
||||
new_attrs->kernel_size = attrs->kernel_size;
|
||||
new_attrs->data_layout = attrs->data_layout;
|
||||
new_attrs->weight_layout = attrs->weight_layout;
|
||||
new_attrs->out_layout = attrs->out_layout;
|
||||
new_attrs->out_dtype = attrs->out_dtype;
|
||||
new_attrs->channels = new_channels;
|
||||
|
||||
return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
|
||||
}
|
||||
|
||||
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
|
||||
AttrsEqual eq;
|
||||
auto ta = a->args[index]->type_as<TensorTypeNode>();
|
||||
auto tb = b->args[index]->type_as<TensorTypeNode>();
|
||||
auto toutput_a = a->type_as<TensorTypeNode>();
|
||||
auto toutput_b = b->type_as<TensorTypeNode>();
|
||||
|
||||
if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
|
||||
return false;
|
||||
|
||||
// Position of the 'C' dimension in the argument
|
||||
size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
|
||||
|
||||
// Channel super-dimension shoule be present and not broadcasted
|
||||
if ((arg_channel_pos > channel_pos) || // size_t overflow
|
||||
!eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
|
||||
!eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
|
||||
return false;
|
||||
|
||||
for (size_t i = 0; i < ta->shape.size(); i++) {
|
||||
if (i == arg_channel_pos) continue;
|
||||
if (!eq(ta->shape[i], tb->shape[i]))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if ops in depth-th level can be combined
|
||||
bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
|
||||
const CallNode* call = branches[0][depth];
|
||||
AttrsEqual attrs_equal;
|
||||
// check if all branches in current depth can be combined
|
||||
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
|
||||
const Branch& branch = *it;
|
||||
if (!branch[depth]->op.same_as(call->op) ||
|
||||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
|
||||
branch[depth]->args.size() != call->args.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (branch[depth]->args[parent_index].get() != branch[depth - 1])
|
||||
return false;
|
||||
|
||||
// Check args
|
||||
for (size_t i = 0; i < call->args.size(); i++) {
|
||||
if (i == parent_index) continue;
|
||||
|
||||
if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
|
||||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Combine args and make the combined CallNode
|
||||
Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
|
||||
size_t parent_index) {
|
||||
Array<Expr> new_args;
|
||||
const CallNode* call = branches[0][depth];
|
||||
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();
|
||||
|
||||
for (size_t i = 0; i < call->args.size(); i++) {
|
||||
if (i == parent_index) {
|
||||
new_args.push_back(data);
|
||||
continue;
|
||||
}
|
||||
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
|
||||
size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
|
||||
Array<Expr> tuple;
|
||||
for (const auto& branch : branches) {
|
||||
tuple.push_back(branch[depth]->args[i]);
|
||||
}
|
||||
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
|
||||
new_args.push_back(std::move(concat));
|
||||
}
|
||||
return CallNode::make(call->op, new_args, call->attrs, {});
|
||||
}
|
||||
|
||||
// Replace output of each branch with slices of the combined output
|
||||
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
|
||||
size_t channel_pos) {
|
||||
int64_t index = 0;
|
||||
for (const auto& branch : branches) {
|
||||
const CallNode* conv2d = branch[0];
|
||||
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
|
||||
Array<Integer> begin;
|
||||
Array<Integer> end;
|
||||
for (size_t i = 0; i < channel_pos; i++) {
|
||||
begin.push_back(0);
|
||||
end.push_back(NullValue<Integer>());
|
||||
}
|
||||
begin.push_back(index);
|
||||
index += channels;
|
||||
end.push_back(index);
|
||||
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
|
||||
subst_map_[GetRef<Expr>(branch[depth])] = slice;
|
||||
}
|
||||
}
|
||||
|
||||
// Combine branches in a group. Conv2d in different branches in the same group are safe to
|
||||
// combine. Subsequent ops may or may not be combined. We start from conv2d and try to
|
||||
// combine ops from all branches in the same depth.
|
||||
void CombineBranches(const Group& branches) {
|
||||
Call combined = MakeCombinedConv2D(branches);
|
||||
auto conv_param = combined->attrs.as<Conv2DAttrs>();
|
||||
const std::string& layout =
|
||||
conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
|
||||
size_t channel_pos = layout.find('C');
|
||||
CHECK_NE(channel_pos, std::string::npos);
|
||||
auto it = std::min_element(branches.begin(), branches.end(),
|
||||
[](const Branch& branch_a,
|
||||
const Branch& branch_b) {
|
||||
return branch_a.size() < branch_b.size();
|
||||
});
|
||||
size_t depth = it->size();
|
||||
size_t i;
|
||||
// starting from 1 to skip the conv2d
|
||||
for (i = 1; i < depth; i++) {
|
||||
size_t parent_index;
|
||||
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
|
||||
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
|
||||
}
|
||||
CHECK_NE(parent_index, branches[0][i]->args.size());
|
||||
if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
|
||||
combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
|
||||
}
|
||||
UpdateGroupOutput(combined, branches, i - 1, channel_pos);
|
||||
}
|
||||
};
|
||||
|
||||
Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = CombineParallelConv2D(args[0]);
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -0,0 +1,35 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file expr_subst.h
|
||||
* \brief Utility functions for substituting expressions.
|
||||
*/
|
||||
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include "./expr_subst.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
class ExprSubstituter : public ExprMutator {
|
||||
public:
|
||||
explicit ExprSubstituter(std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map)
|
||||
: subst_map_(subst_map) {}
|
||||
|
||||
Expr VisitExpr(const Expr& expr) final {
|
||||
auto it = subst_map_.find(expr);
|
||||
if (it != subst_map_.end()) {
|
||||
return (*it).second;
|
||||
}
|
||||
return ExprMutator::VisitExpr(expr);
|
||||
}
|
||||
|
||||
private:
|
||||
tvm::Map<Expr, Expr> subst_map_;
|
||||
};
|
||||
|
||||
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map) {
|
||||
return ExprSubstituter(std::move(subst_map)).Mutate(expr);
|
||||
}
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -0,0 +1,18 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file expr_subst.h
|
||||
* \brief Utility functions for substituting expressions.
|
||||
*/
|
||||
#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_
|
||||
#define TVM_RELAY_PASS_EXPR_SUBST_H_
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
Expr ExprSubst(const Expr& expr, std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_PASS_EXPR_SUBST_H_
|
|
@ -11,6 +11,7 @@
|
|||
#include <tvm/relay/op.h>
|
||||
#include <tvm/relay/expr.h>
|
||||
#include <tvm/relay/attrs/transform.h>
|
||||
#include <string>
|
||||
#include "../op/layout.h"
|
||||
|
||||
|
||||
|
@ -120,6 +121,19 @@ inline bool IsDepthwiseConv2D(const Call& call,
|
|||
is_const_int(wshape[1], 1);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get super-dimension of output channels of conv2d
|
||||
* \param call The conv2d call.
|
||||
* \return Super-dimension size of output channels of conv2d.
|
||||
*/
|
||||
inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
|
||||
auto param = call->attrs.as<Conv2DAttrs>();
|
||||
auto tweight = call->args[1]->type_as<TensorTypeNode>();
|
||||
auto index = param->weight_layout.find('O');
|
||||
CHECK_NE(index, std::string::npos);
|
||||
auto channels = as_const_int(tweight->shape[index]);
|
||||
return *channels;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Create a Constant with a scalar
|
||||
|
@ -172,6 +186,10 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) {
|
|||
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
|
||||
}
|
||||
|
||||
Expr MakeConcatenate(Expr data, int axis);
|
||||
|
||||
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
from tvm import relay
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_combine_parallel_conv2d():
|
||||
"""Simple testcase."""
|
||||
def before(x, w1, w2, w3, w4):
|
||||
args = [x, w1, w2, w3, w4]
|
||||
y1 = relay.nn.conv2d(x, w1)
|
||||
y2 = relay.nn.conv2d(x, w2)
|
||||
# y3 cannot be combined
|
||||
y3 = relay.nn.conv2d(x, w3)
|
||||
y4 = relay.nn.conv2d(x, w4)
|
||||
y = relay.Tuple((y1, y2, y3, y4))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
|
||||
# use a fixed order of args so alpha equal check can pass
|
||||
args = [x, w1, w2, w3, w4]
|
||||
w = relay.concatenate((w1, w2, w4), axis=0)
|
||||
y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4)
|
||||
y1 = relay.strided_slice(y, [0, 0], [None, channels1])
|
||||
y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
|
||||
y3 = relay.nn.conv2d(x, w3)
|
||||
y4 = relay.strided_slice(y, [0, channels1 + channels2],
|
||||
[None, channels1 + channels2 + channels4])
|
||||
y = relay.Tuple((y1, y2, y3, y4))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def check(x_shape, channels1, channels2, channels3, channels4):
|
||||
x = relay.var("x", shape=x_shape)
|
||||
in_c = x_shape[1]
|
||||
w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
|
||||
w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
|
||||
w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
|
||||
w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))
|
||||
|
||||
y_before = before(x, w1, w2, w3, w4)
|
||||
y = relay.ir_pass.infer_type(y_before)
|
||||
y = relay.ir_pass.combine_parallel_conv2d(y)
|
||||
y = relay.ir_pass.infer_type(y)
|
||||
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
|
||||
y_expected = relay.ir_pass.infer_type(y_expected)
|
||||
assert relay.ir_pass.alpha_equal(y, y_expected)
|
||||
|
||||
check((1, 4, 16, 16), 4, 4, 4, 4)
|
||||
check((1, 4, 16, 16), 4, 8, 4, 7)
|
||||
|
||||
|
||||
def test_combine_parallel_conv2d_scale_relu():
|
||||
"""Testcase of combining conv2d + scale + relu"""
|
||||
def before(x, w1, w2, scale1, scale2, bias):
|
||||
args = [x, w1, w2, scale1, scale2, bias]
|
||||
y1 = relay.nn.conv2d(x, w1)
|
||||
y1 = relay.multiply(y1, scale1)
|
||||
y1 = relay.nn.relu(y1)
|
||||
y2 = relay.nn.conv2d(x, w2)
|
||||
y2 = relay.multiply(y2, scale2)
|
||||
y2 = relay.nn.relu(y2)
|
||||
y2 = relay.add(y2, bias)
|
||||
y = relay.Tuple((y1, y2))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2):
|
||||
args = [x, w1, w2, scale1, scale2, bias]
|
||||
w = relay.concatenate((w1, w2), axis=0)
|
||||
scale = relay.concatenate((scale1, scale2), axis=0)
|
||||
y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
|
||||
y = relay.multiply(y, scale)
|
||||
y = relay.nn.relu(y)
|
||||
y1 = relay.strided_slice(y, [0, 0], [None, channels1])
|
||||
y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
|
||||
y2 = relay.add(y2, bias)
|
||||
y = relay.Tuple((y1, y2))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def check(x_shape, channels1, channels2):
|
||||
x = relay.var("x", shape=x_shape)
|
||||
in_c = x_shape[1]
|
||||
w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
|
||||
w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
|
||||
scale1 = relay.var("scale1", shape=(channels1, 1, 1))
|
||||
scale2 = relay.var("scale2", shape=(channels2, 1, 1))
|
||||
bias = relay.var("bias", shape=(channels2, 1, 1))
|
||||
y_before = before(x, w1, w2, scale1, scale2, bias)
|
||||
y = relay.ir_pass.infer_type(y_before)
|
||||
y = relay.ir_pass.combine_parallel_conv2d(y)
|
||||
y = relay.ir_pass.infer_type(y)
|
||||
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
|
||||
y_expected = relay.ir_pass.infer_type(y_expected)
|
||||
assert relay.ir_pass.alpha_equal(y, y_expected)
|
||||
|
||||
check((1, 4, 16, 16), 4, 8)
|
||||
|
||||
|
||||
def test_combine_parallel_conv2d_scale():
|
||||
"""Testcase of un-combinable scale"""
|
||||
def before(x, w1, w2, scale1, scale2):
|
||||
args = [x, w1, w2, scale1, scale2]
|
||||
y1 = relay.nn.conv2d(x, w1)
|
||||
y1 = relay.multiply(y1, scale1)
|
||||
y2 = relay.nn.conv2d(x, w2)
|
||||
y2 = relay.multiply(y2, scale2)
|
||||
y = relay.Tuple((y1, y2))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def expected(x, w1, w2, scale1, scale2, channels1, channels2):
|
||||
args = [x, w1, w2, scale1, scale2]
|
||||
w = relay.concatenate((w1, w2), axis=0)
|
||||
y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
|
||||
y1 = relay.strided_slice(y, [0, 0], [None, channels1])
|
||||
y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
|
||||
y1 = relay.multiply(y1, scale1)
|
||||
y2 = relay.multiply(y2, scale2)
|
||||
y = relay.Tuple((y1, y2))
|
||||
return relay.Function(args, y)
|
||||
|
||||
def check(x_shape, channels1, channels2):
|
||||
x = relay.var("x", shape=x_shape)
|
||||
in_c = x_shape[1]
|
||||
w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
|
||||
w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
|
||||
scale1 = relay.var("scale1", shape=(1,))
|
||||
scale2 = relay.var("scale2", shape=(1,))
|
||||
y_before = before(x, w1, w2, scale1, scale2)
|
||||
y = relay.ir_pass.infer_type(y_before)
|
||||
y = relay.ir_pass.combine_parallel_conv2d(y)
|
||||
y = relay.ir_pass.infer_type(y)
|
||||
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
|
||||
y_expected = relay.ir_pass.infer_type(y_expected)
|
||||
assert relay.ir_pass.alpha_equal(y, y_expected)
|
||||
|
||||
check((1, 4, 16, 16), 4, 8)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_combine_parallel_conv2d()
|
||||
test_combine_parallel_conv2d_scale_relu()
|
||||
test_combine_parallel_conv2d_scale()
|
Загрузка…
Ссылка в новой задаче