save (#3033)
save save save upstream lint remove bad changes fix build save save please the ci god Update src/relay/pass/partial_eval.cc Co-Authored-By: Wei Chen <ipondering.weic@gmail.com> save fix test ci is ANGRY fix rebase problem fix rebase add test save save comment
This commit is contained in:
Родитель
50dd03ca86
Коммит
df88c411f5
|
@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
|
|||
* For example, this pass should turn `let a = 1 in 2` into `2`,
|
||||
* as the value of the expression does not depend on a.
|
||||
*
|
||||
* As another example, `let a = 1 in a` will be optimized into 1.
|
||||
* As another example, `let a = 1 in a` will be optimized into 1,
|
||||
* if the flag is turned on.
|
||||
*
|
||||
* \param e the expression to optimize.
|
||||
* \param inline_once whether or not to inline binding used one.
|
||||
*
|
||||
* \return the optimized expression.
|
||||
*/
|
||||
TVM_DLL Expr DeadCodeElimination(const Expr& e);
|
||||
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
|
||||
|
||||
/*!
|
||||
* \brief Fold constant expressions.
|
||||
|
@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
|
|||
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
|
||||
* As a side effect, code size will explode.
|
||||
*
|
||||
* \param e the expression,
|
||||
* \param e the expression
|
||||
* \param mod the module
|
||||
*
|
||||
* \return the optimized expression.
|
||||
*/
|
||||
TVM_DLL Expr PartialEval(const Expr& e);
|
||||
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
|
||||
|
||||
/*!
|
||||
* \brief Bind the free variables to a Relay expression.
|
||||
|
|
|
@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
|
|||
*
|
||||
* As another example, `let a = 1 in a` will be optimized into 1.
|
||||
*
|
||||
* \param inline_once whether or not to inline binding used one.
|
||||
*
|
||||
* \return the pass.
|
||||
*/
|
||||
TVM_DLL Pass DeadCodeElimination();
|
||||
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
|
||||
|
||||
/*!
|
||||
* \brief Fold constant expressions.
|
||||
|
|
|
@ -129,7 +129,7 @@ def well_formed(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input expression
|
||||
|
||||
Returns
|
||||
|
@ -175,7 +175,7 @@ def free_vars(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input expression
|
||||
|
||||
Returns
|
||||
|
@ -197,7 +197,7 @@ def bound_vars(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input expression
|
||||
|
||||
Returns
|
||||
|
@ -213,7 +213,7 @@ def all_vars(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input expression
|
||||
|
||||
Returns
|
||||
|
@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
The input expression/type
|
||||
mod: tvm.relay.Module, optional
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module
|
||||
|
||||
Returns
|
||||
|
@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
The input expression/type
|
||||
mod: tvm.relay.Module, optional
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module
|
||||
|
||||
Returns
|
||||
|
@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
expr : Union[tvm.relay.Expr,tvm.relay.Type]
|
||||
The input expression/type
|
||||
mod: tvm.relay.Module, optional
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module
|
||||
|
||||
Returns
|
||||
|
@ -286,12 +288,12 @@ def simplify_inference(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
e: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input Expression
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
An expression which is semantically equal to the input expression,
|
||||
but with some simplification
|
||||
"""
|
||||
|
@ -304,32 +306,34 @@ def canonicalize_ops(expr):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
e: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input Expression
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
An expression without bias_add
|
||||
"""
|
||||
return _ir_pass.canonicalize_ops(expr)
|
||||
|
||||
|
||||
def dead_code_elimination(expr):
|
||||
def dead_code_elimination(expr, inline_once=False):
|
||||
""" Remove expressions which does not effect the program result (dead code).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
e: tvm.relay.Expr
|
||||
expr : tvm.relay.Expr
|
||||
The input Expression
|
||||
|
||||
inline_once : Optional[Bool]
|
||||
Whether to inline binding that occur only once.
|
||||
Returns
|
||||
-------
|
||||
result: tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
An expression which is semantically equal to the input expression,
|
||||
but with dead code removed.
|
||||
"""
|
||||
return _ir_pass.dead_code_elimination(expr)
|
||||
return _ir_pass.dead_code_elimination(expr, inline_once)
|
||||
|
||||
|
||||
def alpha_equal(lhs, rhs):
|
||||
|
@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
lhs: tvm.relay.Expr
|
||||
lhs : tvm.relay.Expr
|
||||
One of the input Expression.
|
||||
|
||||
rhs: tvm.relay.Expr
|
||||
rhs : tvm.relay.Expr
|
||||
One of the input Expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: bool
|
||||
result : bool
|
||||
True iff lhs is alpha equal to rhs.
|
||||
"""
|
||||
return bool(_make._alpha_equal(lhs, rhs))
|
||||
|
@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
lhs: tvm.relay.Expr
|
||||
lhs : tvm.relay.Expr
|
||||
One of the input Expression.
|
||||
|
||||
rhs: tvm.relay.Expr
|
||||
rhs : tvm.relay.Expr
|
||||
One of the input Expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: bool
|
||||
result : bool
|
||||
True iff lhs is data-flow equivalent to rhs.
|
||||
"""
|
||||
return bool(_make._graph_equal(lhs, rhs))
|
||||
|
@ -378,12 +382,12 @@ def structural_hash(value):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
expr: tvm.relay.Expr or tvm.relay.Type
|
||||
expr : Union[tvm.relay.Expr, tvm.relay.Type]
|
||||
The expression to hash.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: int
|
||||
result : int
|
||||
The hash value
|
||||
"""
|
||||
if isinstance(value, Expr):
|
||||
|
@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
|
|||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
mod: Optional[tvm.relay.Module]
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
expr: tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
The output expression.
|
||||
"""
|
||||
return _ir_pass.to_a_normal_form(expr, mod)
|
||||
|
@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
|
|||
The input expression
|
||||
Returns
|
||||
-------
|
||||
expr : tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
The output expression
|
||||
"""
|
||||
return _ir_pass.to_graph_normal_form(expr)
|
||||
|
@ -612,7 +616,7 @@ def get_total_mac_number(expr):
|
|||
|
||||
Returns
|
||||
-------
|
||||
ret : int64
|
||||
result : int64
|
||||
The number of MACs (multiply-accumulate) of a model
|
||||
"""
|
||||
return _ir_pass.GetTotalMacNumber(expr)
|
||||
|
@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
|
|||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
fskip: function
|
||||
fskip : function
|
||||
The callback function that decides whether an expression should be skipped.
|
||||
|
||||
Returns
|
||||
-------
|
||||
expr : tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
The output expression.
|
||||
"""
|
||||
return _ir_pass.eliminate_common_subexpr(expr, fskip)
|
||||
|
||||
def partial_evaluate(expr):
|
||||
def partial_evaluate(expr, mod=None):
|
||||
"""
|
||||
Evaluate the static fragment of the code.
|
||||
|
||||
|
@ -646,12 +650,15 @@ def partial_evaluate(expr):
|
|||
expr : tvm.relay.Expr
|
||||
The input expression.
|
||||
|
||||
mod : Optional[tvm.relay.Module]
|
||||
The global module
|
||||
|
||||
Returns
|
||||
-------
|
||||
expr : tvm.relay.Expr
|
||||
result : tvm.relay.Expr
|
||||
The output expression.
|
||||
"""
|
||||
return _ir_pass.partial_evaluate(expr)
|
||||
return _ir_pass.partial_evaluate(expr, mod)
|
||||
|
||||
def unmatched_cases(match, mod=None):
|
||||
"""
|
||||
|
|
|
@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
|
|||
|
||||
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
||||
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
|
||||
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
|
||||
<< node->attrs << ", " << node->type_args << ")";
|
||||
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
|
||||
<< node->attrs << ", " << node->type_args << ")";
|
||||
});
|
||||
|
||||
Let LetNode::make(Var var, Expr value, Expr body) {
|
||||
|
@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
|
|||
|
||||
TVM_REGISTER_API("relay._expr.TempExprRealize")
|
||||
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
|
||||
return temp->Realize();
|
||||
return temp->Realize();
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -38,10 +38,10 @@ namespace relay {
|
|||
// calculate the dependency graph from expression
|
||||
class CalcDep : private ExprVisitor {
|
||||
public:
|
||||
static Expr Eliminate(const Expr& e) {
|
||||
static Expr Eliminate(const Expr& e, bool inline_once) {
|
||||
CalcDep cd;
|
||||
cd.Calculate(e);
|
||||
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
|
||||
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
|
||||
return el(e);
|
||||
}
|
||||
|
||||
|
@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
|
|||
VarMap<Expr> expr_map_;
|
||||
VarMap<size_t> use_map_;
|
||||
VarSet letrec_set_;
|
||||
bool inline_once_;
|
||||
explicit Eliminator(const VarMap<Expr>& expr_map,
|
||||
const VarMap<size_t>& use_map,
|
||||
const VarSet& letrec_set) :
|
||||
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
|
||||
const VarSet& letrec_set,
|
||||
bool inline_once) :
|
||||
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
|
||||
friend CalcDep;
|
||||
|
||||
bool HasLet(const Var& v) {
|
||||
// TODO(@jroesch): MK fix me
|
||||
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
|
||||
switch (use_map_[v]) {
|
||||
case 0:
|
||||
return false;
|
||||
case 1:
|
||||
return letrec_set_.count(v) > 0 || !inline_once_;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const VarNode* op) final {
|
||||
|
@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
|
|||
};
|
||||
};
|
||||
|
||||
Expr DeadCodeElimination(const Expr& e) {
|
||||
return CalcDep::Eliminate(e);
|
||||
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
|
||||
return CalcDep::Eliminate(e, inline_once);
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
|
||||
|
@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
|
|||
|
||||
namespace transform {
|
||||
|
||||
Pass DeadCodeElimination() {
|
||||
Pass DeadCodeElimination(bool inline_once) {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(DeadCodeElimination(f));
|
||||
return Downcast<Function>(DeadCodeElimination(f, inline_once));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
|
||||
}
|
||||
|
|
|
@ -74,28 +74,19 @@
|
|||
*
|
||||
* The partial evaluator makes several assumptions, so there is room for improvement:
|
||||
*
|
||||
* 0: The partial evaluator treats global variables as opaque.
|
||||
* Doing PartialEval on a module level will solve this.
|
||||
*
|
||||
* 1: The partial evaluator assume all functions as terminating.
|
||||
* We need to has a max_expand parameter that shrink on every compile time evaluation,
|
||||
* to make sure PE does not infinite loop.
|
||||
* Additionally, we might add a termination analysis pass that lift this requirement
|
||||
* for function that analysis found terminating.
|
||||
*
|
||||
* 2: Every time an unknown effect happened, we clear the whole store.
|
||||
* 0: Every time an unknown effect happened, we clear the whole store.
|
||||
* It is too conservative: if a local reference is created (and do not get passed outside),
|
||||
* An unknown global function call/global reference write can not modify it.
|
||||
* We can pair PE with escape analysis/alias analysis.
|
||||
*
|
||||
* 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
|
||||
* 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
|
||||
*
|
||||
* 4: When doing pattern matching, we can simplify the match even for dynamic case.
|
||||
* 2: When doing pattern matching, we can simplify the match even for dynamic case.
|
||||
* Right now it is all or nothing: either a complete match, or the original dynamic code.
|
||||
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
|
||||
* We then can reify the result.
|
||||
*
|
||||
* 5: Every time a function is called, it's code will get expanded and partially evaluated.
|
||||
* 3: Every time a function is called, its code will get expanded and partially evaluated.
|
||||
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
|
||||
*
|
||||
* These assumptions do not affect the correctness of the algorithm, however.
|
||||
|
@ -104,6 +95,7 @@
|
|||
#include <tvm/relay/expr_functor.h>
|
||||
#include <tvm/relay/pattern_functor.h>
|
||||
#include <tvm/relay/interpreter.h>
|
||||
#include "../ir/type_functor.h"
|
||||
#include "pass_util.h"
|
||||
#include "let_list.h"
|
||||
|
||||
|
@ -132,6 +124,8 @@ struct VarEqual {
|
|||
}
|
||||
};
|
||||
|
||||
Expr PostProcess(const Expr&);
|
||||
|
||||
/*! \brief The base container type of Relay values. */
|
||||
class StaticNode : public RelayNode {
|
||||
public:
|
||||
|
@ -150,10 +144,20 @@ class Static : public NodeRef {
|
|||
using ContainerType = StaticNode;
|
||||
};
|
||||
|
||||
using Time = size_t;
|
||||
|
||||
struct PStaticNode : Node {
|
||||
static Time time() {
|
||||
static Time time_ = 0;
|
||||
Time ret = time_;
|
||||
time_++;
|
||||
return ret;
|
||||
}
|
||||
Static pstatic; // may be null
|
||||
Expr dynamic;
|
||||
PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
|
||||
Time created_time;
|
||||
PStaticNode(const Static& pstatic, const Expr& dynamic) :
|
||||
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
|
||||
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
|
||||
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
|
||||
};
|
||||
|
@ -341,6 +345,7 @@ class Store {
|
|||
};
|
||||
|
||||
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
|
||||
CHECK(stat.defined());
|
||||
return PStatic(make_node<PStaticNode>(stat, dynamic));
|
||||
}
|
||||
|
||||
|
@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() {
|
|||
return CreateInterpreter(Module(nullptr), CPUContext(), target);
|
||||
}
|
||||
|
||||
bool IsAtomic(const Expr& e) {
|
||||
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
|
||||
}
|
||||
|
||||
using FuncId = int;
|
||||
|
||||
/*!
|
||||
* \brief Annotate a function with a FuncId.
|
||||
*/
|
||||
struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
|
||||
FuncId fid;
|
||||
|
||||
TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
|
||||
TVM_ATTR_FIELD(fid)
|
||||
.describe("The FuncId that an function is annotated with.")
|
||||
.set_default(-1);
|
||||
}
|
||||
};
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
|
||||
|
||||
Op WithFuncIdOp() {
|
||||
static const Op& op = Op::Get("annotation.with_funcid");
|
||||
return op;
|
||||
}
|
||||
|
||||
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
|
||||
auto attrs = make_node<WithFuncIdAttrs>();
|
||||
attrs->fid = fid;
|
||||
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
|
||||
}
|
||||
|
||||
RELAY_REGISTER_OP("annotation.with_funcid")
|
||||
.describe(R"code(Annotate a function with a funcid.)code"
|
||||
TVM_ADD_FILELINE)
|
||||
.set_num_inputs(1)
|
||||
.add_argument("func", "Function", "The input data.");
|
||||
|
||||
Expr StripWithFuncId(const Expr& e);
|
||||
|
||||
Expr DeDup(const Expr& e);
|
||||
|
||||
Function AsFunc(const Expr& e) {
|
||||
if (e.as<FunctionNode>()) {
|
||||
return Downcast<Function>(e);
|
||||
} else if (const CallNode* c = e.as<CallNode>()) {
|
||||
CHECK(c->op.same_as(WithFuncIdOp()));
|
||||
CHECK_EQ(c->args.size(), 1);
|
||||
return AsFunc(c->args[0]);
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown case";
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
|
||||
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
|
||||
public:
|
||||
PartialEvaluator(const tvm::Array<Var>& free_vars) {
|
||||
PartialEvaluator(const tvm::Array<Var>& free_vars,
|
||||
const Module& mod) :
|
||||
mod_(mod) {
|
||||
for (const Var& v : free_vars) {
|
||||
env_.Insert(v, NoStatic(v));
|
||||
}
|
||||
}
|
||||
|
||||
PStatic VisitExpr(const Expr& e, LetList* ll) final {
|
||||
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
|
||||
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
|
||||
return ret;
|
||||
}
|
||||
|
||||
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
|
||||
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
|
||||
}
|
||||
|
@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
}
|
||||
|
||||
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
|
||||
return NoStatic(GetRef<Expr>(op));
|
||||
GlobalVar gv = GetRef<GlobalVar>(op);
|
||||
if (gv_map_.count(gv) == 0) {
|
||||
if (mod_.defined()) {
|
||||
Function func = mod_->Lookup(gv);
|
||||
InitializeFuncId(func);
|
||||
Func f = VisitFuncStatic(func, gv);
|
||||
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
|
||||
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
|
||||
mod_->Update(gv, func);
|
||||
} else {
|
||||
gv_map_.insert({gv, NoStatic(gv)});
|
||||
}
|
||||
}
|
||||
return gv_map_.at(gv);
|
||||
}
|
||||
|
||||
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
|
||||
|
@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
}
|
||||
|
||||
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
|
||||
if (op->op.same_as(WithFuncIdOp())) {
|
||||
CHECK_EQ(op->args.size(), 1);
|
||||
return VisitExpr(op->args[0], ll);
|
||||
}
|
||||
PStatic f = VisitExpr(op->op, ll);
|
||||
std::vector<PStatic> x;
|
||||
tvm::Array<Expr> x_dyn;
|
||||
|
@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
}
|
||||
}
|
||||
|
||||
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
|
||||
Function func = GetRef<Function>(op);
|
||||
struct TimeFrame {
|
||||
PartialEvaluator* pe_;
|
||||
FuncId fid_;
|
||||
std::vector<Time> old_time;
|
||||
bool has_old_time;
|
||||
TimeFrame(PartialEvaluator* pe,
|
||||
FuncId fid,
|
||||
const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
|
||||
has_old_time = pe_->time_map_.count(fid_) > 0;
|
||||
old_time = pe_->time_map_[fid_];
|
||||
pe_->time_map_[fid_] = args_time;
|
||||
}
|
||||
~TimeFrame() {
|
||||
if (has_old_time) {
|
||||
pe_->time_map_[fid_] = old_time;
|
||||
} else {
|
||||
pe_->time_map_.erase(fid_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Func VisitFuncStatic(const Function& func, const Expr& var) {
|
||||
CHECK(IsAtomic(var));
|
||||
if (func->IsPrimitive()) {
|
||||
return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
|
||||
return ConstEvaluateFunc(func);
|
||||
}
|
||||
std::vector<std::pair<Var, PStatic> > free_vars;
|
||||
for (const auto& v : FreeVars(GetRef<Expr>(op))) {
|
||||
for (const auto& v : FreeVars(func)) {
|
||||
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
|
||||
}
|
||||
Func f = [=](const std::vector<PStatic>& pv,
|
||||
const Attrs& attrs,
|
||||
const tvm::Array<Type>& type_args,
|
||||
LetList* ll) {
|
||||
return [=](const std::vector<PStatic>& pv,
|
||||
const Attrs& attrs,
|
||||
const tvm::Array<Type>& type_args,
|
||||
LetList* ll) {
|
||||
return env_.Extend<PStatic>([&]() {
|
||||
CHECK_EQ(pv.size(), func->params.size());
|
||||
for (size_t i = 0; i < pv.size(); ++i) {
|
||||
|
@ -529,10 +635,50 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
|
||||
subst.Set(func->type_params[i], Type());
|
||||
}
|
||||
return VisitExpr(TypeSubst(func->body, subst), ll);
|
||||
std::vector<Time> args_time;
|
||||
for (const auto& v : pv) {
|
||||
args_time.push_back(v->created_time);
|
||||
}
|
||||
CHECK_GT(func_map_.count(func), 0);
|
||||
FuncId fid = func_map_.at(func);
|
||||
auto recurse = [&]() {
|
||||
TimeFrame tf(this, fid, args_time);
|
||||
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
|
||||
};
|
||||
if (time_map_.count(fid) == 0) {
|
||||
return recurse();
|
||||
} else {
|
||||
/* We check to see that at least one argument decrease
|
||||
* with respect to all previous invocation.
|
||||
* The depth of the recursion is bounded by
|
||||
* the sum of the time of all argument at the first call.
|
||||
*/
|
||||
bool can_recurse = false;
|
||||
std::vector<Time>& min_time = time_map_.at(fid);
|
||||
CHECK_EQ(args_time.size(), min_time.size());
|
||||
for (size_t i = 0; i < args_time.size(); ++i) {
|
||||
if (args_time[i] < min_time[i]) {
|
||||
can_recurse = true;
|
||||
}
|
||||
args_time[i] = std::min(args_time[i], min_time[i]);
|
||||
}
|
||||
if (can_recurse) {
|
||||
return recurse();
|
||||
} else {
|
||||
std::vector<Expr> dyn;
|
||||
for (const auto& v : pv) {
|
||||
dyn.push_back(v->dynamic);
|
||||
}
|
||||
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
Expr dyn = store_.Extend<Expr>([&]() {
|
||||
}
|
||||
|
||||
|
||||
Expr VisitFuncDynamic(const Function& func, const Func& f) {
|
||||
return store_.Extend<Expr>([&]() {
|
||||
store_.Invalidate();
|
||||
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
|
||||
std::vector<PStatic> pv;
|
||||
|
@ -546,7 +692,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
return f(pv, Attrs(), type_args, ll)->dynamic;
|
||||
}), func->ret_type, func->type_params, func->attrs);
|
||||
});
|
||||
return HasStatic(MkSFunc(f), ll->Push(dyn));
|
||||
}
|
||||
|
||||
PStatic VisitFunc(const Function& func, LetList* ll) {
|
||||
Var v = VarNode::make("x", Type());
|
||||
Func f = VisitFuncStatic(func, v);
|
||||
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
|
||||
// TODO(@M.K.): we seems to reduce landin knot into letrec.
|
||||
// restore letrec support across whole relay.
|
||||
return HasStatic(MkSFunc(f),
|
||||
ll->Push(v, VisitFuncDynamic(u_func, f)));
|
||||
}
|
||||
|
||||
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
|
||||
return VisitFunc(GetRef<Function>(op), ll);
|
||||
}
|
||||
|
||||
Expr Reflect(const PStatic& st) {
|
||||
|
@ -590,7 +749,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
return Reify(executor_(fused_infered), ll);
|
||||
}
|
||||
|
||||
Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
|
||||
Func ConstEvaluateFunc(const Expr& expr) {
|
||||
CHECK_EQ(FreeVars(expr).size(), 0);
|
||||
return [=](const std::vector<PStatic>& pv,
|
||||
const Attrs& attrs,
|
||||
const tvm::Array<Type>& type_args,
|
||||
|
@ -599,7 +759,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
for (const PStatic& ps : pv) {
|
||||
ns_args.push_back(ps->dynamic);
|
||||
}
|
||||
PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
|
||||
PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
|
||||
if (StatefulOp(expr)) {
|
||||
return ns;
|
||||
}
|
||||
|
@ -616,7 +776,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
}
|
||||
|
||||
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
|
||||
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
|
||||
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
|
||||
}
|
||||
|
||||
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
|
||||
|
@ -680,7 +840,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
CHECK_NE(op->constructor->tag, -1);
|
||||
CHECK_NE(scn->constructor->tag, -1);
|
||||
if (op->constructor->tag == scn->constructor->tag) {
|
||||
// todo(M.K.): should use ptr equality but it is broken
|
||||
CHECK_EQ(op->patterns.size(), scn->fields.size());
|
||||
MatchStatus current_match_status = MatchStatus::Match;
|
||||
for (size_t i = 0; i < op->patterns.size(); ++i) {
|
||||
|
@ -702,27 +861,119 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
|
|||
}
|
||||
}
|
||||
|
||||
void InitializeFuncId(const Expr& e) {
|
||||
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
|
||||
PartialEvaluator* pe;
|
||||
explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
|
||||
|
||||
void VisitExpr_(const FunctionNode* op) final {
|
||||
Function f = GetRef<Function>(op);
|
||||
CHECK_EQ(pe->func_map_.count(f), 0);
|
||||
pe->func_map_.insert({f, pe->func_map_.size()});
|
||||
VisitExpr(f->body);
|
||||
}
|
||||
|
||||
void VisitPattern(const Pattern& p) final {
|
||||
PatternVisitor::VisitPattern(p);
|
||||
}
|
||||
};
|
||||
InitializeFuncIdVisitor(this).VisitExpr(e);
|
||||
}
|
||||
|
||||
Expr RegisterFuncId(const Expr& e) {
|
||||
struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
|
||||
PartialEvaluator* pe;
|
||||
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
|
||||
|
||||
void VisitExpr_(const CallNode* op) final {
|
||||
if (op->op.same_as(WithFuncIdOp())) {
|
||||
CHECK_EQ(op->args.size(), 1);
|
||||
CHECK(op->attrs.defined());
|
||||
CHECK(op->attrs.as<WithFuncIdAttrs>());
|
||||
Function f = AsFunc(op->args[0]);
|
||||
FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
|
||||
if (pe->func_map_.count(f) != 0) {
|
||||
CHECK_EQ(pe->func_map_.at(f), fid);
|
||||
}
|
||||
pe->func_map_.insert({f, fid});
|
||||
}
|
||||
ExprVisitor::VisitExpr_(op);
|
||||
}
|
||||
|
||||
void VisitExpr_(const FunctionNode* op) final {
|
||||
Function f = GetRef<Function>(op);
|
||||
CHECK_GT(pe->func_map_.count(f), 0);
|
||||
ExprVisitor::VisitExpr_(op);
|
||||
}
|
||||
|
||||
void VisitPattern(const Pattern& p) final {
|
||||
PatternVisitor::VisitPattern(p);
|
||||
}
|
||||
};
|
||||
RegisterFuncIdVisitor(this).VisitExpr(e);
|
||||
return e;
|
||||
}
|
||||
|
||||
Expr AnnotateFuncId(const Expr& e) {
|
||||
struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
|
||||
PartialEvaluator* pe;
|
||||
explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
|
||||
|
||||
Expr VisitExpr_(const FunctionNode* op) final {
|
||||
Function f = GetRef<Function>(op);
|
||||
CHECK_GT(pe->func_map_.count(f), 0);
|
||||
return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
|
||||
}
|
||||
|
||||
Pattern VisitPattern(const Pattern& p) final {
|
||||
return PatternMutator::VisitPattern(p);
|
||||
}
|
||||
|
||||
Var VisitVar(const Var& v) final {
|
||||
return v;
|
||||
}
|
||||
};
|
||||
return AnnotateFuncIdMutator(this).VisitExpr(e);
|
||||
}
|
||||
|
||||
private:
|
||||
Environment env_;
|
||||
Module mod_;
|
||||
std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
|
||||
/*! Termination checking is done as follows:
|
||||
* We have finitely many FunctionIds.
|
||||
* Each FunctionId maps to a class of semantically equivalent function (ignoring type),
|
||||
* as both TypeSubst and DeDup create semantically equivalent function.
|
||||
* We partially map each FunctionId to a std::vector<Time>,
|
||||
* denoting the minimal TimeFrame of each argument of the function.
|
||||
* Every time we try to inline a Function,
|
||||
* we make sure it either does not have a vector<Time>, which means this is the initial call,
|
||||
* or some argument has a lesser time, which means some earlier argument is passed in.
|
||||
* In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
|
||||
* when we PE inside the Function body.
|
||||
* Termination is guaranteed because the creation time of at least one argument will decrease every call.
|
||||
*/
|
||||
std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
|
||||
std::unordered_map<FuncId, std::vector<Time> > time_map_;
|
||||
Store store_;
|
||||
DLContext context_ = CPUContext();
|
||||
FInterpreter executor_ = CPUInterpreter();
|
||||
};
|
||||
|
||||
Var DeDupVar(const Var& v) {
|
||||
return VarNode::make(v->name_hint(), v->type_annotation);
|
||||
}
|
||||
|
||||
TypeVar DeDupTypeVar(const TypeVar& tv) {
|
||||
return TypeVarNode::make(tv->var->name_hint, tv->kind);
|
||||
}
|
||||
|
||||
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
|
||||
Expr DeDup(const Expr& e) {
|
||||
class DeDupMutator : public ExprMutator, public PatternMutator {
|
||||
class DeDupMutator : public TypeMutator,
|
||||
public ExprMutator,
|
||||
public PatternMutator {
|
||||
public:
|
||||
TypeVar Fresh(const TypeVar& tv) {
|
||||
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
|
||||
type_rename_[tv] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
Var Fresh(const Var& v) {
|
||||
Var ret = DeDupVar(v);
|
||||
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
|
||||
rename_[v] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
@ -737,18 +988,27 @@ Expr DeDup(const Expr& e) {
|
|||
}
|
||||
|
||||
Expr VisitExpr_(const LetNode* op) final {
|
||||
return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
|
||||
Var v = Fresh(op->var);
|
||||
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
|
||||
}
|
||||
|
||||
Type VisitType(const Type& t) final {
|
||||
return t.defined() ? TypeMutator::VisitType(t) : t;
|
||||
}
|
||||
|
||||
Expr VisitExpr_(const FunctionNode* op) final {
|
||||
tvm::Array<TypeVar> type_params;
|
||||
for (const TypeVar& type_param : op->type_params) {
|
||||
type_params.push_back(Fresh(type_param));
|
||||
}
|
||||
tvm::Array<Var> params;
|
||||
for (const Var& param : op->params) {
|
||||
params.push_back(Fresh(param));
|
||||
}
|
||||
return FunctionNode::make(params,
|
||||
VisitExpr(op->body),
|
||||
op->ret_type,
|
||||
op->type_params,
|
||||
VisitType(op->ret_type),
|
||||
type_params,
|
||||
op->attrs);
|
||||
}
|
||||
|
||||
|
@ -756,14 +1016,28 @@ Expr DeDup(const Expr& e) {
|
|||
return PatternMutator::VisitPattern(p);
|
||||
}
|
||||
|
||||
Clause VisitClause(const Clause& c) final {
|
||||
Pattern pat = VisitPattern(c->lhs);
|
||||
return ClauseNode::make(pat, VisitExpr(c->rhs));
|
||||
}
|
||||
|
||||
Type VisitType_(const TypeVarNode* op) final {
|
||||
TypeVar v = GetRef<TypeVar>(op);
|
||||
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
|
||||
}
|
||||
|
||||
Var VisitVar(const Var& v) final {
|
||||
return Fresh(v);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
|
||||
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
|
||||
};
|
||||
return DeDupMutator().VisitExpr(e);
|
||||
|
||||
Expr ret = DeDupMutator().VisitExpr(e);
|
||||
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
|
||||
|
@ -787,11 +1061,38 @@ Expr Remap(const Expr& e) {
|
|||
return RemapMutator().VisitExpr(e);
|
||||
}
|
||||
|
||||
Expr PartialEval(const Expr& e) {
|
||||
Expr StripWithFuncId(const Expr& e) {
|
||||
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
|
||||
Expr VisitExpr_(const CallNode* op) final {
|
||||
if (op->op.same_as(WithFuncIdOp())) {
|
||||
CHECK_EQ(op->args.size(), 1);
|
||||
return VisitExpr(op->args[0]);
|
||||
} else {
|
||||
return ExprMutator::VisitExpr_(op);
|
||||
}
|
||||
}
|
||||
|
||||
Pattern VisitPattern(const Pattern& p) final {
|
||||
return PatternMutator::VisitPattern(p);
|
||||
}
|
||||
|
||||
Var VisitVar(const Var& v) final {
|
||||
return v;
|
||||
}
|
||||
};
|
||||
return StripWithFuncIdMutator().VisitExpr(e);
|
||||
}
|
||||
|
||||
Expr PostProcess(const Expr& e) {
|
||||
return StripWithFuncId(DeDup(Remap(e)));
|
||||
}
|
||||
|
||||
Expr PartialEval(const Expr& e, const Module& m) {
|
||||
return TransformF([&](const Expr& e) {
|
||||
return LetList::With([&](LetList* ll) {
|
||||
PartialEvaluator pe(FreeVars(e));
|
||||
return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
|
||||
PartialEvaluator pe(FreeVars(e), m);
|
||||
pe.InitializeFuncId(e);
|
||||
return PostProcess(pe.VisitExpr(e, ll)->dynamic);
|
||||
});
|
||||
}, e);
|
||||
}
|
||||
|
@ -804,7 +1105,7 @@ namespace transform {
|
|||
Pass PartialEval() {
|
||||
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
|
||||
[=](Function f, Module m, PassContext pc) {
|
||||
return Downcast<Function>(PartialEval(f));
|
||||
return Downcast<Function>(PartialEval(f, m));
|
||||
};
|
||||
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
|
||||
}
|
||||
|
|
|
@ -18,14 +18,17 @@
|
|||
import numpy as np
|
||||
import tvm
|
||||
from tvm import relay
|
||||
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
|
||||
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
|
||||
from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
|
||||
from tvm.relay.ir_pass import gradient
|
||||
from tvm.relay import op, create_executor
|
||||
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
|
||||
from tvm.relay.prelude import Prelude
|
||||
from tvm.relay import create_executor
|
||||
|
||||
from nose.tools import nottest
|
||||
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
|
||||
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
|
||||
from tvm.relay import GlobalVar, Call, Type
|
||||
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
|
||||
|
||||
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
||||
ctx = tvm.context("llvm", 0)
|
||||
|
@ -35,24 +38,22 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
|
|||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
|
||||
|
||||
|
||||
def dcpe(expr):
|
||||
return dead_code_elimination(partial_evaluate(expr))
|
||||
def dcpe(expr, mod=None):
|
||||
return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
|
||||
|
||||
|
||||
def test_tuple():
|
||||
t = relay.TypeVar("t")
|
||||
x = relay.Var("x", t)
|
||||
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
|
||||
f = relay.Function([x], body, None, [t])
|
||||
t = TypeVar("t")
|
||||
x = Var("x", t)
|
||||
body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
|
||||
f = Function([x], body, None, [t])
|
||||
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
|
||||
|
||||
@nottest
|
||||
def test_const_inline():
|
||||
# TODO(MK): fix me
|
||||
d = relay.Var("d")
|
||||
double = relay.Function([d], d + d)
|
||||
orig = double(relay.const(4.0))
|
||||
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
|
||||
d = Var("d")
|
||||
double = Function([d], d + d)
|
||||
orig = double(const(4.0))
|
||||
assert alpha_equal(dcpe(orig), const(8.0))
|
||||
|
||||
|
||||
def test_ref():
|
||||
|
@ -60,44 +61,57 @@ def test_ref():
|
|||
r = relay.Var("r")
|
||||
x = relay.Var("x")
|
||||
body = relay.RefRead(r)
|
||||
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
|
||||
body = relay.Let(r, relay.RefCreate(d), body)
|
||||
square = relay.Function([d], body)
|
||||
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
|
||||
body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
|
||||
body = Let(r, RefCreate(d), body)
|
||||
square = Function([d], body)
|
||||
assert alpha_equal(dcpe(square), Function([d], d * d))
|
||||
|
||||
@nottest
|
||||
def test_ad():
|
||||
# TODO(MK): fix me
|
||||
|
||||
def test_empty_ad():
|
||||
shape = (10, 10)
|
||||
dtype = "float32"
|
||||
t = relay.TensorType(shape, dtype)
|
||||
d = relay.Var("d", t)
|
||||
f = relay.Function([d], d * d)
|
||||
t = TensorType(shape, dtype)
|
||||
d = Var("d", t)
|
||||
f = Function([d], d)
|
||||
g = dcpe(gradient(f))
|
||||
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
|
||||
assert alpha_equal(g, expected)
|
||||
|
||||
def test_ad():
|
||||
shape = (10, 10)
|
||||
dtype = "float32"
|
||||
t = TensorType(shape, dtype)
|
||||
d = Var("d", t)
|
||||
f = Function([d], d * d)
|
||||
g = dcpe(gradient(f))
|
||||
m = d * d
|
||||
o = relay.op.ones_like(m)
|
||||
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
|
||||
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
|
||||
x = relay.Var("x")
|
||||
o = op.ones_like(x)
|
||||
x1 = relay.Var("x1")
|
||||
grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
|
||||
body = Tuple([x, Tuple([grad])])
|
||||
body = relay.Let(x1, o, body)
|
||||
expected = Function([d], relay.Let(x, m, body))
|
||||
assert alpha_equal(g, expected)
|
||||
|
||||
|
||||
def test_if_ref():
|
||||
shape = ()
|
||||
dtype = "bool"
|
||||
t = relay.TensorType(shape, dtype)
|
||||
d = relay.Var("d", t)
|
||||
r = relay.Var("r")
|
||||
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r)))
|
||||
u = relay.Var("u")
|
||||
body = relay.If(d, u(), u())
|
||||
eff = relay.Var("eff")
|
||||
body = relay.Let(eff, body, relay.RefRead(r))
|
||||
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body)))
|
||||
t = TensorType(shape, dtype)
|
||||
d = Var("d", t)
|
||||
r = Var("r")
|
||||
update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
|
||||
u = Var("u")
|
||||
body = If(d, u(), u())
|
||||
eff = Var("eff")
|
||||
body = Let(eff, body, RefRead(r))
|
||||
f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
|
||||
f = infer_type(f)
|
||||
pe_f = infer_type(partial_evaluate(f))
|
||||
ex = create_executor()
|
||||
f_res = ex.evaluate(f)(relay.const(True))
|
||||
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
|
||||
f_res = ex.evaluate(f)(const(True))
|
||||
pe_f_res = ex.evaluate(pe_f)(const(True))
|
||||
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
|
||||
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
|
||||
|
||||
|
@ -105,52 +119,162 @@ def test_if_ref():
|
|||
def test_function_invalidate():
|
||||
shape = ()
|
||||
dtype = "bool"
|
||||
t = relay.TensorType(shape, dtype)
|
||||
d = relay.Var("d", t)
|
||||
r = relay.Var("r")
|
||||
fetch = relay.Function([], relay.RefRead(r))
|
||||
fet = relay.Var("fetch")
|
||||
fet_obscured = relay.Var("fetch_obscured")
|
||||
u = relay.Var("u")
|
||||
body = relay.If(d, fet_obscured(), fet_obscured())
|
||||
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body)
|
||||
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body)
|
||||
body = relay.Let(fet, fetch, body)
|
||||
body = relay.Let(r, relay.RefCreate(relay.const(0)), body)
|
||||
f = relay.Function([d], body)
|
||||
t = TensorType(shape, dtype)
|
||||
d = Var("d", t)
|
||||
r = Var("r")
|
||||
fetch = Function([], RefRead(r))
|
||||
fet = Var("fetch")
|
||||
fet_obscured = Var("fetch_obscured")
|
||||
u = Var("u")
|
||||
body = If(d, fet_obscured(), fet_obscured())
|
||||
body = Let(u, RefWrite(r, const(1)), body)
|
||||
body = Let(fet_obscured, If(d, fet, fet), body)
|
||||
body = Let(fet, fetch, body)
|
||||
body = Let(r, RefCreate(const(0)), body)
|
||||
f = Function([d], body)
|
||||
f = infer_type(f)
|
||||
pe_f = infer_type(partial_evaluate(f))
|
||||
ex = create_executor()
|
||||
f_res = ex.evaluate(f)(relay.const(True))
|
||||
pe_f_res = ex.evaluate(pe_f)(relay.const(True))
|
||||
f_res = ex.evaluate(f)(const(True))
|
||||
pe_f_res = ex.evaluate(pe_f)(const(True))
|
||||
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
|
||||
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
|
||||
|
||||
|
||||
def test_head_cons():
|
||||
mod = relay.Module()
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
def hd_impl():
|
||||
a = relay.TypeVar("a")
|
||||
x = relay.Var("x", p.l(a))
|
||||
y = relay.Var("y")
|
||||
z = relay.Var("z")
|
||||
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
|
||||
[relay.PatternVar(y),
|
||||
relay.PatternVar(z)]),
|
||||
y)
|
||||
return relay.Function([x], relay.Match(x, [cons_case]), a, [a])
|
||||
t = relay.TypeVar("t")
|
||||
x = relay.Var("x", t)
|
||||
hd = relay.Var("hd")
|
||||
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
|
||||
f = relay.Function([x], body, None, [t])
|
||||
a = TypeVar("a")
|
||||
x = Var("x", p.l(a))
|
||||
y = Var("y")
|
||||
z = Var("z")
|
||||
cons_case = Clause(PatternConstructor(p.cons,
|
||||
[PatternVar(y),
|
||||
PatternVar(z)]),
|
||||
y)
|
||||
y = Var("y")
|
||||
z = Var("z")
|
||||
return Function([x], Match(x, [cons_case]), a, [a])
|
||||
t = TypeVar("t")
|
||||
x = Var("x", t)
|
||||
hd = Var("hd")
|
||||
body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
|
||||
f = Function([x], body, None, [t])
|
||||
f = infer_type(f, mod=mod)
|
||||
res = dcpe(f)
|
||||
assert alpha_equal(res, relay.Function([x], x, t, [t]))
|
||||
assert alpha_equal(res, Function([x], x, t, [t]))
|
||||
|
||||
|
||||
def test_map():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
f = Var("f")
|
||||
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
|
||||
expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
|
||||
assert alpha_equal(dcpe(orig, mod=mod), expected)
|
||||
|
||||
|
||||
def test_loop():
|
||||
mod = Module()
|
||||
t = TypeVar("t")
|
||||
x = Var("x", t)
|
||||
loop = GlobalVar("loop")
|
||||
mod[loop] = Function([x], loop(x), t, [t])
|
||||
res = dcpe(loop(const(1)), mod=mod)
|
||||
expected = Call(loop, [const(1)], None, [None])
|
||||
assert alpha_equal(res, expected)
|
||||
|
||||
|
||||
def test_swap_loop():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat()
|
||||
x = Var("x", nat)
|
||||
y = Var("y", nat)
|
||||
loop = GlobalVar("loop")
|
||||
mod[loop] = Function([x, y], loop(y, x), nat)
|
||||
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
|
||||
res = dcpe(prog, mod=mod)
|
||||
assert alpha_equal(prog, res)
|
||||
|
||||
|
||||
def test_abs_diff():
|
||||
# TODO(@M.K.): refactor using tuple pattern (not yet implemented)
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat()
|
||||
x = Var("x", nat)
|
||||
y = Var("y", nat)
|
||||
xp = Var("x'", nat)
|
||||
yp = Var("y'", nat)
|
||||
diff = GlobalVar("diff")
|
||||
y_z_case = Clause(PatternConstructor(p.z, []), x)
|
||||
y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
|
||||
x_z_case = Clause(PatternConstructor(p.z, []), y)
|
||||
x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
|
||||
mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
|
||||
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
|
||||
res = dcpe(orig, mod=mod)
|
||||
assert alpha_equal(res, make_nat_expr(p, 4))
|
||||
|
||||
|
||||
def test_match_nat_id():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat()
|
||||
x = Var("x", nat)
|
||||
y = Var("y", nat)
|
||||
nat_id = GlobalVar("nat_id")
|
||||
z_case = Clause(PatternConstructor(p.z, []), p.z())
|
||||
s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
|
||||
mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
|
||||
orig = nat_id(make_nat_expr(p, 3))
|
||||
res = dcpe(orig, mod=mod)
|
||||
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||
|
||||
|
||||
def test_nat_id():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat()
|
||||
x = Var("x", nat)
|
||||
y = Var("y", nat)
|
||||
nat_id = GlobalVar("nat_id")
|
||||
mod[nat_id] = Function([x], x)
|
||||
orig = nat_id(make_nat_expr(p, 3))
|
||||
res = dcpe(orig, mod=mod)
|
||||
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||
|
||||
|
||||
def test_global_match_nat_id():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
nat = p.nat()
|
||||
x = Var("x", nat)
|
||||
z_case = Clause(PatternConstructor(p.z, []), p.z())
|
||||
s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
|
||||
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
|
||||
res = dcpe(orig, mod=mod)
|
||||
assert alpha_equal(res, make_nat_expr(p, 3))
|
||||
|
||||
|
||||
def test_double():
|
||||
mod = Module()
|
||||
p = Prelude(mod)
|
||||
add_nat_definitions(p)
|
||||
orig = p.double(make_nat_expr(p, 3))
|
||||
res = dcpe(orig, mod=mod)
|
||||
assert alpha_equal(res, make_nat_expr(p, 6))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_empty_ad()
|
||||
test_tuple()
|
||||
test_const_inline()
|
||||
test_ref()
|
||||
|
@ -158,3 +282,11 @@ if __name__ == '__main__':
|
|||
test_if_ref()
|
||||
test_function_invalidate()
|
||||
test_head_cons()
|
||||
test_map()
|
||||
test_loop()
|
||||
test_swap_loop()
|
||||
test_abs_diff()
|
||||
test_double()
|
||||
test_nat_id()
|
||||
test_global_match_nat_id()
|
||||
test_match_nat_id()
|
||||
|
|
Загрузка…
Ссылка в новой задаче