[RELAY][IR] Introduce IdNode to preserve var id across rewriting (#2178)
This commit is contained in:
Родитель
246a38a1db
Коммит
7af48f1aa2
|
@ -165,6 +165,34 @@ class RelayNode : public Node {
|
|||
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief The unique identifier of variables.
|
||||
*
|
||||
* Id is like name to the variables,
|
||||
* except that id is unique for each Var.
|
||||
*
|
||||
* \note Do not create Id directly, they are created in Var.
|
||||
*/
|
||||
class IdNode : public Node {
|
||||
public:
|
||||
/*!
|
||||
* \brief The name of the variable,
|
||||
* this only acts as a hint to the user,
|
||||
* and is not used for equality.
|
||||
*/
|
||||
std::string name_hint;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("name_hint", &name_hint);
|
||||
}
|
||||
|
||||
static constexpr const char* _type_key = "relay.Id";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef);
|
||||
|
||||
|
||||
struct Module;
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -124,18 +124,22 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
|
|||
* Its semantics are similar to tvm.Var node used in TVM's low level
|
||||
* tensor expression language.
|
||||
*
|
||||
* \note Each Var is bind only once and is immutable/
|
||||
* \note Each Var is bind only once and is immutable.
|
||||
*/
|
||||
class Var;
|
||||
/*! \brief Container for Var */
|
||||
class VarNode : public ExprNode {
|
||||
public:
|
||||
/*!
|
||||
* \brief The name of the variable,
|
||||
* this only acts as a hint to the user,
|
||||
* and is not used for equality.
|
||||
* \brief The unique identifier of the Var.
|
||||
*
|
||||
* vid will be preserved for the same Var during type inference
|
||||
* and other rewritings, while the VarNode might be recreated
|
||||
* to attach additional information.
|
||||
* This property can be used to keep track of parameter Var
|
||||
* information across passes.
|
||||
*/
|
||||
std::string name_hint;
|
||||
Id vid;
|
||||
/*!
|
||||
* \brief type annotaion of the variable.
|
||||
* This field records user provided type annotation of the Var.
|
||||
|
@ -143,8 +147,13 @@ class VarNode : public ExprNode {
|
|||
*/
|
||||
Type type_annotation;
|
||||
|
||||
/*! \return The name hint of the variable */
|
||||
const std::string& name_hint() const {
|
||||
return vid->name_hint;
|
||||
}
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("name_hint", &name_hint);
|
||||
v->Visit("vid", &vid);
|
||||
v->Visit("type_annotation", &type_annotation);
|
||||
v->Visit("span", &span);
|
||||
v->Visit("_checked_type_", &checked_type_);
|
||||
|
@ -153,6 +162,9 @@ class VarNode : public ExprNode {
|
|||
TVM_DLL static Var make(std::string name_hint,
|
||||
Type type_annotation);
|
||||
|
||||
TVM_DLL static Var make(Id vid,
|
||||
Type type_annotation);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Var";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
|
||||
};
|
||||
|
|
|
@ -54,3 +54,10 @@ class RelayNode(NodeBase):
|
|||
class Span(RelayNode):
|
||||
def __init__(self, source, lineno, col_offset):
|
||||
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Id(NodeBase):
|
||||
"""Unique identifier(name) for Var across type checking."""
|
||||
def __init__(self):
|
||||
raise RuntimeError("Cannot directly construct Id")
|
||||
|
|
|
@ -166,6 +166,12 @@ class Var(Expr):
|
|||
self.__init_handle_by_constructor__(
|
||||
_make.Var, name_hint, type_annotation)
|
||||
|
||||
@property
|
||||
def name_hint(self):
|
||||
"""Get name hint of the current var."""
|
||||
name = self.vid.name_hint
|
||||
return name
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class GlobalVar(Expr):
|
||||
|
|
|
@ -99,7 +99,7 @@ class ScheduleGetter :
|
|||
}
|
||||
|
||||
Array<Tensor> VisitExpr_(const VarNode* op) final {
|
||||
LOG(FATAL) << "Free variable " << op->name_hint;
|
||||
LOG(FATAL) << "Free variable " << op->name_hint();
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
|
@ -240,8 +240,9 @@ class AlphaEqualHandler:
|
|||
}
|
||||
|
||||
bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
|
||||
// This function will only be triggered if we are matching free variables.
|
||||
if (const VarNode* rhs = other.as<VarNode>()) {
|
||||
if (lhs->name_hint != rhs->name_hint) return false;
|
||||
if (lhs->name_hint() != rhs->name_hint()) return false;
|
||||
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
|
||||
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
|
||||
} else {
|
||||
|
|
|
@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
<< node->col_offset << ")";
|
||||
});
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(IdNode);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
||||
|
||||
|
|
|
@ -63,23 +63,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
p->stream << "Tuple(" << node->fields << ")";
|
||||
});
|
||||
|
||||
Var VarNode::make(std::string name_hint, Type type_annotation) {
|
||||
|
||||
Var VarNode::make(Id vid, Type type_annotation) {
|
||||
NodePtr<VarNode> n = make_node<VarNode>();
|
||||
n->name_hint = std::move(name_hint);
|
||||
n->vid = std::move(vid);
|
||||
n->type_annotation = std::move(type_annotation);
|
||||
return Var(n);
|
||||
}
|
||||
|
||||
Var VarNode::make(std::string name_hint, Type type_annotation) {
|
||||
NodePtr<IdNode> n = make_node<IdNode>();
|
||||
n->name_hint = std::move(name_hint);
|
||||
return VarNode::make(Id(n), type_annotation);
|
||||
}
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(VarNode);
|
||||
|
||||
TVM_REGISTER_API("relay._make.Var")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = VarNode::make(args[0], args[1]);
|
||||
*ret = VarNode::make(args[0].operator std::string(), args[1]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
|
||||
p->stream << "Var(" << node->name_hint;
|
||||
p->stream << "Var(" << node->name_hint();
|
||||
if (node->type_annotation.defined()) {
|
||||
p->stream << ", ty=";
|
||||
p->print(node->type_annotation);
|
||||
|
|
|
@ -30,7 +30,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
|
|||
if (op->type_annotation.defined()) {
|
||||
auto type = this->VisitType(op->type_annotation);
|
||||
if (!op->type_annotation.same_as(type)) {
|
||||
return VarNode::make(op->name_hint, type);
|
||||
return VarNode::make(op->vid, type);
|
||||
}
|
||||
}
|
||||
// default case return self.
|
||||
|
|
|
@ -202,7 +202,8 @@ class RelayHashHandler:
|
|||
}
|
||||
|
||||
size_t VisitExpr_(const VarNode* var) final {
|
||||
size_t name_hash = std::hash<std::string>()(var->name_hint);
|
||||
// hash free variable
|
||||
size_t name_hash = std::hash<const Node*>()(var->vid.get());
|
||||
return Combine(name_hash, TypeHash(var->type_annotation));
|
||||
}
|
||||
|
||||
|
|
|
@ -690,7 +690,7 @@ class TextPrinter :
|
|||
* \return The corresponding name.
|
||||
*/
|
||||
TextValue AllocVarName(const Var& var) {
|
||||
std::string name = var->name_hint;
|
||||
std::string name = var->name_hint();
|
||||
// always make sure first name is alpha
|
||||
if (name.length() != 0 && !std::isalpha(name[0])) {
|
||||
name = "%v" + name;
|
||||
|
|
|
@ -141,6 +141,7 @@ def test_free_expr():
|
|||
y = relay.add(x, x)
|
||||
yy = relay.ir_pass.infer_type(y)
|
||||
assert yy.checked_type == relay.scalar_type("float32")
|
||||
assert x.vid.same_as(yy.args[0].vid)
|
||||
|
||||
|
||||
def test_type_args():
|
||||
|
|
Загрузка…
Ссылка в новой задаче