Small refactors and bug fixes. (#2281)
This commit is contained in:
Родитель
5cb729ec48
Коммит
395804e524
|
@ -248,6 +248,13 @@ class FunctionNode : public ExprNode {
|
|||
*/
|
||||
TVM_DLL FuncType func_type_annotation() const;
|
||||
|
||||
/*!
|
||||
* \brief Check whether the function is a primitive function.
|
||||
*
|
||||
* \return Whether the function is primitive or not.
|
||||
*/
|
||||
bool IsPrimitive() const;
|
||||
|
||||
TVM_DLL static Function make(tvm::Array<Var> params,
|
||||
Expr body,
|
||||
Type ret_type,
|
||||
|
|
|
@ -5,6 +5,7 @@ from ..api import register_func
|
|||
from . import base
|
||||
from . import ty
|
||||
from . import expr
|
||||
from . import expr_functor
|
||||
from . import module
|
||||
from . import ir_pass
|
||||
from .build_module import build, build_config, create_executor
|
||||
|
@ -53,6 +54,10 @@ Let = expr.Let
|
|||
If = expr.If
|
||||
TupleGetItem = expr.TupleGetItem
|
||||
|
||||
# ExprFunctor
|
||||
ExprFunctor = expr_functor.ExprFunctor
|
||||
ExprMutator = expr_functor.ExprMutator
|
||||
|
||||
# helper functions
|
||||
var = expr.var
|
||||
const = expr.const
|
||||
|
|
|
@ -24,7 +24,8 @@ import attr
|
|||
from . import _backend
|
||||
from . import compile_engine
|
||||
from ..op import Op
|
||||
from ..expr import Function, GlobalVar, ExprFunctor
|
||||
from ..expr import Function, GlobalVar
|
||||
from ..expr_functor import ExprFunctor
|
||||
from ..ty import TupleType, TensorType
|
||||
|
||||
|
||||
|
@ -251,6 +252,9 @@ class GraphRuntimeCodegen(ExprFunctor):
|
|||
op_name, inputs, {})
|
||||
return self.add_node(op_node, call)
|
||||
|
||||
def visit_op(self, _):
|
||||
raise Exception("can not compile op in non-eta expanded form")
|
||||
|
||||
def _get_json(self):
|
||||
"""
|
||||
Convert the sequence of nodes stored by the compiler into the
|
||||
|
|
|
@ -222,12 +222,13 @@ class Function(Expr):
|
|||
params,
|
||||
body,
|
||||
ret_type=None,
|
||||
type_params=None):
|
||||
type_params=None,
|
||||
attrs=None):
|
||||
if type_params is None:
|
||||
type_params = convert([])
|
||||
|
||||
self.__init_handle_by_constructor__(
|
||||
_make.Function, params, body, ret_type, type_params)
|
||||
_make.Function, params, body, ret_type, type_params, attrs)
|
||||
|
||||
def __call__(self, *args):
|
||||
"""Invoke the gobal function.
|
||||
|
@ -343,131 +344,6 @@ class TempExpr(Expr):
|
|||
return _expr.TempExprRealize(self)
|
||||
|
||||
|
||||
class ExprFunctor(object):
|
||||
"""
|
||||
An abstract visitor defined over Expr.
|
||||
|
||||
Defines the default dispatch over expressions, and
|
||||
implements memoization.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.memo_map = {}
|
||||
|
||||
# pylint: disable=no-else-return
|
||||
def visit(self, expr):
|
||||
"""Apply the visitor to an expression."""
|
||||
found = self.memo_map.get(expr)
|
||||
if found:
|
||||
return found
|
||||
|
||||
if isinstance(expr, Function):
|
||||
res = self.visit_function(expr)
|
||||
elif isinstance(expr, Call):
|
||||
res = self.visit_call(expr)
|
||||
elif isinstance(expr, Let):
|
||||
res = self.visit_let(expr)
|
||||
elif isinstance(expr, Var):
|
||||
res = self.visit_var(expr)
|
||||
elif isinstance(expr, GlobalVar):
|
||||
res = self.visit_global_var(expr)
|
||||
elif isinstance(expr, If):
|
||||
res = self.visit_if(expr)
|
||||
elif isinstance(expr, Tuple):
|
||||
res = self.visit_tuple(expr)
|
||||
elif isinstance(expr, TupleGetItem):
|
||||
res = self.visit_tuple_getitem(expr)
|
||||
elif isinstance(expr, Constant):
|
||||
res = self.visit_constant(expr)
|
||||
else:
|
||||
raise Exception("warning unhandled case: {0}".format(type(expr)))
|
||||
|
||||
self.memo_map[expr] = res
|
||||
return res
|
||||
|
||||
def visit_function(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_let(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_call(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_var(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_type(self, typ):
|
||||
return typ
|
||||
|
||||
def visit_if(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_tuple(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_tuple_getitem(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_constant(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_global_var(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ExprMutator(ExprFunctor):
|
||||
"""
|
||||
A functional visitor over Expr.
|
||||
|
||||
The default behavior recursively traverses the AST
|
||||
and reconstructs the AST.
|
||||
"""
|
||||
def visit_function(self, fn):
|
||||
new_body = self.visit(fn.body)
|
||||
return Function(
|
||||
list(fn.params),
|
||||
fn.ret_type, new_body,
|
||||
fn.type_params)
|
||||
|
||||
def visit_let(self, let):
|
||||
new_var = self.visit(let.var)
|
||||
new_val = self.visit(let.value)
|
||||
new_body = self.visit(let.body)
|
||||
return Let(new_var, new_val, new_body)
|
||||
|
||||
def visit_call(self, call):
|
||||
new_fn = self.visit(call.op)
|
||||
new_args = [self.visit(arg) for arg in call.args]
|
||||
return Call(new_fn, new_args, call.attrs)
|
||||
|
||||
def visit_var(self, rvar):
|
||||
return rvar
|
||||
|
||||
def visit_global_id(self, global_var):
|
||||
return global_var
|
||||
|
||||
def visit_if(self, ite):
|
||||
return If(
|
||||
self.visit(ite.guard),
|
||||
self.visit(ite.true_b),
|
||||
self.visit(ite.false_b))
|
||||
|
||||
def visit_tuple(self, tup):
|
||||
return Tuple([self.visit(field) for field in tup.fields])
|
||||
|
||||
def visit_tuple_getitem(self, op):
|
||||
tuple_value = self.visit(op.tuple_value)
|
||||
if not tuple_value.same_as(op.tuple_value):
|
||||
return TupleGetItem(tuple_value, op.index)
|
||||
return op
|
||||
|
||||
def visit_global_var(self, gvar):
|
||||
return gvar
|
||||
|
||||
def visit_constant(self, rconst):
|
||||
return rconst
|
||||
|
||||
|
||||
class TupleWrapper(object):
|
||||
"""TupleWrapper.
|
||||
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
|
||||
"""The expression functor of Relay."""
|
||||
|
||||
from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant
|
||||
from .op import Op
|
||||
|
||||
class ExprFunctor:
|
||||
"""
|
||||
An abstract visitor defined over Expr.
|
||||
|
||||
Defines the default dispatch over expressions, and
|
||||
implements memoization.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.memo_map = {}
|
||||
|
||||
# pylint: disable=no-else-return
|
||||
def visit(self, expr):
|
||||
"""Apply the visitor to an expression."""
|
||||
found = self.memo_map.get(expr)
|
||||
if found:
|
||||
return found
|
||||
|
||||
if isinstance(expr, Function):
|
||||
res = self.visit_function(expr)
|
||||
elif isinstance(expr, Call):
|
||||
res = self.visit_call(expr)
|
||||
elif isinstance(expr, Let):
|
||||
res = self.visit_let(expr)
|
||||
elif isinstance(expr, Var):
|
||||
res = self.visit_var(expr)
|
||||
elif isinstance(expr, GlobalVar):
|
||||
res = self.visit_global_var(expr)
|
||||
elif isinstance(expr, If):
|
||||
res = self.visit_if(expr)
|
||||
elif isinstance(expr, Tuple):
|
||||
res = self.visit_tuple(expr)
|
||||
elif isinstance(expr, TupleGetItem):
|
||||
res = self.visit_tuple_getitem(expr)
|
||||
elif isinstance(expr, Constant):
|
||||
res = self.visit_constant(expr)
|
||||
elif isinstance(expr, Op):
|
||||
res = self.visit_op(expr)
|
||||
else:
|
||||
raise Exception("warning unhandled case: {0}".format(type(expr)))
|
||||
|
||||
self.memo_map[expr] = res
|
||||
|
||||
return res
|
||||
|
||||
def visit_function(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_let(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_call(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_var(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_type(self, typ):
|
||||
return typ
|
||||
|
||||
def visit_if(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_tuple(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_tuple_getitem(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_global_var(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_op(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
def visit_constant(self, _):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ExprMutator(ExprFunctor):
|
||||
"""
|
||||
A functional visitor over Expr.
|
||||
|
||||
The default behavior recursively traverses the AST
|
||||
and reconstructs the AST.
|
||||
"""
|
||||
def visit_function(self, fn):
|
||||
new_body = self.visit(fn.body)
|
||||
return Function(
|
||||
list(fn.params),
|
||||
new_body,
|
||||
fn.ret_type,
|
||||
fn.type_params,
|
||||
fn.attrs)
|
||||
|
||||
def visit_let(self, let):
|
||||
new_var = self.visit(let.var)
|
||||
new_val = self.visit(let.value)
|
||||
new_body = self.visit(let.body)
|
||||
return Let(new_var, new_val, new_body)
|
||||
|
||||
def visit_call(self, call):
|
||||
new_fn = self.visit(call.op)
|
||||
new_args = [self.visit(arg) for arg in call.args]
|
||||
return Call(new_fn, new_args, call.attrs)
|
||||
|
||||
def visit_var(self, rvar):
|
||||
return rvar
|
||||
|
||||
def visit_global_id(self, global_var):
|
||||
return global_var
|
||||
|
||||
def visit_if(self, ite):
|
||||
return If(
|
||||
self.visit(ite.guard),
|
||||
self.visit(ite.true_b),
|
||||
self.visit(ite.false_b))
|
||||
|
||||
def visit_tuple(self, tup):
|
||||
return Tuple([self.visit(field) for field in tup.fields])
|
||||
|
||||
def visit_tuple_getitem(self, op):
|
||||
tuple_value = self.visit(op.tuple_value)
|
||||
if not tuple_value.same_as(op.tuple_value):
|
||||
return TupleGetItem(tuple_value, op.index)
|
||||
return op
|
||||
|
||||
def visit_global_var(self, gvar):
|
||||
return gvar
|
||||
|
||||
def visit_op(self, op):
|
||||
return op
|
||||
|
||||
def visit_constant(self, const):
|
||||
return const
|
||||
|
||||
def visit_constructor(self, con):
|
||||
return con
|
||||
|
||||
def visit_match(self, m):
|
||||
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
|
||||
|
||||
def visit_ref_new(self, r):
|
||||
return RefNew(self.visit(r.value))
|
||||
|
||||
def visit_ref_write(self, r):
|
||||
return RefWrite(self.visit(r.ref), self.visit(r.value))
|
||||
|
||||
def visit_ref_read(self, r):
|
||||
return RefRead(self.visit(r.ref))
|
|
@ -157,14 +157,14 @@ class ScheduleGetter :
|
|||
|
||||
int op_pattern = fpattern[op];
|
||||
if (op_pattern >= kCommReduce) {
|
||||
CHECK(!master_op_.defined() || master_op_patetrn_ < kCommReduce)
|
||||
CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce)
|
||||
<< "Two complicated op in a primitive function "
|
||||
<< " master=" << master_op_ << " current=" << op;
|
||||
}
|
||||
if (op_pattern >= master_op_patetrn_) {
|
||||
if (op_pattern >= master_op_pattern_) {
|
||||
master_op_ = op;
|
||||
master_attrs_ = call_node->attrs;
|
||||
master_op_patetrn_ = op_pattern;
|
||||
master_op_pattern_ = op_pattern;
|
||||
}
|
||||
if (outputs.size() != 1) {
|
||||
const auto* tuple_type =
|
||||
|
@ -213,7 +213,7 @@ class ScheduleGetter :
|
|||
tvm::Target target_;
|
||||
Op master_op_;
|
||||
Attrs master_attrs_;
|
||||
int master_op_patetrn_{0};
|
||||
int master_op_pattern_{0};
|
||||
std::ostringstream readable_name_stream_;
|
||||
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
|
||||
};
|
||||
|
|
|
@ -292,17 +292,10 @@ class Interpreter :
|
|||
}
|
||||
}
|
||||
|
||||
// Check if function is a primitive function.
|
||||
bool IsPrimitive(const Function& func) const {
|
||||
NodeRef res = FunctionGetAttr(func, "Primitive");
|
||||
const ir::IntImm* pval = res.as<ir::IntImm>();
|
||||
return pval && pval->value != 0;
|
||||
}
|
||||
|
||||
// Invoke the closure
|
||||
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
|
||||
// Get a reference to the function inside the closure.
|
||||
if (IsPrimitive(closure->func)) {
|
||||
if (closure->func->IsPrimitive()) {
|
||||
return InvokePrimitiveOp(closure->func, args);
|
||||
}
|
||||
auto func = closure->func;
|
||||
|
|
|
@ -135,6 +135,12 @@ FuncType FunctionNode::func_type_annotation() const {
|
|||
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
|
||||
}
|
||||
|
||||
bool FunctionNode::IsPrimitive() const {
|
||||
NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
|
||||
const ir::IntImm* pval = res.as<ir::IntImm>();
|
||||
return pval && pval->value != 0;
|
||||
}
|
||||
|
||||
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
|
||||
if (!func->attrs.defined()) { return NodeRef(); }
|
||||
|
||||
|
@ -172,7 +178,7 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
|
|||
|
||||
TVM_REGISTER_API("relay._make.Function")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
|
||||
*ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
|
||||
});
|
||||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
|
|
|
@ -699,9 +699,7 @@ class FuseMutator : private ExprMutator {
|
|||
std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
|
||||
// Skip primitive function.
|
||||
Expr VisitExpr_(const FunctionNode* fn_node) {
|
||||
NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
|
||||
const ir::IntImm* pval = res.as<ir::IntImm>();
|
||||
if (pval && pval->value != 0) {
|
||||
if (fn_node->IsPrimitive()) {
|
||||
return GetRef<Expr>(fn_node);
|
||||
} else {
|
||||
return ExprMutator::VisitExpr_(fn_node);
|
||||
|
|
Загрузка…
Ссылка в новой задаче