268 строки
10 KiB
C++
268 строки
10 KiB
C++
/*!
|
|
* Copyright (c) 2017 by Contributors
|
|
* \file ir_functor_ext.h
|
|
* \brief More powerful Visitor that allows defining function signatures.
|
|
*/
|
|
#ifndef TVM_IR_FUNCTOR_EXT_H_
|
|
#define TVM_IR_FUNCTOR_EXT_H_
|
|
|
|
#include <tvm/ir_functor.h>
|
|
#include "./ir.h"
|
|
|
|
namespace tvm {
|
|
namespace ir {
|
|
|
|
/*!
|
|
* \brief A dynamic functor that dispatches on the first Expr argument.
|
|
* You can use this as a more powerful Visitor, since it allows you to
* define the function signatures of the Visit functions.
|
|
*
|
|
* This helps you avoid book-keeping the return value of a Visitor via state,
* which can easily cause bugs when that state is incorrectly maintained.
|
|
*
|
|
* \code
|
|
* // A functor that sets variables to b and calculates results.
|
|
* class MyExprFunctor
|
|
* : public ir::ExprFunctor<int(const Expr&, int)> {
|
|
* public:
|
|
* int VisitExpr_(const Variable* op, int b) final {
|
|
* return b;
|
|
* }
|
|
* int VisitExpr_(const IntImm* op, int b) final {
|
|
* return op->value;
|
|
* }
|
|
* int VisitExpr_(const Add* op, int b) final {
|
|
* return Visit(op->a, b) + Visit(op->b, b);
|
|
* }
|
|
* };
|
|
* MyExprFunctor f;
|
|
* Var x("x");
|
|
* CHECK_EQ(f(x + 1, 2), 3);
|
|
* \endcode
|
|
*
|
|
* \note Why do we need this more powerful Functor:
|
|
*
|
|
* We often need to implement transformer tasks.
* Say we want to take an Expr and transform it into some analysis result;
* this is easy to get wrong with a plain Visitor. See IRVisitor's
* documentation for possible error cases.
|
|
*
|
|
* \tparam FType The function signature.
* This type is only defined for FType with a function signature of the form R(const Expr&, Args...).
|
|
*/
|
|
// Primary template: declared only. The usable definition is the partial
// specialization for signatures of the form R(const Expr&, Args...).
template <class FType> class ExprFunctor;
|
|
/*!
 * \brief Statement counterpart of ExprFunctor: dispatches on the first
 *  Stmt argument instead of the first Expr argument.
 * \tparam FType The function signature; only R(const Stmt&, Args...) has a
 *  definition (the partial specialization further down in this header).
 */
template <class FType> class StmtFunctor;
|
|
|
|
// Default bodies for the overridable VisitExpr_/VisitStmt_ functions:
// hand the node `op` and the extra `args` on to the matching *Default_ hook.
// They expand inside member declarations where `op`, `args` and `Args` exist.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

#define STMT_FUNCTOR_DEFAULT \
  { return VisitStmtDefault_(op, std::forward<Args>(args)...); }
|
|
|
|
// Register, in the local `vtable`, an entry for node type OP that downcasts
// the node and forwards it (plus the extra arguments) to the VisitExpr_ /
// VisitStmt_ overload for OP. These expand inside InitVTable(), where
// `vtable`, `TSelf` and `Args` are in scope.
//
// The last line of each macro deliberately has no trailing `\`: a dangling
// line continuation would splice whatever line follows into the macro body.
// The invocation supplies the terminating semicolon.
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                   \
  vtable.template set_dispatch<OP>(                                    \
      [](const NodeRef& n, TSelf* self, Args... args) {                \
        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);          \
      })

#define IR_STMT_FUNCTOR_DISPATCH(OP)                                   \
  vtable.template set_dispatch<OP>(                                    \
      [](const NodeRef& n, TSelf* self, Args... args) {                \
        return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);          \
      })
|
|
|
|
template<typename R, typename ...Args>
|
|
class ExprFunctor<R(const Expr& n, Args...)> {
|
|
private:
|
|
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
|
|
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
|
|
|
|
public:
|
|
/*! \brief the result type of this functor */
|
|
using result_type = R;
|
|
/*! \brief virtual destructor */
|
|
virtual ~ExprFunctor() {}
|
|
/*!
|
|
* \brief Same as call.
|
|
* \param n The expression node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
R operator()(const Expr& n, Args... args) {
|
|
return VisitExpr(n, std::forward<Args>(args)...);
|
|
}
|
|
/*!
|
|
* \brief The functor call.
|
|
* \param n The expression node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
virtual R VisitExpr(const Expr& n, Args... args) {
|
|
static FType vtable = InitVTable();
|
|
return vtable(n, this, std::forward<Args>(args)...);
|
|
}
|
|
// Functions that can be overriden by subclass
|
|
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExprDefault_(const Node* op, Args ...) {
|
|
LOG(FATAL) << "Do not have a default for " << op->type_key();
|
|
return R();
|
|
}
|
|
|
|
private:
|
|
// initialize the vtable.
|
|
static FType InitVTable() {
|
|
FType vtable;
|
|
// Set dispatch
|
|
IR_EXPR_FUNCTOR_DISPATCH(Variable);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Load);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Let);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Call);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Add);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Sub);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Mul);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Div);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Mod);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Min);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Max);
|
|
IR_EXPR_FUNCTOR_DISPATCH(EQ);
|
|
IR_EXPR_FUNCTOR_DISPATCH(NE);
|
|
IR_EXPR_FUNCTOR_DISPATCH(LT);
|
|
IR_EXPR_FUNCTOR_DISPATCH(LE);
|
|
IR_EXPR_FUNCTOR_DISPATCH(GT);
|
|
IR_EXPR_FUNCTOR_DISPATCH(GE);
|
|
IR_EXPR_FUNCTOR_DISPATCH(And);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Or);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Cast);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Not);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Select);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
|
|
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
|
|
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
|
|
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
|
|
IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
|
|
IR_EXPR_FUNCTOR_DISPATCH(StringImm);
|
|
return vtable;
|
|
}
|
|
};
|
|
|
|
template<typename R, typename ...Args>
|
|
class StmtFunctor<R(const Stmt& n, Args... args)> {
|
|
private:
|
|
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
|
|
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;
|
|
|
|
public:
|
|
/*! \brief the result type of this functor */
|
|
using result_type = R;
|
|
/*! \brief virtual destructor */
|
|
virtual ~StmtFunctor() {}
|
|
/*!
|
|
* \brief Same as call.
|
|
* \param n The stmt node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
R operator()(const Stmt& n, Args... args) {
|
|
return VisitStmt(n, std::forward<Args>(args)...);
|
|
}
|
|
/*!
|
|
* \brief The functor call.
|
|
* \param n The stmt node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
virtual R VisitStmt(const Stmt& n, Args... args) {
|
|
static FType vtable = InitVTable();
|
|
return vtable(n, this, std::forward<Args>(args)...);
|
|
}
|
|
// Functions that can be overriden by subclass
|
|
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
|
|
virtual R VisitStmtDefault_(const Node* op, Args ...) {
|
|
LOG(FATAL) << "Do not have a default for " << op->type_key();
|
|
return R();
|
|
}
|
|
|
|
private:
|
|
// initialize the vtable.
|
|
static FType InitVTable() {
|
|
FType vtable;
|
|
IR_STMT_FUNCTOR_DISPATCH(LetStmt);
|
|
IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
|
|
IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
|
|
IR_STMT_FUNCTOR_DISPATCH(For);
|
|
IR_STMT_FUNCTOR_DISPATCH(Allocate);
|
|
IR_STMT_FUNCTOR_DISPATCH(Store);
|
|
IR_STMT_FUNCTOR_DISPATCH(Free);
|
|
IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
|
|
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
|
|
IR_STMT_FUNCTOR_DISPATCH(Provide);
|
|
IR_STMT_FUNCTOR_DISPATCH(Realize);
|
|
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
|
|
IR_STMT_FUNCTOR_DISPATCH(Block);
|
|
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
|
|
return vtable;
|
|
}
|
|
};
|
|
|
|
// The helper macros are private to this header; remove them so they do not
// leak into files that include it.
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef IR_STMT_FUNCTOR_DISPATCH
|
|
|
|
} // namespace ir
|
|
} // namespace tvm
|
|
#endif // TVM_IR_FUNCTOR_EXT_H_
|