[BACKEND] Explicitly allow specialization of FMA in llvm (#407)

This commit is contained in:
Tianqi Chen 2017-09-01 11:46:45 -07:00 коммит произвёл GitHub
Родитель a45d3b01f7
Коммит f73c461f50
2 изменённых файлов: 25 добавлений и 4 удалений

Просмотреть файл

@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
// Global registrations of lowering rules under the "tvm.intrin.rule.llvm." prefix.
// Each rule is bound to DispatchLLVMPureIntrin instantiated with the LLVM
// intrinsic id that the named operation should be dispatched to.

// exp -> llvm.exp
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
// fma -> llvm.fmuladd (note: the fmuladd intrinsic id, not ::llvm::Intrinsic::fma).
// NOTE(review): per the LLVM LangRef, fmuladd allows but does not require the
// multiply and add to be fused — confirm this is the intended semantics.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);
// log -> llvm.log
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);

Просмотреть файл

@ -16,11 +16,12 @@ namespace ir {
class IntrinInjecter : public IRMutator {
public:
// Builds the list of registry-name prefixes ("tvm.intrin.rule.<name>.") that are
// searched for intrinsic lowering rules, and looks up an optional "fma" rule.
//
// \param target The target string used to derive the rule prefixes.
//
// NOTE(review): this text looks like it was taken from a rendered diff, so the
// statements below may interleave the removed and the added version of the
// constructor. Confirm against the real file which push_back calls are live and
// therefore which prefix ends up in patterns_[0].
explicit IntrinInjecter(std::string target) {
// Prefix built from the complete target string.
patterns_.push_back("tvm.intrin.rule." + target + ".");
// strncmp(...) == 0 means the first four characters are "llvm"; for such targets
// that are not exactly "llvm", the plain llvm prefix is added as well.
if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") {
patterns_.push_back("tvm.intrin.rule.llvm.");
}
// Extract the first whitespace-delimited token of the target string and
// add a prefix built from that token only.
std::istringstream is(target);
std::string starget;
is >> starget;
patterns_.push_back("tvm.intrin.rule." + starget + ".");
// Prefix for rules registered under "default".
patterns_.push_back("tvm.intrin.rule.default.");
// Look up an "fma" rule under the first prefix only; the result of
// Registry::Get is stored as-is and is compared against nullptr before use
// in Mutate_(const Add*).
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
}
Expr Mutate_(const Call* op, const Expr& e) final {
@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator {
return IRMutator::Mutate_(op, e);
}
/*!
 * \brief Try to rewrite a floating-point `x * y + z` (or `z + x * y`) into
 *        the target's "fma" rule.
 *
 * The rule stored in fma_ is invoked with a pure-intrinsic call named "fma"
 * whose arguments are {multiplicand, multiplier, addend}. If the rule returns
 * a defined expression, that expression replaces the addition; otherwise the
 * addition is handled by the default IRMutator traversal.
 *
 * \param op The addition node being visited.
 * \param e  The expression that wraps \p op.
 * \return The rewritten expression.
 */
Expr Mutate_(const Add* op, const Expr& e) final {
  // No fma rule is available, or the addition is not floating point:
  // nothing to specialize.
  if (fma_ == nullptr || !op->type.is_float()) {
    return IRMutator::Mutate_(op, e);
  }
  if (const Mul* mb = op->b.as<Mul>()) {
    // op->a + mb->a * mb->b
    Expr r = (*fma_)(Call::make(
        op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
    // The operands were handed to the rule unmutated, so keep traversing the
    // rewritten expression; returning it directly would leave any intrinsic
    // calls nested inside the operands unlowered.
    if (r.defined()) return this->Mutate(r);
  } else if (const Mul* ma = op->a.as<Mul>()) {
    // ma->a * ma->b + op->b
    Expr r = (*fma_)(Call::make(
        op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
    if (r.defined()) return this->Mutate(r);
  }
  // The rule declined (undefined result) or neither operand is a multiply.
  return IRMutator::Mutate_(op, e);
}
private:
Expr ApplyPattern(const std::string& name, const Expr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator {
}
// Registry-name prefixes of the form "tvm.intrin.rule.<name>.", filled in by
// the constructor and scanned from index 0 upward by ApplyPattern.
std::vector<std::string> patterns_;
// Rule looked up in the constructor under patterns_[0] + "fma";
// stays nullptr unless the constructor assigns a non-null lookup result.
const PackedFunc* fma_{nullptr};
};
LoweredFunc