From 394cf9f72a9c16f41055ded4c1cf8b3fa9420db3 Mon Sep 17 00:00:00 2001
From: Wu Zhao
Date: Fri, 11 Jan 2019 17:49:09 +0800
Subject: [PATCH] [ARM][Performance] Improve ARM CPU depthwise convolution
 performance (#2345)

* Add spatial pack schedule for arm cpu depthwise convolution

* Supply comments.
---
 nnvm/src/top/nn/convolution.cc               |   4 +-
 src/relay/op/nn/convolution.cc               |   3 +-
 topi/python/topi/arm_cpu/conv2d.py           | 127 +++++----
 topi/python/topi/arm_cpu/depthwise_conv2d.py | 280 ++++++++++++++++++-
 4 files changed, 358 insertions(+), 56 deletions(-)

diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc
index 81394749..e6ff7223 100644
--- a/nnvm/src/top/nn/convolution.cc
+++ b/nnvm/src/top/nn/convolution.cc
@@ -73,15 +73,13 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(param.channels % param.groups, 0U)
         << "output channels must divide group size";
 
-    TShape wshape({param.channels / param.groups,
+    TShape wshape({param.channels,
                    dshape[1] / param.groups,
                    param.kernel_size[0],
                    param.kernel_size[1]});
 
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-    wshape[kernel_layout.indexof('O')] *= param.groups;
-
     if (in_shape->at(Conv2DParam::kWeight).ndim() == 0) {
       NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DParam::kWeight, wshape);
     }
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 608cdab2..53098b71 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -52,12 +52,11 @@ bool Conv2DRel(const Array<Type>& types,
     CHECK_EQ(param->kernel_size.size(), 2);
     CHECK_EQ(param->dilation.size(), 2);
     std::vector<IndexExpr> wshape(
-        {param->channels / param->groups,
+        {param->channels,
          dshape_nchw[1] / param->groups,
          param->kernel_size[0],
          param->kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
-    wshape[kernel_layout.Indexof('O')] *= param->groups;
     channels = param->channels;
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 017c62b7..605749d4 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -11,7 +11,8 @@ from tvm import autotvm
 
 from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
 from ..util import traverse_inline, get_const_tuple, const_matrix
-from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
+from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
+    conv2d_winograd_without_weight_transform, depthwise_conv2d_nchw
 from ..nn.util import get_const_int, get_pad_tuple
 
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
@@ -556,7 +557,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     if out_dtype == "" or out_dtype == "same":
         out_dtype = tinfos[0].dtype
 
-    if layout != 'NCHW' or groups != 1:
+    if layout != 'NCHW':
         return None
     if dilation != (1, 1):
         warnings.warn("Does not support weight pre-transform for dilated convolution.")
@@ -566,54 +567,84 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
     N, CI, H, W = get_const_tuple(data.shape)
     CO, _, KH, KW = get_const_tuple(kernel.shape)
 
-    # query config of this workload
-    workload = autotvm.task.args_to_workload(
-        [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
-    target = tvm.target.current_target()
-    dispatch_ctx = autotvm.DispatchContext.current
-    cfg = dispatch_ctx.query(target, workload)
+    if groups == 1:
+        # query config of this workload
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
+        target = tvm.target.current_target()
+        dispatch_ctx = autotvm.DispatchContext.current
+        cfg = dispatch_ctx.query(target, workload)
 
-    if cfg.is_fallback:  # if is fallback, clear query cache and return None
-        autotvm.task.clear_fallback_cache(target, workload)
-        return None
+        if cfg.is_fallback:  # if it is a fallback, clear the query cache and return None
+            autotvm.task.clear_fallback_cache(target, workload)
+            return None
 
-    if cfg.template_key == 'direct':  # pack weight tensor
-        VC = cfg['tile_co'].size[-1]
-        new_attrs['kernel_layout'] = 'OIHW%do' % VC
+        if cfg.template_key == 'direct':  # pack weight tensor
+            VC = cfg['tile_co'].size[-1]
+            new_attrs['kernel_layout'] = 'OIHW%do' % VC
 
-        # Store the same config for the altered operator (workload)
-        new_data = data
-        new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
-        dispatch_ctx.update(target, new_workload, cfg)
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
+            dispatch_ctx.update(target, new_workload, cfg)
 
-        return F.nn.conv2d(*copy_inputs, **new_attrs)
-    else:  # pre-compute weight transformation in winograd
-        if "-device=arm_cpu" in target.options:
-            tile_size = 4
-            VC = cfg['tile_k'].size[-1]
+            return F.nn.conv2d(*copy_inputs, **new_attrs)
+        else:  # pre-compute weight transformation in winograd
+            if "-device=arm_cpu" in target.options:
+                tile_size = 4
+                VC = cfg['tile_k'].size[-1]
+            else:
+                from ..mali.conv2d import _pick_tile_size
+                tile_size = _pick_tile_size(tinfos[0], tinfos[1])
+                VC = cfg['tile_bna'].val
+
+            weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
+                                                                   tile_size=tile_size)
+            weight = F.reshape(weight,
+                               newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
+            weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
+
+            copy_inputs[1] = weight
+            new_attrs['tile_size'] = tile_size
+
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CO // VC, CI, VC),
+                                         kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_weight, strides, padding, dilation,
+                 new_attrs[data_layout_key], out_dtype, tile_size],
+                conv2d_winograd_without_weight_transform)
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+    else:
+        workload = autotvm.task.args_to_workload(
+            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
+        target = tvm.target.current_target()
+        dispatch_ctx = autotvm.DispatchContext.current
+        cfg = dispatch_ctx.query(target, workload)
+
+        if cfg.is_fallback:  # if it is a fallback, clear the query cache and return None
+            autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
+            return None
+        if cfg.template_key == 'contrib_spatial_pack':
+            VC = cfg['tile_co'].size[-1]
+            new_attrs['kernel_layout'] = 'OIHW%do' % VC
+
+            # Store the same config for the altered operator (workload)
+            new_data = data
+            CO, M, KH, KW = get_const_tuple(kernel.shape)
+            new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, out_dtype],
+                depthwise_conv2d_nchw)
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return F.nn.conv2d(*copy_inputs, **new_attrs)
         else:
-            from ..mali.conv2d import _pick_tile_size
-            tile_size = _pick_tile_size(tinfos[0], tinfos[1])
-            VC = cfg['tile_bna'].val
-
-        weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size)
-        weight = F.reshape(weight,
-                           newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
-        weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
-
-        copy_inputs[1] = weight
-        new_attrs['tile_size'] = tile_size
-
-        # Store the same config for the altered operator (workload)
-        new_data = data
-        new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
-                                     kernel.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_weight, strides, padding, dilation,
-             new_attrs[data_layout_key], out_dtype, tile_size],
-            conv2d_winograd_without_weight_transform)
-        dispatch_ctx.update(target, new_workload, cfg)
-
-        return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+            # currently we only have the contrib_spatial_pack and direct templates;
+            # add more schedule templates here as they become available.
+            return None
diff --git a/topi/python/topi/arm_cpu/depthwise_conv2d.py b/topi/python/topi/arm_cpu/depthwise_conv2d.py
index 2556af36..1e25eb58 100644
--- a/topi/python/topi/arm_cpu/depthwise_conv2d.py
+++ b/topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -5,15 +5,17 @@ import tvm
 from tvm import autotvm
 
 from ..generic import schedule_depthwise_conv2d_nchw
-from ..nn import depthwise_conv2d_nchw
-from ..util import traverse_inline
+from ..nn import depthwise_conv2d_nchw, pad
+from ..util import traverse_inline, get_const_tuple, get_const_int
+from ..nn.util import get_pad_tuple
 
 # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
 autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
                               depthwise_conv2d_nchw.fdefault)
 
 # register customized schedule for arm cpu.
-@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct')
+@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
+                                ['direct', 'contrib_spatial_pack'])
 def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
     """Schedule depthwise conv2d
 
@@ -116,5 +118,277 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
             data = data_pad.op.input_tensors[0]
             _schedule(cfg, s, data, data_pad, kernel, output)
 
+        if op.tag == 'spatial_depthwise_conv_nchw_output':
+            output = op.output(0)
+            conv = op.input_tensors[0]
+            data_vec = conv.op.input_tensors[0]
+            kernel_vec = conv.op.input_tensors[1]
+            if kernel_vec.op.name == 'kernel_vec':
+                kernel = kernel_vec.op.input_tensors[0]
+            else:
+                kernel = kernel_vec
+            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
+                s[kernel].compute_inline()
+
+            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])
+
     traverse_inline(s, outs[0].op, _callback)
     return s
+
+@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
+def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """TOPI compute callback for depthwise_conv2d NCHW
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel : tvm.Tensor
+        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
+        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
+        filter_width, num_filter_block]
+
+    strides : list of two ints
+        [stride_height, stride_width]
+
+    padding : list of two ints
+        [pad_height, pad_width]
+
+    dilation : list of two ints
+        [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+
+    return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2)
+
+
+def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
+    out_dtype = out_dtype or data.dtype
+
+    N, C, IH, IW = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if len(kernel.shape) == 4:
+        pre_packed = False
+        C, M, KH, KW = get_const_tuple(kernel.shape)
+    else:  # kernel tensor is pre-packed
+        pre_packed = True
+        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
+        C = C * VC
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    # pack data
+    HPAD = pad_top + pad_down
+    WPAD = pad_left + pad_right
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
+                       name="data_pad")
+    else:
+        data_pad = data
+
+    # fallback support
+    # Currently, the Mali schedule doesn't use the fallback mechanism like conv2d does.
+    if cfg.is_fallback:
+        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
+                                                    'contrib_spatial_pack')
+        cfg.fallback_with_reference_log(ref_log)
+
+    # ==================== define configuration space ====================
+    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
+    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+    # Currently, the Mali schedule doesn't use num_tile like conv2d does.
+    # Leave num_tile in place for possible future use by a Mali schedule.
+    if num_tile == 2:  # for arm cpu
+        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
+        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
+        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
+    else:
+        raise RuntimeError("Invalid num_tile")
+
+    cfg.define_reorder("reorder_0",
+                       [n, co, oh, ow, kh, kw, vh, vw, vc],
+                       policy='candidate', candidate=[
+                           [n, co, oh, ow, kh, kw, vh, vw, vc],
+                           [n, co, oh, ow, kh, kw, vc, vh, vw]])
+
+    cfg.define_reorder("reorder_1",
+                       [n, co, oh, ow, vh, vw, vc],
+                       policy='candidate', candidate=[
+                           [n, co, oh, ow, vh, vw, vc],
+                           [n, co, oh, ow, vc, vh, vw],
+                           [n, co, oh, ow, vh, vc, vw]])
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
+    # ====================================================================
+
+    VC = cfg["tile_co"].size[-1]
+    VH = cfg["tile_oh"].size[-1]
+    VW = cfg["tile_ow"].size[-1]
+
+    kvshape = (C // VC, M, KH, KW, VC)
+    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
+    oshape = (N, C * M, OH, OW)
+
+    if dilation_h != 1 or dilation_w != 1:
+        # undilate input data
+        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
+                               data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
+                               [(w * VW + vw) * WSTR + kw * dilation_w],
+                               name='data_vec_undilated')
+    else:
+        dvshape = (N, OH // VH, OW // VW, C, VH * HSTR + KH - 1, VW * WSTR + KW - 1)
+        data_vec = tvm.compute(dvshape, lambda n, h, w, c, vh, vw:
+                               data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
+                               name='data_vec')
+
+    if pre_packed:
+        kernel_vec = kernel
+    else:
+        kernel_vec = tvm.compute(kvshape, lambda co, m, kh, kw, vc:
+                                 kernel[co * VC + vc][m][kh][kw],
+                                 name='kernel_vec')
+
+    kh = tvm.reduce_axis((0, KH), name='kh')
+    kw = tvm.reduce_axis((0, KW), name='kw')
+
+    if dilation_h != 1 or dilation_w != 1:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+            tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, kh, kw, vh, vw]
+                    .astype(out_dtype) *
+                    kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
+                    axis=[kh, kw]), name='depthwise_conv')
+    else:
+        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
+            tvm.sum(data_vec[n, h, w, (co * VC + vc) // M, vh * HSTR + kh,
+                             vw * WSTR + kw].astype(out_dtype) *
+                    kernel_vec[co // M, co % M, kh, kw, vc].astype(out_dtype),
+                    axis=[kh, kw]), name='depthwise_conv')
+
+    output = tvm.compute(oshape, lambda n, co, h, w:
+                         conv[n][co // VC][h // VH][w // VW][h % VH][w % VW][co % VC],
+                         name='output_unpack', tag='spatial_depthwise_conv_nchw_output')
+    return output
+
+def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
+                           conv, output, last):
+    """schedule implementation"""
+    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
+    kh, kw = s[conv].op.reduce_axis
+
+    if data_vec.op.name == 'data_vec_undilated':
+        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
+    else:
+        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis
+
+    data_pad = data_vec.op.input_tensors[0]
+    if data_pad.op.name == "data_pad":
+        assert isinstance(data_pad.op, tvm.tensor.ComputeOp)
+        has_padding = True
+    else:
+        assert isinstance(data_pad.op, tvm.tensor.PlaceholderOp)
+        has_padding = False
+
+    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])
+
+    if cfg['data_pad_inline'].val == 1 and has_padding:
+        s[data_pad].compute_inline()
+    if cfg['data_pad_inline'].val == 2 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+    if cfg['data_pad_inline'].val == 3 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+        s[data_pad].compute_at(s[data_vec], dv_oh)
+    if cfg['data_pad_inline'].val == 4 and has_padding:
+        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
+        s[data_pad].compute_at(s[data_vec], dv_ow)
+
+    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
+    if cfg['data_vec_inline'].val == 1:
+        s[data_vec].compute_at(s[conv], oh)
+    if cfg['data_vec_inline'].val == 2:
+        s[data_vec].compute_at(s[conv], ow)
+    if cfg['data_vec_inline'].val == 3:
+        s[data_vec].compute_at(s[conv], co)
+
+    # schedule conv
+    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
+    cfg["ann_reduce"].apply(s, conv, [kh, kw],
+                            axis_lens=[get_const_int(kh.dom.extent),
+                                       get_const_int(kw.dom.extent)],
+                            max_unroll=16,
+                            cfg=cfg)
+    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
+                             axis_lens=[cfg['tile_oh'].size[-1],
+                                        cfg['tile_ow'].size[-1],
+                                        cfg['tile_co'].size[-1]],
+                             max_unroll=16,
+                             cfg=cfg)
+
+    # schedule fusion
+    n, co, h, w = s[last].op.axis
+    co, vc = cfg['tile_co'].apply(s, last, co)
+    oh, vh = cfg['tile_oh'].apply(s, last, h)
+    ow, vw = cfg['tile_ow'].apply(s, last, w)
+    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
+    if last != output:
+        s[output].compute_inline()
+        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
+                                 axis_lens=[cfg['tile_oh'].size[-1],
+                                            cfg['tile_ow'].size[-1],
+                                            cfg['tile_co'].size[-1]],
+                                 max_unroll=16,
+                                 cfg=cfg)
+    else:
+        s[last].vectorize(vw)
+    cfg.define_knob('conv_inline', [0, 1, 2, 3])
+    if cfg['conv_inline'].val == 1:
+        s[conv].compute_at(s[last], ow)
+    if cfg['conv_inline'].val == 2:
+        s[conv].compute_at(s[last], oh)
+    if cfg['conv_inline'].val == 3:
+        s[conv].compute_at(s[last], co)
+
+    # mark parallel
+    s[last].parallel(co)
+
+    if data_vec.op.name == 'data_vec_undilated':
+        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
+    else:
+        _, h, _, _, _, _ = s[data_vec].op.axis
+    s[data_vec].parallel(h)
+
+    if kernel_vec.op.name == 'kernel_vec':
+        co, _, _, _, _ = s[kernel_vec].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # kernel packing will be pre-computed during compilation, so we skip
+            # this part to make tuning records correct
+            s[kernel_vec].pragma(co, 'debug_skip_region')
+        else:
+            s[kernel_vec].parallel(co)
+
+    return s
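
Note on the InferShape changes in convolution.cc: building the full OIHW weight shape
(channels, in_channels // groups, KH, KW) up front is arithmetically the same as the
old "divide O by groups, then multiply the O axis back" sequence, but it no longer
needs to locate a single 'O' axis after ConvertLayout, which matters once packed
kernel layouts such as OIHW4o are allowed for grouped (depthwise) convolution. A
small sanity check of the arithmetic, with hypothetical depthwise numbers:

    # hypothetical depthwise example: groups == channels
    channels, in_channels, groups, KH, KW = 32, 32, 32, 3, 3
    old = [channels // groups, in_channels // groups, KH, KW]
    old[0] *= groups                                 # old code restored O afterwards
    new = [channels, in_channels // groups, KH, KW]  # new code builds it directly
    assert old == new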
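
Note on the 'OIHW%do' kernel layout requested by _alter_conv2d_layout_arm: it splits
the output-channel axis O into [CO // VC, VC] chunks and moves the VC sub-axis
innermost so the schedule can vectorize across VC filters. A minimal numpy sketch of
that packing (shapes here are hypothetical; VC would come from cfg['tile_co'].size[-1]):

    import numpy as np

    CO, M, KH, KW, VC = 16, 1, 3, 3, 4  # hypothetical depthwise kernel shape
    kernel = np.random.rand(CO, M, KH, KW).astype("float32")

    # split O into (CO // VC, VC) chunks, then move the VC sub-axis innermost,
    # matching the tvm.placeholder((CO // VC, M, KH, KW, VC), ...) declared above
    packed = kernel.reshape(CO // VC, VC, M, KH, KW).transpose(0, 2, 3, 4, 1)
    assert packed.shape == (CO // VC, M, KH, KW, VC)
    assert packed[1, 0, 2, 2, 3] == kernel[1 * VC + 3, 0, 2, 2]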
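
With the 'contrib_spatial_pack' template registered, depthwise workloads can be tuned
the same way as ordinary conv2d. A hedged sketch against the autotvm API of this era
(the target string is only an example, and `net`/`params` are assumed to hold a Relay
function and its weights):

    import tvm
    from tvm import relay, autotvm

    target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
    # depthwise workloads (groups == channels) are extracted through nn.conv2d;
    # `net` and `params` are assumed to exist (a Relay function and its weights)
    tasks = autotvm.task.extract_from_program(net, target=target, params=params,
                                              ops=(relay.op.nn.conv2d,))
    measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                            runner=autotvm.LocalRunner(number=10))
    for task in tasks:
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=min(200, len(task.config_space)),
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file('depthwise.log')])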