[PYTHON] Allow general types (#425)
This commit is contained in:
Родитель
df3c996b2a
Коммит
5ea4072c5a
|
@ -65,10 +65,7 @@ class TVMType(ctypes.Structure):
|
|||
head = ""
|
||||
else:
|
||||
raise ValueError("Donot know how to handle type %s" % type_str)
|
||||
|
||||
bits = int(head) if head else bits
|
||||
if (bits & (bits - 1)) != 0 or bits < 8:
|
||||
raise ValueError("Donot know how to handle type %s" % type_str)
|
||||
self.bits = bits
|
||||
|
||||
|
||||
|
|
|
@ -231,6 +231,7 @@ def test_multiple_func():
|
|||
check_llvm()
|
||||
|
||||
|
||||
|
||||
def test_llvm_select():
|
||||
def check_llvm(n, offset):
|
||||
if not tvm.module.enabled("llvm"):
|
||||
|
@ -251,7 +252,27 @@ def test_llvm_select():
|
|||
check_llvm(64, 8)
|
||||
|
||||
|
||||
def test_llvm_bool():
|
||||
def check_llvm(n):
|
||||
if not tvm.module.enabled("llvm"):
|
||||
return
|
||||
A = tvm.placeholder((n, ), name='A', dtype="int32")
|
||||
C = tvm.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C')
|
||||
s = tvm.create_schedule(C.op)
|
||||
# build and invoke the kernel.
|
||||
f = tvm.build(s, [A, C], "llvm")
|
||||
ctx = tvm.cpu(0)
|
||||
# launch the kernel.
|
||||
a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx)
|
||||
c = tvm.nd.empty((n,), C.dtype, ctx)
|
||||
f(a, c)
|
||||
c_np = a.asnumpy() == 1
|
||||
np.testing.assert_allclose(c.asnumpy(), c_np)
|
||||
check_llvm(64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llvm_bool()
|
||||
test_llvm_persist_parallel()
|
||||
test_llvm_select()
|
||||
test_llvm_vadd_pipeline()
|
||||
|
|
Загрузка…
Ссылка в новой задаче