[BACKEND] Explicitly allow specialization of FMA in llvm (#407)
This commit is contained in:
Родитель
a45d3b01f7
Коммит
f73c461f50
|
@ -67,6 +67,9 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
|
|||
// Registers the global "tvm.intrin.rule.llvm.exp" with DispatchLLVMPureIntrin
// instantiated for ::llvm::Intrinsic::exp as its body.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp>);
|
||||
|
||||
// Registers the global "tvm.intrin.rule.llvm.fma"; note that the body is
// instantiated for ::llvm::Intrinsic::fmuladd rather than ::llvm::Intrinsic::fma.
// NOTE(review): per the LLVM LangRef, fmuladd permits but does not require a
// fused multiply-add -- confirm that this is the intended semantics.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd>);
|
||||
|
||||
// Registers the global "tvm.intrin.rule.llvm.log" with DispatchLLVMPureIntrin
// instantiated for ::llvm::Intrinsic::log as its body.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
|
||||
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log>);
|
||||
|
||||
|
|
|
@ -16,11 +16,12 @@ namespace ir {
|
|||
class IntrinInjecter : public IRMutator {
|
||||
public:
|
||||
// Builds the list of registry name prefixes ("tvm.intrin.rule.<x>.") for the
// given target string and looks up an optional "fma" rule.
// NOTE(review): this listing comes from a diff view whose +/- markers are not
// visible; the first push_back and the "llvm" prefix check below may be the
// removed (old) lines and the istringstream-based lines their replacement --
// confirm against the actual file before relying on the exact prefix order.
explicit IntrinInjecter(std::string target) {
|
||||
// Prefix built from the complete target string, exactly as passed in.
patterns_.push_back("tvm.intrin.rule." + target + ".");
|
||||
// Target begins with "llvm" but is not exactly "llvm": also add the plain llvm prefix.
if (!strncmp(target.c_str(), "llvm", 4) && target != "llvm") {
|
||||
patterns_.push_back("tvm.intrin.rule.llvm.");
|
||||
}
|
||||
// operator>> extracts the first whitespace-delimited token of target into starget.
std::istringstream is(target);
|
||||
std::string starget;
|
||||
is >> starget;
|
||||
patterns_.push_back("tvm.intrin.rule." + starget + ".");
|
||||
// The "default" prefix is appended last.
patterns_.push_back("tvm.intrin.rule.default.");
|
||||
// Only the first prefix is consulted for the fma rule; Mutate_(const Add*)
// checks fma_ against nullptr before using it.
// NOTE(review): as displayed, patterns_[0] is the prefix built from the full
// target string, not from starget -- confirm which one is intended.
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
|
||||
}
|
||||
|
||||
Expr Mutate_(const Call* op, const Expr& e) final {
|
||||
|
@ -32,6 +33,22 @@ class IntrinInjecter : public IRMutator {
|
|||
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
|
||||
// Rewrites a floating-point Add that has a Mul operand into an "fma" pure
// intrinsic call and hands that call to the registered fma_ rule. Falls back
// to the default IRMutator handling when no rule is registered, the type is
// not float, neither operand is a Mul, or the rule returns an undefined Expr.
Expr Mutate_(const Add* op, const Expr& e) final {
|
||||
// No fma rule was found, or a non-float add: leave it to the base mutator.
if (fma_ == nullptr || !op->type.is_float()) {
|
||||
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
// Case a + (x * y): the fma call is built with arguments {x, y, a}.
if (const Mul* mb = op->b.as<Mul>()) {
|
||||
Expr r = (*fma_)(Call::make(
|
||||
op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
|
||||
// NOTE(review): r is returned as-is; neither r nor the operands placed in
// the call are passed through this mutator here, so nested expressions may
// stay unvisited unless the fma_ rule handles them -- confirm.
if (r.defined()) return r;
|
||||
// Case (x * y) + b: the fma call is built with arguments {x, y, b}.
// NOTE(review): because this is an else-if, it is not tried when op->b is a
// Mul but the rule returned an undefined Expr for it.
} else if (const Mul* ma = op->a.as<Mul>()) {
|
||||
Expr r = (*fma_)(Call::make(
|
||||
op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
|
||||
if (r.defined()) return r;
|
||||
}
|
||||
// No Mul operand, or the rule produced nothing: default mutation.
return IRMutator::Mutate_(op, e);
|
||||
}
|
||||
|
||||
private:
|
||||
Expr ApplyPattern(const std::string& name, const Expr& e) {
|
||||
for (size_t i = 0; i < patterns_.size(); ++i) {
|
||||
|
@ -54,6 +71,7 @@ class IntrinInjecter : public IRMutator {
|
|||
}
|
||||
// Registry name prefixes of the form "tvm.intrin.rule.<x>.", filled in by the
// constructor; ApplyPattern loops over them by index.
std::vector<std::string> patterns_;
|
||||
// Result of the constructor's Registry::Get lookup for the "fma" rule;
// Mutate_(const Add*) skips the FMA rewrite while this is nullptr.
const PackedFunc* fma_{nullptr};
|
||||
};
|
||||
|
||||
LoweredFunc
|
||||
|
|
Загрузка…
Ссылка в новой задаче