[TOPI] Add topi.target; Schedule for raspberry pi (#406)

* CPU Schedule for raspberry pi

* Update

* Update

* Add topi.target

* Refactor

* Update

* Make python3 happy

* Improve

* Improve

* Improve

* Use get_const_int
This commit is contained in:
ziheng 2017-09-02 23:25:07 -07:00 коммит произвёл GitHub
Родитель f6bb7ababa
Коммит e05f54bee0
11 изменённых файлов: 857 добавлений и 145 удалений

Просмотреть файл

@ -14,5 +14,7 @@ from .reduction import *
from .broadcast import *
from . import nn
from . import cuda
from . import rasp
from . import target
from . import testing
from . import util

Просмотреть файл

@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .batch_norm import *
from .convolution import *
from .depthwise_convolution import *
from .elemwise import *
from .dilate import *
from .flatten import *

Просмотреть файл

@ -1,9 +1,232 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
"""Convolution operators"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
from .pad import pad, _spatial2d_pad_option
from .. import target as _target
# workload description of convolution
Workload = namedtuple('Workload',
['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
# schedule description of spatial
SpatialPack = namedtuple('SpatialPack',
['vh', 'vw', 'vc', 'ba', 'bc', 'unroll'])
# schedule description of im2col
Im2ColPack = namedtuple('Im2ColPack',
['vp', 'vq', 'ba', 'bc', 'unroll'])
# workloads of resnet18 on imagenet
_WORKLOADS = [
Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
]
# platform specific schedule
_CONV_SCHEDULE = {}
# platform specific declaration
_CONV_DECLARATION = {}
def convolution(data, kernel, stride, padding, layout='NCHW'):
"""Convolution operator.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
layout of data
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
# search platform specific declaration first
target = _target.current_target()
if target in _CONV_DECLARATION:
return _CONV_DECLARATION[target](data, kernel, stride, padding, layout)
# default declaration
if layout == 'NCHW':
conv2d_nchw(data, kernel, stride, padding)
elif layout == 'HWCN':
conv2d_hwcn(data, kernel, stride, padding)
else:
raise ValueError("not support this layout {} yet".format(layout))
def _get_workload(data, kernel, stride, padding):
""" Get the workload structure. """
_, CI, IH, IW = [x.value for x in data.shape]
CO, _, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return Workload(IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _get_schedule(wkl, target=None):
""" Get the platform specific schedule. """
if target is None:
target = _target.current_target()
else:
target = _target.Target(target)
assert target in _CONV_SCHEDULE, "no schedule for such target: {}".format(target)
return _CONV_SCHEDULE[target](wkl)
def _spatial_pack(data, kernel, stride, padding):
""" Compute convolution with pack on spatial axes. """
assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
KH, KW = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = KH-1, KW-1
VH = sch.vh
VW = sch.vw
VC = sch.vc
UNROLL = sch.unroll
TH = H + 2*HPAD
TW = W + 2*WPAD
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
dshape = (1, CI, H, W)
dpshape = (1, CI, TH, TW)
dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT)
kshape = (CO, CI, KH, KW)
kvshape = (CO/VC, CI, KH, KW, VC)
ovshape = (1, CO // VC, OH // VH, OW // VW, VH, VW, VC)
oshape = (1, CO, OH, OW)
DOPAD = (HPAD != 0 and WPAD != 0)
if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
else:
data_pad = data
data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: \
data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, vc: \
kernel[co*VC+vc][ci][dh][dw], name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
dh = tvm.reduce_axis((0, KH), name='dh')
dw = tvm.reduce_axis((0, KW), name='dw')
conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
tvm.sum(data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw] *
kernel_vec[co, ci, dh, dw, vc],
axis=[ci, dh, dw]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w:
conv[n][co//VC][h/VH][w//VW][h%VH][w%VW][co%VC],
name='output_unpack', tag='spatial_conv_output')
return output
def _im2col_pack(data, kernel, stride, padding):
""" Compute convolution with im2col pack layout. """
assert data.shape[0].value == 1, "im2col pack convolution only support batch size=1"
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl)
N = 1
H, W = wkl.height, wkl.width
CI = wkl.in_filter
CO = wkl.out_filter
KH, KW = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.hpad
HSTR, WSTR = wkl.hstride, wkl.wstride
OH = (H + 2*HPAD - KH) // HSTR + 1
OW = (W + 2*WPAD - KW) // WSTR + 1
P = sch.vp
Q = sch.vq
UNROLL = sch.unroll
dshape = (N, CI, H, W)
dpshape = (N, CI, H+2*HPAD, W+2*WPAD)
dcshape = (N, OH, OW, CI, KH, KW)
dvshape = (N, OH * OW // P, CI, KH, KW, P)
kshape = (CO, CI, KH, KW)
kvshape = (CO // Q, CI, KH, KW, Q)
ovshape = (N, CO // Q, OH * OW // P, P, Q)
oshape = (N, CO, OH, OW)
############### declaration
DO_PAD = (wkl.hpad != 0 and wkl.wpad != 0)
if DO_PAD:
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
else:
data_pad = data
data_col = tvm.compute(dcshape, lambda n, oh, ow, ci, hk, wk: \
data_pad[n][ci][oh*HSTR+hk][ow*WSTR+wk], name='data_col')
data_vec = tvm.compute(dvshape, lambda n, im, ci, hk, wk, vim: \
data_col[n][(im*P+vim)//OW][(im*P+vim)%OW][ci][hk][wk], name='data_vec')
kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, vc: \
kernel[co*Q+vc][ci][dh][dw], name='kernel_vec')
ci = tvm.reduce_axis((0, CI), name='ci')
hk = tvm.reduce_axis((0, KH), name='hk')
wk = tvm.reduce_axis((0, KW), name='wk')
conv = tvm.compute(ovshape, lambda n, co, im, vim, vco: \
tvm.sum(data_vec[n][im][ci][hk][wk][vim] * kernel_vec[co][ci][hk][wk][vco],
axis=[ci, hk, wk]), name='conv')
output = tvm.compute(oshape, lambda n, co, h, w: \
conv[n][co//Q][(h*OW+w)//P][(h*OW+w)%P][co%Q],
name='output_vec', tag='im2col_conv_output')
return output
def conv2d_nchw(Input, Filter, stride, padding):
@ -35,7 +258,7 @@ def conv2d_nchw(Input, Filter, stride, padding):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
@ -86,7 +309,7 @@ def conv2d_hwcn(Input, Filter, stride, padding):
else:
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
@ -106,98 +329,8 @@ def conv2d_hwcn(Input, Filter, stride, padding):
name="Conv2dOutput", tag="conv2d_hwcn")
return Output
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
"""Depthwise convolution nchw forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = Input.shape
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
Stride : tvm.Tensor
1-D of size 2
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = Input.shape
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output
# map from schedule type to declaration function
_SCH_TO_DECL_FUNC = {
SpatialPack: _spatial_pack,
Im2ColPack: _im2col_pack,
}

Просмотреть файл

@ -0,0 +1,104 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
"""Depthwise Convolution operators"""
from __future__ import absolute_import as _abs
import tvm
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
def depthwise_conv2d_nchw(Input, Filter, stride, padding):
"""Depthwise convolution nchw forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
stride : tuple of two ints
The spatial stride along height and width
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = Input.shape
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nchw")
return Output
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
Stride : tvm.Tensor
1-D of size 2
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = Input.shape
filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_height, out_width, out_channel),
lambda b, i, j, c: tvm.sum(
(PaddedInput[b, i*stride_h + di, j*stride_w + dj, c/channel_multiplier] *
Filter[di, dj, c/channel_multiplier, c%channel_multiplier]),
axis=[di, dj]),
name='DepthwiseConv2d', tag="depthwise_conv2d_nhwc")
return Output

Просмотреть файл

@ -3,51 +3,6 @@ from __future__ import absolute_import as _abs
import tvm
from ..util import equal_const_int
def _spatial2d_pad_option(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int or str
Padding size, or ['VALID', 'SAME']
kernel : tuple of int
Conv kernel size
Returns
-------
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
@tvm.tag_scope(tag="pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
"""Dilate Input with zeros.

Просмотреть файл

@ -1,8 +1,9 @@
"""TVM operator pooling compute."""
from __future__ import absolute_import
import tvm
from .pad import pad
from .util import get_pad_tuple
from .. import util
from .pad import pad, _spatial2d_pad_option
def max_pool(data, kernel, stride, padding):
"""Perform max pooling on the data
@ -32,7 +33,7 @@ def max_pool(data, kernel, stride, padding):
stride_height, stride_width = stride
batch, channel, height, width = data.shape
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]

102
topi/python/topi/nn/util.py Normal file
Просмотреть файл

@ -0,0 +1,102 @@
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
from __future__ import absolute_import
from ..util import get_const_int
def infer_pad(data, data_pad):
"""Infer the padding from stages in reverse.
Parameters
----------
data : Tensor
data stage.
data_pad : Tensor
pad stage.
Returns
-------
hpad : int
padding size on height
wpad : int
padding size on width
"""
if data_pad is None:
return 0, 0
_, _, IH, IW = data.shape
_, _, TH, TW = data_pad.shape
hpad = (TH - IH) // 2
wpad = (TW - IW) // 2
return get_const_int(hpad), get_const_int(wpad)
def infer_stride(data, kernel, out):
"""Infer the stride from stages in reverse.
Parameters
----------
data : Tensor
data stage.
kernel : Tensor
kernel stage.
out : Tensor
output stage.
Returns
-------
hstride : int
stride size on height
wstride : int
stride size on width
"""
_, _, IH, IW = data.shape
_, _, KH, KW = kernel.shape
_, _, OH, OW = out.shape
hstride = (IH - KH) // (OH - 1)
wstride = (IW - KW) // (OW - 1)
return get_const_int(hstride), get_const_int(wstride)
def get_pad_tuple(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int or str
Padding size, or ['VALID', 'SAME']
kernel : tuple of int
Conv kernel size
Returns
-------
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left

Просмотреть файл

@ -0,0 +1,5 @@
# pylint: disable=redefined-builtin, wildcard-import
"""Raspberry pi specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .convolution import *

Просмотреть файл

@ -0,0 +1,312 @@
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Convolution schedule on raspberry pi"""
from __future__ import absolute_import as _abs
import tvm
from .. import target as _target
from ..nn.convolution import SpatialPack, Im2ColPack
from ..nn.convolution import _CONV_DECLARATION, _CONV_SCHEDULE
from ..nn.convolution import _WORKLOADS, _SCH_TO_DECL_FUNC
from ..nn.convolution import _get_workload, _get_schedule
from ..nn.util import infer_pad, infer_stride
_SCHEDULES = [
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
]
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
sch = _SCHEDULES[idx]
return sch
_CONV_SCHEDULE[_target.rasp()] = _schedule_conv2d
def _declaration_conv2d(data, kernel, stride, padding, layout):
assert layout == 'NCHW', "only support NCHW convolution on rasp"
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding)
_CONV_DECLARATION[_target.rasp()] = _declaration_conv2d
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
VH = sch.vh
VW = sch.vw
VC = sch.vc
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1 = data_pad, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
_, co, oh, ow, vh, vw, vc = s[C0].op.axis
if UNROLL:
s[C0].unroll(vw)
s[C0].vectorize(vc)
s[CC].compute_at(s[C0], ow)
_, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, vc)
if UNROLL:
s[CC].unroll(vw)
s[CC].vectorize(vc)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
_, h, _, _, _, _ = s[A1].op.axis
if sch.ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[A1].split(h, sch.ba)
oaxis = oh
paxis = ih
s[A1].parallel(paxis)
s[A1].pragma(oaxis, "parallel_launch_point")
s[A1].pragma(paxis, "parallel_stride_pattern")
s[A1].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C
n, co, h, w = s[C].op.axis
co, vc = s[C].split(co, VC)
oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
s[C].reorder(n, co, oh, ow, vh, vw, vc)
if C != C1:
s[C1].compute_inline()
s[C0].compute_at(s[C], ow)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
return s
def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)
sch = _get_schedule(wkl, 'rasp')
H, W = wkl.height, wkl.width
CI = wkl.in_filter
CO = wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
P = sch.vp
Q = sch.vq
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1, A2 = data_pad, data_col, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
AA = s.cache_read(A2, "global", [CC])
BB = s.cache_read(B0, "global", [CC])
##### Schedule CC
_, co, im, vim, vco = s[C0].op.axis
s[C0].unroll(vim)
s[C0].vectorize(vco)
s[CC].compute_at(s[C0], im)
_, co, im, vim, vco = s[CC].op.axis
ci, hk, wk = s[CC].op.reduce_axis
s[CC].reorder(ci, hk, wk, vim, vco)
s[CC].unroll(vim)
s[CC].vectorize(vco)
# s[CC].unroll(ccr)
### Schedule C
_, co, h, w = s[C].op.axis
im = s[C].fuse(h, w)
im, vim = s[C].split(im, P)
co, vco = s[C].split(co, Q)
s[C].reorder(co, im, vim, vco)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
if C1 != C:
s[C1].compute_inline()
s[C0].compute_at(s[C], paxis)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
s[A1].compute_inline()
s[AA].compute_at(s[CC], wk)
s[AA].unroll(AA.op.axis[4])
_, im, _, _, _, _ = s[A2].op.axis
if sch.ba == 1:
oaxis = im
paxis = im
else:
oim, iim = s[A2].split(im, sch.ba)
oaxis = oim
paxis = iim
s[A2].parallel(paxis)
s[A2].pragma(oaxis, "parallel_launch_point")
s[A2].pragma(paxis, "parallel_stride_pattern")
s[A2].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
s[BB].compute_at(s[CC], wk)
s[BB].vectorize(BB.op.axis[4])
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
return s
def schedule_convolution(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in op.tag or 'bcast' in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'spatial_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
if 'im2col_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_col = data_vec.op.input_tensors[0]
data = data_col.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
traverse(outs[0].op)
return s

Просмотреть файл

@ -0,0 +1,63 @@
"""Target management API of topi"""
from __future__ import absolute_import
class Target(object):
"""A Target describes the target type on which computation should be carried on"""
default_target = None
str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
def __init__(self, target_type):
"""Constructs a context."""
if isinstance(target_type, Target):
self.target_typeid = target_type.target_typeid
else:
self.target_typeid = Target.str2type[target_type]
@property
def target_type(self):
"""Returns the target type of current target."""
return Target.type2str[self.target_typeid]
def __hash__(self):
"""Compute hash value of target for dictionary lookup"""
return hash(self.target_typeid)
def __eq__(self, other):
"""Compares two targets. Two targets are equal if they
have the same target type.
"""
return isinstance(other, Target) and \
self.target_typeid == other.target_typeid
def __str__(self):
return '%s' % (self.target_type)
def __repr__(self):
return self.__str__()
def __enter__(self):
self._old_target = Target.default_target
Target.default_target = self
return self
def __exit__(self, ptype, value, trace):
Target.default_target = self._old_target
Target.default_target = Target('x86')
def x86():
"""Returns a x86 target."""
return Target('x86')
def cuda():
"""Returns a cuda target."""
return Target('cuda')
def rasp():
"""Returns a rasp target."""
return Target('rasp')
def current_target():
"""Returns the current target."""
return Target.default_target

Просмотреть файл

@ -0,0 +1,34 @@
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_convolution(batch, in_size, in_channel, num_filter, kernel, stride, padding):
in_height = in_width = in_size
with topi.target.rasp():
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.convolution(A, W, stride, padding)
s = topi.rasp.schedule_convolution([B])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], "llvm")
func(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_convolution():
verify_convolution(1, 56, 64, 64, 3, 1, 1)
if __name__ == "__main__":
test_convolution()