Improve x86 Inception (#1506)
* Improve x86 pooling and concat
* Fix
* Fix test concatenate correct layout
* Add conditional vectorize
* Fix lint
* Modify schedule for global pooling
* Fix
* Fix warning
* Fix alter layout test
* Remove vectorization for pooling when using 4D layout
* Remove vectorization for 4D concat
* Fix concatenate layout
* Fix concatenate schedule
* Fix concat
* Fix lint
* Fix concat
* Simplify pooling logic
* Update docstring
* Fix test topi pooling
* Small changes
This commit is contained in:
Parent: 4dc21bdb29
Commit: 5d533ec99b
@@ -280,20 +280,22 @@ reg.register_pattern("conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

 # max_pool2d
 @reg.register_schedule("max_pool2d")
-def schedule_max_pool2d(_, outs, target):
+def schedule_max_pool2d(attrs, outs, target):
     """Schedule definition of max_pool2d"""
+    layout = attrs["layout"]
     with tvm.target.create(target):
-        return topi.generic.schedule_pool(outs)
+        return topi.generic.schedule_pool(outs, layout)

 reg.register_pattern("max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


 # avg_pool2d
 @reg.register_schedule("avg_pool2d")
-def schedule_avg_pool2d(_, outs, target):
+def schedule_avg_pool2d(attrs, outs, target):
     """Schedule definition of avg_pool2d"""
+    layout = attrs["layout"]
     with tvm.target.create(target):
-        return topi.generic.schedule_pool(outs)
+        return topi.generic.schedule_pool(outs, layout)

 reg.register_pattern("avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

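For orientation, a minimal NNVM-level sketch (not part of the diff) of where the forwarded layout comes from; the shapes and pool parameters are illustrative. The layout attribute set on the symbol is the value the schedules above read from attrs["layout"] and hand to topi.generic.schedule_pool.

    import nnvm.symbol as sym
    import nnvm.compiler

    # Hypothetical NCHW max-pool workload (example values only).
    data = sym.Variable("data")
    net = sym.max_pool2d(data, pool_size=(3, 3), strides=(2, 2),
                         layout="NCHW", name="pool")
    # Compiling for llvm ends up calling topi.generic.schedule_pool(outs, "NCHW").
    graph, lib, params = nnvm.compiler.build(
        net, target="llvm", shape={"data": (1, 64, 56, 56)})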
@@ -2,6 +2,7 @@
 """Tensor transformation ops"""
 from __future__ import absolute_import

+import tvm
 import topi
 from .tensor import _fschedule_broadcast, _fschedule_injective
 from . import registry as reg
@@ -58,8 +59,13 @@ reg.register_pattern("squeeze", OpPattern.INJECTIVE)
 reg.register_schedule("squeeze", _fschedule_injective)

 # concatenate
+@reg.register_schedule("concatenate")
+def schedule_concatenate(_, outs, target):
+    """Schedule definition of concatenate"""
+    with tvm.target.create(target):
+        return topi.generic.schedule_concatenate(outs)
+
 reg.register_pattern("concatenate", OpPattern.INJECTIVE)
-reg.register_schedule("concatenate", _fschedule_injective)

 # split
 reg.register_pattern("split", OpPattern.INJECTIVE)
@@ -129,15 +129,31 @@ inline bool ConcatenateCorrectLayout(const NodeAttrs& attrs,
                                      std::vector<Layout> *ilayouts,
                                      const std::vector<Layout> *last_ilayouts,
                                      std::vector<Layout> *olayouts) {
+  const ConcatenateParam& param = nnvm::get<ConcatenateParam>(attrs.parsed);
   CHECK_EQ(ilayouts->size(), last_ilayouts->size());
   CHECK_EQ(olayouts->size(), 1U);

-  for (size_t i = 0; i < ilayouts->size(); ++i) {
-    const Layout& input = last_ilayouts->at(i).defined() ?
-                          last_ilayouts->at(i) : ilayouts->at(i);
-    NNVM_ASSIGN_LAYOUT(*ilayouts, i, input);
+  Layout layout;
+  if (!ilayouts->at(0).defined()) {
+    layout = last_ilayouts->at(0);
+  } else if (param.axis >= static_cast<int>(ilayouts->at(0).ndim())) {
+    CHECK(last_ilayouts->at(0).defined())
+      << "Current input layout " << ilayouts->at(0)
+      << " is invalid but last input layout is not "
+         "defined for the first input.";
+    layout = last_ilayouts->at(0);
+  } else if (last_ilayouts->at(0).defined()
+             && ilayouts->at(0)[param.axis]
+                != last_ilayouts->at(0)[param.axis]) {
+    layout = last_ilayouts->at(0);
+  } else {
+    layout = ilayouts->at(0);
   }

+  for (size_t i = 0; i < ilayouts->size(); ++i) {
+    NNVM_ASSIGN_LAYOUT(*ilayouts, i, layout);
+  }
+  NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
   return true;
 }
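A rough pure-Python restatement of the layout rule above, for illustration only: layouts are modeled as plain lists of axis tokens rather than nnvm::Layout objects, and resolve_concat_layout is a made-up name.

    def resolve_concat_layout(cur, last, axis):
        """Pick the single layout imposed on all concatenate inputs and the output."""
        if cur is None:                  # current layout undefined: fall back
            return last
        if axis >= len(cur):             # concat axis no longer exists in cur
            assert last is not None
            return last
        if last is not None and axis < len(last) and cur[axis] != last[axis]:
            return last                  # concat axis was transformed: revert
        return cur                       # otherwise keep the current layout

    # Both inputs arriving as HW16w with axis=1 keep HW16w, which is what the
    # updated test_concatenate below asserts for the second pass.
    print(resolve_concat_layout(["H", "W", "16w"], ["H", "W"], 1))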
@@ -77,14 +77,25 @@ def test_concatenate():
     g, ldict = correct_layout(z, {"x": "HW", "y": "HW"})
     assert(ldict["x"][0] == "HW")
     assert(ldict["y"][0] == "HW")
-    assert(ldict["concat"][0] == "__undef__")
+    assert(ldict["concat"][0] == "HW")
     # second pass will insert layout transform
     _, ldict = correct_layout(g, {"x": "HW16w", "y": "HW16w"})
     assert(ldict["x"][0] == "HW16w")
     assert(ldict["y"][0] == "HW16w")
     assert(ldict["x_HW"][0] == "HW")
     assert(ldict["y_HW"][0] == "HW")
-    assert(ldict["concat"][0] == "__undef__")
+    assert(ldict["concat"][0] == "HW16w")
+
+    x1 = sym.Variable("x", shape=(10, 20, 60))
+    x2 = sym.Variable("y", shape=(10, 20, 40))
+    z = sym.concatenate(x1, x2, axis=2, name="concat")
+    g, ldict = correct_layout(z, {"x": "H20wW", "y": "H20wW"})
+    assert(ldict["x"][0] == "H20wW")
+    assert(ldict["y"][0] == "H20wW")
+    assert(ldict["concat"][0] == "H20wW")
+    # second pass will insert layout transform
+    _, ldict = correct_layout(g, {"x": "HW", "y": "HW"})
+    assert(ldict["x_H20wW"][0] == "H20wW")
+    assert(ldict["x_H20wW"][0] == "H20wW")
+    assert(ldict["concat"][0] == "H20wW")


 def test_expand_dims():
@@ -349,4 +360,4 @@ if __name__ == "__main__":
     test_transpose()
     test_broadcast_to()
     test_broadcast_binary()
-    test_reduce()
+    test_reduce()
@@ -112,18 +112,18 @@ inline Tensor pool_impl(const Tensor& x,
     }, "tensor", "pool_max");
   } else if (pool_type == kAvgPool) {
     auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
-    auto tsum = tvm::compute(out_shape, [&](const Array<Var>& output) {
+    auto tavg = [&](const Array<Var>& output, Expr divide_factor) {
       Array<Expr> indices;
       for (const Var& var : output) indices.push_back(var);
       indices.Set(height_axis, output[height_axis] * stride_height + dheight);
       indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
-      return tvm::sum(temp(indices), { dheight, dwidth });
-    }, "tensor", "pool_avg");
+      return tvm::sum(temp(indices) / divide_factor, { dheight, dwidth });
+    };

     return tvm::compute(out_shape,
     [&](const Array<Var>& output) {
       if (count_include_pad) {
-        return tsum(output) / (kernel_height * kernel_width);
+        return tavg(output, kernel_height * kernel_width);
       } else {
         Expr h_start = output[height_axis] * stride_height - pad_top;
         Expr w_start = output[width_axis] * stride_width - pad_left;
@@ -133,9 +133,9 @@ inline Tensor pool_impl(const Tensor& x,
         w_start = ir::Max::make(w_start, make_const(Int(32), 0));
         Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
                                            make_const(Int(32), 1));
-        return tsum(output) / divide_factor;
+        return tavg(output, divide_factor);
       }
-    }, "tensor", kElementWise);
+    }, "tensor", "pool_avg");
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
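A scalar illustration (plain Python, not TVM IR) of the divisor the refactored tavg lambda receives when count_include_pad is false: the window is clipped to the unpadded input and the divisor is clamped to at least 1. The concrete sizes are made up for the example.

    def avg_pool_divisor(oh, ow, stride, kernel, pad, in_h, in_w):
        h_start = oh * stride - pad
        w_start = ow * stride - pad
        h_end = min(h_start + kernel, in_h)
        w_end = min(w_start + kernel, in_w)
        h_start = max(h_start, 0)
        w_start = max(w_start, 0)
        return max((h_end - h_start) * (w_end - w_start), 1)

    # The corner output of a 3x3, stride-1 average pool on a 5x5 input with one
    # pixel of padding overlaps only 4 real pixels, so it is divided by 4, not 9.
    print(avg_pool_divisor(0, 0, stride=1, kernel=3, pad=1, in_h=5, in_w=5))  # -> 4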
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, unused-variable
+# pylint: disable=invalid-name, unused-variable, unused-argument
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
@@ -70,7 +70,7 @@ def schedule_global_pool(outs):


 @generic.schedule_pool.register(["cuda", "gpu"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool.

     Parameters
@@ -79,6 +79,9 @@ def schedule_pool(outs):
          The computation graph description of pool
          in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     s: Schedule
@@ -29,5 +29,22 @@ def schedule_injective(outs):
     s[x].fuse(s[x].op.axis)
     return s

+@tvm.target.generic_func
+def schedule_concatenate(outs):
+    """Schedule for concatenate op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of reduce in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return schedule_injective(outs)
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
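A short usage sketch of the new generic entry point (not part of the diff), assuming topi.concatenate and an llvm target are available as in mainline TOPI of this era; targets without a specialized registration fall back to schedule_injective.

    import tvm
    import topi

    A = tvm.placeholder((1, 3, 56, 56), name="A")
    B = tvm.placeholder((1, 5, 56, 56), name="B")
    C = topi.concatenate([A, B], axis=1)
    with tvm.target.create("llvm"):
        # Dispatches to the x86 registration added further down in this commit;
        # other targets use the schedule_injective fallback defined here.
        s = topi.generic.schedule_concatenate(C)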
@@ -282,7 +282,7 @@ def schedule_dense(outs):


 @tvm.target.override_native_generic_func("schedule_pool")
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool

     Parameters
@@ -291,6 +291,9 @@ def schedule_pool(outs):
          The computation graph description of pool
          in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     sch: Schedule
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name, unused-variable
+# pylint: disable=invalid-name, unused-variable, unused-argument
 """Schedule for pooling operators"""
 import tvm
 from .. import tag
@@ -54,7 +54,7 @@ def schedule_global_pool(outs):


 @generic.schedule_pool.register(["opengl"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool.

     Parameters
@@ -63,6 +63,9 @@ def schedule_pool(outs):
          The computation graph description of pool
          in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     s: Schedule
@@ -33,5 +33,51 @@ def schedule_injective(outs):
         s[x].parallel(s[x].op.axis[0])
     return s

+@generic.schedule_concatenate.register(["cpu"])
+def schedule_concatenate(outs):
+    """X86 schedule for concatenate op.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of injective in the format
+          of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    def vectorize(sch, tensor, vectorize_limit):
+        """Internal vectorization function for concatenate."""
+        inner_axis = s[tensor].op.axis[len(s[tensor].op.axis) - 1]
+        inner_length = tensor.shape[len(tensor.shape) - 1].value
+        if inner_length <= vectorize_limit:
+            sch[tensor].vectorize(inner_axis)
+        else:
+            split_factor = 1
+            for i in range(vectorize_limit, 1, -1):
+                if inner_length % i == 0:
+                    split_factor = i
+                    break
+            if split_factor > 1:
+                _, inner_i = sch[tensor].split(inner_axis, split_factor)
+                sch[tensor].vectorize(inner_i)
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    x = outs[0]
+    s = tvm.create_schedule([x.op for x in outs])
+    tvm.schedule.AutoInlineInjective(s)
+    if len(s[x].op.axis) >= 5:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
+        vectorize(s, x, 64)
+        s[x].parallel(fused)
+    elif len(s[x].op.axis) >= 3:
+        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
+        s[x].parallel(fused)
+    else:
+        s[x].parallel(s[x].op.axis[0])
+    return s
+
 schedule_elemwise = schedule_injective
 schedule_broadcast = schedule_injective
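The split-factor search in the helper above, restated as a standalone function for illustration: pick the largest divisor of the inner dimension that does not exceed the vectorization limit. In the schedule itself this path only runs when the inner axis is longer than the limit.

    def pick_split_factor(inner_length, vectorize_limit=64):
        for i in range(vectorize_limit, 1, -1):
            if inner_length % i == 0:
                return i
        return 1

    print(pick_split_factor(96))   # -> 48: split the 96-wide axis and vectorize 48 lanes
    print(pick_split_factor(256))  # -> 64: the full limit divides evenly
    print(pick_split_factor(67))   # -> 1: prime length, no inner split gets vectorized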
@@ -4,19 +4,47 @@ import tvm
 from .. import generic
 from .. import tag

-def _parallel_sch(sch):
+def _parallel_sch(sch, oshape, do_vectorize=False):
+    def vectorize(fused_axis, num_parallel_axis, vectorize_limit=64):
+        """Internal vectorization utility function."""
+        reorder_axis = [fused_axis]
+        for i in range(num_parallel_axis, len(sch.op.axis) - 1):
+            reorder_axis.append(sch.op.axis[i])
+        kw, kh = sch.op.reduce_axis
+        fuse_k = sch.fuse(kw, kh)
+        c = sch.op.axis[len(sch.op.axis) - 1]
+        reorder_axis += [fuse_k, c]
+        sch.reorder(*reorder_axis)
+        inner_length = oshape[len(oshape) - 1].value
+        if inner_length <= vectorize_limit:
+            sch.vectorize(c)
+        else:
+            split_factor = 1
+            for i in range(vectorize_limit, 1, -1):
+                if inner_length % i == 0:
+                    split_factor = i
+                    break
+            if split_factor > 1:
+                _, c_i = sch.split(c, split_factor)
+                sch.vectorize(c_i)
+
     if len(sch.op.axis) >= 5:
         fused = sch.fuse(sch.op.axis[0], sch.op.axis[1], sch.op.axis[2])
+        sch.parallel(fused)
+        if do_vectorize:
+            vectorize(fused, 3)
+
     elif len(sch.op.axis) >= 3:
         fused = sch.fuse(sch.op.axis[0], sch.op.axis[1])
+        sch.parallel(fused)
+        if do_vectorize:
+            vectorize(fused, 2)
     else:
         sch.parallel(sch.op.axis[0])
-        return
-    sch.parallel(fused)


 @generic.schedule_pool.register(["cpu"])
-def schedule_pool(outs):
+def schedule_pool(outs, layout):
     """Schedule for pool

     Parameters
@@ -25,6 +53,9 @@ def schedule_pool(outs):
          The computation graph description of pool
          in the format of an array of tensors.

+    layout: str
+        Data layout.
+
     Returns
     -------
     sch: Schedule
@@ -37,7 +68,8 @@ def schedule_pool(outs):
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].compute_inline()
-        _parallel_sch(s[Pool])
+        do_vectorize = layout[-1] not in "HWhw"
+        _parallel_sch(s[Pool], outs[0].shape, do_vectorize)

     def traverse(OP):
         """Internal travserse function"""
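A quick illustration of the do_vectorize test above: the innermost axis is only vectorized when it is not a spatial (H/W) axis. The layouts listed are just examples.

    for layout in ["NCHW", "NHWC", "NCHW16c"]:
        print(layout, layout[-1] not in "HWhw")
    # NCHW    False -> innermost axis is W, left scalar
    # NHWC    True  -> innermost axis is the channel
    # NCHW16c True  -> innermost axis is the split channel block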
@@ -93,7 +125,7 @@ def schedule_global_pool(outs):
         # schedule pool
         elif OP.tag.startswith('global_pool'):
             Pool = OP.output(0)
-            _parallel_sch(s[Pool])
+            _parallel_sch(s[Pool], outs[0].shape)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)

@@ -10,9 +10,11 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
     kw = kh
     sw = sh
     pt, pl, pb, pr = padding
+    layout = "NCHW"
     A = tvm.placeholder((n, ic, ih, iw), name='A')
     B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
-                     pool_type=pool_type, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+                     pool_type=pool_type, ceil_mode=ceil_mode,
+                     layout="NCHW", count_include_pad=count_include_pad)
     B = topi.nn.relu(B)
     dtype = A.dtype

@@ -54,7 +56,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            s = topi.generic.schedule_pool(B)
+            s = topi.generic.schedule_pool(B, layout)

         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)