[BUGFIX] [Hybrid Script] fix in-correct value index in hybrid script (#2268)
This commit is contained in:
Родитель
6b4058240b
Коммит
4bbf96e43c
|
@ -35,20 +35,21 @@ class HybridParser(ast.NodeVisitor):
|
|||
|
||||
|
||||
_binop_maker = {
|
||||
ast.Add : operator.add,
|
||||
ast.Sub : operator.sub,
|
||||
ast.Mult : operator.mul,
|
||||
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
|
||||
ast.Mod : operator.mod,
|
||||
ast.BitOr : operator.or_,
|
||||
ast.BitAnd: operator.and_,
|
||||
ast.BitXor: operator.xor,
|
||||
ast.Gt : operator.gt,
|
||||
ast.GtE : operator.ge,
|
||||
ast.Lt : operator.lt,
|
||||
ast.LtE : operator.le,
|
||||
ast.Eq : operator.eq,
|
||||
ast.NotEq : operator.ne,
|
||||
ast.Add : operator.add,
|
||||
ast.Sub : operator.sub,
|
||||
ast.Mult : operator.mul,
|
||||
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
|
||||
ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
|
||||
ast.Mod : operator.mod,
|
||||
ast.BitOr : operator.or_,
|
||||
ast.BitAnd : operator.and_,
|
||||
ast.BitXor : operator.xor,
|
||||
ast.Gt : operator.gt,
|
||||
ast.GtE : operator.ge,
|
||||
ast.Lt : operator.lt,
|
||||
ast.LtE : operator.le,
|
||||
ast.Eq : operator.eq,
|
||||
ast.NotEq : operator.ne,
|
||||
ast.And : _all,
|
||||
ast.Or : _any,
|
||||
}
|
||||
|
@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor):
|
|||
if isinstance(node.value, ast.Name):
|
||||
array = node.value.id
|
||||
_buf = self._get_buffer_from_id(array)
|
||||
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
|
||||
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index)
|
||||
|
||||
_internal_assert(isinstance(node.value, ast.Attribute), \
|
||||
"Only variable and attribute's subscript supported so far")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import tvm, inspect, sys, traceback, numpy, nose
|
||||
import tvm, inspect, sys, traceback, numpy, nose, types
|
||||
from tvm.hybrid import script
|
||||
from tvm.hybrid.intrin import HYBRID_GLOBALS
|
||||
|
||||
|
@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
|
|||
return val.value
|
||||
|
||||
ctx = tvm.context(target, 0)
|
||||
op = None
|
||||
|
||||
outs = func(*args)
|
||||
op = outs[0].op if isinstance(outs, list) else outs.op
|
||||
|
||||
emu_args = []
|
||||
nd_args = []
|
||||
|
@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
|
|||
emu_args.append(tvm_val_2_py_val(i))
|
||||
nd_args.append(emu_args[-1])
|
||||
|
||||
outs = func(*args)
|
||||
op = outs[0].op if isinstance(outs, list) else outs.op
|
||||
sch = tvm.create_schedule(op)
|
||||
module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
|
||||
assert module
|
||||
|
@ -425,10 +427,12 @@ def test_downstream():
|
|||
for i in range(20):
|
||||
b[i] = a[i] * i
|
||||
return b
|
||||
|
||||
|
||||
a = tvm.placeholder((20, ), 'float32')
|
||||
b = downstream(a)
|
||||
c = tvm.compute((20, ), lambda x: b[x] + 1.0)
|
||||
|
||||
sch = tvm.create_schedule(c.op)
|
||||
module = tvm.build(sch, [a, c])
|
||||
assert module
|
||||
|
@ -469,6 +473,40 @@ def test_const_param():
|
|||
|
||||
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
|
||||
|
||||
def test_value_index():
|
||||
@tvm.hybrid.script
|
||||
def kernel_a(a):
|
||||
b = output_tensor((16, ), 'int32')
|
||||
c = output_tensor((4, 4), 'int32')
|
||||
for i in range(16):
|
||||
b[i] = a[i] + 2
|
||||
c[i // 4, i % 4] = a[i] + 1
|
||||
return b, c
|
||||
|
||||
@tvm.hybrid.script
|
||||
def kernel_b(b, a):
|
||||
c = output_tensor((4, 4), 'int32')
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
c[i, j] = a[i * 4 + j] * b[i, j]
|
||||
return c
|
||||
|
||||
a = tvm.placeholder((16, ), 'int32')
|
||||
b, c = kernel_a(a)
|
||||
d = kernel_b(c, b)
|
||||
sch = tvm.create_schedule(d.op)
|
||||
module = tvm.build(sch, [a, d])
|
||||
assert module
|
||||
|
||||
np_a = numpy.arange(16).astype('int32')
|
||||
np_b, np_c = kernel_a(np_a)
|
||||
ref = kernel_b(np_c, np_b)
|
||||
|
||||
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
|
||||
module(tvm.ndarray.array(np_a), res)
|
||||
tvm.testing.assert_allclose(res.asnumpy(), ref)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_outer_product()
|
||||
|
@ -479,9 +517,11 @@ if __name__ == "__main__":
|
|||
test_math_intrin()
|
||||
test_non_zero()
|
||||
test_allocate()
|
||||
#test_inplace()
|
||||
test_upstream()
|
||||
test_downstream()
|
||||
test_const_param()
|
||||
test_value_index()
|
||||
# TODO:
|
||||
# test_inplace()
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче