207 строки
7.1 KiB
C++
207 строки
7.1 KiB
C++
/*!
|
|
* Copyright (c) 2018 by Contributors
|
|
* \file tvm/relay/expr_functor.h
|
|
* \brief A more powerful visitor which enables defining arbitrary function
|
|
* signatures with type based dispatch on first argument.
|
|
*/
|
|
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
|
|
#define TVM_RELAY_EXPR_FUNCTOR_H_
|
|
|
|
#include <tvm/node/ir_functor.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include "./expr.h"
#include "./op.h"
#include "./error.h"
|
|
|
|
namespace tvm {
|
|
namespace relay {
|
|
|
|
/*!
 * \brief A dynamic functor that dispatches on the first Expr argument.
 *  It can be used as a more powerful Visitor, because it lets you choose
 *  the function signature of the Visit functions.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType The function signature.
 *  Only the partial specialization for signatures of the form
 *  R(const Expr&, Args...) is defined; this primary template is
 *  declared but intentionally left without a definition.
 */
template <typename FType>
class ExprFunctor;
|
|
|
|
// Default body for the overridable VisitExpr_ overloads.
// Expands to a function body that forwards the node pointer `op` and the
// extra arguments `args` to VisitExprDefault_, so the macro can only be used
// where `op`, `args` and the parameter pack `Args` are in scope.
#define EXPR_FUNCTOR_DEFAULT                                      \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
|
|
|
|
// Registers the dispatch entry for node type OP on a local named `vtable`.
// The registered lambda casts the node held by `n` to `const OP*` and calls
// the matching `self->VisitExpr_` overload with the forwarded arguments.
// Requires `vtable`, `TSelf` and the parameter pack `Args` to be in scope.
// The expansion already ends with a semicolon.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                                \
  vtable.template set_dispatch<OP>(                                    \
      [](const NodeRef& n, TSelf* self, Args... args) {                \
        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);          \
      });
|
|
|
|
template <typename R, typename... Args>
|
|
class ExprFunctor<R(const Expr& n, Args...)> {
|
|
private:
|
|
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
|
|
using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
|
|
|
|
public:
|
|
/*! \brief the result type of this functor */
|
|
using result_type = R;
|
|
/*! \brief virtual destructor */
|
|
virtual ~ExprFunctor() {}
|
|
/*!
|
|
* \brief Same as call.
|
|
* \param n The expression node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
R operator()(const Expr& n, Args... args) {
|
|
return VisitExpr(n, std::forward<Args>(args)...);
|
|
}
|
|
/*!
|
|
* \brief The functor call.
|
|
* \param n The expression node.
|
|
* \param args Additional arguments.
|
|
* \return The result of the call
|
|
*/
|
|
virtual R VisitExpr(const Expr& n, Args... args) {
|
|
static FType vtable = InitVTable();
|
|
return vtable(n, this, std::forward<Args>(args)...);
|
|
}
|
|
// Functions that can be overriden by subclass
|
|
virtual R VisitExpr_(const ConstantNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const TupleNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const VarNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const GlobalVarNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const FunctionNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const IfNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const OpNode* op,
|
|
Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
|
virtual R VisitExprDefault_(const Node* op, Args...) {
|
|
throw Error(std::string("Do not have a default for ") + op->type_key());
|
|
}
|
|
|
|
private:
|
|
// initialize the vtable.
|
|
static FType InitVTable() {
|
|
FType vtable;
|
|
// Set dispatch
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
|
|
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
|
|
return vtable;
|
|
}
|
|
};
|
|
|
|
/*!
|
|
* \brief A simple visitor wrapper around ExprFunctor.
|
|
* Recursively visit the content.
|
|
*
|
|
* ExprVisitor treats Expr as dataflow graph,
|
|
* and only visit each Expr node once.
|
|
*/
|
|
class ExprVisitor
|
|
: public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
|
|
public:
|
|
void VisitExpr(const Expr& expr) override;
|
|
void VisitExpr_(const VarNode* op) override;
|
|
void VisitExpr_(const GlobalVarNode* op) override;
|
|
void VisitExpr_(const ConstantNode* op) override;
|
|
void VisitExpr_(const TupleNode* op) override;
|
|
void VisitExpr_(const FunctionNode* op) override;
|
|
void VisitExpr_(const CallNode* op) override;
|
|
void VisitExpr_(const LetNode* op) override;
|
|
void VisitExpr_(const IfNode* op) override;
|
|
void VisitExpr_(const OpNode* op) override;
|
|
void VisitExpr_(const TupleGetItemNode* op) override;
|
|
virtual void VisitType(const Type& t);
|
|
|
|
protected:
|
|
// Internal visiting counter
|
|
std::unordered_map<const Node*, size_t> visit_counter_;
|
|
};
|
|
|
|
/*!
|
|
* \brief A wrapper around ExprFunctor which functionally updates the AST.
|
|
*
|
|
* ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
|
|
* The mutated results are memoized in a map and reused so that
|
|
* local transformation on the dataflow preserves the graph structure.
|
|
*/
|
|
class ExprMutator
|
|
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
|
|
public:
|
|
/*!
|
|
* \brief Mutate is alias for VisitExpr
|
|
* \return expr.
|
|
*/
|
|
Expr Mutate(const Expr& expr) {
|
|
return this->VisitExpr(expr);
|
|
}
|
|
Expr VisitExpr(const Expr& expr) override;
|
|
Expr VisitExpr_(const VarNode* op) override;
|
|
Expr VisitExpr_(const ConstantNode* op) override;
|
|
Expr VisitExpr_(const GlobalVarNode* op) override;
|
|
Expr VisitExpr_(const OpNode* op) override;
|
|
Expr VisitExpr_(const TupleNode* op) override;
|
|
Expr VisitExpr_(const FunctionNode* op) override;
|
|
Expr VisitExpr_(const CallNode* call_node) override;
|
|
Expr VisitExpr_(const LetNode* op) override;
|
|
Expr VisitExpr_(const IfNode* op) override;
|
|
Expr VisitExpr_(const TupleGetItemNode* op) override;
|
|
/*!
|
|
* \brief Used to visit the types inside of expressions.
|
|
*
|
|
* Can be overloaded to transform the types in arbitrary
|
|
* ways, one way would be to define a sub-class of type
|
|
* visitor for types which transform them appropriately.
|
|
*/
|
|
virtual Type VisitType(const Type& t);
|
|
|
|
protected:
|
|
/*! \brief Internal map used for memoization. */
|
|
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
|
|
};
|
|
|
|
/*!
|
|
* \brief recursively visit the ir in post DFS order node, apply fvisit
|
|
* Each node is guaranteed to be visited only once.
|
|
* \param node The ir to be visited.
|
|
* \param fvisit The visitor function to be applied.
|
|
*/
|
|
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
|
|
|
|
/*
|
|
* \brief Bind function parameters or free variables.
|
|
*
|
|
* Parameter binding can only happen if expr is a Function.
|
|
* binds cannot change internal arguments of internal functions.
|
|
*
|
|
* \param expr The function to be binded.
|
|
* \param binds The map of arguments to
|
|
*/
|
|
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
|
|
|
|
} // namespace relay
|
|
} // namespace tvm
|
|
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
|