140 строки
4.4 KiB
C++
140 строки
4.4 KiB
C++
/*!
|
|
* Copyright (c) 2016 by Contributors
|
|
* \file tvm/ir_visitor.h
|
|
* \brief Visitor to quickly visit IR trees
|
|
*/
|
|
#ifndef TVM_IR_VISITOR_H_
|
|
#define TVM_IR_VISITOR_H_
|
|
|
|
#include "ir.h"
|
|
#include "tvm/node/ir_functor.h"
|
|
|
|
namespace tvm {
|
|
namespace ir {
|
|
|
|
/*!
|
|
* \brief a base class for visitor to iterative traverse the IR
|
|
*
|
|
* This IRVisitor is implemented via IRFunctor
|
|
* This enables extensions of possible new Node.
|
|
*
|
|
* \sa ExprFunctor, StmtFunctor, PostOrderVisit
|
|
*
|
|
* \note If you need to return values during Visit:
|
|
* - If it is mutation of the IR, use IRMutator
|
|
* - If you want to return other things, consider use ExprFunctor/StmtFunctor
|
|
* - Watch out for possible bug pattern if you use IRVisitor to simulate returns.
|
|
*
|
|
* \code
|
|
*
|
|
* // This is an example code to show cases for traps in IRVisitor
|
|
* // The use case is to count number of Variables in the ir tree.
|
|
* class MyCounter : public IRVisitor {
|
|
* public:
|
|
* int Count(const NodeRef& n) {
|
|
* ret_ = 0;
|
|
* this->Visit(n);
|
|
* return ret_;
|
|
* }
|
|
* void Visit_(const Variable* op) final {
|
|
* ret_ = 1;
|
|
* }
|
|
* void Visit_(const Add* op) final {
|
|
* ret_ = count(op->a) + count(op->b);
|
|
* }
|
|
|
|
* private:
|
|
* int ret_;
|
|
* };
|
|
* MyCounter counter;
|
|
* Var x("x");
|
|
* // this returns 2
|
|
* CHECK_EQ(counter.Count(x + x), 2);
|
|
* // Think what is the result of the following count
|
|
* counter.count(Max::make(x, x));
|
|
* // The result is actually 1
|
|
* // This is because Visit is not overriden for Max
|
|
* // so it simply calls Visit for the left and right children
|
|
* // and because Count is not called, ret_ is not cleared.
|
|
* // There can also be cases where ret_ is forgetten to be set.
|
|
*
|
|
* // These traps may not happen if we program carefully
|
|
* // But it is recommended to use ExprFunctor, which allows direct
|
|
* // return the value, this helps us to avoid such problems.
|
|
*
|
|
* \endcode
|
|
*/
|
|
class TVM_DLL IRVisitor {
|
|
public:
|
|
/*!
|
|
* \brief recursively visit an IR node
|
|
*/
|
|
virtual void Visit(const NodeRef& node) {
|
|
static const FVisit& f = vtable();
|
|
if (node.defined()) f(node, this);
|
|
}
|
|
/*! \brief destructor */
|
|
virtual ~IRVisitor() {}
|
|
/*! \brief functor type of visitor */
|
|
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
|
|
/*! \return internal vtable*/
|
|
static FVisit& vtable();
|
|
// overloadable visit function.
|
|
virtual void Visit_(const Variable* op);
|
|
virtual void Visit_(const LetStmt* op);
|
|
virtual void Visit_(const AttrStmt* op);
|
|
virtual void Visit_(const IfThenElse* op);
|
|
virtual void Visit_(const For* op);
|
|
virtual void Visit_(const Allocate* op);
|
|
virtual void Visit_(const Load* op);
|
|
virtual void Visit_(const Store* op);
|
|
virtual void Visit_(const Let* op);
|
|
virtual void Visit_(const Free* op);
|
|
virtual void Visit_(const Call* op);
|
|
virtual void Visit_(const Add* op);
|
|
virtual void Visit_(const Sub* op);
|
|
virtual void Visit_(const Mul* op);
|
|
virtual void Visit_(const Div* op);
|
|
virtual void Visit_(const Mod* op);
|
|
virtual void Visit_(const Min* op);
|
|
virtual void Visit_(const Max* op);
|
|
virtual void Visit_(const EQ* op);
|
|
virtual void Visit_(const NE* op);
|
|
virtual void Visit_(const LT* op);
|
|
virtual void Visit_(const LE* op);
|
|
virtual void Visit_(const GT* op);
|
|
virtual void Visit_(const GE* op);
|
|
virtual void Visit_(const And* op);
|
|
virtual void Visit_(const Or* op);
|
|
virtual void Visit_(const Reduce* op);
|
|
virtual void Visit_(const Cast* op);
|
|
virtual void Visit_(const Not* op);
|
|
virtual void Visit_(const Select* op);
|
|
virtual void Visit_(const Ramp* op);
|
|
virtual void Visit_(const Broadcast* op);
|
|
virtual void Visit_(const AssertStmt* op);
|
|
virtual void Visit_(const ProducerConsumer* op);
|
|
virtual void Visit_(const Provide* op);
|
|
virtual void Visit_(const Realize* op);
|
|
virtual void Visit_(const Prefetch* op);
|
|
virtual void Visit_(const Block* op);
|
|
virtual void Visit_(const Evaluate* op);
|
|
virtual void Visit_(const IntImm* op);
|
|
virtual void Visit_(const UIntImm* op);
|
|
virtual void Visit_(const FloatImm* op);
|
|
virtual void Visit_(const StringImm* op);
|
|
};
|
|
|
|
/*!
|
|
* \brief recursively visit the ir in post DFS order node, apply fvisit
|
|
* Each node is guaranteed to be visited only once.
|
|
* \param node The ir to be visited.
|
|
* \param fvisit The visitor function to be applied.
|
|
*/
|
|
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
|
|
|
|
} // namespace ir
|
|
} // namespace tvm
|
|
|
|
#endif // TVM_IR_VISITOR_H_
|