diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 8fa96780..38c8f1d0 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -65,10 +65,7 @@ class TVMType(ctypes.Structure): head = "" else: raise ValueError("Donot know how to handle type %s" % type_str) - bits = int(head) if head else bits - if (bits & (bits - 1)) != 0 or bits < 8: - raise ValueError("Donot know how to handle type %s" % type_str) self.bits = bits diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 2ac4ffa3..cb171306 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -231,6 +231,7 @@ def test_multiple_func(): check_llvm() + def test_llvm_select(): def check_llvm(n, offset): if not tvm.module.enabled("llvm"): @@ -251,7 +252,27 @@ def test_llvm_select(): check_llvm(64, 8) +def test_llvm_bool(): + def check_llvm(n): + if not tvm.module.enabled("llvm"): + return + A = tvm.placeholder((n, ), name='A', dtype="int32") + C = tvm.compute((n,), lambda i: A[i].equal(1).astype("float"), name='C') + s = tvm.create_schedule(C.op) + # build and invoke the kernel. + f = tvm.build(s, [A, C], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.array(np.random.randint(0, 2, size=(n,)).astype(A.dtype), ctx) + c = tvm.nd.empty((n,), C.dtype, ctx) + f(a, c) + c_np = a.asnumpy() == 1 + np.testing.assert_allclose(c.asnumpy(), c_np) + check_llvm(64) + + if __name__ == "__main__": + test_llvm_bool() test_llvm_persist_parallel() test_llvm_select() test_llvm_vadd_pipeline()