[RELAY][IR] Move type_annotation to Var, remove Param (#1900)
This commit is contained in:
Родитель
53428606c8
Коммит
0b4cc05050
|
@ -118,17 +118,27 @@ class Var;
|
|||
/*! \brief Container for Var */
|
||||
class VarNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The name of the variable, this only acts as a hint to the user,
|
||||
* and is not used for equality.
|
||||
/*!
|
||||
* \brief The name of the variable,
|
||||
* this only acts as a hint to the user,
|
||||
* and is not used for equality.
|
||||
*/
|
||||
std::string name_hint;
|
||||
/*!
|
||||
* \brief type annotaion of the variable.
|
||||
* This field records user provided type annotation of the Var.
|
||||
* This field is optional and can be None.
|
||||
*/
|
||||
Type type_annotation;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("name_hint", &name_hint);
|
||||
v->Visit("type_annotation", &type_annotation);
|
||||
v->Visit("_checked_type_", &checked_type_);
|
||||
}
|
||||
|
||||
TVM_DLL static Var make(std::string name_hint);
|
||||
TVM_DLL static Var make(std::string name_hint,
|
||||
Type type_annotation);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Var";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
|
||||
|
@ -162,32 +172,6 @@ class GlobalVarNode : public ExprNode {
|
|||
|
||||
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
|
||||
|
||||
/*!
|
||||
* \brief Function parameter declaration.
|
||||
*/
|
||||
class Param;
|
||||
/*! \brief A parameter. */
|
||||
class ParamNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The variable */
|
||||
Var var;
|
||||
/*! \brief The type of the parameter */
|
||||
Type type;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("var", &var);
|
||||
v->Visit("type", &type);
|
||||
v->Visit("span", &span);
|
||||
}
|
||||
|
||||
TVM_DLL static Param make(Var var, Type type);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Param";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
|
||||
};
|
||||
|
||||
RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
|
||||
|
||||
/*!
|
||||
* \brief Function (subgraph in computational graph)
|
||||
*/
|
||||
|
@ -196,7 +180,7 @@ class Function;
|
|||
class FunctionNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief Function parameters */
|
||||
tvm::Array<Param> params;
|
||||
tvm::Array<Var> params;
|
||||
/*! \brief User annotated return type of the function. */
|
||||
Type ret_type;
|
||||
/*!
|
||||
|
@ -224,10 +208,18 @@ class FunctionNode : public ExprNode {
|
|||
v->Visit("_checked_type_", &checked_type_);
|
||||
}
|
||||
|
||||
Type fn_type() const;
|
||||
/*!
|
||||
* \brief Return the derived function annotation of this expression.
|
||||
*
|
||||
* \return The function type annotation.
|
||||
* \note The function type annotation can contain IncompleteType.
|
||||
*/
|
||||
TVM_DLL FuncType func_type_annotation() const;
|
||||
|
||||
TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
|
||||
Expr body, tvm::Array<TypeParam> ty_params);
|
||||
TVM_DLL static Function make(tvm::Array<Var> params,
|
||||
Type ret_type,
|
||||
Expr body,
|
||||
tvm::Array<TypeParam> ty_params);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Function";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
|
||||
|
@ -289,7 +281,7 @@ class CallNode : public ExprNode {
|
|||
TVM_DLL static Call make(Expr op,
|
||||
Array<Expr> args,
|
||||
Attrs attrs = Attrs(),
|
||||
Array<Type> ty_args = Array<Type>());
|
||||
Array<Type> type_args = Array<Type>());
|
||||
|
||||
static constexpr const char* _type_key = "relay.Call";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
|
||||
|
@ -318,19 +310,16 @@ class LetNode : public ExprNode {
|
|||
Expr value;
|
||||
/*! \brief The body of the let binding */
|
||||
Expr body;
|
||||
/*! \brief Type annotation of value, this can be null */
|
||||
Type value_type;
|
||||
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {
|
||||
v->Visit("var", &var);
|
||||
v->Visit("value", &value);
|
||||
v->Visit("body", &body);
|
||||
v->Visit("value_type", &value_type);
|
||||
v->Visit("span", &span);
|
||||
v->Visit("_checked_type_", &checked_type_);
|
||||
}
|
||||
|
||||
TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
|
||||
TVM_DLL static Let make(Var var, Expr value, Expr body);
|
||||
|
||||
static constexpr const char* _type_key = "relay.Let";
|
||||
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
|
||||
|
@ -376,11 +365,11 @@ class IfNode : public ExprNode {
|
|||
|
||||
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
|
||||
|
||||
/*! \brief Get a field out of a tuple. */
|
||||
/*! \brief Get index-th field out of a tuple. */
|
||||
class TupleGetItem;
|
||||
class TupleGetItemNode : public ExprNode {
|
||||
public:
|
||||
/*! \brief The tuple */
|
||||
/*! \brief The tuple Expression */
|
||||
Expr tuple;
|
||||
/*! \brief which value to get */
|
||||
int index;
|
||||
|
|
|
@ -80,7 +80,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
|
|||
Args... args) EXPR_FUNCTOR_DEFAULT;
|
||||
virtual R VisitExpr_(const GlobalVarNode* op,
|
||||
Args... args) EXPR_FUNCTOR_DEFAULT;
|
||||
virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
||||
virtual R VisitExpr_(const FunctionNode* op,
|
||||
Args... args) EXPR_FUNCTOR_DEFAULT;
|
||||
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
|
||||
|
@ -103,7 +102,6 @@ class ExprFunctor<R(const Expr& n, Args...)> {
|
|||
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
|
||||
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
|
||||
|
@ -127,7 +125,6 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
|
|||
void VisitExpr_(const GlobalVarNode* op) override;
|
||||
void VisitExpr_(const ConstantNode* op) override;
|
||||
void VisitExpr_(const TupleNode* op) override;
|
||||
void VisitExpr_(const ParamNode* op) override;
|
||||
void VisitExpr_(const FunctionNode* op) override;
|
||||
void VisitExpr_(const CallNode* op) override;
|
||||
void VisitExpr_(const LetNode* op) override;
|
||||
|
@ -151,7 +148,6 @@ class ExprMutator
|
|||
Expr VisitExpr_(const GlobalVarNode* op) override;
|
||||
Expr VisitExpr_(const OpNode* op) override;
|
||||
Expr VisitExpr_(const TupleNode* op) override;
|
||||
Expr VisitExpr_(const ParamNode* op) override;
|
||||
Expr VisitExpr_(const FunctionNode* op) override;
|
||||
Expr VisitExpr_(const CallNode* call_node) override;
|
||||
Expr VisitExpr_(const LetNode* op) override;
|
||||
|
|
|
@ -34,7 +34,6 @@ Constant = expr.Constant
|
|||
Tuple = expr.Tuple
|
||||
Var = expr.Var
|
||||
GlobalVar = expr.GlobalVar
|
||||
Param = expr.Param
|
||||
Function = expr.Function
|
||||
Call = expr.Call
|
||||
Let = expr.Let
|
||||
|
|
|
@ -11,11 +11,11 @@ class Expr(NodeBase):
|
|||
"""The base type for all Relay expressions."""
|
||||
@property
|
||||
def checked_type(self):
|
||||
"""Get the checked type of relay.
|
||||
"""Get the checked type of tvm.relay.Expr.
|
||||
|
||||
Returns
|
||||
-------
|
||||
checked_type : relay.Type
|
||||
checked_type : tvm.relay.Type
|
||||
The checked type.
|
||||
"""
|
||||
ret = self._checked_type_
|
||||
|
@ -25,70 +25,97 @@ class Expr(NodeBase):
|
|||
return ret
|
||||
|
||||
def __call__(self, *args):
|
||||
converted_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Param):
|
||||
converted_args.append(arg.var)
|
||||
else:
|
||||
converted_args.append(arg)
|
||||
|
||||
return Call(self, args, None, None)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Constant(Expr):
|
||||
"""A constant tensor in Relay, see tvm/relay/type.h for more details.
|
||||
"""
|
||||
"""A constant expression in Relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : tvm.nd.NDArray
|
||||
The data content of the constant expression.
|
||||
"""
|
||||
def __init__(self, data):
|
||||
self.__init_handle_by_constructor__(_make.Constant, data)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Tuple(Expr):
|
||||
"""A hetereogenous sequence of values.
|
||||
see tvm/relay/type.h for more details.
|
||||
"""
|
||||
"""Tuple expression that groups several fields together.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fields : List[tvm.relay.Expr]
|
||||
The fields in the tuple.
|
||||
"""
|
||||
def __init__(self, fields):
|
||||
self.__init_handle_by_constructor__(_make.Tuple, fields)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Var(Expr):
|
||||
"""A local variable in Relay."""
|
||||
"""A local variable in Tvm.Relay.
|
||||
|
||||
def __init__(self, name_hint):
|
||||
self.__init_handle_by_constructor__(_make.Var, name_hint)
|
||||
Local variable can be used to declare input
|
||||
arguments to a function, or intermediate variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name_hint: str
|
||||
The name of the variable.
|
||||
This name only acts as a hint, and is not used
|
||||
for equality.
|
||||
|
||||
type_annotation: tvm.relay.Type, optional
|
||||
The type annotation on the variable.
|
||||
"""
|
||||
def __init__(self, name_hint, type_annotation=None):
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.Var, name_hint, type_annotation)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class GlobalVar(Expr):
|
||||
"""A global variable in Relay."""
|
||||
"""A global variable in Tvm.Relay.
|
||||
|
||||
GlobalVar is used to refer to the global functions
|
||||
stored in the environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name_hint: str
|
||||
The name of the variable.
|
||||
"""
|
||||
def __init__(self, name_hint):
|
||||
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Param(Expr):
|
||||
"""A function type in Relay, see tvm/relay/type.h for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, var, ty):
|
||||
self.__init_handle_by_constructor__(_make.Param, var, ty)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Function(Expr):
|
||||
"""A function in Relay, see tvm/relay/expr.h for more details."""
|
||||
"""A function declaration expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params: List[tvm.relay.Var]
|
||||
List of input parameters to the function.
|
||||
|
||||
ret_type: tvm.relay.Type
|
||||
The return type annotation of the function.
|
||||
|
||||
body: tvm.relay.Expr
|
||||
The body of the function.
|
||||
|
||||
type_params: Optional[List[tvm.relay.TypeParam]]
|
||||
The additional type parameters, this is only
|
||||
used in advanced usecase of template functions.
|
||||
"""
|
||||
def __init__(self,
|
||||
params,
|
||||
ret_type,
|
||||
body,
|
||||
type_params=None
|
||||
):
|
||||
type_params=None):
|
||||
if type_params is None:
|
||||
type_params = convert([])
|
||||
|
||||
|
@ -98,39 +125,87 @@ class Function(Expr):
|
|||
|
||||
@register_relay_node
|
||||
class Call(Expr):
|
||||
"""A function call in Relay, see tvm/relay/expr.h for more details."""
|
||||
"""Function call node in Relay.
|
||||
|
||||
def __init__(self, op, args, attrs, ty_args=None):
|
||||
if not ty_args:
|
||||
ty_args = []
|
||||
Call node corresponds the operator application node
|
||||
in computational graph terminology.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op: tvm.relay.Op or any tvm.relay.Expr with function type.
|
||||
The operation to be called.
|
||||
|
||||
args: List[tvm.relay.Expr]
|
||||
The arguments to the call.
|
||||
|
||||
attrs: Optional[tvm.Attrs]
|
||||
Attributes to the call, can be None
|
||||
|
||||
type_args: Optional[List[tvm.relay.Type]]
|
||||
The additional type arguments, this is only
|
||||
used in advanced usecase of template functions.
|
||||
"""
|
||||
def __init__(self, op, args, attrs=None, type_args=None):
|
||||
if not type_args:
|
||||
type_args = []
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.Call, op, args, attrs, ty_args)
|
||||
_make.Call, op, args, attrs, type_args)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class Let(Expr):
|
||||
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
|
||||
"""Let variable binding expression.
|
||||
|
||||
def __init__(self, var, value, body, value_type=None):
|
||||
Parameters
|
||||
----------
|
||||
var: tvm.relay.Var
|
||||
The local variable to be bound.
|
||||
|
||||
value: tvm.relay.Expr
|
||||
The value to be bound.
|
||||
|
||||
body: tvm.relay.Expr
|
||||
The body of the let binding.
|
||||
"""
|
||||
def __init__(self, var, value, body):
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.Let, var, value, body, value_type)
|
||||
_make.Let, var, value, body)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class If(Expr):
|
||||
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
|
||||
"""A conditional expression in Relay.
|
||||
|
||||
def __init__(self, cond, true_value, false_value):
|
||||
Parameters
|
||||
----------
|
||||
cond: tvm.relay.Expr
|
||||
The condition.
|
||||
|
||||
true_branch: tvm.relay.Expr
|
||||
The expression evaluated when condition is true.
|
||||
|
||||
false_branch: tvm.relay.Expr
|
||||
The expression evaluated when condition is false.
|
||||
"""
|
||||
def __init__(self, cond, true_branch, false_branch):
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.If, cond, true_value, false_value)
|
||||
_make.If, cond, true_branch, false_branch)
|
||||
|
||||
|
||||
@register_relay_node
|
||||
class TupleGetItem(Expr):
|
||||
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""
|
||||
"""Get index-th item from a tuple.
|
||||
|
||||
def __init__(self, tuple_, index):
|
||||
Parameters
|
||||
----------
|
||||
tuple_value: tvm.relay.Expr
|
||||
The input tuple expression.
|
||||
|
||||
index: int
|
||||
The index.
|
||||
"""
|
||||
def __init__(self, tuple_value, index):
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.TupleGetItem, tuple_, index)
|
||||
_make.TupleGetItem, tuple_value, index)
|
||||
|
||||
debug_print = _expr._debug_print
|
||||
|
|
|
@ -7,7 +7,7 @@ from collections import OrderedDict
|
|||
import numpy as np
|
||||
import tvm
|
||||
from .ty import Type, FuncType, TensorType
|
||||
from .expr import Expr, Constant, Let, Var, Param, Function, If
|
||||
from .expr import Expr, Constant, Let, Var, Function, If
|
||||
from .env import Environment
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ class PartialFunc(object):
|
|||
self.type_params = type_params
|
||||
|
||||
def param_ids(self):
|
||||
return [p.var for p in self.params]
|
||||
return [p for p in self.params]
|
||||
|
||||
def to_func(self):
|
||||
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
|
||||
|
@ -113,9 +113,8 @@ class PartialFunc(object):
|
|||
|
||||
def _mk_let(bindings, ret_value):
|
||||
let_expr = ret_value
|
||||
for var, (value, ty) in reversed(list(bindings.items())):
|
||||
let_expr = Let(var, value, let_expr, ty)
|
||||
|
||||
for var, value in reversed(list(bindings.items())):
|
||||
let_expr = Let(var, value, let_expr)
|
||||
return let_expr
|
||||
|
||||
|
||||
|
@ -168,15 +167,12 @@ class IRBuilder(object):
|
|||
|
||||
#pylint: disable=invalid-name
|
||||
def bind(self, name, value, ty):
|
||||
lv = Var(name)
|
||||
lv = Var(name, ty)
|
||||
self.scopes[-1][name] = lv
|
||||
self.bindings[-1][lv] = (value, ty)
|
||||
self.bindings[-1][lv] = value
|
||||
return lv
|
||||
|
||||
def let(self, name, value, value_type=None):
|
||||
if isinstance(value, Param):
|
||||
value = value.var
|
||||
|
||||
if not isinstance(value, Expr):
|
||||
value = convert(value)
|
||||
|
||||
|
@ -185,23 +181,18 @@ class IRBuilder(object):
|
|||
def _convert_params(self, raw_params):
|
||||
relay_params = []
|
||||
for raw_param in raw_params:
|
||||
if isinstance(raw_param, Param):
|
||||
var = raw_param.var
|
||||
if isinstance(raw_param, Var):
|
||||
param = raw_param
|
||||
elif isinstance(raw_param, tuple):
|
||||
var, ty = raw_param
|
||||
if isinstance(var, str):
|
||||
var = Var(var)
|
||||
ty = _convert_type(ty)
|
||||
param = Param(var, ty)
|
||||
elif isinstance(param, str):
|
||||
var = Var(raw_param)
|
||||
ty = None
|
||||
param = Param(var, ty)
|
||||
param = Var(var, ty)
|
||||
elif isinstance(raw_param, str):
|
||||
param = Var(raw_param, None)
|
||||
else:
|
||||
raise Exception("unknown parameter type")
|
||||
|
||||
self.scopes[-1][var.name_hint] = var
|
||||
self.scopes[-1][param.name_hint] = param
|
||||
relay_params.append(param)
|
||||
|
||||
return relay_params
|
||||
|
@ -265,7 +256,7 @@ class IRBuilder(object):
|
|||
else:
|
||||
ty = _convert_type(ty)
|
||||
|
||||
return Param(Var(name), ty)
|
||||
return Var(name, ty)
|
||||
|
||||
def global_var(self, name):
|
||||
# type: (str) -> GlobalVar
|
||||
|
|
|
@ -96,7 +96,9 @@ class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
|
|||
}
|
||||
|
||||
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
|
||||
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) { return Docify(tp); });
|
||||
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
|
||||
return Docify(tp);
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
|
||||
|
@ -188,10 +190,11 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
|
|||
return vec;
|
||||
}
|
||||
|
||||
std::vector<Doc> DocifyParamArray(const tvm::Array<Param>& arr) {
|
||||
std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
|
||||
std::vector<Doc> vec;
|
||||
for (size_t i = 0; i < arr.size(); ++i) {
|
||||
vec.push_back(Docify(arr[i]));
|
||||
for (Var param : arr) {
|
||||
vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
|
||||
param->type_annotation));
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
@ -212,10 +215,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
|
|||
return DocOfStr(g->name_hint);
|
||||
}
|
||||
|
||||
Doc VisitExpr_(const ParamNode* p) final {
|
||||
return TypeAnnotation(Docify(p->var), p->type);
|
||||
}
|
||||
|
||||
Doc VisitExpr_(const FunctionNode* f) final {
|
||||
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
|
||||
DocOfStr("=>") + Sep() +
|
||||
|
@ -227,7 +226,8 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
|
|||
}
|
||||
|
||||
Doc VisitExpr_(const LetNode* l) final {
|
||||
return Group(DocOfStr("let") + Sep() + TypeAnnotation(Docify(l->var), l->value_type) + Sep() +
|
||||
return Group(DocOfStr("let") + Sep() +
|
||||
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
|
||||
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
|
||||
Docify(l->body));
|
||||
}
|
||||
|
|
|
@ -54,20 +54,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
p->stream << "Tuple(" << node->fields << ")";
|
||||
});
|
||||
|
||||
Var VarNode::make(std::string name_hint) {
|
||||
Var VarNode::make(std::string name_hint, Type type_annotation) {
|
||||
NodePtr<VarNode> n = make_node<VarNode>();
|
||||
n->name_hint = std::move(name_hint);
|
||||
n->type_annotation = std::move(type_annotation);
|
||||
return Var(n);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.Var")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = VarNode::make(args[0]);
|
||||
*ret = VarNode::make(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<VarNode>([](const VarNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "Var(" << node->name_hint << ")";
|
||||
p->stream << "Var(" << node->name_hint;
|
||||
if (node->type_annotation.defined()) {
|
||||
p->stream << ", ty=";
|
||||
p->print(node->type_annotation);
|
||||
}
|
||||
p->stream << ")";
|
||||
});
|
||||
|
||||
GlobalVar GlobalVarNode::make(std::string name_hint) {
|
||||
|
@ -86,24 +92,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
p->stream << "GlobalVar(" << node->name_hint << ")";
|
||||
});
|
||||
|
||||
Param ParamNode::make(Var var, Type type) {
|
||||
NodePtr<ParamNode> n = make_node<ParamNode>();
|
||||
n->var = std::move(var);
|
||||
n->type = std::move(type);
|
||||
return Param(n);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.Param")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = ParamNode::make(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<ParamNode>([](const ParamNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "Param(" << node->var << ", " << node->type << ")";
|
||||
});
|
||||
|
||||
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
|
||||
Function FunctionNode::make(tvm::Array<Var> params,
|
||||
Type ret_type,
|
||||
Expr body,
|
||||
tvm::Array<TypeParam> type_params) {
|
||||
NodePtr<FunctionNode> n = make_node<FunctionNode>();
|
||||
n->params = std::move(params);
|
||||
|
@ -113,12 +105,11 @@ Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
|
|||
return Function(n);
|
||||
}
|
||||
|
||||
Type FunctionNode::fn_type() const {
|
||||
FuncType FunctionNode::func_type_annotation() const {
|
||||
Array<Type> param_types;
|
||||
for (auto param : this->params) {
|
||||
param_types.push_back(param->type);
|
||||
param_types.push_back(param->type_annotation);
|
||||
}
|
||||
|
||||
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
|
||||
}
|
||||
|
||||
|
@ -155,24 +146,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
<< node->attrs << ", " << node->type_args << ")";
|
||||
});
|
||||
|
||||
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
|
||||
Let LetNode::make(Var var, Expr value, Expr body) {
|
||||
NodePtr<LetNode> n = make_node<LetNode>();
|
||||
n->var = std::move(var);
|
||||
n->value = std::move(value);
|
||||
n->body = std::move(body);
|
||||
n->value_type = std::move(value_type);
|
||||
return Let(n);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._make.Let")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
*ret = LetNode::make(args[0], args[1], args[2], args[3]);
|
||||
});
|
||||
*ret = LetNode::make(args[0], args[1], args[2]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<LetNode>([](const LetNode *node, tvm::IRPrinter *p) {
|
||||
p->stream << "LetNode(" << node->var << ", " << node->value
|
||||
<< ", " << node->body << ", " << node->value_type << ")";
|
||||
<< ", " << node->body << ")";
|
||||
});
|
||||
|
||||
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
|
||||
|
|
|
@ -24,6 +24,16 @@ Expr ExprMutator::Mutate(const Expr& expr) {
|
|||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const VarNode* op) {
|
||||
// NOTE: var will only be mutated once
|
||||
// Thanks to the memo and reused during rewriting if necessary.
|
||||
// It is safe to assume that the
|
||||
if (op->type_annotation.defined()) {
|
||||
auto type = this->VisitType(op->type_annotation);
|
||||
if (!op->type_annotation.same_as(type)) {
|
||||
return VarNode::make(op->name_hint, type);
|
||||
}
|
||||
}
|
||||
// default case return self.
|
||||
return GetRef<Expr>(op);
|
||||
}
|
||||
|
||||
|
@ -55,16 +65,6 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
|
|||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const ParamNode* op) {
|
||||
Var var = Downcast<Var>(this->Mutate(op->var));
|
||||
auto type = this->VisitType(op->type);
|
||||
if (op->var.same_as(var) && op->type.same_as(type)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return ParamNode::make(var, type);
|
||||
}
|
||||
}
|
||||
|
||||
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
|
||||
tvm::Array<TypeParam> ty_params;
|
||||
bool all_ty_params_changed = true;
|
||||
|
@ -75,10 +75,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
|
|||
all_ty_params_changed &= new_ty_param.same_as(ty_param);
|
||||
}
|
||||
|
||||
tvm::Array<Param> params;
|
||||
tvm::Array<Var> params;
|
||||
bool all_params_changed = true;
|
||||
for (auto param : op->params) {
|
||||
Param new_param = Downcast<Param>(this->Mutate(param));
|
||||
Var new_param = Downcast<Var>(this->Mutate(param));
|
||||
params.push_back(new_param);
|
||||
all_params_changed &= param.same_as(new_param);
|
||||
}
|
||||
|
@ -123,17 +123,15 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
|
|||
|
||||
Expr ExprMutator::VisitExpr_(const LetNode* op) {
|
||||
Var var = Downcast<Var>(this->Mutate(op->var));
|
||||
auto type = this->VisitType(op->value_type);
|
||||
auto value = this->Mutate(op->value);
|
||||
auto body = this->Mutate(op->body);
|
||||
|
||||
if (var.same_as(op->var) &&
|
||||
type.same_as(op->value_type) &&
|
||||
value.same_as(op->value) &&
|
||||
body.same_as(op->body)) {
|
||||
return GetRef<Expr>(op);
|
||||
} else {
|
||||
return LetNode::make(var, value, body, type);
|
||||
return LetNode::make(var, value, body);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,6 +160,9 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
|
|||
Type ExprMutator::VisitType(const Type& t) { return t; }
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
|
||||
if (op->type_annotation.defined()) {
|
||||
this->VisitType(op->type_annotation);
|
||||
}
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
|
||||
|
@ -176,10 +177,6 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
|
|||
}
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const ParamNode* op) {
|
||||
this->VisitExpr(op->var);
|
||||
}
|
||||
|
||||
void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
|
||||
for (auto param : op->params) {
|
||||
this->VisitExpr(param);
|
||||
|
|
|
@ -252,15 +252,6 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
|
|||
}
|
||||
}
|
||||
|
||||
void VisitExpr_(const ParamNode* p1, const Expr& e2) final {
|
||||
if (const ParamNode* p2 = e2.as<ParamNode>()) {
|
||||
eq_map.Set(p1->var, p2->var);
|
||||
equal = equal && AlphaEqual(p1->type, p2->type);
|
||||
} else {
|
||||
equal = false;
|
||||
}
|
||||
}
|
||||
|
||||
void VisitExpr_(const FunctionNode* func1, const Expr& e2) final {
|
||||
if (const FunctionNode* func2 = e2.as<FunctionNode>()) {
|
||||
if (func1->params.size() != func2->params.size()) {
|
||||
|
@ -273,9 +264,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
|
|||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0U; i < func1->params.size(); i++) {
|
||||
this->VisitExpr(func1->params[i], func2->params[i]);
|
||||
for (size_t i = 0; i < func1->params.size(); ++i) {
|
||||
MergeVarDecl(func1->params[i], func2->params[i]);
|
||||
}
|
||||
if (!equal) return;
|
||||
|
||||
for (size_t i = 0U; i < func1->type_params.size(); i++) {
|
||||
equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]);
|
||||
|
@ -332,19 +324,9 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
|
|||
|
||||
void VisitExpr_(const LetNode* op, const Expr& e2) final {
|
||||
if (const LetNode* let = e2.as<LetNode>()) {
|
||||
eq_map.Set(op->var, let->var);
|
||||
MergeVarDecl(op->var, let->var);
|
||||
this->VisitExpr(op->value, let->value);
|
||||
this->VisitExpr(op->body, let->body);
|
||||
|
||||
// value_type should match as well (including nulls)
|
||||
if (op->value_type.defined() != let->value_type.defined()) {
|
||||
equal = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (op->value_type.defined()) {
|
||||
equal = equal && AlphaEqual(op->value_type, let->value_type);
|
||||
}
|
||||
} else {
|
||||
equal = false;
|
||||
}
|
||||
|
@ -388,6 +370,20 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
|
|||
equal = false;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void MergeVarDecl(const Var& var1, const Var& var2) {
|
||||
if (var1->type_annotation.defined() != var2->type_annotation.defined()) {
|
||||
equal = false;
|
||||
return;
|
||||
}
|
||||
if (var1->type_annotation.defined() &&
|
||||
!AlphaEqual(var1->type_annotation, var2->type_annotation)) {
|
||||
equal = false;
|
||||
return;
|
||||
}
|
||||
eq_map.Set(var1, var2);
|
||||
}
|
||||
};
|
||||
|
||||
bool AlphaEqual(const Expr& e1, const Expr& e2) {
|
||||
|
|
|
@ -54,12 +54,7 @@ class CalcDep : private ExprMutator {
|
|||
}
|
||||
|
||||
private:
|
||||
struct Binder {
|
||||
Type t;
|
||||
Expr e;
|
||||
Binder(const Type& t, const Expr& e) : t(t), e(e) { }
|
||||
};
|
||||
using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
|
||||
using VarMap = std::unordered_map<Var, Expr, NodeHash, NodeEqual>;
|
||||
VarMap var_map_;
|
||||
|
||||
Expr VisitExpr_(const IfNode* i) final {
|
||||
|
@ -74,9 +69,7 @@ class CalcDep : private ExprMutator {
|
|||
}
|
||||
|
||||
Expr VisitExpr_(const LetNode* l) final {
|
||||
var_map_.insert(std::pair<Var, Binder>(l->var,
|
||||
Binder(l->value_type,
|
||||
Eliminate(l->value))));
|
||||
var_map_[l->var] = Eliminate(l->value);
|
||||
return VisitExpr(l->body);
|
||||
}
|
||||
|
||||
|
@ -92,15 +85,16 @@ class CalcDep : private ExprMutator {
|
|||
explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
|
||||
friend CalcDep;
|
||||
|
||||
void VisitExpr_(const VarNode* vn) final {
|
||||
Var v = GetRef<Var>(vn);
|
||||
if (var_map_.count(v) != 0) {
|
||||
auto val = var_map_.at(v);
|
||||
var_map_.erase(v);
|
||||
void VisitExpr_(const VarNode* vnode) final {
|
||||
Var v = GetRef<Var>(vnode);
|
||||
auto it = var_map_.find(v);
|
||||
if (it != var_map_.end()) {
|
||||
Expr expr = it->second;
|
||||
var_map_.erase(it);
|
||||
// erase before visit to handle letrec
|
||||
VisitExpr(val.e);
|
||||
VisitExpr(expr);
|
||||
// visit before push back so the dependency of dependency is before the dependency
|
||||
lets_.Push(v, val.t, val.e);
|
||||
lets_.Push(v, expr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -26,57 +26,46 @@ namespace relay {
|
|||
*/
|
||||
class LetList {
|
||||
public:
|
||||
/*! \brief insert a binding.
|
||||
/*!
|
||||
* \brief insert a binding.
|
||||
*
|
||||
* \param pv the var of the binding.
|
||||
* \param pv the var of the binding.
|
||||
*
|
||||
* \param ty the type of the binding.
|
||||
* \param expr the value of the binding.
|
||||
*
|
||||
* \param expr the value of the binding.
|
||||
*
|
||||
* \return a Var that hold the inserted expr.
|
||||
* \return a Var that hold the inserted expr.
|
||||
*/
|
||||
Var Push(const Var& pv, const Type& ty, const Expr& expr) {
|
||||
std::tuple<Var, Type, Expr> tuple(pv, ty, expr);
|
||||
lets_.push_back(tuple);
|
||||
Var Push(Var pv, Expr expr) {
|
||||
lets_.emplace_back(std::make_pair(pv, expr));
|
||||
return pv;
|
||||
}
|
||||
|
||||
/*! \brief insert a binding.
|
||||
/*!
|
||||
* \brief insert a binding.
|
||||
*
|
||||
* \param ty the type of the binding.
|
||||
* \param ty the type of the binding.
|
||||
*
|
||||
* \param expr the value of the binding.
|
||||
* \param expr the value of the binding.
|
||||
*
|
||||
* \return a Var that hold the inserted expr.
|
||||
* \return a Var that hold the inserted expr.
|
||||
*/
|
||||
Var Push(const Type& ty, const Expr& expr) {
|
||||
return Push(VarNode::make("x"), ty, expr);
|
||||
Var Push(Type ty, Expr expr) {
|
||||
return Push(VarNode::make("x", ty), expr);
|
||||
}
|
||||
|
||||
/*! \brief insert a binding.
|
||||
*
|
||||
* \param pv the var of the binding.
|
||||
/*!
|
||||
* \brief insert a binding.
|
||||
*
|
||||
* \param expr the value of the binding.
|
||||
*
|
||||
* \return a Var that hold the inserted expr.
|
||||
*/
|
||||
Var Push(const Var& pv, const Expr& expr) {
|
||||
return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr);
|
||||
}
|
||||
|
||||
/*! \brief insert a binding.
|
||||
*
|
||||
* \param expr the value of the binding.
|
||||
*
|
||||
* \return a Var that hold the inserted expr.
|
||||
*/
|
||||
Var Push(const Expr& expr) {
|
||||
Var Push(Expr expr) {
|
||||
return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr);
|
||||
}
|
||||
|
||||
/*! \brief wrap an expr around the LetList.
|
||||
/*!
|
||||
* \brief wrap an expr around the LetList.
|
||||
*
|
||||
* \param body the Expression to be wrapped around.
|
||||
*
|
||||
|
@ -85,7 +74,7 @@ class LetList {
|
|||
Expr Get(const Expr& body) const {
|
||||
Expr ret = body;
|
||||
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
|
||||
ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit));
|
||||
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -118,7 +107,7 @@ class LetList {
|
|||
}
|
||||
|
||||
private:
|
||||
std::vector<std::tuple<Var, Type, Expr> > lets_;
|
||||
std::vector<std::pair<Var, Expr> > lets_;
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -87,15 +87,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
|
||||
// Visitor logics
|
||||
Type VisitExpr_(const VarNode* op) final {
|
||||
// The type of Var can already been lookedup in type_map_;
|
||||
LOG(FATAL) << "Cannot find binding for var " << GetRef<Var>(op);
|
||||
return Type();
|
||||
}
|
||||
|
||||
Type VisitExpr_(const ParamNode* op) final {
|
||||
// directly handled by Funtion
|
||||
LOG(FATAL) << "not reached";
|
||||
return Type();
|
||||
if (op->type_annotation.defined()) {
|
||||
return op->type_annotation;
|
||||
} else {
|
||||
return IncompleteTypeNode::make(TypeParamNode::kType);
|
||||
}
|
||||
}
|
||||
|
||||
Type VisitExpr_(const GlobalVarNode* op) final {
|
||||
|
@ -139,11 +135,11 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
|
||||
Type VisitExpr_(const LetNode* op) final {
|
||||
Type vtype = GetType(op->value);
|
||||
if (op->value_type.defined()) {
|
||||
vtype = Unify(vtype, op->value_type, op->span);
|
||||
if (op->var->type_annotation.defined()) {
|
||||
vtype = Unify(vtype, op->var->type_annotation, op->span);
|
||||
}
|
||||
CHECK(!type_map_.count(op->var));
|
||||
// NOTE: no scoping is necessary becase var are unique in program
|
||||
// NOTE: no scoping is necessary because var are unique in program
|
||||
type_map_[op->var] = vtype;
|
||||
return GetType(op->body);
|
||||
}
|
||||
|
@ -256,8 +252,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
|
||||
Type VisitExpr_(const FunctionNode* f) final {
|
||||
for (auto param : f->params) {
|
||||
type_map_[param->var] = param->type;
|
||||
type_map_[param] = param->type;
|
||||
GetType(param);
|
||||
}
|
||||
Type rtype = GetType(f->body);
|
||||
// Run solver using the currently known information
|
||||
|
@ -265,8 +260,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
|
|||
// Trying to resolve
|
||||
Array<Type> arg_types;
|
||||
for (size_t i = 0; i < f->params.size(); ++i) {
|
||||
Param param = f->params[i];
|
||||
Type atype = solver_.Resolve(param->type);
|
||||
Type atype = solver_.Resolve(GetType(f->params[i]));
|
||||
CHECK(atype.as<IncompleteTypeNode>() == nullptr)
|
||||
<< "Cannot resolve type of " << i
|
||||
<< "-th parameter of function at" << f->span;
|
||||
|
@ -311,9 +305,6 @@ class TypeInferencer::Resolver : public ExprMutator {
|
|||
return AttachCheckedType(op);
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const ParamNode* op) final {
|
||||
return ExprMutator::VisitExpr_(op);
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const FunctionNode* op) final {
|
||||
return AttachCheckedType(op);
|
||||
|
@ -380,7 +371,7 @@ Expr InferType(const Environment& env,
|
|||
const GlobalVar& var,
|
||||
const Function& func) {
|
||||
Function func_copy = Function(make_node<FunctionNode>(*func.operator->()));
|
||||
func_copy->checked_type_ = func_copy->fn_type();
|
||||
func_copy->checked_type_ = func_copy->func_type_annotation();
|
||||
env->functions.Set(var, func_copy);
|
||||
Expr func_ret = TypeInferencer(env).Infer(func_copy);
|
||||
auto map_node = env->functions.CopyOnWrite();
|
||||
|
|
|
@ -50,14 +50,17 @@ class FreeVar : public ExprVisitor {
|
|||
if (bound_vars.count(var) == 0) {
|
||||
free_vars.insert(var);
|
||||
}
|
||||
if (v->type_annotation.defined()) {
|
||||
VisitType(v->type_annotation);
|
||||
}
|
||||
}
|
||||
|
||||
void VisitExpr_(const FunctionNode *f) final {
|
||||
for (const auto& tp : f->type_params) {
|
||||
bound_types.insert(tp);
|
||||
}
|
||||
for (const auto& p : f->params) {
|
||||
bound_vars.insert(p->var);
|
||||
for (const auto& param : f->params) {
|
||||
bound_vars.insert(param);
|
||||
}
|
||||
VisitExpr(f->body);
|
||||
VisitType(f->ret_type);
|
||||
|
@ -67,7 +70,6 @@ class FreeVar : public ExprVisitor {
|
|||
bound_vars.insert(l->var);
|
||||
VisitExpr(l->value);
|
||||
VisitExpr(l->body);
|
||||
VisitType(l->value_type);
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
|
@ -34,8 +34,8 @@ class WellFormedChecker : private ExprVisitor {
|
|||
}
|
||||
|
||||
void VisitExpr_(const FunctionNode * f) final {
|
||||
for (const Param & p : f->params) {
|
||||
Check(p->var);
|
||||
for (const Var & param : f->params) {
|
||||
Check(param);
|
||||
}
|
||||
CheckWellFormed(f->body);
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ def test_let():
|
|||
assert var == prog.body
|
||||
assert isinstance(value, Constant)
|
||||
assert value.data.asnumpy() == np.array(1)
|
||||
assert prog.value_type == None
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_let()
|
||||
|
|
|
@ -49,18 +49,11 @@ def test_global_var():
|
|||
show(gv)
|
||||
|
||||
|
||||
def test_param():
|
||||
lv = relay.Var('x')
|
||||
ty = None
|
||||
param = relay.Param(lv, ty)
|
||||
show(lv)
|
||||
|
||||
|
||||
def test_function():
|
||||
param_names = ['a', 'b', 'c', 'd']
|
||||
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
|
||||
params = tvm.convert([relay.Var(n) for n in param_names])
|
||||
ret_type = None
|
||||
body = params[0].var
|
||||
body = params[0]
|
||||
type_params = tvm.convert([])
|
||||
fn = relay.Function(params, ret_type, body, type_params)
|
||||
show(fn)
|
||||
|
@ -76,11 +69,11 @@ def test_call():
|
|||
|
||||
|
||||
def test_let():
|
||||
lv = relay.Var('x')
|
||||
ty = relay.ty.TensorType((10, 20), 'float32')
|
||||
lv = relay.Var('x', ty)
|
||||
arr = tvm.nd.array(10)
|
||||
value = relay.Constant(arr)
|
||||
let = relay.Let(lv, value, lv, ty)
|
||||
let = relay.Let(lv, value, lv)
|
||||
show(let)
|
||||
|
||||
|
||||
|
|
|
@ -99,10 +99,16 @@ def test_tuple():
|
|||
def test_local_var():
|
||||
name_hint = 's'
|
||||
lv = relay.Var(name_hint)
|
||||
lv.name_hint == name_hint
|
||||
assert lv.name_hint == name_hint
|
||||
assert lv.type_annotation is None
|
||||
# assert lv.span == None todo(@jroesch): what do we do about spans
|
||||
str(lv)
|
||||
|
||||
t1 = relay.ty.TensorType((), "float")
|
||||
lv = relay.Var(name_hint, t1)
|
||||
assert lv.name_hint == name_hint
|
||||
assert lv.type_annotation == t1
|
||||
|
||||
|
||||
def test_global_var():
|
||||
name_hint = 'g'
|
||||
|
@ -112,19 +118,9 @@ def test_global_var():
|
|||
str(gv)
|
||||
|
||||
|
||||
def test_param():
|
||||
lv = relay.Var('x')
|
||||
ty = None
|
||||
param = relay.Param(lv, ty)
|
||||
assert param.var == lv
|
||||
assert param.type == ty
|
||||
assert param.span == None
|
||||
str(param)
|
||||
|
||||
|
||||
def test_function():
|
||||
param_names = ['a', 'b', 'c', 'd']
|
||||
params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
|
||||
params = tvm.convert([relay.Var(n) for n in param_names])
|
||||
ret_type = None
|
||||
body = None
|
||||
type_params = tvm.convert([])
|
||||
|
@ -154,10 +150,9 @@ def test_let():
|
|||
value = relay.Constant(arr)
|
||||
# I would prefer that the order of arguments
|
||||
# matches syntax let x: t = v in b
|
||||
let = relay.Let(lv, value, lv, ty)
|
||||
let = relay.Let(lv, value, lv)
|
||||
assert let.var == lv
|
||||
assert let.value == value
|
||||
assert let.value_type == ty
|
||||
assert let.body == lv
|
||||
assert let.span == None
|
||||
str(let)
|
||||
|
@ -194,7 +189,6 @@ if __name__ == "__main__":
|
|||
test_tuple()
|
||||
test_local_var()
|
||||
test_global_var()
|
||||
test_param()
|
||||
test_function()
|
||||
test_call()
|
||||
test_let()
|
||||
|
|
|
@ -7,23 +7,22 @@ def test_well_formed():
|
|||
assert well_formed(x)
|
||||
v = relay.Constant(tvm.nd.array(10))
|
||||
ty = None
|
||||
let = relay.Let(x, v, x, ty)
|
||||
let = relay.Let(x, v, x)
|
||||
assert well_formed(let)
|
||||
assert not well_formed(relay.Let(x, v, let, ty))
|
||||
f = relay.Function([relay.Param(x, ty)], ty, x)
|
||||
assert not well_formed(relay.Let(x, v, let))
|
||||
f = relay.Function([x], ty, x)
|
||||
assert well_formed(f)
|
||||
# this test should pass in case of weak uniqueness (only test for shadowing)
|
||||
# but we want all binder to be distinct from each other.
|
||||
assert not well_formed(relay.Let(relay.Var("y"), f,
|
||||
relay.Let(relay.Var("z"), f, v, ty), ty))
|
||||
relay.Let(relay.Var("z"), f, v)))
|
||||
|
||||
|
||||
def test_tuple():
|
||||
x = relay.Var('x')
|
||||
assert well_formed(x)
|
||||
v = relay.Constant(tvm.nd.array(10))
|
||||
ty = None
|
||||
let = relay.Let(x, v, x, ty)
|
||||
let = relay.Let(x, v, x)
|
||||
assert well_formed(let)
|
||||
assert well_formed(relay.Tuple([v, v]))
|
||||
assert not well_formed(relay.Tuple([let, let]))
|
||||
|
|
|
@ -27,6 +27,8 @@ def test_single_op():
|
|||
tvm.relay.sigmoid, tvm.relay.tanh]:
|
||||
check_single_op(opfunc)
|
||||
|
||||
|
||||
|
||||
def test_expand_dims_infer_type():
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, t, d = tvm.var("n"), tvm.var("t"), 100
|
||||
|
@ -75,12 +77,13 @@ def test_unary_op():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.TensorType((10, 4), "int32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(op(x.var))
|
||||
ib.ret(op(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
assert ftype.ret_type == relay.TensorType((10, 4), "int32")
|
||||
|
||||
|
||||
def test_binary_op():
|
||||
def check_binary_op(opfunc):
|
||||
"""
|
||||
|
@ -94,7 +97,7 @@ def test_binary_op():
|
|||
x = b.param('x', tensor_type(5, 5, 5))
|
||||
y = b.param('y', tensor_type(5, 5, 5))
|
||||
with b.function(x, y) as func:
|
||||
b.ret(opfunc(x.var, y.var))
|
||||
b.ret(opfunc(x, y))
|
||||
b.ret(func)
|
||||
prog, env = b.get()
|
||||
ttype = tensor_type(5, 5, 5)
|
||||
|
@ -118,7 +121,7 @@ def test_binary_broadcast_op():
|
|||
x = b.param('x', tensor_type(10, 4))
|
||||
y = b.param('y', tensor_type(5, 10, 1))
|
||||
with b.function(x, y) as func:
|
||||
b.ret(opfunc(x.var, y.var))
|
||||
b.ret(opfunc(x, y))
|
||||
b.ret(func)
|
||||
prog, env = b.get()
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ def test_conv2d_infer_type():
|
|||
w = ib.param("w", relay.ty.IncompleteType())
|
||||
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var,
|
||||
ib.ret(relay.nn.conv2d(x, w,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=2))
|
||||
|
@ -29,7 +29,7 @@ def test_conv2d_infer_type():
|
|||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
|
||||
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
|
||||
ib.ret(relay.nn.conv2d(x, w, out_dtype="int32"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -42,7 +42,7 @@ def test_conv2d_infer_type():
|
|||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
|
||||
w = ib.param("w", relay.ty.IncompleteType())
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var,
|
||||
ib.ret(relay.nn.conv2d(x, w,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=16,
|
||||
|
@ -65,7 +65,7 @@ def test_conv2d_transpose_infer_type():
|
|||
w = ib.param("w", relay.ty.IncompleteType())
|
||||
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
|
||||
ib.ret(relay.nn.conv2d_transpose(x, w,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=15))
|
||||
|
@ -83,7 +83,7 @@ def test_conv2d_transpose_infer_type():
|
|||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
|
||||
ib.ret(relay.nn.conv2d_transpose(x, w,
|
||||
output_padding=(1, 1),
|
||||
channels=11,
|
||||
data_layout="NHWC"))
|
||||
|
@ -98,7 +98,7 @@ def test_upsampling_infer_type():
|
|||
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
|
||||
ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -108,7 +108,7 @@ def test_upsampling_infer_type():
|
|||
n, c = tvm.var("n"), tvm.var("c")
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR"))
|
||||
ib.ret(relay.nn.upsampling(x, scale=2, layout="NCHW", method="BILINEAR"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -119,7 +119,7 @@ def _test_pool2d_infer_type(opfunc):
|
|||
n, c, h, w = tvm.var("n"), 10, 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(opfunc(x.var, pool_size=(1, 1)))
|
||||
ib.ret(opfunc(x, pool_size=(1, 1)))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -132,7 +132,7 @@ def _test_pool2d_infer_type(opfunc):
|
|||
n, c, h, w = tvm.var("n"), 10, 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw)))
|
||||
ib.ret(opfunc(x, pool_size=(ph, pw), strides=(sh, sw)))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -144,7 +144,7 @@ def _test_global_pool2d_infer_type(opfunc):
|
|||
n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(opfunc(x.var, layout="NHWC"))
|
||||
ib.ret(opfunc(x, layout="NHWC"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -154,7 +154,7 @@ def _test_global_pool2d_infer_type(opfunc):
|
|||
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(opfunc(x.var))
|
||||
ib.ret(opfunc(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -172,7 +172,7 @@ def test_flatten_infer_type():
|
|||
x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32"))
|
||||
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.nn.batch_flatten(x.var))
|
||||
ib.ret(relay.nn.batch_flatten(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -181,7 +181,7 @@ def test_flatten_infer_type():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.nn.batch_flatten(x.var))
|
||||
ib.ret(relay.nn.batch_flatten(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -190,7 +190,7 @@ def test_flatten_infer_type():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.nn.batch_flatten(x.var))
|
||||
ib.ret(relay.nn.batch_flatten(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -202,7 +202,7 @@ def test_pad_infer_type():
|
|||
n, c, h, w = 1, 2, 3, 4
|
||||
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(t) as func:
|
||||
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
|
||||
ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -213,7 +213,7 @@ def test_pad_infer_type():
|
|||
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
|
||||
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
|
||||
with ib.function(t) as func:
|
||||
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
|
||||
ib.ret(relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4))))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -227,4 +227,3 @@ if __name__ == "__main__":
|
|||
test_flatten_infer_type()
|
||||
test_pad_infer_type()
|
||||
test_conv2d_transpose_infer_type()
|
||||
|
||||
|
|
|
@ -17,12 +17,13 @@ def test_zeros_ones():
|
|||
ftype = func.checked_type
|
||||
assert ftype.ret_type == relay.TensorType((124, 50), "float64")
|
||||
|
||||
|
||||
def test_unary_identity():
|
||||
for op in [relay.zeros_like, relay.ones_like]:
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.TensorType((8, 9, 4), "int32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(op(x.var))
|
||||
ib.ret(op(x))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -33,7 +34,7 @@ def test_clip_type():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
a = ib.param("a", relay.TensorType((10, 4), "float32"))
|
||||
with ib.function(a) as func:
|
||||
ib.ret(relay.clip(a.var, 1., 4.))
|
||||
ib.ret(relay.clip(a, 1., 4.))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -106,7 +107,7 @@ def test_take_infer_type():
|
|||
x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
|
||||
indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
|
||||
with ib.function(x, indices) as func:
|
||||
ib.ret(relay.take(x.var, indices.var, axis=axis))
|
||||
ib.ret(relay.take(x, indices, axis=axis))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -127,7 +128,7 @@ def test_full():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.TensorType((), "int8"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.full(x.var, ()))
|
||||
ib.ret(relay.full(x, ()))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -137,7 +138,7 @@ def test_full():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.TensorType((), "float32"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.full(x.var, (1, 2), "int8"))
|
||||
ib.ret(relay.full(x, (1, 2), "int8"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -150,7 +151,7 @@ def test_full_like():
|
|||
base = ib.param("base", relay.TensorType((1, 2, 3), "float32"))
|
||||
fill = ib.param("fill", relay.TensorType((), "float32"))
|
||||
with ib.function(base, fill) as func:
|
||||
ib.ret(relay.full_like(base.var, fill.var))
|
||||
ib.ret(relay.full_like(base, fill))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -162,7 +163,7 @@ def test_full_like():
|
|||
base = ib.param("base", relay.TensorType((n, c, h, w), "float32"))
|
||||
fill = ib.param("fill", relay.TensorType((), "float32"))
|
||||
with ib.function(base, fill) as func:
|
||||
ib.ret(relay.full_like(base.var, fill.var))
|
||||
ib.ret(relay.full_like(base, fill))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_cmp_type():
|
|||
x = ib.param("x", relay.TensorType((10, 4), "float32"))
|
||||
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
|
||||
with ib.function(x, y) as func:
|
||||
ib.ret(op(x.var, y.var))
|
||||
ib.ret(op(x, y))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -39,7 +39,7 @@ def test_binary_broadcast():
|
|||
x = ib.param("x", relay.TensorType((10, 4), "int32"))
|
||||
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
|
||||
with ib.function(x, y) as func:
|
||||
ib.ret(op(x.var, y.var))
|
||||
ib.ret(op(x, y))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -58,7 +58,7 @@ def test_binary_op():
|
|||
x = b.param('x', tensor_type(5, 5, 5))
|
||||
y = b.param('y', tensor_type(5, 5, 5))
|
||||
with b.function(x, y) as func:
|
||||
b.ret(opfunc(x.var, y.var))
|
||||
b.ret(opfunc(x, y))
|
||||
b.ret(func)
|
||||
prog, env = b.get()
|
||||
ttype = tensor_type(5, 5, 5)
|
||||
|
@ -81,7 +81,7 @@ def test_binary_broadcast_op():
|
|||
x = b.param('x', tensor_type(10, 4))
|
||||
y = b.param('y', tensor_type(5, 10, 1))
|
||||
with b.function(x, y) as func:
|
||||
b.ret(opfunc(x.var, y.var))
|
||||
b.ret(opfunc(x, y))
|
||||
b.ret(func)
|
||||
prog, env = b.get()
|
||||
|
||||
|
@ -103,7 +103,7 @@ def test_cmp_type():
|
|||
x = ib.param("x", relay.TensorType((10, 4), "float32"))
|
||||
y = ib.param("y", relay.TensorType((5, 10, 1), "float32"))
|
||||
with ib.function(x, y) as func:
|
||||
ib.ret(op(x.var, y.var))
|
||||
ib.ret(op(x, y))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -118,7 +118,7 @@ def test_binary_broadcast():
|
|||
x = ib.param("x", relay.TensorType((10, 4), "int32"))
|
||||
y = ib.param("y", relay.TensorType((5, 10, 1), "int32"))
|
||||
with ib.function(x, y) as func:
|
||||
ib.ret(op(x.var, y.var))
|
||||
ib.ret(op(x, y))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -131,7 +131,7 @@ def test_where():
|
|||
x = ib.param("x", relay.TensorType((3, 4), "float32"))
|
||||
y = ib.param("y", relay.TensorType((3, 4), "float32"))
|
||||
with ib.function(cond, x, y) as func:
|
||||
ib.ret(relay.where(cond.var, x.var, y.var))
|
||||
ib.ret(relay.where(cond, x, y))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
|
|
@ -10,7 +10,7 @@ def test_resize_infer_type():
|
|||
th, tw = tvm.var("th"), tvm.var("tw")
|
||||
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.image.resize(x.var, (th, tw)))
|
||||
ib.ret(relay.image.resize(x, (th, tw)))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
@ -19,7 +19,7 @@ def test_resize_infer_type():
|
|||
ib = relay.ir_builder.IRBuilder()
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
|
||||
with ib.function(x) as func:
|
||||
ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False))
|
||||
ib.ret(relay.image.resize(x, (100, 200), "NCHW", "BILINEAR", False))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import tvm
|
||||
import numpy as np
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import alpha_equal
|
||||
from tvm.relay.ir_builder import convert
|
||||
|
@ -179,9 +180,9 @@ def test_var_alpha_equal():
|
|||
assert not alpha_equal(v1, v2)
|
||||
|
||||
# let node allows for setting the eq_map
|
||||
l1 = relay.Let(v1, convert(1), v1, None)
|
||||
l2 = relay.Let(v2, convert(1), v2, None)
|
||||
l3 = relay.Let(v1, convert(1), v2, None)
|
||||
l1 = relay.Let(v1, convert(1), v1)
|
||||
l2 = relay.Let(v2, convert(1), v2)
|
||||
l3 = relay.Let(v1, convert(1), v2)
|
||||
|
||||
assert alpha_equal(l1, l2)
|
||||
assert not alpha_equal(l1, l3)
|
||||
|
@ -209,10 +210,10 @@ def test_tuple_alpha_equal():
|
|||
assert alpha_equal(tup, same)
|
||||
|
||||
# use the eq_map
|
||||
let_tup = relay.Let(v1, tup, v1, None)
|
||||
let_tup = relay.Let(v1, tup, v1)
|
||||
let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3),
|
||||
relay.Tuple([convert(4)])]),
|
||||
v2, None)
|
||||
v2)
|
||||
assert alpha_equal(let_tup, let_mapped)
|
||||
|
||||
more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2])
|
||||
|
@ -242,61 +243,44 @@ def test_tuple_get_item_alpha_equal():
|
|||
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
|
||||
|
||||
|
||||
def test_param_alpha_equal():
|
||||
# only checks equality of the types
|
||||
v1 = relay.Var("v1")
|
||||
v2 = relay.Var("v2")
|
||||
|
||||
p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32"))
|
||||
p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32"))
|
||||
assert alpha_equal(p1, p2)
|
||||
|
||||
p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8"))
|
||||
assert not alpha_equal(p1, p3)
|
||||
|
||||
p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3),
|
||||
"float32")]))
|
||||
assert not alpha_equal(p1, p4)
|
||||
|
||||
|
||||
def test_function_alpha_equal():
|
||||
v1 = relay.Var("v1")
|
||||
v2 = relay.Var("v2")
|
||||
v3 = relay.Var("v3")
|
||||
v4 = relay.Var("v4")
|
||||
|
||||
tt1 = relay.TensorType((1, 2, 3), "float32")
|
||||
tt2 = relay.TensorType((4, 5, 6), "int8")
|
||||
tt3 = relay.TupleType([tt1, tt2])
|
||||
|
||||
v1 = relay.Var("v1", tt1)
|
||||
v2 = relay.Var("v2", tt2)
|
||||
v3 = relay.Var("v3", tt3)
|
||||
v4 = relay.Var("v4", tt2)
|
||||
vret = relay.Constant(tvm.nd.array(np.ones(1)))
|
||||
|
||||
tp1 = relay.TypeParam("tp1", relay.Kind.Type)
|
||||
tp2 = relay.TypeParam("tp2", relay.Kind.Type)
|
||||
tp3 = relay.TypeParam("tp3", relay.Kind.Shape)
|
||||
tp4 = relay.TypeParam("tp4", relay.Kind.Shape)
|
||||
|
||||
basic_args = [relay.Param(v3, tt1), relay.Param(v4, tt2)]
|
||||
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
|
||||
basic_tps = [tp1, tp2]
|
||||
|
||||
func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)],
|
||||
tt2, v2, basic_tps)
|
||||
mapped = relay.Function(basic_args, tt2, v4, basic_tps)
|
||||
func = relay.Function([v1, v2],
|
||||
tt2, v1, basic_tps)
|
||||
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
|
||||
assert alpha_equal(func, mapped)
|
||||
|
||||
fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps)
|
||||
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
|
||||
assert not alpha_equal(func, fewer_params)
|
||||
|
||||
more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2),
|
||||
relay.Param(v2, tt2)], tt2, v4, basic_tps)
|
||||
more_params = relay.Function([relay.Var("v3", tt1),
|
||||
relay.Var("v4", tt2),
|
||||
relay.Var("v2", tt2)], tt2, v4, basic_tps)
|
||||
assert not alpha_equal(func, more_params)
|
||||
|
||||
params_unordered = relay.Function([relay.Param(v3, tt2),
|
||||
relay.Param(v4, tt1)],
|
||||
tt1, v3, basic_tps)
|
||||
params_unordered = relay.Function([v2, v1],
|
||||
tt2, v1, basic_tps)
|
||||
assert not alpha_equal(func, params_unordered)
|
||||
|
||||
params_mismatch = relay.Function([relay.Param(v3, tt3),
|
||||
relay.Param(v4, tt2)],
|
||||
tt2, v4, basic_tps)
|
||||
params_mismatch = relay.Function([v1, v3],
|
||||
tt2, v1, basic_tps)
|
||||
assert not alpha_equal(func, params_mismatch)
|
||||
|
||||
# also would not typecheck
|
||||
|
@ -376,7 +360,10 @@ def test_call_alpha_equal():
|
|||
|
||||
|
||||
def test_let_alpha_equal():
|
||||
tt1 = relay.TensorType((), "float32")
|
||||
tt2 = relay.TensorType((), "int8")
|
||||
v1 = relay.Var("v1")
|
||||
v1_wtype = relay.Var("v1", tt1)
|
||||
v2 = relay.Var("v2")
|
||||
v3 = relay.Var("v3")
|
||||
|
||||
|
@ -394,14 +381,13 @@ def test_let_alpha_equal():
|
|||
assert not alpha_equal(let, different_body)
|
||||
|
||||
# specified types must match
|
||||
tt1 = relay.TensorType((), "float32")
|
||||
tt2 = relay.TensorType((), "int8")
|
||||
let_with_type = relay.Let(v1, convert(2), v1, tt1)
|
||||
same_type = relay.Let(v1, convert(2), v1, tt1)
|
||||
|
||||
let_with_type = relay.Let(v1_wtype, convert(2), v1_wtype)
|
||||
same_type = relay.Let(v1_wtype, convert(2), v1_wtype)
|
||||
assert alpha_equal(let_with_type, same_type)
|
||||
assert not alpha_equal(let, let_with_type)
|
||||
|
||||
different_type = relay.Let(v1, convert(2), v1, tt2)
|
||||
v2 = relay.Var("v1", tt2)
|
||||
different_type = relay.Let(v2, convert(2), v2)
|
||||
assert not alpha_equal(let_with_type, different_type)
|
||||
|
||||
|
||||
|
@ -437,16 +423,13 @@ if __name__ == "__main__":
|
|||
test_tensor_type_alpha_equal()
|
||||
test_incomplete_type_alpha_equal()
|
||||
test_constant_alpha_equal()
|
||||
test_type_param_alpha_equal()
|
||||
test_func_type_alpha_equal()
|
||||
test_tuple_type_alpha_equal()
|
||||
test_type_relation_alpha_equal()
|
||||
test_constant_alpha_equal()
|
||||
test_var_alpha_equal()
|
||||
test_global_var_alpha_equal()
|
||||
test_tuple_alpha_equal()
|
||||
test_tuple_get_item_alpha_equal()
|
||||
test_param_alpha_equal()
|
||||
test_function_alpha_equal()
|
||||
test_call_alpha_equal()
|
||||
test_let_alpha_equal()
|
||||
|
|
|
@ -28,17 +28,17 @@ e = env()
|
|||
|
||||
|
||||
def test_let():
|
||||
orig = relay.Let(e.x, e.y, e.z, e.tt)
|
||||
orig = relay.Let(e.x, e.y, e.z)
|
||||
assert alpha_equal(dead_code_elimination(orig), e.z)
|
||||
|
||||
|
||||
def test_used_let():
|
||||
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
|
||||
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
|
||||
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
|
||||
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c))
|
||||
|
||||
|
||||
def test_chain_unused_let():
|
||||
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
|
||||
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
|
||||
assert alpha_equal(dead_code_elimination(orig), e.e)
|
||||
|
||||
|
||||
|
@ -56,19 +56,17 @@ def test_recursion():
|
|||
f(2, 10000);
|
||||
"""
|
||||
f = relay.Var("f")
|
||||
n = relay.Var("n")
|
||||
np = relay.Param(n, e.int32)
|
||||
data = relay.Var("data")
|
||||
datap = relay.Param(data, e.float32)
|
||||
n = relay.Var("n", e.int32)
|
||||
data = relay.Var("data", e.float32)
|
||||
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
|
||||
value = relay.Function([np, datap], e.float32, funcbody, [])
|
||||
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32)
|
||||
value = relay.Function([n, data], e.float32, funcbody, [])
|
||||
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
|
||||
assert alpha_equal(dead_code_elimination(orig), orig)
|
||||
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
|
||||
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
|
||||
|
||||
|
||||
def test_op_let():
|
||||
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))
|
||||
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two))
|
||||
|
||||
|
||||
def test_if():
|
||||
|
@ -80,7 +78,7 @@ def test_tuple_get_item():
|
|||
t = relay.Var('t')
|
||||
g = relay.TupleGetItem(t, 0)
|
||||
assert alpha_equal(dead_code_elimination(g), g)
|
||||
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
|
||||
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -3,16 +3,17 @@ from tvm import relay
|
|||
from tvm.relay.ir_pass import free_vars, free_type_vars
|
||||
|
||||
def test_free_vars():
|
||||
x = relay.Var("x")
|
||||
ty = relay.TensorType([], "int32")
|
||||
x = relay.Var("x", ty)
|
||||
fvx = free_vars(x)
|
||||
assert len(fvx) == 1
|
||||
assert fvx[0] == x
|
||||
v = relay.Constant(tvm.nd.array(10))
|
||||
ty = relay.TensorType([], "int32")
|
||||
let = relay.Let(x, v, x, ty)
|
||||
|
||||
let = relay.Let(x, v, x)
|
||||
fvx = free_vars(let)
|
||||
assert len(free_vars(let)) == 0
|
||||
f = relay.Function([relay.Param(x, ty)], ty, x)
|
||||
f = relay.Function([x], ty, x)
|
||||
assert len(free_vars(f)) == 0
|
||||
|
||||
|
||||
|
@ -29,9 +30,9 @@ def test_tuple():
|
|||
def test_free_type_vars():
|
||||
tp = relay.TypeParam("")
|
||||
ty = relay.TupleType([tp, relay.TensorType([], "int32")])
|
||||
x = relay.Var("x")
|
||||
x = relay.Var("x", ty)
|
||||
y = relay.Var("y")
|
||||
let = relay.Let(x, y, x, ty)
|
||||
let = relay.Let(x, y, x)
|
||||
fvl = free_vars(let)
|
||||
assert len(fvl) == 1
|
||||
assert fvl[0] == y
|
||||
|
|
Загрузка…
Ссылка в новой задаче