[BUGFIX] [Hybrid Script] fix in-correct value index in hybrid script (#2268)

This commit is contained in:
Jian Weng 2018-12-13 10:21:36 -08:00 коммит произвёл Tianqi Chen
Родитель 6b4058240b
Коммит 4bbf96e43c
2 изменённых файлов: 60 добавлений и 19 удалений

Просмотреть файл

@ -35,20 +35,21 @@ class HybridParser(ast.NodeVisitor):
_binop_maker = {
ast.Add : operator.add,
ast.Sub : operator.sub,
ast.Mult : operator.mul,
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.Mod : operator.mod,
ast.BitOr : operator.or_,
ast.BitAnd: operator.and_,
ast.BitXor: operator.xor,
ast.Gt : operator.gt,
ast.GtE : operator.ge,
ast.Lt : operator.lt,
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
ast.Add : operator.add,
ast.Sub : operator.sub,
ast.Mult : operator.mul,
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.Mod : operator.mod,
ast.BitOr : operator.or_,
ast.BitAnd : operator.and_,
ast.BitXor : operator.xor,
ast.Gt : operator.gt,
ast.GtE : operator.ge,
ast.Lt : operator.lt,
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
ast.And : _all,
ast.Or : _any,
}
@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor):
if isinstance(node.value, ast.Name):
array = node.value.id
_buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index)
_internal_assert(isinstance(node.value, ast.Attribute), \
"Only variable and attribute's subscript supported so far")

Просмотреть файл

@ -1,4 +1,4 @@
import tvm, inspect, sys, traceback, numpy, nose
import tvm, inspect, sys, traceback, numpy, nose, types
from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
return val.value
ctx = tvm.context(target, 0)
op = None
outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
emu_args = []
nd_args = []
@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1])
outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
sch = tvm.create_schedule(op)
module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
assert module
@ -425,10 +427,12 @@ def test_downstream():
for i in range(20):
b[i] = a[i] * i
return b
a = tvm.placeholder((20, ), 'float32')
b = downstream(a)
c = tvm.compute((20, ), lambda x: b[x] + 1.0)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c])
assert module
@ -469,6 +473,40 @@ def test_const_param():
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
def test_value_index():
@tvm.hybrid.script
def kernel_a(a):
b = output_tensor((16, ), 'int32')
c = output_tensor((4, 4), 'int32')
for i in range(16):
b[i] = a[i] + 2
c[i // 4, i % 4] = a[i] + 1
return b, c
@tvm.hybrid.script
def kernel_b(b, a):
c = output_tensor((4, 4), 'int32')
for i in range(4):
for j in range(4):
c[i, j] = a[i * 4 + j] * b[i, j]
return c
a = tvm.placeholder((16, ), 'int32')
b, c = kernel_a(a)
d = kernel_b(c, b)
sch = tvm.create_schedule(d.op)
module = tvm.build(sch, [a, d])
assert module
np_a = numpy.arange(16).astype('int32')
np_b, np_c = kernel_a(np_a)
ref = kernel_b(np_c, np_b)
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.ndarray.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref)
if __name__ == "__main__":
test_outer_product()
@ -479,9 +517,11 @@ if __name__ == "__main__":
test_math_intrin()
test_non_zero()
test_allocate()
#test_inplace()
test_upstream()
test_downstream()
test_const_param()
test_value_index()
# TODO:
# test_inplace()