diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 977bb679..fff630f5 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -296,13 +296,15 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); * For example, this pass should turn `let a = 1 in 2` into `2`, * as the value of the expression does not depend on a. * - * As another example, `let a = 1 in a` will be optimized into 1. + * As another example, `let a = 1 in a` will be optimized into 1, + * if the flag is turned on. * * \param e the expression to optimize. + * \param inline_once whether or not to inline binding used one. * * \return the optimized expression. */ -TVM_DLL Expr DeadCodeElimination(const Expr& e); +TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); /*! * \brief Fold constant expressions. @@ -435,11 +437,12 @@ TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. * - * \param e the expression, + * \param e the expression + * \param mod the module * * \return the optimized expression. */ -TVM_DLL Expr PartialEval(const Expr& e); +TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); /*! * \brief Bind the free variables to a Relay expression. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index f579f1c7..fb8ebbf0 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< * * As another example, `let a = 1 in a` will be optimized into 1. * + * \param inline_once whether or not to inline binding used one. + * * \return the pass. */ -TVM_DLL Pass DeadCodeElimination(); +TVM_DLL Pass DeadCodeElimination(bool inline_once = false); /*! * \brief Fold constant expressions. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8f1ceded..dd0f54c6 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -129,7 +129,7 @@ def well_formed(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -175,7 +175,7 @@ def free_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -197,7 +197,7 @@ def bound_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -213,7 +213,7 @@ def all_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + mod : Optional[tvm.relay.Module] The global module Returns @@ -286,12 +288,12 @@ def simplify_inference(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with some simplification """ @@ -304,32 +306,34 @@ def canonicalize_ops(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression without bias_add """ return _ir_pass.canonicalize_ops(expr) -def dead_code_elimination(expr): +def dead_code_elimination(expr, inline_once=False): """ Remove expressions which does not effect the program result (dead code). Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression + inline_once : Optional[Bool] + Whether to inline binding that occur only once. Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with dead code removed. """ - return _ir_pass.dead_code_elimination(expr) + return _ir_pass.dead_code_elimination(expr, inline_once) def alpha_equal(lhs, rhs): @@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is alpha equal to rhs. """ return bool(_make._alpha_equal(lhs, rhs)) @@ -359,15 +363,15 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is data-flow equivalent to rhs. """ return bool(_make._graph_equal(lhs, rhs)) @@ -378,12 +382,12 @@ def structural_hash(value): Parameters ---------- - expr: tvm.relay.Expr or tvm.relay.Type + expr : Union[tvm.relay.Expr, tvm.relay.Type] The expression to hash. Returns ------- - result: int + result : int The hash value """ if isinstance(value, Expr): @@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None): expr : tvm.relay.Expr The input expression. - mod: Optional[tvm.relay.Module] + mod : Optional[tvm.relay.Module] The global module. Returns ------- - expr: tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ return _ir_pass.to_a_normal_form(expr, mod) @@ -563,7 +567,7 @@ def to_graph_normal_form(expr): The input expression Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression """ return _ir_pass.to_graph_normal_form(expr) @@ -612,7 +616,7 @@ def get_total_mac_number(expr): Returns ------- - ret : int64 + result : int64 The number of MACs (multiply-accumulate) of a model """ return _ir_pass.GetTotalMacNumber(expr) @@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None): expr : tvm.relay.Expr The input expression. - fskip: function + fskip : function The callback function that decides whether an expression should be skipped. Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ return _ir_pass.eliminate_common_subexpr(expr, fskip) -def partial_evaluate(expr): +def partial_evaluate(expr, mod=None): """ Evaluate the static fragment of the code. @@ -646,12 +650,15 @@ def partial_evaluate(expr): expr : tvm.relay.Expr The input expression. + mod : Optional[tvm.relay.Module] + The global module + Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ - return _ir_pass.partial_evaluate(expr) + return _ir_pass.partial_evaluate(expr, mod) def unmatched_cases(match, mod=None): """ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 64706933..e0ec10a8 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; }); Let LetNode::make(Var var, Expr value, Expr body) { @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_API("relay._expr.TempExprRealize") .set_body_typed([](TempExpr temp) { - return temp->Realize(); + return temp->Realize(); }); } // namespace relay diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index be677456..7e186f80 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -38,10 +38,10 @@ namespace relay { // calculate the dependency graph from expression class CalcDep : private ExprVisitor { public: - static Expr Eliminate(const Expr& e) { + static Expr Eliminate(const Expr& e, bool inline_once) { CalcDep cd; cd.Calculate(e); - Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); + Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once); return el(e); } @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor { VarMap expr_map_; VarMap use_map_; VarSet letrec_set_; + bool inline_once_; explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, - const VarSet& letrec_set) : - expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } + const VarSet& letrec_set, + bool inline_once) : + expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { } friend CalcDep; bool HasLet(const Var& v) { - // TODO(@jroesch): MK fix me - return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + switch (use_map_[v]) { + case 0: + return false; + case 1: + return letrec_set_.count(v) > 0 || !inline_once_; + default: + return true; + } } Expr VisitExpr_(const VarNode* op) final { @@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor { }; }; -Expr DeadCodeElimination(const Expr& e) { - return CalcDep::Eliminate(e); +Expr DeadCodeElimination(const Expr& e, bool inline_once) { + return CalcDep::Eliminate(e, inline_once); } TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") @@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") namespace transform { -Pass DeadCodeElimination() { +Pass DeadCodeElimination(bool inline_once) { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(DeadCodeElimination(f)); + return Downcast(DeadCodeElimination(f, inline_once)); }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 71ba7cd1..07ec1b07 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -74,28 +74,19 @@ * * The partial evaluator makes several assumptions, so there is room for improvement: * - * 0: The partial evaluator treats global variables as opaque. - * Doing PartialEval on a module level will solve this. - * - * 1: The partial evaluator assume all functions as terminating. - * We need to has a max_expand parameter that shrink on every compile time evaluation, - * to make sure PE does not infinite loop. - * Additionally, we might add a termination analysis pass that lift this requirement - * for function that analysis found terminating. - * - * 2: Every time an unknown effect happened, we clear the whole store. + * 0: Every time an unknown effect happened, we clear the whole store. * It is too conservative: if a local reference is created (and do not get passed outside), * An unknown global function call/global reference write can not modify it. * We can pair PE with escape analysis/alias analysis. * - * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise. + * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise. * - * 4: When doing pattern matching, we can simplify the match even for dynamic case. + * 2: When doing pattern matching, we can simplify the match even for dynamic case. * Right now it is all or nothing: either a complete match, or the original dynamic code. * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form. * We then can reify the result. * - * 5: Every time a function is called, it's code will get expanded and partially evaluated. + * 3: Every time a function is called, its code will get expanded and partially evaluated. * We can do a binding time analysis to cache the result and avoid re-partial evaluation. * * These assumptions do not affect the correctness of the algorithm, however. @@ -104,6 +95,7 @@ #include #include #include +#include "../ir/type_functor.h" #include "pass_util.h" #include "let_list.h" @@ -132,6 +124,8 @@ struct VarEqual { } }; +Expr PostProcess(const Expr&); + /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: @@ -150,10 +144,20 @@ class Static : public NodeRef { using ContainerType = StaticNode; }; +using Time = size_t; + struct PStaticNode : Node { + static Time time() { + static Time time_ = 0; + Time ret = time_; + time_++; + return ret; + } Static pstatic; // may be null Expr dynamic; - PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + Time created_time; + PStaticNode(const Static& pstatic, const Expr& dynamic) : + pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -341,6 +345,7 @@ class Store { }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { + CHECK(stat.defined()); return PStatic(make_node(stat, dynamic)); } @@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() { return CreateInterpreter(Module(nullptr), CPUContext(), target); } +bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + +using FuncId = int; + +/*! + * \brief Annotate a function with a FuncId. + */ +struct WithFuncIdAttrs : public tvm::AttrsNode { + FuncId fid; + + TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { + TVM_ATTR_FIELD(fid) + .describe("The FuncId that an function is annotated with.") + .set_default(-1); + } +}; + +TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); + +Op WithFuncIdOp() { + static const Op& op = Op::Get("annotation.with_funcid"); + return op; +} + +Expr MkWithFuncId(const Expr& expr, FuncId fid) { + auto attrs = make_node(); + attrs->fid = fid; + return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("annotation.with_funcid") +.describe(R"code(Annotate a function with a funcid.)code" +TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("func", "Function", "The input data."); + +Expr StripWithFuncId(const Expr& e); + +Expr DeDup(const Expr& e); + +Function AsFunc(const Expr& e) { + if (e.as()) { + return Downcast(e); + } else if (const CallNode* c = e.as()) { + CHECK(c->op.same_as(WithFuncIdOp())); + CHECK_EQ(c->args.size(), 1); + return AsFunc(c->args[0]); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } +} + class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars) { + PartialEvaluator(const tvm::Array& free_vars, + const Module& mod) : + mod_(mod) { for (const Var& v : free_vars) { env_.Insert(v, NoStatic(v)); } } + PStatic VisitExpr(const Expr& e, LetList* ll) final { + PStatic ret = ExprFunctor::VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - return NoStatic(GetRef(op)); + GlobalVar gv = GetRef(op); + if (gv_map_.count(gv) == 0) { + if (mod_.defined()) { + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); + } else { + gv_map_.insert({gv, NoStatic(gv)}); + } + } + return gv_map_.at(gv); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { @@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { + if (op->op.same_as(WithFuncIdOp())) { + CHECK_EQ(op->args.size(), 1); + return VisitExpr(op->args[0], ll); + } PStatic f = VisitExpr(op->op, ll); std::vector x; tvm::Array x_dyn; @@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - Function func = GetRef(op); + struct TimeFrame { + PartialEvaluator* pe_; + FuncId fid_; + std::vector