From 0a1f3d41c96109602e00149ff03f90ed85d73c85 Mon Sep 17 00:00:00 2001 From: ziheng Date: Mon, 26 Nov 2018 18:25:17 +0000 Subject: [PATCH] [PASS] PostOrderVisit (#2169) --- include/tvm/relay/attrs/transform.h | 13 +++++++++++++ include/tvm/relay/expr_functor.h | 8 ++++++++ python/tvm/relay/ir_pass.py | 13 +++++++++++++ src/relay/ir/expr_functor.cc | 30 +++++++++++++++++++++++++++++ src/relay/op/tensor/unary.cc | 16 +++------------ 5 files changed, 67 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 39cd82de..3e56106d 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -151,6 +151,19 @@ struct SliceLikeAttrs : public tvm::AttrsNode { } }; +// Clip +struct ClipAttrs : public tvm::AttrsNode { + double a_min; + double a_max; + + TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { + TVM_ATTR_FIELD(a_min) + .describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max) + .describe("The maximum clip value."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 1681f9b8..60b18218 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -182,6 +182,14 @@ class ExprMutator std::unordered_map memo_; }; +/*! + * \brief recursively visit the ir in post DFS order node, apply fvisit + * Each node is guaranteed to be visited only once. + * \param node The ir to be visited. + * \param fvisit The visitor function to be applied. + */ +void PostOrderVisit(const NodeRef& node, std::function fvisit); + /* * \brief Bind function parameters or free variables. * diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index ef0a59cd..6297e366 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -10,6 +10,19 @@ from . import _make from .expr import Expr from .ty import Type +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + fvisit : function + The visitor function to be applied. + """ + return _ir_pass.post_order_visit(expr, fvisit) def infer_type(expr, mod=None): """Infer the type of expr under the context of mod. diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 5e3ee176..bacbfea7 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -228,6 +228,36 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitType(const Type& t) { return; } + +// visitor to implement apply +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + void VisitExpr(const Expr& e) final { + if (visited_.count(e.get()) != 0) return; + visited_.insert(e.get()); + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; + std::unordered_set visited_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + +TVM_REGISTER_API("relay._ir_pass.post_order_visit") +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc f = args[1]; + PostOrderVisit(args[0], [f](const Expr& n) { + f(n); + }); + }); + + // Implement bind. class ExprBinder : public ExprMutator { public: diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 6c94fe2a..fef0302a 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include #include "../type_relations.h" #include "../op_common.h" @@ -89,19 +90,8 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") .add_type_rel("Identity", IdentityRel) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); - -// Clip -struct ClipAttrs : public tvm::AttrsNode { - double a_min; - double a_max; - - TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); - } -}; +// relay.clip +TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_API("relay.op._make.clip") .set_body_typed([](Expr a, double a_min, double a_max) {