diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 08c6af86..4c1ee8b1 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -751,6 +751,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { llvm::Function* f = llvm::Intrinsic::getDeclaration( module_.get(), id, arg_types); return builder_->CreateCall(f, arg_values); + } else if (op->is_intrinsic("llvm_buildin")) { + std::vector arg_values; + for (size_t i = 1; i < op->args.size(); ++i) { + llvm::Value* v = MakeValue(op->args[i]); + arg_values.push_back(v); + } + auto id = static_cast(op->args[0].as()->value); + llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id); + return builder_->CreateCall(f, arg_values); } else if (op->is_intrinsic(Call::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); return builder_->CreateAnd( @@ -785,8 +794,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); CHECK(op->args.size() == 1 && l); - return CreateBufferPtr( - l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)); + return builder_->CreatePointerCast( + CreateBufferPtr( + l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)), + t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); llvm::Value* ptr = MakeValue(op->args[0]); diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index c2065153..695ef544 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -17,7 +17,22 @@ namespace llvm { using namespace ir; template -inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { +inline void DispatchLLVMBuildin(const TVMArgs& targs, TVMRetValue* rv) { + Expr e = targs[0]; + const Call* call = e.as(); + CHECK(call != nullptr); + Array cargs; + // intrin id. + cargs.push_back(UIntImm::make(UInt(32), id)); + for (Expr arg : call->args) { + cargs.push_back(arg); + } + *rv = Call::make( + call->type, "llvm_buildin", cargs, Call::Intrinsic); +} + +template +inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { Expr e = targs[0]; const Call* call = e.as(); CHECK(call != nullptr); @@ -31,14 +46,17 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { call->type, "llvm_intrin", cargs, Call::PureIntrinsic); } +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.__buildin_prefetch") +.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::exp>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>); +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>); } // namespace llvm } // namespace codegen diff --git a/src/pass/lower_packed_call.cc b/src/pass/lower_packed_call.cc index 2ee3d99a..1aaddf0f 100644 --- a/src/pass/lower_packed_call.cc +++ b/src/pass/lower_packed_call.cc @@ -82,6 +82,8 @@ class PackedCallBuilder : public IRMutator { for (size_t i = 0; i < op->extents.size(); ++i) { total_bytes = total_bytes * op->extents[i]; } + CHECK(device_type_.defined()) << "Unknown device type in current IR"; + CHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate::make(Call::make(Int(32), intrinsic::tvm_throw_last_error, {}, Call::Intrinsic)); diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index c9c201a5..23279f9b 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -1,5 +1,21 @@ import tvm import numpy as np +import ctypes + +def test_llvm_intrin(): + ib = tvm.ir_builder.create() + n = tvm.convert(4) + A = ib.pointer("float32", name="A") + args = [ + tvm.call_pure_intrin("handle", "tvm_address_of", A[0]), + 0, 3, 1 + ] + ib.emit(tvm.make.Evaluate( + tvm.make.Call( + "int32", "__buildin_prefetch", args, tvm.expr.Call.Intrinsic, None, 0))) + body = ib.get() + func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) + fcode = tvm.build(func, None, "llvm") def test_llvm_add_pipeline(): nn = 1024 @@ -151,6 +167,7 @@ def test_multiple_func(): if __name__ == "__main__": + test_llvm_intrin() test_multiple_func() test_llvm_add_pipeline() test_llvm_flip_pipeline() diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 099fb5a9..87c9bf6a 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -14,12 +14,10 @@ def test_flatten2(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.schedule.ScheduleOps(s, bounds) - print(stmt) Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2') stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}) stmt = tvm.ir_pass.Simplify(stmt) - print(stmt) if __name__ == "__main__": test_flatten2()