From ec0d497c69ca307fb998c3d81c0a7e48bb5f18d6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 20 Sep 2018 20:17:24 -0700 Subject: [PATCH] [NODE][RELAY] Move most of the reference related code to node (#1747) --- include/tvm/node/node.h | 52 +++++++++++++++-- include/tvm/relay/base.h | 37 ------------ include/tvm/relay/expr.h | 6 +- include/tvm/relay/expr_functor.h | 24 ++++---- src/relay/ir/environment.cc | 4 +- src/relay/ir/expr_functor.cc | 97 +++++++++++++++++--------------- src/relay/pass/type_visitor.h | 3 +- tests/cpp/expr_test.cc | 2 +- 8 files changed, 120 insertions(+), 105 deletions(-) diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index d726b1da..efa93056 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -102,10 +102,10 @@ class TVM_DLL Node : public NodeBase { template inline bool is_type() const; /*! - * \brief Get a NodeRef that holds reference to this Node. - * \return the NodeRef + * \brief Get a NodePtr that holds reference to this Node. + * \return the NodePtr */ - inline NodeRef GetNodeRef() const; + inline NodePtr GetNodePtr() const; // node ref can see this friend class NodeRef; static constexpr const char* _type_key = "Node"; @@ -176,6 +176,32 @@ class NodeRef { NodePtr node_; }; +/*! + * \brief Get a reference type from a Node ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam NodeType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const NodeType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + /*! * \brief helper macro to declare type information in a base node. */ @@ -218,8 +244,24 @@ inline bool Node::derived_from() const { return this->_DerivedFrom(type_id); } -inline NodeRef Node::GetNodeRef() const { - return NodeRef(NodePtr(const_cast(this))); +inline NodePtr Node::GetNodePtr() const { + return NodePtr(const_cast(this)); +} + +template +inline RefType GetRef(const NodeType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ptr->GetNodePtr()); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template is_type() || + ref->template derived_from()) + << "Downcast from " << ref->type_key() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.node_)); } inline const Node* NodeRef::get() const { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index ecf45353..ab55f6f3 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -158,43 +158,6 @@ class RelayNode : public Node { TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); }; -/*! - * \brief Get a reference type from a Node ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. - * - * \param ptr The node pointer - * \tparam RefType The reference type - * \tparam NodeType The node type - * \return The corresponding RefType - */ -template -RefType GetRef(const NodeType* ptr) { - static_assert(std::is_same::value, - "Can only cast to the ref of same container type"); - return RefType(std::move(ptr->GetNodeRef().node_)); -} - -// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR -template -inline const T* As(const NodeRef& node) { - const Node* ptr = static_cast(node.get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; -} - -template -SubRef Downcast(BaseRef ref) { - CHECK(ref->template is_type()) - << "Downcast from " << ref->type_key() << " to " - << SubRef::ContainerType::_type_key << " failed."; - return SubRef(ref.node_); -} - } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 6388e836..0dc2ff6f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -65,7 +65,9 @@ class ConstantNode : public ExprNode { TensorType tensor_type() const; /*! \return Whether it is scalar(rank-0 tensor) */ - bool is_scalar() const { return data->ndim == 0; } + bool is_scalar() const { + return data->ndim == 0; + } void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); @@ -341,7 +343,7 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr); * * let x = if (true) { 1 } else { 0 }; // x is 1 * let y = if (false) { 1 } else { 0 }; // y is 0 - * + * * \note This is similar to C's ternary operator. */ class If; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 27bb464b..e79535a5 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -139,19 +139,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor { * the cost of using functional updates. */ class ExprMutator - : public ::tvm::relay::ExprFunctor { + : public ::tvm::relay::ExprFunctor { public: Expr Mutate(const Expr& expr); - Expr VisitExpr_(const VarNode* op, const Expr& e) override; - Expr VisitExpr_(const ConstantNode* op, const Expr& e) override; - Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override; - Expr VisitExpr_(const OpNode* op, const Expr& expr) override; - Expr VisitExpr_(const TupleNode* op, const Expr& e) override; - Expr VisitExpr_(const ParamNode* op, const Expr& e) override; - Expr VisitExpr_(const FunctionNode* op, const Expr& e) override; - Expr VisitExpr_(const CallNode* call_node, const Expr& e) override; - Expr VisitExpr_(const LetNode* op, const Expr& e) override; - Expr VisitExpr_(const IfNode* op, const Expr& e) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const ConstantNode* op) override; + Expr VisitExpr_(const GlobalVarNode* op) override; + Expr VisitExpr_(const OpNode* op) override; + Expr VisitExpr_(const TupleNode* op) override; + Expr VisitExpr_(const ParamNode* op) override; + Expr VisitExpr_(const FunctionNode* op) override; + Expr VisitExpr_(const CallNode* call_node) override; + Expr VisitExpr_(const LetNode* op) override; + Expr VisitExpr_(const IfNode* op) override; /*! \brief Used to visit the types inside of expressions. * * Can be overloaded to transform the types in arbitrary @@ -162,7 +162,7 @@ class ExprMutator private: /*! \brief Internal map used for memoization. */ - tvm::Map memo_; + std::unordered_map memo_; }; } // namespace relay diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index 16b03145..d7a28231 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -41,12 +41,12 @@ void EnvironmentNode::Add(const GlobalVar &var, const Function &func, bool update) { // Type check the item before we add it to the environment. - auto env = relay::GetRef(this); + auto env = GetRef(this); Expr checked_expr = InferType(env, var, func); if (const FunctionNode *func_node = checked_expr.as()) { - auto checked_func = relay::GetRef(func_node); + auto checked_func = GetRef(func_node); auto type = checked_func->checked_type(); CHECK(IsFullyResolved(type)); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 85ae5ffa..e3393bdb 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -13,33 +13,33 @@ namespace tvm { namespace relay { Expr ExprMutator::Mutate(const Expr& expr) { - auto cached_expr = this->memo_.find(expr); - if (cached_expr != this->memo_.end()) { - return (*cached_expr).second; + auto it = this->memo_.find(expr); + if (it != this->memo_.end()) { + return it->second; } else { - auto new_expr = this->ExprMutator::VisitExpr(expr, expr); - this->memo_.Set(expr, new_expr); + Expr new_expr = ExprMutator::VisitExpr(expr); + memo_[expr] = new_expr; return new_expr; } } -Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) { - return expr; +Expr ExprMutator::VisitExpr_(const VarNode* op) { + return GetRef(op); } -Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) { - return expr; +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { + return GetRef(op); } -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) { - return expr; +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { + return GetRef(op); } -Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) { - return expr; +Expr ExprMutator::VisitExpr_(const OpNode* op) { + return GetRef(op); } -Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { +Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; bool all_fields_unchanged = true; for (auto field : op->fields) { @@ -49,23 +49,23 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) { } if (all_fields_unchanged) { - return e; + return GetRef(op); } else { return TupleNode::make(fields); } } -Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) { +Expr ExprMutator::VisitExpr_(const ParamNode* op) { Var var = Downcast(this->Mutate(op->var)); auto type = this->VisitType(op->type); - if (var == op->var && type == op->type) { - return e; + if (op->var.same_as(var) && op->type.same_as(type)) { + return GetRef(op); } else { return ParamNode::make(var, type); } } -Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) { +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array ty_params; bool all_ty_params_changed = true; @@ -86,74 +86,82 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) { auto ret_type = this->VisitType(op->ret_type); auto body = this->Mutate(op->body); - if (ty_params.same_as(op->type_params) && params.same_as(op->params) && - ret_type.same_as(op->ret_type) && body.same_as(op->body)) { - return e; + if (ty_params.same_as(op->type_params) && + params.same_as(op->params) && + ret_type.same_as(op->ret_type) && + body.same_as(op->body)) { + return GetRef(op); } else { return FunctionNode::make(params, ret_type, body, ty_params); } } -Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) { - auto op = this->Mutate(call_node->op); +Expr ExprMutator::VisitExpr_(const CallNode* call_node) { + auto new_op = this->Mutate(call_node->op); + bool unchanged = call_node->op.same_as(new_op); tvm::Array ty_args; - bool all_ty_args_unchanged = true; for (auto ty_arg : call_node->type_args) { auto new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); - all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg); + unchanged &= new_ty_arg.same_as(ty_arg); } tvm::Array call_args; - bool all_args_unchanged = true; for (auto arg : call_node->args) { auto new_arg = this->Mutate(arg); call_args.push_back(new_arg); - all_args_unchanged &= new_arg.same_as(arg); + unchanged &= new_arg.same_as(arg); } - if (all_ty_args_unchanged && all_args_unchanged && - call_node->op.same_as(op)) { - return e; + if (unchanged) { + return GetRef(call_node); } else { - return CallNode::make(op, call_args, call_node->attrs, ty_args); + return CallNode::make(new_op, call_args, call_node->attrs, ty_args); } } -Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) { +Expr ExprMutator::VisitExpr_(const LetNode* op) { Var var = Downcast(this->Mutate(op->var)); auto type = this->VisitType(op->value_type); auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); - if (var.same_as(op->var) && type.same_as(op->value_type) && - value.same_as(op->value) && body.same_as(op->body)) { - return e; + if (var.same_as(op->var) && + type.same_as(op->value_type) && + value.same_as(op->value) && + body.same_as(op->body)) { + return GetRef(op); } else { return LetNode::make(var, value, body, type); } } -Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) { +Expr ExprMutator::VisitExpr_(const IfNode* op) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); - if (op->cond == guard && true_b == op->true_branch && - false_b == op->false_branch) { - return e; + if (op->cond.same_as(guard) && + op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op);; } else { return IfNode::make(guard, true_b, false_b); } } -Type ExprMutator::VisitType(const Type& t) { return t; } +Type ExprMutator::VisitType(const Type& t) { + return t; +} -void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; } +void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { +} -void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; } +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { +} -void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; } +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { +} void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { for (auto field : op->fields) { @@ -202,4 +210,3 @@ void ExprVisitor::VisitType(const Type& t) { return; } } // namespace relay } // namespace tvm - diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h index 725e3d9b..c37b536c 100644 --- a/src/relay/pass/type_visitor.h +++ b/src/relay/pass/type_visitor.h @@ -78,7 +78,8 @@ struct TypeMutator : TypeFunctor { Array type_constraints; for (auto type_cs : op->type_constraints) { auto new_type_cs = VisitType(type_cs); - if (const TypeConstraintNode* tin = As(new_type_cs)) { + if (const TypeConstraintNode* tin = + new_type_cs.as_derived()) { type_constraints.push_back(GetRef(tin)); } else { CHECK(false) << new_type_cs << std::endl; diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 9cdfef7f..dca76205 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -20,7 +20,7 @@ TEST(ExprNodeRef, Basic) { Var x("x"); Expr z = max(x + 1 + 2, 100); const ir::Max* op = z.as(); - CHECK(op->GetNodeRef().same_as(z)); + CHECK(NodeRef(op->GetNodePtr()).same_as(z)); }