[TOPI] Relu and schedule elemwise (#390)

* convolution compute typo fixed

* relu activation migrated to topi

* reviews addressed

* elemwise schedule added

* relu compute deleted
This commit is contained in:
Leyuan Wang 2017-08-27 16:47:22 -07:00 коммит произвёл Tianqi Chen
Родитель 897b50e72c
Коммит efafa1a0dd
4 изменённых файлов: 75 добавлений и 1 удалений

Просмотреть файл

@ -8,3 +8,4 @@ from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise
from .reduction import schedule_reduce
from .broadcast import schedule_broadcast_to
from .softmax import schedule_softmax
from .elemwise import schedule_elemwise

Просмотреть файл

@ -0,0 +1,39 @@
# pylint: disable=invalid-name, unused-variable, trailing-whitespace, no-member
"""Schedule for element wise operator"""
import tvm
def schedule_elemwise(outs):
"""Schedule for element wise op.
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
x = outs[0]
num_dim = len(x.shape)
block_factor = tvm.ir_pass.Simplify(x.op.output(0).shape[num_dim-1]).value
if block_factor % 48 == 0:
block_factor = 48
elif block_factor % 32 == 0:
block_factor = 32
bx, tx = s[x].split(x.op.axis[num_dim-1], factor=block_factor)
for i in range(num_dim-2, 0, -1):
bx = s[x].fuse(bx, x.op.axis[i])
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return s

Просмотреть файл

@ -29,7 +29,6 @@ def conv2d_nchw(Input, Filter, stride, padding):
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(padding, int) or padding in ['VALID', 'SAME']
batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape
if isinstance(stride, int):

Просмотреть файл

@ -0,0 +1,35 @@
"""Test code for relu activation"""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.relu(A)
s = topi.cuda.schedule_elemwise(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="relu")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_relu():
verify_relu(10, 128)
if __name__ == "__main__":
test_relu()