save (#3033)
save save save upstream lint remove bad changes fix build save save please the ci god Update src/relay/pass/partial_eval.cc Co-Authored-By: Wei Chen <ipondering.weic@gmail.com> save fix test ci is ANGRY fix rebase problem fix rebase add test save save comment
This commit is contained in:
Родитель
50dd03ca86
Коммит
df88c411f5
|
@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
|
||||||
* For example, this pass should turn `let a = 1 in 2` into `2`,
|
* For example, this pass should turn `let a = 1 in 2` into `2`,
|
||||||
* as the value of the expression does not depend on a.
|
* as the value of the expression does not depend on a.
|
||||||
*
|
*
|
||||||
* As another example, `let a = 1 in a` will be optimized into 1.
|
* As another example, `let a = 1 in a` will be optimized into 1,
|
||||||
|
* if the flag is turned on.
|
||||||
*
|
*
|
||||||
* \param e the expression to optimize.
|
* \param e the expression to optimize.
|
||||||
|
* \param inline_once whether or not to inline binding used one.
|
||||||
*
|
*
|
||||||
* \return the optimized expression.
|
* \return the optimized expression.
|
||||||
*/
|
*/
|
||||||
TVM_DLL Expr DeadCodeElimination(const Expr& e);
|
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Fold constant expressions.
|
* \brief Fold constant expressions.
|
||||||
|
@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
|
||||||
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
|
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
|
||||||
* As a side effect, code size will explode.
|
* As a side effect, code size will explode.
|
||||||
*
|
*
|
||||||
* \param e the expression,
|
* \param e the expression
|
||||||
|
* \param mod the module
|
||||||
*
|
*
|
||||||
* \return the optimized expression.
|
* \return the optimized expression.
|
||||||
*/
|
*/
|
||||||
TVM_DLL Expr PartialEval(const Expr& e);
|
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Bind the free variables to a Relay expression.
|
* \brief Bind the free variables to a Relay expression.
|
||||||
|
|
|
@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
|
||||||
*
|
*
|
||||||
* As another example, `let a = 1 in a` will be optimized into 1.
|
* As another example, `let a = 1 in a` will be optimized into 1.
|
||||||
*
|
*
|
||||||
|
* \param inline_once whether or not to inline binding used one.
|
||||||
|
*
|
||||||
* \return the pass.
|
* \return the pass.
|
||||||
*/
|
*/
|
||||||
TVM_DLL Pass DeadCodeElimination();
|
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Fold constant expressions.
|
* \brief Fold constant expressions.
|
||||||
|
|
|
@ -129,7 +129,7 @@ def well_formed(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression
|
The input expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -175,7 +175,7 @@ def free_vars(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression
|
The input expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -197,7 +197,7 @@ def bound_vars(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression
|
The input expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -213,7 +213,7 @@ def all_vars(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression
|
The input expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||||
The input expression/type
|
The input expression/type
|
||||||
mod: tvm.relay.Module, optional
|
|
||||||
|
mod : Optional[tvm.relay.Module]
|
||||||
The global module
|
The global module
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||||
The input expression/type
|
The input expression/type
|
||||||
mod: tvm.relay.Module, optional
|
|
||||||
|
mod : Optional[tvm.relay.Module]
|
||||||
The global module
|
The global module
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||||
The input expression/type
|
The input expression/type
|
||||||
mod: tvm.relay.Module, optional
|
mod : Optional[tvm.relay.Module]
|
||||||
The global module
|
The global module
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
|
@ -286,12 +288,12 @@ def simplify_inference(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
e: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input Expression
|
The input Expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
An expression which is semantically equal to the input expression,
|
An expression which is semantically equal to the input expression,
|
||||||
but with some simplification
|
but with some simplification
|
||||||
"""
|
"""
|
||||||
|
@ -304,32 +306,34 @@ def canonicalize_ops(expr):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
e: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input Expression
|
The input Expression
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
An expression without bias_add
|
An expression without bias_add
|
||||||
"""
|
"""
|
||||||
return _ir_pass.canonicalize_ops(expr)
|
return _ir_pass.canonicalize_ops(expr)
|
||||||
|
|
||||||
|
|
||||||
def dead_code_elimination(expr):
|
def dead_code_elimination(expr, inline_once=False):
|
||||||
""" Remove expressions which does not effect the program result (dead code).
|
""" Remove expressions which does not effect the program result (dead code).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
e: tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input Expression
|
The input Expression
|
||||||
|
|
||||||
|
inline_once : Optional[Bool]
|
||||||
|
Whether to inline binding that occur only once.
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
An expression which is semantically equal to the input expression,
|
An expression which is semantically equal to the input expression,
|
||||||
but with dead code removed.
|
but with dead code removed.
|
||||||
"""
|
"""
|
||||||
return _ir_pass.dead_code_elimination(expr)
|
return _ir_pass.dead_code_elimination(expr, inline_once)
|
||||||
|
|
||||||
|
|
||||||
def alpha_equal(lhs, rhs):
|
def alpha_equal(lhs, rhs):
|
||||||
|
@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
lhs: tvm.relay.Expr
|
lhs : tvm.relay.Expr
|
||||||
One of the input Expression.
|
One of the input Expression.
|
||||||
|
|
||||||
rhs: tvm.relay.Expr
|
rhs : tvm.relay.Expr
|
||||||
One of the input Expression.
|
One of the input Expression.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: bool
|
result : bool
|
||||||
True iff lhs is alpha equal to rhs.
|
True iff lhs is alpha equal to rhs.
|
||||||
"""
|
"""
|
||||||
return bool(_make._alpha_equal(lhs, rhs))
|
return bool(_make._alpha_equal(lhs, rhs))
|
||||||
|
@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
lhs: tvm.relay.Expr
|
lhs : tvm.relay.Expr
|
||||||
One of the input Expression.
|
One of the input Expression.
|
||||||
|
|
||||||
rhs: tvm.relay.Expr
|
rhs : tvm.relay.Expr
|
||||||
One of the input Expression.
|
One of the input Expression.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: bool
|
result : bool
|
||||||
True iff lhs is data-flow equivalent to rhs.
|
True iff lhs is data-flow equivalent to rhs.
|
||||||
"""
|
"""
|
||||||
return bool(_make._graph_equal(lhs, rhs))
|
return bool(_make._graph_equal(lhs, rhs))
|
||||||
|
@ -378,12 +382,12 @@ def structural_hash(value):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
expr: tvm.relay.Expr or tvm.relay.Type
|
expr : Union[tvm.relay.Expr, tvm.relay.Type]
|
||||||
The expression to hash.
|
The expression to hash.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
result: int
|
result : int
|
||||||
The hash value
|
The hash value
|
||||||
"""
|
"""
|
||||||
if isinstance(value, Expr):
|
if isinstance(value, Expr):
|
||||||
|
@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
|
||||||
expr : tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression.
|
The input expression.
|
||||||
|
|
||||||
mod: Optional[tvm.relay.Module]
|
mod : Optional[tvm.relay.Module]
|
||||||
The global module.
|
The global module.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
expr: tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
The output expression.
|
The output expression.
|
||||||
"""
|
"""
|
||||||
return _ir_pass.to_a_normal_form(expr, mod)
|
return _ir_pass.to_a_normal_form(expr, mod)
|
||||||
|
@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
|
||||||
The input expression
|
The input expression
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
expr : tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
The output expression
|
The output expression
|
||||||
"""
|
"""
|
||||||
return _ir_pass.to_graph_normal_form(expr)
|
return _ir_pass.to_graph_normal_form(expr)
|
||||||
|
@ -612,7 +616,7 @@ def get_total_mac_number(expr):
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
ret : int64
|
result : int64
|
||||||
The number of MACs (multiply-accumulate) of a model
|
The number of MACs (multiply-accumulate) of a model
|
||||||
"""
|
"""
|
||||||
return _ir_pass.GetTotalMacNumber(expr)
|
return _ir_pass.GetTotalMacNumber(expr)
|
||||||
|
@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
|
||||||
expr : tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression.
|
The input expression.
|
||||||
|
|
||||||
fskip: function
|
fskip : function
|
||||||
The callback function that decides whether an expression should be skipped.
|
The callback function that decides whether an expression should be skipped.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
expr : tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
The output expression.
|
The output expression.
|
||||||
"""
|
"""
|
||||||
return _ir_pass.eliminate_common_subexpr(expr, fskip)
|
return _ir_pass.eliminate_common_subexpr(expr, fskip)
|
||||||
|
|
||||||
def partial_evaluate(expr):
|
def partial_evaluate(expr, mod=None):
|
||||||
"""
|
"""
|
||||||
Evaluate the static fragment of the code.
|
Evaluate the static fragment of the code.
|
||||||
|
|
||||||
|
@ -646,12 +650,15 @@ def partial_evaluate(expr):
|
||||||
expr : tvm.relay.Expr
|
expr : tvm.relay.Expr
|
||||||
The input expression.
|
The input expression.
|
||||||
|
|
||||||
|
mod : Optional[tvm.relay.Module]
|
||||||
|
The global module
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
expr : tvm.relay.Expr
|
result : tvm.relay.Expr
|
||||||
The output expression.
|
The output expression.
|
||||||
"""
|
"""
|
||||||
return _ir_pass.partial_evaluate(expr)
|
return _ir_pass.partial_evaluate(expr, mod)
|
||||||
|
|
||||||
def unmatched_cases(match, mod=None):
|
def unmatched_cases(match, mod=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
|
||||||
|
|
||||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
|
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
|
||||||
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
|
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
|
||||||
<< node->attrs << ", " << node->type_args << ")";
|
<< node->attrs << ", " << node->type_args << ")";
|
||||||
});
|
});
|
||||||
|
|
||||||
Let LetNode::make(Var var, Expr value, Expr body) {
|
Let LetNode::make(Var var, Expr value, Expr body) {
|
||||||
|
@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._expr.TempExprRealize")
|
TVM_REGISTER_API("relay._expr.TempExprRealize")
|
||||||
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
|
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
|
||||||
return temp->Realize();
|
return temp->Realize();
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace relay
|
} // namespace relay
|
||||||
|
|
|
@ -38,10 +38,10 @@ namespace relay {
|
||||||
// calculate the dependency graph from expression
|
// calculate the dependency graph from expression
|
||||||
class CalcDep : private ExprVisitor {
|
class CalcDep : private ExprVisitor {
|
||||||
public:
|
public:
|
||||||
static Expr Eliminate(const Expr& e) {
|
static Expr Eliminate(const Expr& e, bool inline_once) {
|
||||||
CalcDep cd;
|
CalcDep cd;
|
||||||
cd.Calculate(e);
|
cd.Calculate(e);
|
||||||
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
|
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
|
||||||
return el(e);
|
return el(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
|
||||||
VarMap<Expr> expr_map_;
|
VarMap<Expr> expr_map_;
|
||||||
VarMap<size_t> use_map_;
|
VarMap<size_t> use_map_;
|
||||||
VarSet letrec_set_;
|
VarSet letrec_set_;
|
||||||
|
bool inline_once_;
|
||||||
explicit Eliminator(const VarMap<Expr>& expr_map,
|
explicit Eliminator(const VarMap<Expr>& expr_map,
|
||||||
const VarMap<size_t>& use_map,
|
const VarMap<size_t>& use_map,
|
||||||
const VarSet& letrec_set) :
|
const VarSet& letrec_set,
|
||||||
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
|
bool inline_once) :
|
||||||
|
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
|
||||||
friend CalcDep;
|
friend CalcDep;
|
||||||
|
|
||||||
bool HasLet(const Var& v) {
|
bool HasLet(const Var& v) {
|
||||||
// TODO(@jroesch): MK fix me
|
switch (use_map_[v]) {
|
||||||
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
|
case 0:
|
||||||
|
return false;
|
||||||
|
case 1:
|
||||||
|
return letrec_set_.count(v) > 0 || !inline_once_;
|
||||||
|
default:
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr VisitExpr_(const VarNode* op) final {
|
Expr VisitExpr_(const VarNode* op) final {
|
||||||
|
@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
Expr DeadCodeElimination(const Expr& e) {
|
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
|
||||||
return CalcDep::Eliminate(e);
|
return CalcDep::Eliminate(e, inline_once);
|
||||||
}
|
}
|
||||||
|
|
||||||
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
|
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
|
||||||
|
@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
|
||||||
|
|
||||||
namespace transform {
|
namespace transform {
|
||||||
|
|
||||||
Pass DeadCodeElimination() {
|
Pass DeadCodeElimination(bool inline_once) {
|
||||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||||
[=](Function f, Module m, PassContext pc) {
|
[=](Function f, Module m, PassContext pc) {
|
||||||
return Downcast<Function>(DeadCodeElimination(f));
|
return Downcast<Function>(DeadCodeElimination(f, inline_once));
|
||||||
};
|
};
|
||||||
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
|
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,28 +74,19 @@
|
||||||
*
|
*
|
||||||
* The partial evaluator makes several assumptions, so there is room for improvement:
|
* The partial evaluator makes several assumptions, so there is room for improvement:
|
||||||
*
|
*
|
||||||
* 0: The partial evaluator treats global variables as opaque.
|
* 0: Every time an unknown effect happened, we clear the whole store.
|
||||||
* Doing PartialEval on a module level will solve this.
|
|
||||||
*
|
|
||||||
* 1: The partial evaluator assume all functions as terminating.
|
|
||||||
* We need to has a max_expand parameter that shrink on every compile time evaluation,
|
|
||||||
* to make sure PE does not infinite loop.
|
|
||||||
* Additionally, we might add a termination analysis pass that lift this requirement
|
|
||||||
* for function that analysis found terminating.
|
|
||||||
*
|
|
||||||
* 2: Every time an unknown effect happened, we clear the whole store.
|
|
||||||
* It is too conservative: if a local reference is created (and do not get passed outside),
|
* It is too conservative: if a local reference is created (and do not get passed outside),
|
||||||
* An unknown global function call/global reference write can not modify it.
|
* An unknown global function call/global reference write can not modify it.
|
||||||
* We can pair PE with escape analysis/alias analysis.
|
* We can pair PE with escape analysis/alias analysis.
|
||||||
*
|
*
|
||||||
* 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
|
* 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
|
||||||
*
|
*
|
||||||
* 4: When doing pattern matching, we can simplify the match even for dynamic case.
|
* 2: When doing pattern matching, we can simplify the match even for dynamic case.
|
||||||
* Right now it is all or nothing: either a complete match, or the original dynamic code.
|
* Right now it is all or nothing: either a complete match, or the original dynamic code.
|
||||||
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
|
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
|
||||||
* We then can reify the result.
|
* We then can reify the result.
|
||||||
*
|
*
|
||||||
* 5: Every time a function is called, it's code will get expanded and partially evaluated.
|
* 3: Every time a function is called, its code will get expanded and partially evaluated.
|
||||||
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
|
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
|
||||||
*
|
*
|
||||||
* These assumptions do not affect the correctness of the algorithm, however.
|
* These assumptions do not affect the correctness of the algorithm, however.
|
||||||
|
@ -104,6 +95,7 @@
|
||||||
#include <tvm/relay/expr_functor.h>
|
#include <tvm/relay/expr_functor.h>
|
||||||
#include <tvm/relay/pattern_functor.h>
|
#include <tvm/relay/pattern_functor.h>
|
||||||
#include <tvm/relay/interpreter.h>
|
#include <tvm/relay/interpreter.h>
|
||||||
|
#include "../ir/type_functor.h"
|
||||||
#include "pass_util.h"
|
#include "pass_util.h"
|
||||||
#include "let_list.h"
|
#include "let_list.h"
|
||||||
|
|
||||||
|
@ -132,6 +124,8 @@ struct VarEqual {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Expr PostProcess(const Expr&);
|
||||||
|
|
||||||
/*! \brief The base container type of Relay values. */
|
/*! \brief The base container type of Relay values. */
|
||||||
class StaticNode : public RelayNode {
|
class StaticNode : public RelayNode {
|
||||||
public:
|
public:
|
||||||
|
@ -150,10 +144,20 @@ class Static : public NodeRef {
|
||||||
using ContainerType = StaticNode;
|
using ContainerType = StaticNode;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using Time = size_t;
|
||||||
|
|
||||||
struct PStaticNode : Node {
|
struct PStaticNode : Node {
|
||||||
|
static Time time() {
|
||||||
|
static Time time_ = 0;
|
||||||
|
Time ret = time_;
|
||||||
|
time_++;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
Static pstatic; // may be null
|
Static pstatic; // may be null
|
||||||
Expr dynamic;
|
Expr dynamic;
|
||||||
PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
|
Time created_time;
|
||||||
|
PStaticNode(const Static& pstatic, const Expr& dynamic) :
|
||||||
|
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
|
||||||
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
|
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
|
||||||
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
|
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
|
||||||
};
|
};
|
||||||
|
@ -341,6 +345,7 @@ class Store {
|
||||||
};
|
};
|
||||||
|
|
||||||
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
|
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
|
||||||
|
CHECK(stat.defined());
|
||||||
return PStatic(make_node<PStaticNode>(stat, dynamic));
|
return PStatic(make_node<PStaticNode>(stat, dynamic));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() {
|
||||||
return CreateInterpreter(Module(nullptr), CPUContext(), target);
|
return CreateInterpreter(Module(nullptr), CPUContext(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsAtomic(const Expr& e) {
|
||||||
|
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
|
||||||
|
}
|
||||||
|
|
||||||
|
using FuncId = int;
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Annotate a function with a FuncId.
|
||||||
|
*/
|
||||||
|
struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
|
||||||
|
FuncId fid;
|
||||||
|
|
||||||
|
TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
|
||||||
|
TVM_ATTR_FIELD(fid)
|
||||||
|
.describe("The FuncId that an function is annotated with.")
|
||||||
|
.set_default(-1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
|
||||||
|
|
||||||
|
Op WithFuncIdOp() {
|
||||||
|
static const Op& op = Op::Get("annotation.with_funcid");
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
|
||||||
|
auto attrs = make_node<WithFuncIdAttrs>();
|
||||||
|
attrs->fid = fid;
|
||||||
|
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
|
||||||
|
}
|
||||||
|
|
||||||
|
RELAY_REGISTER_OP("annotation.with_funcid")
|
||||||
|
.describe(R"code(Annotate a function with a funcid.)code"
|
||||||
|
TVM_ADD_FILELINE)
|
||||||
|
.set_num_inputs(1)
|
||||||
|
.add_argument("func", "Function", "The input data.");
|
||||||
|
|
||||||
|
Expr StripWithFuncId(const Expr& e);
|
||||||
|
|
||||||
|
Expr DeDup(const Expr& e);
|
||||||
|
|
||||||
|
Function AsFunc(const Expr& e) {
|
||||||
|
if (e.as<FunctionNode>()) {
|
||||||
|
return Downcast<Function>(e);
|
||||||
|
} else if (const CallNode* c = e.as<CallNode>()) {
|
||||||
|
CHECK(c->op.same_as(WithFuncIdOp()));
|
||||||
|
CHECK_EQ(c->args.size(), 1);
|
||||||
|
return AsFunc(c->args[0]);
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Unknown case";
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
|
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
|
||||||
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
|
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
|
||||||
public:
|
public:
|
||||||
PartialEvaluator(const tvm::Array<Var>& free_vars) {
|
PartialEvaluator(const tvm::Array<Var>& free_vars,
|
||||||
|
const Module& mod) :
|
||||||
|
mod_(mod) {
|
||||||
for (const Var& v : free_vars) {
|
for (const Var& v : free_vars) {
|
||||||
env_.Insert(v, NoStatic(v));
|
env_.Insert(v, NoStatic(v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PStatic VisitExpr(const Expr& e, LetList* ll) final {
|
||||||
|
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
|
||||||
|
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
|
||||||
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
|
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
|
||||||
}
|
}
|
||||||
|
@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
|
||||||
return NoStatic(GetRef<Expr>(op));
|
GlobalVar gv = GetRef<GlobalVar>(op);
|
||||||
|
if (gv_map_.count(gv) == 0) {
|
||||||
|
if (mod_.defined()) {
|
||||||
|
Function func = mod_->Lookup(gv);
|
||||||
|
InitializeFuncId(func);
|
||||||
|
Func f = VisitFuncStatic(func, gv);
|
||||||
|
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
|
||||||
|
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
|
||||||
|
mod_->Update(gv, func);
|
||||||
|
} else {
|
||||||
|
gv_map_.insert({gv, NoStatic(gv)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gv_map_.at(gv);
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
|
||||||
|
@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
|
||||||
|
if (op->op.same_as(WithFuncIdOp())) {
|
||||||
|
CHECK_EQ(op->args.size(), 1);
|
||||||
|
return VisitExpr(op->args[0], ll);
|
||||||
|
}
|
||||||
PStatic f = VisitExpr(op->op, ll);
|
PStatic f = VisitExpr(op->op, ll);
|
||||||
std::vector<PStatic> x;
|
std::vector<PStatic> x;
|
||||||
tvm::Array<Expr> x_dyn;
|
tvm::Array<Expr> x_dyn;
|
||||||
|
@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
|
struct TimeFrame {
|
||||||
Function func = GetRef<Function>(op);
|
PartialEvaluator* pe_;
|
||||||
|
FuncId fid_;
|
||||||
|
std::vector<Time> old_time;
|
||||||
|
bool has_old_time;
|
||||||
|
TimeFrame(PartialEvaluator* pe,
|
||||||
|
FuncId fid,
|
||||||
|
const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
|
||||||
|
has_old_time = pe_->time_map_.count(fid_) > 0;
|
||||||
|
old_time = pe_->time_map_[fid_];
|
||||||
|
pe_->time_map_[fid_] = args_time;
|
||||||
|
}
|
||||||
|
~TimeFrame() {
|
||||||
|
if (has_old_time) {
|
||||||
|
pe_->time_map_[fid_] = old_time;
|
||||||
|
} else {
|
||||||
|
pe_->time_map_.erase(fid_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Func VisitFuncStatic(const Function& func, const Expr& var) {
|
||||||
|
CHECK(IsAtomic(var));
|
||||||
if (func->IsPrimitive()) {
|
if (func->IsPrimitive()) {
|
||||||
return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
|
return ConstEvaluateFunc(func);
|
||||||
}
|
}
|
||||||
std::vector<std::pair<Var, PStatic> > free_vars;
|
std::vector<std::pair<Var, PStatic> > free_vars;
|
||||||
for (const auto& v : FreeVars(GetRef<Expr>(op))) {
|
for (const auto& v : FreeVars(func)) {
|
||||||
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
|
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
|
||||||
}
|
}
|
||||||
Func f = [=](const std::vector<PStatic>& pv,
|
return [=](const std::vector<PStatic>& pv,
|
||||||
const Attrs& attrs,
|
const Attrs& attrs,
|
||||||
const tvm::Array<Type>& type_args,
|
const tvm::Array<Type>& type_args,
|
||||||
LetList* ll) {
|
LetList* ll) {
|
||||||
return env_.Extend<PStatic>([&]() {
|
return env_.Extend<PStatic>([&]() {
|
||||||
CHECK_EQ(pv.size(), func->params.size());
|
CHECK_EQ(pv.size(), func->params.size());
|
||||||
for (size_t i = 0; i < pv.size(); ++i) {
|
for (size_t i = 0; i < pv.size(); ++i) {
|
||||||
|
@ -529,10 +635,50 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
|
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
|
||||||
subst.Set(func->type_params[i], Type());
|
subst.Set(func->type_params[i], Type());
|
||||||
}
|
}
|
||||||
return VisitExpr(TypeSubst(func->body, subst), ll);
|
std::vector<Time> args_time;
|
||||||
|
for (const auto& v : pv) {
|
||||||
|
args_time.push_back(v->created_time);
|
||||||
|
}
|
||||||
|
CHECK_GT(func_map_.count(func), 0);
|
||||||
|
FuncId fid = func_map_.at(func);
|
||||||
|
auto recurse = [&]() {
|
||||||
|
TimeFrame tf(this, fid, args_time);
|
||||||
|
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
|
||||||
|
};
|
||||||
|
if (time_map_.count(fid) == 0) {
|
||||||
|
return recurse();
|
||||||
|
} else {
|
||||||
|
/* We check to see that at least one argument decrease
|
||||||
|
* with respect to all previous invocation.
|
||||||
|
* The depth of the recursion is bounded by
|
||||||
|
* the sum of the time of all argument at the first call.
|
||||||
|
*/
|
||||||
|
bool can_recurse = false;
|
||||||
|
std::vector<Time>& min_time = time_map_.at(fid);
|
||||||
|
CHECK_EQ(args_time.size(), min_time.size());
|
||||||
|
for (size_t i = 0; i < args_time.size(); ++i) {
|
||||||
|
if (args_time[i] < min_time[i]) {
|
||||||
|
can_recurse = true;
|
||||||
|
}
|
||||||
|
args_time[i] = std::min(args_time[i], min_time[i]);
|
||||||
|
}
|
||||||
|
if (can_recurse) {
|
||||||
|
return recurse();
|
||||||
|
} else {
|
||||||
|
std::vector<Expr> dyn;
|
||||||
|
for (const auto& v : pv) {
|
||||||
|
dyn.push_back(v->dynamic);
|
||||||
|
}
|
||||||
|
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
Expr dyn = store_.Extend<Expr>([&]() {
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Expr VisitFuncDynamic(const Function& func, const Func& f) {
|
||||||
|
return store_.Extend<Expr>([&]() {
|
||||||
store_.Invalidate();
|
store_.Invalidate();
|
||||||
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
|
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
|
||||||
std::vector<PStatic> pv;
|
std::vector<PStatic> pv;
|
||||||
|
@ -546,7 +692,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
return f(pv, Attrs(), type_args, ll)->dynamic;
|
return f(pv, Attrs(), type_args, ll)->dynamic;
|
||||||
}), func->ret_type, func->type_params, func->attrs);
|
}), func->ret_type, func->type_params, func->attrs);
|
||||||
});
|
});
|
||||||
return HasStatic(MkSFunc(f), ll->Push(dyn));
|
}
|
||||||
|
|
||||||
|
PStatic VisitFunc(const Function& func, LetList* ll) {
|
||||||
|
Var v = VarNode::make("x", Type());
|
||||||
|
Func f = VisitFuncStatic(func, v);
|
||||||
|
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
|
||||||
|
// TODO(@M.K.): we seems to reduce landin knot into letrec.
|
||||||
|
// restore letrec support across whole relay.
|
||||||
|
return HasStatic(MkSFunc(f),
|
||||||
|
ll->Push(v, VisitFuncDynamic(u_func, f)));
|
||||||
|
}
|
||||||
|
|
||||||
|
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
|
||||||
|
return VisitFunc(GetRef<Function>(op), ll);
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr Reflect(const PStatic& st) {
|
Expr Reflect(const PStatic& st) {
|
||||||
|
@ -590,7 +749,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
return Reify(executor_(fused_infered), ll);
|
return Reify(executor_(fused_infered), ll);
|
||||||
}
|
}
|
||||||
|
|
||||||
Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
|
Func ConstEvaluateFunc(const Expr& expr) {
|
||||||
|
CHECK_EQ(FreeVars(expr).size(), 0);
|
||||||
return [=](const std::vector<PStatic>& pv,
|
return [=](const std::vector<PStatic>& pv,
|
||||||
const Attrs& attrs,
|
const Attrs& attrs,
|
||||||
const tvm::Array<Type>& type_args,
|
const tvm::Array<Type>& type_args,
|
||||||
|
@ -599,7 +759,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
for (const PStatic& ps : pv) {
|
for (const PStatic& ps : pv) {
|
||||||
ns_args.push_back(ps->dynamic);
|
ns_args.push_back(ps->dynamic);
|
||||||
}
|
}
|
||||||
PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
|
PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
|
||||||
if (StatefulOp(expr)) {
|
if (StatefulOp(expr)) {
|
||||||
return ns;
|
return ns;
|
||||||
}
|
}
|
||||||
|
@ -616,7 +776,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
|
||||||
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
|
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
|
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
|
||||||
|
@ -680,7 +840,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
CHECK_NE(op->constructor->tag, -1);
|
CHECK_NE(op->constructor->tag, -1);
|
||||||
CHECK_NE(scn->constructor->tag, -1);
|
CHECK_NE(scn->constructor->tag, -1);
|
||||||
if (op->constructor->tag == scn->constructor->tag) {
|
if (op->constructor->tag == scn->constructor->tag) {
|
||||||
// todo(M.K.): should use ptr equality but it is broken
|
|
||||||
CHECK_EQ(op->patterns.size(), scn->fields.size());
|
CHECK_EQ(op->patterns.size(), scn->fields.size());
|
||||||
MatchStatus current_match_status = MatchStatus::Match;
|
MatchStatus current_match_status = MatchStatus::Match;
|
||||||
for (size_t i = 0; i < op->patterns.size(); ++i) {
|
for (size_t i = 0; i < op->patterns.size(); ++i) {
|
||||||
|
@ -702,27 +861,119 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InitializeFuncId(const Expr& e) {
|
||||||
|
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
|
||||||
|
PartialEvaluator* pe;
|
||||||
|
explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
|
||||||
|
|
||||||
|
void VisitExpr_(const FunctionNode* op) final {
|
||||||
|
Function f = GetRef<Function>(op);
|
||||||
|
CHECK_EQ(pe->func_map_.count(f), 0);
|
||||||
|
pe->func_map_.insert({f, pe->func_map_.size()});
|
||||||
|
VisitExpr(f->body);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VisitPattern(const Pattern& p) final {
|
||||||
|
PatternVisitor::VisitPattern(p);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
InitializeFuncIdVisitor(this).VisitExpr(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr RegisterFuncId(const Expr& e) {
|
||||||
|
struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
|
||||||
|
PartialEvaluator* pe;
|
||||||
|
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
|
||||||
|
|
||||||
|
void VisitExpr_(const CallNode* op) final {
|
||||||
|
if (op->op.same_as(WithFuncIdOp())) {
|
||||||
|
CHECK_EQ(op->args.size(), 1);
|
||||||
|
CHECK(op->attrs.defined());
|
||||||
|
CHECK(op->attrs.as<WithFuncIdAttrs>());
|
||||||
|
Function f = AsFunc(op->args[0]);
|
||||||
|
FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
|
||||||
|
if (pe->func_map_.count(f) != 0) {
|
||||||
|
CHECK_EQ(pe->func_map_.at(f), fid);
|
||||||
|
}
|
||||||
|
pe->func_map_.insert({f, fid});
|
||||||
|
}
|
||||||
|
ExprVisitor::VisitExpr_(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VisitExpr_(const FunctionNode* op) final {
|
||||||
|
Function f = GetRef<Function>(op);
|
||||||
|
CHECK_GT(pe->func_map_.count(f), 0);
|
||||||
|
ExprVisitor::VisitExpr_(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
void VisitPattern(const Pattern& p) final {
|
||||||
|
PatternVisitor::VisitPattern(p);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
RegisterFuncIdVisitor(this).VisitExpr(e);
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr AnnotateFuncId(const Expr& e) {
|
||||||
|
struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
|
||||||
|
PartialEvaluator* pe;
|
||||||
|
explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
|
||||||
|
|
||||||
|
Expr VisitExpr_(const FunctionNode* op) final {
|
||||||
|
Function f = GetRef<Function>(op);
|
||||||
|
CHECK_GT(pe->func_map_.count(f), 0);
|
||||||
|
return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
Pattern VisitPattern(const Pattern& p) final {
|
||||||
|
return PatternMutator::VisitPattern(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
Var VisitVar(const Var& v) final {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return AnnotateFuncIdMutator(this).VisitExpr(e);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Environment env_;
|
Environment env_;
|
||||||
|
Module mod_;
|
||||||
|
std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
|
||||||
|
/*! Termination checking is done as follows:
|
||||||
|
* We have finitely many FunctionIds.
|
||||||
|
* Each FunctionId maps to a class of semantically equivalent function (ignoring type),
|
||||||
|
* as both TypeSubst and DeDup create semantically equivalent function.
|
||||||
|
* We partially map each FunctionId to a std::vector<Time>,
|
||||||
|
* denoting the minimal TimeFrame of each argument of the function.
|
||||||
|
* Every time we try to inline a Function,
|
||||||
|
* we make sure it either does not have a vector<Time>, which means this is the initial call,
|
||||||
|
* or some argument has a lesser time, which means some earlier argument is passed in.
|
||||||
|
* In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
|
||||||
|
* when we PE inside the Function body.
|
||||||
|
* Termination is guaranteed because the creation time of at least one argument will decrease every call.
|
||||||
|
*/
|
||||||
|
std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
|
||||||
|
std::unordered_map<FuncId, std::vector<Time> > time_map_;
|
||||||
Store store_;
|
Store store_;
|
||||||
DLContext context_ = CPUContext();
|
DLContext context_ = CPUContext();
|
||||||
FInterpreter executor_ = CPUInterpreter();
|
FInterpreter executor_ = CPUInterpreter();
|
||||||
};
|
};
|
||||||
|
|
||||||
Var DeDupVar(const Var& v) {
|
|
||||||
return VarNode::make(v->name_hint(), v->type_annotation);
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeVar DeDupTypeVar(const TypeVar& tv) {
|
|
||||||
return TypeVarNode::make(tv->var->name_hint, tv->kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
|
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
|
||||||
Expr DeDup(const Expr& e) {
|
Expr DeDup(const Expr& e) {
|
||||||
class DeDupMutator : public ExprMutator, public PatternMutator {
|
class DeDupMutator : public TypeMutator,
|
||||||
|
public ExprMutator,
|
||||||
|
public PatternMutator {
|
||||||
public:
|
public:
|
||||||
|
TypeVar Fresh(const TypeVar& tv) {
|
||||||
|
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
|
||||||
|
type_rename_[tv] = ret;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
Var Fresh(const Var& v) {
|
Var Fresh(const Var& v) {
|
||||||
Var ret = DeDupVar(v);
|
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
|
||||||
rename_[v] = ret;
|
rename_[v] = ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -737,18 +988,27 @@ Expr DeDup(const Expr& e) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr VisitExpr_(const LetNode* op) final {
|
Expr VisitExpr_(const LetNode* op) final {
|
||||||
return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
|
Var v = Fresh(op->var);
|
||||||
|
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
|
||||||
|
}
|
||||||
|
|
||||||
|
Type VisitType(const Type& t) final {
|
||||||
|
return t.defined() ? TypeMutator::VisitType(t) : t;
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr VisitExpr_(const FunctionNode* op) final {
|
Expr VisitExpr_(const FunctionNode* op) final {
|
||||||
|
tvm::Array<TypeVar> type_params;
|
||||||
|
for (const TypeVar& type_param : op->type_params) {
|
||||||
|
type_params.push_back(Fresh(type_param));
|
||||||
|
}
|
||||||
tvm::Array<Var> params;
|
tvm::Array<Var> params;
|
||||||
for (const Var& param : op->params) {
|
for (const Var& param : op->params) {
|
||||||
params.push_back(Fresh(param));
|
params.push_back(Fresh(param));
|
||||||
}
|
}
|
||||||
return FunctionNode::make(params,
|
return FunctionNode::make(params,
|
||||||
VisitExpr(op->body),
|
VisitExpr(op->body),
|
||||||
op->ret_type,
|
VisitType(op->ret_type),
|
||||||
op->type_params,
|
type_params,
|
||||||
op->attrs);
|
op->attrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -756,14 +1016,28 @@ Expr DeDup(const Expr& e) {
|
||||||
return PatternMutator::VisitPattern(p);
|
return PatternMutator::VisitPattern(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Clause VisitClause(const Clause& c) final {
|
||||||
|
Pattern pat = VisitPattern(c->lhs);
|
||||||
|
return ClauseNode::make(pat, VisitExpr(c->rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
Type VisitType_(const TypeVarNode* op) final {
|
||||||
|
TypeVar v = GetRef<TypeVar>(op);
|
||||||
|
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
|
||||||
|
}
|
||||||
|
|
||||||
Var VisitVar(const Var& v) final {
|
Var VisitVar(const Var& v) final {
|
||||||
return Fresh(v);
|
return Fresh(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
|
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
|
||||||
|
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
|
||||||
};
|
};
|
||||||
return DeDupMutator().VisitExpr(e);
|
|
||||||
|
Expr ret = DeDupMutator().VisitExpr(e);
|
||||||
|
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
|
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
|
||||||
|
@ -787,11 +1061,38 @@ Expr Remap(const Expr& e) {
|
||||||
return RemapMutator().VisitExpr(e);
|
return RemapMutator().VisitExpr(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr PartialEval(const Expr& e) {
|
Expr StripWithFuncId(const Expr& e) {
|
||||||
|
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
|
||||||
|
Expr VisitExpr_(const CallNode* op) final {
|
||||||
|
if (op->op.same_as(WithFuncIdOp())) {
|
||||||
|
CHECK_EQ(op->args.size(), 1);
|
||||||
|
return VisitExpr(op->args[0]);
|
||||||
|
} else {
|
||||||
|
return ExprMutator::VisitExpr_(op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Pattern VisitPattern(const Pattern& p) final {
|
||||||
|
return PatternMutator::VisitPattern(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
Var VisitVar(const Var& v) final {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return StripWithFuncIdMutator().VisitExpr(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr PostProcess(const Expr& e) {
|
||||||
|
return StripWithFuncId(DeDup(Remap(e)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr PartialEval(const Expr& e, const Module& m) {
|
||||||
return TransformF([&](const Expr& e) {
|
return TransformF([&](const Expr& e) {
|
||||||
return LetList::With([&](LetList* ll) {
|
return LetList::With([&](LetList* ll) {
|
||||||
PartialEvaluator pe(FreeVars(e));
|
PartialEvaluator pe(FreeVars(e), m);
|
||||||
return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
|
pe.InitializeFuncId(e);
|
||||||
|
return PostProcess(pe.VisitExpr(e, ll)->dynamic);
|
||||||
});
|
});
|
||||||
}, e);
|
}, e);
|
||||||
}
|
}
|
||||||
|
@ -804,7 +1105,7 @@ namespace transform {
|
||||||
Pass PartialEval() {
|
Pass PartialEval() {
|
||||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||||
[=](Function f, Module m, PassContext pc) {
|
[=](Function f, Module m, PassContext pc) {
|
||||||
return Downcast<Function>(PartialEval(f));
|
return Downcast<Function>(PartialEval(f, m));
|
||||||
};
|
};
|
||||||
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
|
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,14 +18,17 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tvm
|
import tvm
|
||||||
from tvm import relay
|
from tvm import relay
|
||||||
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
|
from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
|
||||||
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
|
from tvm.relay.ir_pass import gradient
|
||||||
from tvm.relay import op, create_executor
|
from tvm.relay import op, create_executor
|
||||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||||
from tvm.relay.prelude import Prelude
|
from tvm.relay.prelude import Prelude
|
||||||
from tvm.relay import create_executor
|
from tvm.relay import create_executor
|
||||||
|
|
||||||
from nose.tools import nottest
|
from nose.tools import nottest
|
||||||
|
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
|
||||||
|
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
|
||||||
|
from tvm.relay import GlobalVar, Call, Type
|
||||||
|
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
|
||||||
|
|
||||||
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
||||||
ctx = tvm.context("llvm", 0)
|
ctx = tvm.context("llvm", 0)
|
||||||
|
@ -35,24 +38,22 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
||||||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
||||||
|
|
||||||
|
|
||||||
def dcpe(expr):
|
def dcpe(expr, mod=None):
|
||||||
return dead_code_elimination(partial_evaluate(expr))
|
return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
|
||||||
|
|
||||||
|
|
||||||
def test_tuple():
|
def test_tuple():
|
||||||
t = relay.TypeVar("t")
|
t = TypeVar("t")
|
||||||
x = relay.Var("x", t)
|
x = Var("x", t)
|
||||||
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
|
body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
|
||||||
f = relay.Function([x], body, None, [t])
|
f = Function([x], body, None, [t])
|
||||||
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
|
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
|
||||||
|
|
||||||
@nottest
|
|
||||||
def test_const_inline():
|
def test_const_inline():
|
||||||
# TODO(MK): fix me
|
d = Var("d")
|
||||||
d = relay.Var("d")
|
double = Function([d], d + d)
|
||||||
double = relay.Function([d], d + d)
|
orig = double(const(4.0))
|
||||||
orig = double(relay.const(4.0))
|
assert alpha_equal(dcpe(orig), const(8.0))
|
||||||
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
|
|
||||||
|
|
||||||
|
|
||||||
def test_ref():
|
def test_ref():
|
||||||
|
@ -60,44 +61,57 @@ def test_ref():
|
||||||
r = relay.Var("r")
|
r = relay.Var("r")
|
||||||
x = relay.Var("x")
|
x = relay.Var("x")
|
||||||
body = relay.RefRead(r)
|
body = relay.RefRead(r)
|
||||||
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
|
body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
|
||||||
body = relay.Let(r, relay.RefCreate(d), body)
|
body = Let(r, RefCreate(d), body)
|
||||||
square = relay.Function([d], body)
|
square = Function([d], body)
|
||||||
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
|
assert alpha_equal(dcpe(square), Function([d], d * d))
|
||||||
|
|
||||||
@nottest
|
|
||||||
def test_ad():
|
def test_empty_ad():
|
||||||
# TODO(MK): fix me
|
|
||||||
shape = (10, 10)
|
shape = (10, 10)
|
||||||
dtype = "float32"
|
dtype = "float32"
|
||||||
t = relay.TensorType(shape, dtype)
|
t = TensorType(shape, dtype)
|
||||||
d = relay.Var("d", t)
|
d = Var("d", t)
|
||||||
f = relay.Function([d], d * d)
|
f = Function([d], d)
|
||||||
|
g = dcpe(gradient(f))
|
||||||
|
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
|
||||||
|
assert alpha_equal(g, expected)
|
||||||
|
|
||||||
|
def test_ad():
|
||||||
|
shape = (10, 10)
|
||||||
|
dtype = "float32"
|
||||||
|
t = TensorType(shape, dtype)
|
||||||
|
d = Var("d", t)
|
||||||
|
f = Function([d], d * d)
|
||||||
g = dcpe(gradient(f))
|
g = dcpe(gradient(f))
|
||||||
m = d * d
|
m = d * d
|
||||||
o = relay.op.ones_like(m)
|
x = relay.Var("x")
|
||||||
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
|
o = op.ones_like(x)
|
||||||
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
|
x1 = relay.Var("x1")
|
||||||
|
grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
|
||||||
|
body = Tuple([x, Tuple([grad])])
|
||||||
|
body = relay.Let(x1, o, body)
|
||||||
|
expected = Function([d], relay.Let(x, m, body))
|
||||||
assert alpha_equal(g, expected)
|
assert alpha_equal(g, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_if_ref():
|
def test_if_ref():
|
||||||
shape = ()
|
shape = ()
|
||||||
dtype = "bool"
|
dtype = "bool"
|
||||||
t = relay.TensorType(shape, dtype)
|
t = TensorType(shape, dtype)
|
||||||
d = relay.Var("d", t)
|
d = Var("d", t)
|
||||||
r = relay.Var("r")
|
r = Var("r")
|
||||||
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
|
update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
|
||||||
u = relay.Var("u")
|
u = Var("u")
|
||||||
body = relay.If(d, u(), u())
|
body = If(d, u(), u())
|
||||||
eff = relay.Var("eff")
|
eff = Var("eff")
|
||||||
body = relay.Let(eff, body, relay.RefRead(r))
|
body = Let(eff, body, RefRead(r))
|
||||||
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
|
f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
|
||||||
f = infer_type(f)
|
f = infer_type(f)
|
||||||
pe_f = infer_type(partial_evaluate(f))
|
pe_f = infer_type(partial_evaluate(f))
|
||||||
ex = create_executor()
|
ex = create_executor()
|
||||||
f_res = ex.evaluate(f)(relay.const(True))
|
f_res = ex.evaluate(f)(const(True))
|
||||||
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
|
pe_f_res = ex.evaluate(pe_f)(const(True))
|
||||||
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
|
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
|
||||||
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
|
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
|
||||||
|
|
||||||
|
@ -105,52 +119,162 @@ def test_if_ref():
|
||||||
def test_function_invalidate():
|
def test_function_invalidate():
|
||||||
shape = ()
|
shape = ()
|
||||||
dtype = "bool"
|
dtype = "bool"
|
||||||
t = relay.TensorType(shape, dtype)
|
t = TensorType(shape, dtype)
|
||||||
d = relay.Var("d", t)
|
d = Var("d", t)
|
||||||
r = relay.Var("r")
|
r = Var("r")
|
||||||
fetch = relay.Function([], relay.RefRead(r))
|
fetch = Function([], RefRead(r))
|
||||||
fet = relay.Var("fetch")
|
fet = Var("fetch")
|
||||||
fet_obscured = relay.Var("fetch_obscured")
|
fet_obscured = Var("fetch_obscured")
|
||||||
u = relay.Var("u")
|
u = Var("u")
|
||||||
body = relay.If(d, fet_obscured(), fet_obscured())
|
body = If(d, fet_obscured(), fet_obscured())
|
||||||
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
|
body = Let(u, RefWrite(r, const(1)), body)
|
||||||
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
|
body = Let(fet_obscured, If(d, fet, fet), body)
|
||||||
body = relay.Let(fet, fetch, body)
|
body = Let(fet, fetch, body)
|
||||||
body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
|
body = Let(r, RefCreate(const(0)), body)
|
||||||
f = relay.Function([d], body)
|
f = Function([d], body)
|
||||||
f = infer_type(f)
|
f = infer_type(f)
|
||||||
pe_f = infer_type(partial_evaluate(f))
|
pe_f = infer_type(partial_evaluate(f))
|
||||||
ex = create_executor()
|
ex = create_executor()
|
||||||
f_res = ex.evaluate(f)(relay.const(True))
|
f_res = ex.evaluate(f)(const(True))
|
||||||
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
|
pe_f_res = ex.evaluate(pe_f)(const(True))
|
||||||
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
|
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
|
||||||
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
|
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
|
||||||
|
|
||||||
|
|
||||||
def test_head_cons():
|
def test_head_cons():
|
||||||
mod = relay.Module()
|
mod = Module()
|
||||||
p = Prelude(mod)
|
p = Prelude(mod)
|
||||||
def hd_impl():
|
def hd_impl():
|
||||||
a = relay.TypeVar("a")
|
a = TypeVar("a")
|
||||||
x = relay.Var("x", p.l(a))
|
x = Var("x", p.l(a))
|
||||||
y = relay.Var("y")
|
y = Var("y")
|
||||||
z = relay.Var("z")
|
z = Var("z")
|
||||||
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
|
cons_case = Clause(PatternConstructor(p.cons,
|
||||||
[relay.PatternVar(y),
|
[PatternVar(y),
|
||||||
relay.PatternVar(z)]),
|
PatternVar(z)]),
|
||||||
y)
|
y)
|
||||||
return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
|
y = Var("y")
|
||||||
t = relay.TypeVar("t")
|
z = Var("z")
|
||||||
x = relay.Var("x", t)
|
return Function([x], Match(x, [cons_case]), a, [a])
|
||||||
hd = relay.Var("hd")
|
t = TypeVar("t")
|
||||||
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
|
x = Var("x", t)
|
||||||
f = relay.Function([x], body, None, [t])
|
hd = Var("hd")
|
||||||
|
body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
|
||||||
|
f = Function([x], body, None, [t])
|
||||||
f = infer_type(f, mod=mod)
|
f = infer_type(f, mod=mod)
|
||||||
res = dcpe(f)
|
res = dcpe(f)
|
||||||
assert alpha_equal(res, relay.Function([x], x, t, [t]))
|
assert alpha_equal(res, Function([x], x, t, [t]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_map():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
f = Var("f")
|
||||||
|
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
|
||||||
|
expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
|
||||||
|
assert alpha_equal(dcpe(orig, mod=mod), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loop():
|
||||||
|
mod = Module()
|
||||||
|
t = TypeVar("t")
|
||||||
|
x = Var("x", t)
|
||||||
|
loop = GlobalVar("loop")
|
||||||
|
mod[loop] = Function([x], loop(x), t, [t])
|
||||||
|
res = dcpe(loop(const(1)), mod=mod)
|
||||||
|
expected = Call(loop, [const(1)], None, [None])
|
||||||
|
assert alpha_equal(res, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_swap_loop():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
nat = p.nat()
|
||||||
|
x = Var("x", nat)
|
||||||
|
y = Var("y", nat)
|
||||||
|
loop = GlobalVar("loop")
|
||||||
|
mod[loop] = Function([x, y], loop(y, x), nat)
|
||||||
|
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
|
||||||
|
res = dcpe(prog, mod=mod)
|
||||||
|
assert alpha_equal(prog, res)
|
||||||
|
|
||||||
|
|
||||||
|
def test_abs_diff():
|
||||||
|
# TODO(@M.K.): refactor using tuple pattern (not yet implemented)
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
nat = p.nat()
|
||||||
|
x = Var("x", nat)
|
||||||
|
y = Var("y", nat)
|
||||||
|
xp = Var("x'", nat)
|
||||||
|
yp = Var("y'", nat)
|
||||||
|
diff = GlobalVar("diff")
|
||||||
|
y_z_case = Clause(PatternConstructor(p.z, []), x)
|
||||||
|
y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
|
||||||
|
x_z_case = Clause(PatternConstructor(p.z, []), y)
|
||||||
|
x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
|
||||||
|
mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
|
||||||
|
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
|
||||||
|
res = dcpe(orig, mod=mod)
|
||||||
|
assert alpha_equal(res, make_nat_expr(p, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_nat_id():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
nat = p.nat()
|
||||||
|
x = Var("x", nat)
|
||||||
|
y = Var("y", nat)
|
||||||
|
nat_id = GlobalVar("nat_id")
|
||||||
|
z_case = Clause(PatternConstructor(p.z, []), p.z())
|
||||||
|
s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
|
||||||
|
mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
|
||||||
|
orig = nat_id(make_nat_expr(p, 3))
|
||||||
|
res = dcpe(orig, mod=mod)
|
||||||
|
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||||
|
|
||||||
|
|
||||||
|
def test_nat_id():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
nat = p.nat()
|
||||||
|
x = Var("x", nat)
|
||||||
|
y = Var("y", nat)
|
||||||
|
nat_id = GlobalVar("nat_id")
|
||||||
|
mod[nat_id] = Function([x], x)
|
||||||
|
orig = nat_id(make_nat_expr(p, 3))
|
||||||
|
res = dcpe(orig, mod=mod)
|
||||||
|
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||||
|
|
||||||
|
|
||||||
|
def test_global_match_nat_id():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
nat = p.nat()
|
||||||
|
x = Var("x", nat)
|
||||||
|
z_case = Clause(PatternConstructor(p.z, []), p.z())
|
||||||
|
s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
|
||||||
|
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
|
||||||
|
res = dcpe(orig, mod=mod)
|
||||||
|
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||||
|
|
||||||
|
|
||||||
|
def test_double():
|
||||||
|
mod = Module()
|
||||||
|
p = Prelude(mod)
|
||||||
|
add_nat_definitions(p)
|
||||||
|
orig = p.double(make_nat_expr(p, 3))
|
||||||
|
res = dcpe(orig, mod=mod)
|
||||||
|
assert alpha_equal(res, make_nat_expr(p, 6))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
test_empty_ad()
|
||||||
test_tuple()
|
test_tuple()
|
||||||
test_const_inline()
|
test_const_inline()
|
||||||
test_ref()
|
test_ref()
|
||||||
|
@ -158,3 +282,11 @@ if __name__ == '__main__':
|
||||||
test_if_ref()
|
test_if_ref()
|
||||||
test_function_invalidate()
|
test_function_invalidate()
|
||||||
test_head_cons()
|
test_head_cons()
|
||||||
|
test_map()
|
||||||
|
test_loop()
|
||||||
|
test_swap_loop()
|
||||||
|
test_abs_diff()
|
||||||
|
test_double()
|
||||||
|
test_nat_id()
|
||||||
|
test_global_match_nat_id()
|
||||||
|
test_match_nat_id()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче