[NODE][RELAY] Move most of the reference related code to node (#1747)
This commit is contained in:
Родитель
1c2b0b656b
Коммит
ec0d497c69
|
@ -102,10 +102,10 @@ class TVM_DLL Node : public NodeBase {
|
|||
template<typename T>
|
||||
inline bool is_type() const;
|
||||
/*!
|
||||
* \brief Get a NodeRef that holds reference to this Node.
|
||||
* \return the NodeRef
|
||||
* \brief Get a NodePtr that holds reference to this Node.
|
||||
* \return the NodePtr
|
||||
*/
|
||||
inline NodeRef GetNodeRef() const;
|
||||
inline NodePtr<Node> GetNodePtr() const;
|
||||
// node ref can see this
|
||||
friend class NodeRef;
|
||||
static constexpr const char* _type_key = "Node";
|
||||
|
@ -176,6 +176,32 @@ class NodeRef {
|
|||
NodePtr<Node> node_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Get a reference type from a Node ptr type
|
||||
*
|
||||
* It is always important to get a reference type
|
||||
* if we want to return a value as reference or keep
|
||||
* the node alive beyond the scope of the function.
|
||||
*
|
||||
* \param ptr The node pointer
|
||||
* \tparam RefType The reference type
|
||||
* \tparam NodeType The node type
|
||||
* \return The corresponding RefType
|
||||
*/
|
||||
template <typename RefType, typename NodeType>
|
||||
inline RefType GetRef(const NodeType* ptr);
|
||||
|
||||
/*!
|
||||
* \brief Downcast a base reference type to a more specific type.
|
||||
*
|
||||
* \param ref The inptut reference
|
||||
* \return The corresponding SubRef.
|
||||
* \tparam SubRef The target specific reference type.
|
||||
* \tparam BaseRef the current reference type.
|
||||
*/
|
||||
template <typename SubRef, typename BaseRef>
|
||||
inline SubRef Downcast(BaseRef ref);
|
||||
|
||||
/*!
|
||||
* \brief helper macro to declare type information in a base node.
|
||||
*/
|
||||
|
@ -218,8 +244,24 @@ inline bool Node::derived_from() const {
|
|||
return this->_DerivedFrom(type_id);
|
||||
}
|
||||
|
||||
inline NodeRef Node::GetNodeRef() const {
|
||||
return NodeRef(NodePtr<Node>(const_cast<Node*>(this)));
|
||||
inline NodePtr<Node> Node::GetNodePtr() const {
|
||||
return NodePtr<Node>(const_cast<Node*>(this));
|
||||
}
|
||||
|
||||
template <typename RefType, typename NodeType>
|
||||
inline RefType GetRef(const NodeType* ptr) {
|
||||
static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
|
||||
"Can only cast to the ref of same container type");
|
||||
return RefType(ptr->GetNodePtr());
|
||||
}
|
||||
|
||||
template <typename SubRef, typename BaseRef>
|
||||
inline SubRef Downcast(BaseRef ref) {
|
||||
CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
|
||||
ref->template derived_from<typename SubRef::ContainerType>())
|
||||
<< "Downcast from " << ref->type_key() << " to "
|
||||
<< SubRef::ContainerType::_type_key << " failed.";
|
||||
return SubRef(std::move(ref.node_));
|
||||
}
|
||||
|
||||
inline const Node* NodeRef::get() const {
|
||||
|
|
|
@ -158,43 +158,6 @@ class RelayNode : public Node {
|
|||
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Get a reference type from a Node ptr type
|
||||
*
|
||||
* It is always important to get a reference type
|
||||
* if we want to return a value as reference or keep
|
||||
* the node alive beyond the scope of the function.
|
||||
*
|
||||
* \param ptr The node pointer
|
||||
* \tparam RefType The reference type
|
||||
* \tparam NodeType The node type
|
||||
* \return The corresponding RefType
|
||||
*/
|
||||
template <typename RefType, typename NodeType>
|
||||
RefType GetRef(const NodeType* ptr) {
|
||||
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
|
||||
"Can only cast to the ref of same container type");
|
||||
return RefType(std::move(ptr->GetNodeRef().node_));
|
||||
}
|
||||
|
||||
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
|
||||
template <typename T>
|
||||
inline const T* As(const NodeRef& node) {
|
||||
const Node* ptr = static_cast<const Node*>(node.get());
|
||||
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
|
||||
return static_cast<const T*>(ptr);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename SubRef, typename BaseRef>
|
||||
SubRef Downcast(BaseRef ref) {
|
||||
CHECK(ref->template is_type<typename SubRef::ContainerType>())
|
||||
<< "Downcast from " << ref->type_key() << " to "
|
||||
<< SubRef::ContainerType::_type_key << " failed.";
|
||||
return SubRef(ref.node_);
|
||||
}
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
|
|
|
@ -65,7 +65,9 @@ class ConstantNode : public ExprNode {
|
|||
TensorType tensor_type() const;
|
||||
|
||||
/*! \return Whether it is scalar(rank-0 tensor) */
|
||||
bool is_scalar() const { return data->ndim == 0; }
|
||||
bool is_scalar() const {
|
||||
return data->ndim == 0;
|
||||
}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("data", &data);
|
||||
|
@ -341,7 +343,7 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
|
|||
*
|
||||
* let x = if (true) { 1 } else { 0 }; // x is 1
|
||||
* let y = if (false) { 1 } else { 0 }; // y is 0
|
||||
*
|
||||
*
|
||||
* \note This is similar to C's ternary operator.
|
||||
*/
|
||||
class If;
|
||||
|
|
|
@ -139,19 +139,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
|
|||
* the cost of using functional updates.
|
||||
*/
|
||||
class ExprMutator
|
||||
: public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
|
||||
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
|
||||
public:
|
||||
Expr Mutate(const Expr& expr);
|
||||
Expr VisitExpr_(const VarNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
|
||||
Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
|
||||
Expr VisitExpr_(const LetNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const IfNode* op, const Expr& e) override;
|
||||
Expr VisitExpr_(const VarNode* op) override;
|
||||
Expr VisitExpr_(const ConstantNode* op) override;
|
||||
Expr VisitExpr_(const GlobalVarNode* op) override;
|
||||
Expr VisitExpr_(const OpNode* op) override;
|
||||
Expr VisitExpr_(const TupleNode* op) override;
|
||||
Expr VisitExpr_(const ParamNode* op) override;
|
||||
Expr VisitExpr_(const FunctionNode* op) override;
|
||||
Expr VisitExpr_(const CallNode* call_node) override;
|
||||
Expr VisitExpr_(const LetNode* op) override;
|
||||
Expr VisitExpr_(const IfNode* op) override;
|
||||
/*! \brief Used to visit the types inside of expressions.
|
||||
*
|
||||
* Can be overloaded to transform the types in arbitrary
|
||||
|
@ -162,7 +162,7 @@ class ExprMutator
|
|||
|
||||
private:
|
||||
/*! \brief Internal map used for memoization. */
|
||||
tvm::Map<Expr, Expr> memo_;
|
||||
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -41,12 +41,12 @@ void EnvironmentNode::Add(const GlobalVar &var,
|
|||
const Function &func,
|
||||
bool update) {
|
||||
// Type check the item before we add it to the environment.
|
||||
auto env = relay::GetRef<Environment>(this);
|
||||
auto env = GetRef<Environment>(this);
|
||||
|
||||
Expr checked_expr = InferType(env, var, func);
|
||||
|
||||
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
|
||||
auto checked_func = relay::GetRef<Function>(func_node);
|
||||
auto checked_func = GetRef<Function>(func_node);
|
||||
auto type = checked_func->checked_type();
|
||||
|
||||
CHECK(IsFullyResolved(type));
|
||||
|
|
|
@ -13,33 +13,33 @@ namespace tvm {
|
|||
namespace relay {
|
||||
|
||||
Expr ExprMutator::Mutate(const Expr& expr) {
|
||||
auto cached_expr = this->memo_.find(expr);
|
||||
if (cached_expr != this->memo_.end()) {
|
||||
return (*cached_expr).second;
|
||||
auto it = this->memo_.find(expr);
|
||||
if (it != this->memo_.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
auto new_expr = this->ExprMutator::VisitExpr(expr, expr);
|
||||
this->memo_.Set(expr, new_expr);
|
||||
Expr new_expr = ExprMutator::VisitExpr(expr);
|
||||
memo_[expr] = new_expr;
|
||||
return new_expr;
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) {
|
||||
return expr;
|
||||
Expr ExprMutator::VisitExpr_(const VarNode* op) {
|
||||
return GetRef<Expr>(op);
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) {
|
||||
return expr;
|
||||
Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
|
||||
return GetRef<Expr>(op);
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) {
|
||||
return expr;
|
||||
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
|
||||
return GetRef<Expr>(op);
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) {
|
||||
return expr;
|
||||
Expr ExprMutator::VisitExpr_(const OpNode* op) {
|
||||
return GetRef<Expr>(op);
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
|
||||
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
|
||||
tvm::Array<Expr> fields;
|
||||
bool all_fields_unchanged = true;
|
||||
for (auto field : op->fields) {
|
||||
|
@ -49,23 +49,23 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
|
|||
}
|
||||
|
||||
if (all_fields_unchanged) {
|
||||
return e;
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return TupleNode::make(fields);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) {
|
||||
Expr ExprMutator::VisitExpr_(const ParamNode* op) {
|
||||
Var var = Downcast<Var>(this->Mutate(op->var));
|
||||
auto type = this->VisitType(op->type);
|
||||
if (var == op->var && type == op->type) {
|
||||
return e;
|
||||
if (op->var.same_as(var) && op->type.same_as(type)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return ParamNode::make(var, type);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
|
||||
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
|
||||
tvm::Array<TypeParam> ty_params;
|
||||
bool all_ty_params_changed = true;
|
||||
|
||||
|
@ -86,74 +86,82 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
|
|||
auto ret_type = this->VisitType(op->ret_type);
|
||||
auto body = this->Mutate(op->body);
|
||||
|
||||
if (ty_params.same_as(op->type_params) && params.same_as(op->params) &&
|
||||
ret_type.same_as(op->ret_type) && body.same_as(op->body)) {
|
||||
return e;
|
||||
if (ty_params.same_as(op->type_params) &&
|
||||
params.same_as(op->params) &&
|
||||
ret_type.same_as(op->ret_type) &&
|
||||
body.same_as(op->body)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return FunctionNode::make(params, ret_type, body, ty_params);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) {
|
||||
auto op = this->Mutate(call_node->op);
|
||||
Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
|
||||
auto new_op = this->Mutate(call_node->op);
|
||||
bool unchanged = call_node->op.same_as(new_op);
|
||||
|
||||
tvm::Array<Type> ty_args;
|
||||
bool all_ty_args_unchanged = true;
|
||||
for (auto ty_arg : call_node->type_args) {
|
||||
auto new_ty_arg = this->VisitType(ty_arg);
|
||||
ty_args.push_back(new_ty_arg);
|
||||
all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg);
|
||||
unchanged &= new_ty_arg.same_as(ty_arg);
|
||||
}
|
||||
|
||||
tvm::Array<Expr> call_args;
|
||||
bool all_args_unchanged = true;
|
||||
for (auto arg : call_node->args) {
|
||||
auto new_arg = this->Mutate(arg);
|
||||
call_args.push_back(new_arg);
|
||||
all_args_unchanged &= new_arg.same_as(arg);
|
||||
unchanged &= new_arg.same_as(arg);
|
||||
}
|
||||
|
||||
if (all_ty_args_unchanged && all_args_unchanged &&
|
||||
call_node->op.same_as(op)) {
|
||||
return e;
|
||||
if (unchanged) {
|
||||
return GetRef<Expr>(call_node);
|
||||
} else {
|
||||
return CallNode::make(op, call_args, call_node->attrs, ty_args);
|
||||
return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) {
|
||||
Expr ExprMutator::VisitExpr_(const LetNode* op) {
|
||||
Var var = Downcast<Var>(this->Mutate(op->var));
|
||||
auto type = this->VisitType(op->value_type);
|
||||
auto value = this->Mutate(op->value);
|
||||
auto body = this->Mutate(op->body);
|
||||
|
||||
if (var.same_as(op->var) && type.same_as(op->value_type) &&
|
||||
value.same_as(op->value) && body.same_as(op->body)) {
|
||||
return e;
|
||||
if (var.same_as(op->var) &&
|
||||
type.same_as(op->value_type) &&
|
||||
value.same_as(op->value) &&
|
||||
body.same_as(op->body)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return LetNode::make(var, value, body, type);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) {
|
||||
Expr ExprMutator::VisitExpr_(const IfNode* op) {
|
||||
auto guard = this->Mutate(op->cond);
|
||||
auto true_b = this->Mutate(op->true_branch);
|
||||
auto false_b = this->Mutate(op->false_branch);
|
||||
if (op->cond == guard && true_b == op->true_branch &&
|
||||
false_b == op->false_branch) {
|
||||
return e;
|
||||
if (op->cond.same_as(guard) &&
|
||||
op->true_branch.same_as(true_b) &&
|
||||
op->false_branch.same_as(false_b)) {
|
||||
return GetRef<Expr>(op);;
|
||||
} else {
|
||||
return IfNode::make(guard, true_b, false_b);
|
||||
}
|
||||
}
|
||||
|
||||
Type ExprMutator::VisitType(const Type& t) { return t; }
|
||||
Type ExprMutator::VisitType(const Type& t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; }
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; }
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; }
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
|
||||
for (auto field : op->fields) {
|
||||
|
@ -202,4 +210,3 @@ void ExprVisitor::VisitType(const Type& t) { return; }
|
|||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
|
|
|
@ -78,7 +78,8 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
|
|||
Array<TypeConstraint> type_constraints;
|
||||
for (auto type_cs : op->type_constraints) {
|
||||
auto new_type_cs = VisitType(type_cs);
|
||||
if (const TypeConstraintNode* tin = As<TypeConstraintNode>(new_type_cs)) {
|
||||
if (const TypeConstraintNode* tin =
|
||||
new_type_cs.as_derived<TypeConstraintNode>()) {
|
||||
type_constraints.push_back(GetRef<TypeConstraint>(tin));
|
||||
} else {
|
||||
CHECK(false) << new_type_cs << std::endl;
|
||||
|
|
|
@ -20,7 +20,7 @@ TEST(ExprNodeRef, Basic) {
|
|||
Var x("x");
|
||||
Expr z = max(x + 1 + 2, 100);
|
||||
const ir::Max* op = z.as<ir::Max>();
|
||||
CHECK(op->GetNodeRef().same_as(z));
|
||||
CHECK(NodeRef(op->GetNodePtr()).same_as(z));
|
||||
}
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче