[ARM][Performance] Improve ARM CPU depthwise convolution performance (#2345)

* Add spatial pack schedule for ARM CPU depthwise convolution
* Add comments.

This commit is contained in:
Parent
b1c7869fe0
Commit
394cf9f72a
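For context, a minimal sketch of how the depthwise compute and schedule touched by this commit can be exercised through TOPI on an ARM CPU target. The shapes, target string, and variable names below are illustrative assumptions, not part of the commit:

import tvm
import topi

# Hypothetical MobileNet-style depthwise layer; shapes are for illustration only.
data = tvm.placeholder((1, 32, 112, 112), name='data')
kernel = tvm.placeholder((32, 1, 3, 3), name='kernel')

with tvm.target.create('llvm -device=arm_cpu'):
    # Dispatches to the arm_cpu depthwise_conv2d_nchw compute/schedule registered
    # in this commit; without tuned AutoTVM records a fallback config is used.
    out = topi.nn.depthwise_conv2d_nchw(data, kernel, (1, 1), (1, 1), (1, 1), 'float32')
    sched = topi.generic.schedule_depthwise_conv2d_nchw([out])
    func = tvm.build(sched, [data, kernel, out])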
@@ -73,15 +73,13 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
      CHECK_EQ(param.channels % param.groups, 0U)
          << "output channels must divide group size";

      TShape wshape({param.channels / param.groups,
      TShape wshape({param.channels,
                     dshape[1] / param.groups,
                     param.kernel_size[0],
                     param.kernel_size[1]});

      wshape = ConvertLayout(wshape, kOIHW, kernel_layout);

      wshape[kernel_layout.indexof('O')] *= param.groups;

      if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
        NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
      }
@@ -52,12 +52,11 @@ bool Conv2DRel(const Array<Type>& types,
    CHECK_EQ(param->kernel_size.size(), 2);
    CHECK_EQ(param->dilation.size(), 2);
    std::vector<IndexExpr> wshape(
        {param->channels / param->groups,
        {param->channels,
         dshape_nchw[1] / param->groups,
         param->kernel_size[0],
         param->kernel_size[1]});
    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
    wshape[kernel_layout.Indexof('O')] *= param->groups;
    channels = param->channels;
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
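Both shape-inference hunks above build the full OIHW weight shape directly instead of dividing the output channels by the group count and scaling the 'O' axis back after ConvertLayout. A small worked example with assumed numbers (a depthwise case, not taken from the commit itself):

# Illustrative values: 3x3 depthwise convolution, 32 input channels,
# channel multiplier 1, so channels == groups == in_channels == 32.
channels, groups, in_channels, kernel_size = 32, 32, 32, (3, 3)

# New logic: the OIHW weight shape is constructed in one step.
wshape = (channels, in_channels // groups, kernel_size[0], kernel_size[1])
assert wshape == (32, 1, 3, 3)

# The old logic first built (channels // groups, 1, 3, 3) == (1, 1, 3, 3) and then
# multiplied the 'O' dimension by groups, arriving at the same (32, 1, 3, 3).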
@@ -11,7 +11,8 @@ from tvm import autotvm

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
    conv2d_winograd_without_weight_transform, depthwise_conv2d_nchw
from ..nn.util import get_const_int, get_pad_tuple

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
@@ -556,7 +557,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
    if out_dtype == "" or out_dtype == "same":
        out_dtype = tinfos[0].dtype

    if layout != 'NCHW' or groups != 1:
    if layout != 'NCHW':
        return None
    if dilation != (1, 1):
        warnings.warn("Does not support weight pre-transform for dilated convolution.")
@@ -566,54 +567,84 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
    N, CI, H, W = get_const_tuple(data.shape)
    CO, _, KH, KW = get_const_tuple(kernel.shape)

    # query config of this workload
    workload = autotvm.task.args_to_workload(
        [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
    target = tvm.target.current_target()
    dispatch_ctx = autotvm.DispatchContext.current
    cfg = dispatch_ctx.query(target, workload)
    if groups == 1:
        # query config of this workload
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
        target = tvm.target.current_target()
        dispatch_ctx = autotvm.DispatchContext.current
        cfg = dispatch_ctx.query(target, workload)

    if cfg.is_fallback:  # if is fallback, clear query cache and return None
        autotvm.task.clear_fallback_cache(target, workload)
        return None
        if cfg.is_fallback:  # if is fallback, clear query cache and return None
            autotvm.task.clear_fallback_cache(target, workload)
            return None

    if cfg.template_key == 'direct':  # pack weight tensor
        VC = cfg['tile_co'].size[-1]
        new_attrs['kernel_layout'] = 'OIHW%do' % VC
        if cfg.template_key == 'direct':  # pack weight tensor
            VC = cfg['tile_co'].size[-1]
            new_attrs['kernel_layout'] = 'OIHW%do' % VC

        # Store the same config for the altered operator (workload)
        new_data = data
        new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
        dispatch_ctx.update(target, new_workload, cfg)
            # Store the same config for the altered operator (workload)
            new_data = data
            new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
            dispatch_ctx.update(target, new_workload, cfg)

        return F.nn.conv2d(*copy_inputs, **new_attrs)
    else:  # pre-compute weight transformation in winograd
        if "-device=arm_cpu" in target.options:
            tile_size = 4
            VC = cfg['tile_k'].size[-1]
            return F.nn.conv2d(*copy_inputs, **new_attrs)
        else:  # pre-compute weight transformation in winograd
            if "-device=arm_cpu" in target.options:
                tile_size = 4
                VC = cfg['tile_k'].size[-1]
            else:
                from ..mali.conv2d import _pick_tile_size
                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
                VC = cfg['tile_bna'].val

            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
                                                                   tile_size=tile_size)
            weight = F.reshape(weight,
                               newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
            weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

            copy_inputs[1] = weight
            new_attrs['tile_size'] = tile_size

            # Store the same config for the altered operator (workload)
            new_data = data
            new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
                                         kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_weight, strides, padding, dilation,
                 new_attrs[data_layout_key], out_dtype, tile_size],
                conv2d_winograd_without_weight_transform)
            dispatch_ctx.update(target, new_workload, cfg)

            return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
    else:
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
        target = tvm.target.current_target()
        dispatch_ctx = autotvm.DispatchContext.current
        cfg = dispatch_ctx.query(target, workload)

        if cfg.is_fallback:  # if is fallback, clear query cache and return None
            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
            return None
        if cfg.template_key == 'contrib_spatial_pack':
            VC = cfg['tile_co'].size[-1]
            new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])

            # Store the same config for the altered operator (workload)
            new_data = data
            CO, M, KH, KW = get_const_tuple(kernel.shape)
            new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
            new_workload = autotvm.task.args_to_workload(
                [new_data, new_kernel, strides, padding, dilation, out_dtype],
                depthwise_conv2d_nchw)
            dispatch_ctx.update(target, new_workload, cfg)

            return F.nn.conv2d(*copy_inputs, **new_attrs)
        else:
            from ..mali.conv2d import _pick_tile_size
            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
            VC = cfg['tile_bna'].val

        weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
        weight = F.reshape(weight,
                           newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
        weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

        copy_inputs[1] = weight
        new_attrs['tile_size'] = tile_size

        # Store the same config for the altered operator (workload)
        new_data = data
        new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
                                     kernel.dtype)
        new_workload = autotvm.task.args_to_workload(
            [new_data, new_weight, strides, padding, dilation,
             new_attrs[data_layout_key], out_dtype, tile_size],
            conv2d_winograd_without_weight_transform)
        dispatch_ctx.update(target, new_workload, cfg)

        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
        # currently we only have contrib_spatial_pack and direct template
        # add more schedule templates.
        return None
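To illustrate the new depthwise branch above: when the tuned config comes from the 'contrib_spatial_pack' template, the graph-level pass rewrites the kernel layout from OIHW to a blocked layout so that weight packing can happen at compile time, and registers the altered workload under the same config. A small sketch with assumed numbers (VC and the shapes are illustrative, not tuned values):

# Suppose the config chose VC = 8 for a 32-channel, 3x3 depthwise kernel.
VC = 8
CO, M, KH, KW = 32, 1, 3, 3                     # original [num_filter, multiplier, kh, kw]
new_kernel_layout = 'OIHW%do' % VC              # -> 'OIHW8o'
new_kernel_shape = (CO // VC, M, KH, KW, VC)    # -> (4, 1, 3, 3, 8), matching kvshape below
# args_to_workload is then called with a placeholder of new_kernel_shape so the
# altered operator maps back to the same tuned config at compile time.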
@@ -5,15 +5,17 @@ import tvm
from tvm import autotvm

from ..generic import schedule_depthwise_conv2d_nchw
from ..nn import depthwise_conv2d_nchw
from ..util import traverse_inline
from ..nn import depthwise_conv2d_nchw, pad
from ..util import traverse_inline, get_const_tuple, get_const_int
from ..nn.util import get_pad_tuple

# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
                              depthwise_conv2d_nchw.fdefault)

# register customized schedule for arm cpu.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct')
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
                                ['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
    """Schedule depthwise conv2d
@@ -116,5 +118,277 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
                data = data_pad.op.input_tensors[0]
            _schedule(cfg, s, data, data_pad, kernel, output)

        if op.tag == 'spatial_depthwise_conv_nchw_output':
            output = op.output(0)
            conv = op.input_tensors[0]
            data_vec = conv.op.input_tensors[0]
            kernel_vec = conv.op.input_tensors[1]
            if kernel_vec.op.name == 'kernel_vec':
                kernel = kernel_vec.op.input_tensors[0]
            else:
                kernel = kernel_vec
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s

@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nchw

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.Tensor
        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """

    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)


def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
    out_dtype = out_dtype or data.dtype

    N, C, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        C, M, KH, KW = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
        C = C * VC

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # pack data
    HPAD = pad_top + pad_down
    WPAD = pad_left + pad_right
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
                       name="data_pad")
    else:
        data_pad = data

    # fallback support
    # Currently, Mali schedule doesn't use it like conv2d.
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
                                                    'contrib_spatial_pack')
        cfg.fallback_with_reference_log(ref_log)

    # ==================== define configuration space ====================
    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    # Currently, Mali schedule doesn't use it like conv2d.
    # Leave num_tile for possible future use of Mali schedule
    if num_tile == 2:  # for arm cpu
        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder("reorder_0",
                       [n, co, oh, ow, kh, kw, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, kh, kw, vh, vw, vc],
                           [n, co, oh, ow, kh, kw, vc, vh, vw]])

    cfg.define_reorder("reorder_1",
                       [n, co, oh, ow, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, vh, vw, vc],
                           [n, co, oh, ow, vc, vh, vw],
                           [n, co, oh, ow, vh, vc, vw]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
    # ====================================================================

    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (C // VC, M, KH, KW, VC)
    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, C * M, OH, OW)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
                               data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
                               [(w*VW+vw)*WSTR+kw*dilation_w],
                               name='data_vec_undilated')
    else:
        dvshape = (N, OH // VH, OW // VW, C, VH*HSTR + KH-1, VW*WSTR + KW-1)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, vh, vw:
                               data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
                               name='data_vec')

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(kvshape, lambda co, m, kh, kw, vc:
                                 kernel[co*VC+vc][m][kh][kw],
                                 name='kernel_vec')

    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw]
                    .astype(out_dtype) *
                    kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
                    axis=[kh, kw]), name='depthwise_conv')
    else:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh,
                             vw * WSTR + kw].astype(out_dtype) *
                    kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
                    axis=[kh, kw]), name='depthwise_conv')

    output = tvm.compute(oshape, lambda n, co, h, w:
                         conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC],
                         name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
    return output

def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
                           conv, output, last):
    """schedule implementation"""
    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
    kh, kw = s[conv].op.reduce_axis

    if data_vec.op.name == 'data_vec_undilated':
        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
    else:
        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis

    data_pad = data_vec.op.input_tensors[0]
    if data_pad.op.name == "data_pad":
        assert isinstance(data_pad.op, tvm.tensor.ComputeOp)
        has_padding = True
    else:
        assert isinstance(data_pad.op, tvm.tensor.PlaceholderOp)
        has_padding = False

    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])

    if cfg['data_pad_inline'].val == 1 and has_padding:
        s[data_pad].compute_inline()
    if cfg['data_pad_inline'].val == 2 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
    if cfg['data_pad_inline'].val == 3 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_oh)
    if cfg['data_pad_inline'].val == 4 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_ow)

    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
    if cfg['data_vec_inline'].val == 1:
        s[data_vec].compute_at(s[conv], oh)
    if cfg['data_vec_inline'].val == 2:
        s[data_vec].compute_at(s[conv], ow)
    if cfg['data_vec_inline'].val == 3:
        s[data_vec].compute_at(s[conv], co)

    # schedule conv
    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
    cfg["ann_reduce"].apply(s, conv, [kh, kw],
                            axis_lens=[get_const_int(kh.dom.extent),
                                       get_const_int(kw.dom.extent)],
                            max_unroll=16,
                            cfg=cfg)
    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
                             axis_lens=[cfg['tile_oh'].size[-1],
                                        cfg['tile_ow'].size[-1],
                                        cfg['tile_co'].size[-1]],
                             max_unroll=16,
                             cfg=cfg)

    # schedule fusion
    n, co, h, w = s[last].op.axis
    co, vc = cfg['tile_co'].apply(s, last, co)
    oh, vh = cfg['tile_oh'].apply(s, last, h)
    ow, vw = cfg['tile_ow'].apply(s, last, w)
    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
    if last != output:
        s[output].compute_inline()
        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
                                 axis_lens=[cfg['tile_oh'].size[-1],
                                            cfg['tile_ow'].size[-1],
                                            cfg['tile_co'].size[-1]],
                                 max_unroll=16,
                                 cfg=cfg)
    else:
        s[last].vectorize(vw)
    cfg.define_knob('conv_inline', [0, 1, 2, 3])
    if cfg['conv_inline'].val == 1:
        s[conv].compute_at(s[last], ow)
    if cfg['conv_inline'].val == 2:
        s[conv].compute_at(s[last], oh)
    if cfg['conv_inline'].val == 3:
        s[conv].compute_at(s[last], co)

    # mark parallel
    s[last].parallel(co)

    if data_vec.op.name == 'data_vec_undilated':
        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
    else:
        _, h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(h)

    if kernel_vec.op.name == 'kernel_vec':
        co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(co, 'debug_skip_region')
        else:
            s[kernel_vec].parallel(co)

    return s
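A small worked sketch of the packed shapes that _decl_spatial_pack above produces, using assumed tile sizes (the numbers are illustrative, not tuned values):

# Assume N=1, C=32, M=1, a 3x3 kernel, OH=OW=112, and a config that picked
# VC=8, VH=2, VW=4.
N, C, M, KH, KW, OH, OW = 1, 32, 1, 3, 3, 112, 112
VC, VH, VW = 8, 2, 4

kvshape = (C // VC, M, KH, KW, VC)                          # (4, 1, 3, 3, 8)   packed kernel
ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)  # (1, 4, 56, 28, 2, 4, 8) tiled conv
oshape = (N, C * M, OH, OW)                                 # (1, 32, 112, 112) unpacked NCHW output
# 'depthwise_conv' accumulates into the tiled ovshape buffer and 'output_unpack'
# scatters it back to the plain NCHW tensor, as in the compute above.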