[INTRIN] prefetch support (#246)
* [INTRIN] prefetch support
* lint
* add buildin
This commit is contained in:
Родитель
7bcb3f538b
Коммит
8ca7576943
|
@ -751,6 +751,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
|
|||
llvm::Function* f = llvm::Intrinsic::getDeclaration(
|
||||
module_.get(), id, arg_types);
|
||||
return builder_->CreateCall(f, arg_values);
|
||||
} else if (op->is_intrinsic("llvm_buildin")) {
|
||||
std::vector<llvm::Value*> arg_values;
|
||||
for (size_t i = 1; i < op->args.size(); ++i) {
|
||||
llvm::Value* v = MakeValue(op->args[i]);
|
||||
arg_values.push_back(v);
|
||||
}
|
||||
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
|
||||
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
|
||||
return builder_->CreateCall(f, arg_values);
|
||||
} else if (op->is_intrinsic(Call::bitwise_and)) {
|
||||
CHECK_EQ(op->args.size(), 2U);
|
||||
return builder_->CreateAnd(
|
||||
|
@ -785,8 +794,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
|
|||
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
|
||||
const Load *l = op->args[0].as<Load>();
|
||||
CHECK(op->args.size() == 1 && l);
|
||||
return CreateBufferPtr(
|
||||
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index));
|
||||
return builder_->CreatePointerCast(
|
||||
CreateBufferPtr(
|
||||
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)),
|
||||
t_void_p_);
|
||||
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
|
||||
CHECK_EQ(op->args.size(), 1U);
|
||||
llvm::Value* ptr = MakeValue(op->args[0]);
|
||||
|
|
|
@ -17,7 +17,22 @@ namespace llvm {
|
|||
using namespace ir;
|
||||
|
||||
template<unsigned id>
|
||||
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
|
||||
inline void DispatchLLVMBuildin(const TVMArgs& targs, TVMRetValue* rv) {
|
||||
Expr e = targs[0];
|
||||
const Call* call = e.as<Call>();
|
||||
CHECK(call != nullptr);
|
||||
Array<Expr> cargs;
|
||||
// intrin id.
|
||||
cargs.push_back(UIntImm::make(UInt(32), id));
|
||||
for (Expr arg : call->args) {
|
||||
cargs.push_back(arg);
|
||||
}
|
||||
*rv = Call::make(
|
||||
call->type, "llvm_buildin", cargs, Call::Intrinsic);
|
||||
}
|
||||
|
||||
template<unsigned id>
|
||||
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
|
||||
Expr e = targs[0];
|
||||
const Call* call = e.as<Call>();
|
||||
CHECK(call != nullptr);
|
||||
|
@ -31,14 +46,17 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
|
|||
call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
|
||||
}
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.__buildin_prefetch")
|
||||
.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
|
||||
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::exp>);
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
|
||||
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
|
||||
|
||||
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
|
||||
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>);
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>);
|
||||
|
||||
} // namespace llvm
|
||||
} // namespace codegen
|
||||
|
|
|
@ -82,6 +82,8 @@ class PackedCallBuilder : public IRMutator {
|
|||
for (size_t i = 0; i < op->extents.size(); ++i) {
|
||||
total_bytes = total_bytes * op->extents[i];
|
||||
}
|
||||
CHECK(device_type_.defined()) << "Unknown device type in current IR";
|
||||
CHECK(device_id_.defined()) << "Unknown device id in current IR";
|
||||
Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
|
||||
intrinsic::tvm_throw_last_error, {},
|
||||
Call::Intrinsic));
|
||||
|
|
|
@ -1,5 +1,21 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
import ctypes
|
||||
|
||||
def test_llvm_intrin():
|
||||
ib = tvm.ir_builder.create()
|
||||
n = tvm.convert(4)
|
||||
A = ib.pointer("float32", name="A")
|
||||
args = [
|
||||
tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
|
||||
0, 3, 1
|
||||
]
|
||||
ib.emit(tvm.make.Evaluate(
|
||||
tvm.make.Call(
|
||||
"int32", "__buildin_prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
|
||||
body = ib.get()
|
||||
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
|
||||
fcode = tvm.build(func, None, "llvm")
|
||||
|
||||
def test_llvm_add_pipeline():
|
||||
nn = 1024
|
||||
|
@ -151,6 +167,7 @@ def test_multiple_func():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llvm_intrin()
|
||||
test_multiple_func()
|
||||
test_llvm_add_pipeline()
|
||||
test_llvm_flip_pipeline()
|
||||
|
|
|
@ -14,12 +14,10 @@ def test_flatten2():
|
|||
assert isinstance(bounds, tvm.container.Map)
|
||||
stmt = tvm.schedule.ScheduleOps(s, bounds)
|
||||
|
||||
print(stmt)
|
||||
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
|
||||
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
|
||||
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
|
||||
stmt = tvm.ir_pass.Simplify(stmt)
|
||||
print(stmt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_flatten2()
|
||||
|
|
Загрузка…
Ссылка в новой задаче