* [INTRIN] prefetch support

* lint

* add buildin
This commit is contained in:
Tianqi Chen 2017-07-14 09:53:00 -07:00 коммит произвёл GitHub
Родитель 7bcb3f538b
Коммит 8ca7576943
5 изменённых файлов: 54 добавлений и 8 удалений

Просмотреть файл

@ -751,6 +751,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, arg_types);
return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic("llvm_buildin")) {
std::vector<llvm::Value*> arg_values;
for (size_t i = 1; i < op->args.size(); ++i) {
llvm::Value* v = MakeValue(op->args[i]);
arg_values.push_back(v);
}
auto id = static_cast<llvm::Intrinsic::ID>(op->args[0].as<UIntImm>()->value);
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id);
return builder_->CreateCall(f, arg_values);
} else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
@ -785,8 +794,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
return CreateBufferPtr(
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index));
return builder_->CreatePointerCast(
CreateBufferPtr(
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index)),
t_void_p_);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
llvm::Value* ptr = MakeValue(op->args[0]);

Просмотреть файл

@ -17,7 +17,22 @@ namespace llvm {
using namespace ir;
template<unsigned id>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
inline void DispatchLLVMBuildin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
cargs.push_back(UIntImm::make(UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = Call::make(
call->type, "llvm_buildin", cargs, Call::Intrinsic);
}
template<unsigned id>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
@ -31,14 +46,17 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
call->type, "llvm_intrin", cargs, Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.__buildin_prefetch")
.set_body(DispatchLLVMBuildin<::llvm::Intrinsic::prefetch>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::exp>);
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::log>);
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::sqrt>);
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt>);
} // namespace llvm
} // namespace codegen

Просмотреть файл

@ -82,6 +82,8 @@ class PackedCallBuilder : public IRMutator {
for (size_t i = 0; i < op->extents.size(); ++i) {
total_bytes = total_bytes * op->extents[i];
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
intrinsic::tvm_throw_last_error, {},
Call::Intrinsic));

Просмотреть файл

@ -1,5 +1,21 @@
import tvm
import numpy as np
import ctypes
def test_llvm_intrin():
ib = tvm.ir_builder.create()
n = tvm.convert(4)
A = ib.pointer("float32", name="A")
args = [
tvm.call_pure_intrin("handle", "tvm_address_of", A[0]),
0, 3, 1
]
ib.emit(tvm.make.Evaluate(
tvm.make.Call(
"int32", "__buildin_prefetch", args, tvm.expr.Call.Intrinsic, None, 0)))
body = ib.get()
func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
def test_llvm_add_pipeline():
nn = 1024
@ -151,6 +167,7 @@ def test_multiple_func():
if __name__ == "__main__":
test_llvm_intrin()
test_multiple_func()
test_llvm_add_pipeline()
test_llvm_flip_pipeline()

Просмотреть файл

@ -14,12 +14,10 @@ def test_flatten2():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)
if __name__ == "__main__":
test_flatten2()