save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
This commit is contained in:
雾雨魔理沙 2019-06-15 15:08:46 -07:00 коммит произвёл Jared Roesch
Родитель 50dd03ca86
Коммит df88c411f5
7 изменённых файлов: 625 добавлений и 172 удалений

Просмотреть файл

@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`, * For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a. * as the value of the expression does not depend on a.
* *
* As another example, `let a = 1 in a` will be optimized into 1. * As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
* *
* \param e the expression to optimize. * \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
* *
* \return the optimized expression. * \return the optimized expression.
*/ */
TVM_DLL Expr DeadCodeElimination(const Expr& e); TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.
@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode. * As a side effect, code size will explode.
* *
* \param e the expression, * \param e the expression
* \param mod the module
* *
* \return the optimized expression. * \return the optimized expression.
*/ */
TVM_DLL Expr PartialEval(const Expr& e); TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*! /*!
* \brief Bind the free variables to a Relay expression. * \brief Bind the free variables to a Relay expression.

Просмотреть файл

@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
* *
* As another example, `let a = 1 in a` will be optimized into 1. * As another example, `let a = 1 in a` will be optimized into 1.
* *
* \param inline_once whether or not to inline binding used one.
*
* \return the pass. * \return the pass.
*/ */
TVM_DLL Pass DeadCodeElimination(); TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*! /*!
* \brief Fold constant expressions. * \brief Fold constant expressions.

Просмотреть файл

@ -129,7 +129,7 @@ def well_formed(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
@ -175,7 +175,7 @@ def free_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
@ -213,7 +213,7 @@ def all_vars(expr):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr expr : tvm.relay.Expr
The input expression The input expression
Returns Returns
@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional mod : Optional[tvm.relay.Module]
The global module The global module
Returns Returns
@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression which is semantically equal to the input expression, An expression which is semantically equal to the input expression,
but with some simplification but with some simplification
""" """
@ -304,32 +306,34 @@ def canonicalize_ops(expr):
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression without bias_add An expression without bias_add
""" """
return _ir_pass.canonicalize_ops(expr) return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr): def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code). """ Remove expressions which does not effect the program result (dead code).
Parameters Parameters
---------- ----------
e: tvm.relay.Expr expr : tvm.relay.Expr
The input Expression The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns Returns
------- -------
result: tvm.relay.Expr result : tvm.relay.Expr
An expression which is semantically equal to the input expression, An expression which is semantically equal to the input expression,
but with dead code removed. but with dead code removed.
""" """
return _ir_pass.dead_code_elimination(expr) return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs): def alpha_equal(lhs, rhs):
@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
Parameters Parameters
---------- ----------
lhs: tvm.relay.Expr lhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
rhs: tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
Returns Returns
------- -------
result: bool result : bool
True iff lhs is alpha equal to rhs. True iff lhs is alpha equal to rhs.
""" """
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters Parameters
---------- ----------
lhs: tvm.relay.Expr lhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
rhs: tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
Returns Returns
------- -------
result: bool result : bool
True iff lhs is data-flow equivalent to rhs. True iff lhs is data-flow equivalent to rhs.
""" """
return bool(_make._graph_equal(lhs, rhs)) return bool(_make._graph_equal(lhs, rhs))
@ -378,12 +382,12 @@ def structural_hash(value):
Parameters Parameters
---------- ----------
expr: tvm.relay.Expr or tvm.relay.Type expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash. The expression to hash.
Returns Returns
------- -------
result: int result : int
The hash value The hash value
""" """
if isinstance(value, Expr): if isinstance(value, Expr):
@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
mod: Optional[tvm.relay.Module] mod : Optional[tvm.relay.Module]
The global module. The global module.
Returns Returns
------- -------
expr: tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.to_a_normal_form(expr, mod) return _ir_pass.to_a_normal_form(expr, mod)
@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression The input expression
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression The output expression
""" """
return _ir_pass.to_graph_normal_form(expr) return _ir_pass.to_graph_normal_form(expr)
@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns Returns
------- -------
ret : int64 result : int64
The number of MACs (multiply-accumulate) of a model The number of MACs (multiply-accumulate) of a model
""" """
return _ir_pass.GetTotalMacNumber(expr) return _ir_pass.GetTotalMacNumber(expr)
@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
fskip: function fskip : function
The callback function that decides whether an expression should be skipped. The callback function that decides whether an expression should be skipped.
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.eliminate_common_subexpr(expr, fskip) return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr): def partial_evaluate(expr, mod=None):
""" """
Evaluate the static fragment of the code. Evaluate the static fragment of the code.
@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr expr : tvm.relay.Expr
The input expression. The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns Returns
------- -------
expr : tvm.relay.Expr result : tvm.relay.Expr
The output expression. The output expression.
""" """
return _ir_pass.partial_evaluate(expr) return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None): def unmatched_cases(match, mod=None):
""" """

Просмотреть файл

@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) { .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", " p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")"; << node->attrs << ", " << node->type_args << ")";
}); });
Let LetNode::make(Var var, Expr value, Expr body) { Let LetNode::make(Var var, Expr value, Expr body) {
@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize") TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) { .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize(); return temp->Realize();
}); });
} // namespace relay } // namespace relay

Просмотреть файл

@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression // calculate the dependency graph from expression
class CalcDep : private ExprVisitor { class CalcDep : private ExprVisitor {
public: public:
static Expr Eliminate(const Expr& e) { static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd; CalcDep cd;
cd.Calculate(e); cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e); return el(e);
} }
@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_; VarMap<Expr> expr_map_;
VarMap<size_t> use_map_; VarMap<size_t> use_map_;
VarSet letrec_set_; VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map, explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map, const VarMap<size_t>& use_map,
const VarSet& letrec_set) : const VarSet& letrec_set,
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep; friend CalcDep;
bool HasLet(const Var& v) { bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me switch (use_map_[v]) {
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
} }
Expr VisitExpr_(const VarNode* op) final { Expr VisitExpr_(const VarNode* op) final {
@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
}; };
}; };
Expr DeadCodeElimination(const Expr& e) { Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e); return CalcDep::Eliminate(e, inline_once);
} }
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
namespace transform { namespace transform {
Pass DeadCodeElimination() { Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f)); return Downcast<Function>(DeadCodeElimination(f, inline_once));
}; };
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
} }

Просмотреть файл

@ -74,28 +74,19 @@
* *
* The partial evaluator makes several assumptions, so there is room for improvement: * The partial evaluator makes several assumptions, so there is room for improvement:
* *
* 0: The partial evaluator treats global variables as opaque. * 0: Every time an unknown effect happened, we clear the whole store.
* Doing PartialEval on a module level will solve this.
*
* 1: The partial evaluator assume all functions as terminating.
* We need to has a max_expand parameter that shrink on every compile time evaluation,
* to make sure PE does not infinite loop.
* Additionally, we might add a termination analysis pass that lift this requirement
* for function that analysis found terminating.
*
* 2: Every time an unknown effect happened, we clear the whole store.
* It is too conservative: if a local reference is created (and do not get passed outside), * It is too conservative: if a local reference is created (and do not get passed outside),
* An unknown global function call/global reference write can not modify it. * An unknown global function call/global reference write can not modify it.
* We can pair PE with escape analysis/alias analysis. * We can pair PE with escape analysis/alias analysis.
* *
* 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise. * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
* *
* 4: When doing pattern matching, we can simplify the match even for dynamic case. * 2: When doing pattern matching, we can simplify the match even for dynamic case.
* Right now it is all or nothing: either a complete match, or the original dynamic code. * Right now it is all or nothing: either a complete match, or the original dynamic code.
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form. * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
* We then can reify the result. * We then can reify the result.
* *
* 5: Every time a function is called, it's code will get expanded and partially evaluated. * 3: Every time a function is called, its code will get expanded and partially evaluated.
* We can do a binding time analysis to cache the result and avoid re-partial evaluation. * We can do a binding time analysis to cache the result and avoid re-partial evaluation.
* *
* These assumptions do not affect the correctness of the algorithm, however. * These assumptions do not affect the correctness of the algorithm, however.
@ -104,6 +95,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include "../ir/type_functor.h"
#include "pass_util.h" #include "pass_util.h"
#include "let_list.h" #include "let_list.h"
@ -132,6 +124,8 @@ struct VarEqual {
} }
}; };
Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */ /*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode { class StaticNode : public RelayNode {
public: public:
@ -150,10 +144,20 @@ class Static : public NodeRef {
using ContainerType = StaticNode; using ContainerType = StaticNode;
}; };
using Time = size_t;
struct PStaticNode : Node { struct PStaticNode : Node {
static Time time() {
static Time time_ = 0;
Time ret = time_;
time_++;
return ret;
}
Static pstatic; // may be null Static pstatic; // may be null
Expr dynamic; Expr dynamic;
PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } Time created_time;
PStaticNode(const Static& pstatic, const Expr& dynamic) :
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
}; };
@ -341,6 +345,7 @@ class Store {
}; };
PStatic HasStatic(const Static& stat, const Expr& dynamic) { PStatic HasStatic(const Static& stat, const Expr& dynamic) {
CHECK(stat.defined());
return PStatic(make_node<PStaticNode>(stat, dynamic)); return PStatic(make_node<PStaticNode>(stat, dynamic));
} }
@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() {
return CreateInterpreter(Module(nullptr), CPUContext(), target); return CreateInterpreter(Module(nullptr), CPUContext(), target);
} }
bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
using FuncId = int;
/*!
* \brief Annotate a function with a FuncId.
*/
struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
FuncId fid;
TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
TVM_ATTR_FIELD(fid)
.describe("The FuncId that an function is annotated with.")
.set_default(-1);
}
};
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
Op WithFuncIdOp() {
static const Op& op = Op::Get("annotation.with_funcid");
return op;
}
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_node<WithFuncIdAttrs>();
attrs->fid = fid;
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
}
RELAY_REGISTER_OP("annotation.with_funcid")
.describe(R"code(Annotate a function with a funcid.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("func", "Function", "The input data.");
Expr StripWithFuncId(const Expr& e);
Expr DeDup(const Expr& e);
Function AsFunc(const Expr& e) {
if (e.as<FunctionNode>()) {
return Downcast<Function>(e);
} else if (const CallNode* c = e.as<CallNode>()) {
CHECK(c->op.same_as(WithFuncIdOp()));
CHECK_EQ(c->args.size(), 1);
return AsFunc(c->args[0]);
} else {
LOG(FATAL) << "Unknown case";
throw;
}
}
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>, class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> { public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public: public:
PartialEvaluator(const tvm::Array<Var>& free_vars) { PartialEvaluator(const tvm::Array<Var>& free_vars,
const Module& mod) :
mod_(mod) {
for (const Var& v : free_vars) { for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v)); env_.Insert(v, NoStatic(v));
} }
} }
PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
return ret;
}
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op))); return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
} }
@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
return NoStatic(GetRef<Expr>(op)); GlobalVar gv = GetRef<GlobalVar>(op);
if (gv_map_.count(gv) == 0) {
if (mod_.defined()) {
Function func = mod_->Lookup(gv);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
mod_->Update(gv, func);
} else {
gv_map_.insert({gv, NoStatic(gv)});
}
}
return gv_map_.at(gv);
} }
PStatic VisitExpr_(const LetNode* op, LetList* ll) final { PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
PStatic VisitExpr_(const CallNode* op, LetList* ll) final { PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll);
}
PStatic f = VisitExpr(op->op, ll); PStatic f = VisitExpr(op->op, ll);
std::vector<PStatic> x; std::vector<PStatic> x;
tvm::Array<Expr> x_dyn; tvm::Array<Expr> x_dyn;
@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
} }
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { struct TimeFrame {
Function func = GetRef<Function>(op); PartialEvaluator* pe_;
FuncId fid_;
std::vector<Time> old_time;
bool has_old_time;
TimeFrame(PartialEvaluator* pe,
FuncId fid,
const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
has_old_time = pe_->time_map_.count(fid_) > 0;
old_time = pe_->time_map_[fid_];
pe_->time_map_[fid_] = args_time;
}
~TimeFrame() {
if (has_old_time) {
pe_->time_map_[fid_] = old_time;
} else {
pe_->time_map_.erase(fid_);
}
}
};
Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
if (func->IsPrimitive()) { if (func->IsPrimitive()) {
return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func); return ConstEvaluateFunc(func);
} }
std::vector<std::pair<Var, PStatic> > free_vars; std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(GetRef<Expr>(op))) { for (const auto& v : FreeVars(func)) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v))); free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
} }
Func f = [=](const std::vector<PStatic>& pv, return [=](const std::vector<PStatic>& pv,
const Attrs& attrs, const Attrs& attrs,
const tvm::Array<Type>& type_args, const tvm::Array<Type>& type_args,
LetList* ll) { LetList* ll) {
return env_.Extend<PStatic>([&]() { return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size()); CHECK_EQ(pv.size(), func->params.size());
for (size_t i = 0; i < pv.size(); ++i) { for (size_t i = 0; i < pv.size(); ++i) {
@ -529,10 +635,50 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type()); subst.Set(func->type_params[i], Type());
} }
return VisitExpr(TypeSubst(func->body, subst), ll); std::vector<Time> args_time;
for (const auto& v : pv) {
args_time.push_back(v->created_time);
}
CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func);
auto recurse = [&]() {
TimeFrame tf(this, fid, args_time);
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
};
if (time_map_.count(fid) == 0) {
return recurse();
} else {
/* We check to see that at least one argument decrease
* with respect to all previous invocation.
* The depth of the recursion is bounded by
* the sum of the time of all argument at the first call.
*/
bool can_recurse = false;
std::vector<Time>& min_time = time_map_.at(fid);
CHECK_EQ(args_time.size(), min_time.size());
for (size_t i = 0; i < args_time.size(); ++i) {
if (args_time[i] < min_time[i]) {
can_recurse = true;
}
args_time[i] = std::min(args_time[i], min_time[i]);
}
if (can_recurse) {
return recurse();
} else {
std::vector<Expr> dyn;
for (const auto& v : pv) {
dyn.push_back(v->dynamic);
}
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
}
}
}); });
}; };
Expr dyn = store_.Extend<Expr>([&]() { }
Expr VisitFuncDynamic(const Function& func, const Func& f) {
return store_.Extend<Expr>([&]() {
store_.Invalidate(); store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) { return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv; std::vector<PStatic> pv;
@ -546,7 +692,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return f(pv, Attrs(), type_args, ll)->dynamic; return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs); }), func->ret_type, func->type_params, func->attrs);
}); });
return HasStatic(MkSFunc(f), ll->Push(dyn)); }
PStatic VisitFunc(const Function& func, LetList* ll) {
Var v = VarNode::make("x", Type());
Func f = VisitFuncStatic(func, v);
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
// TODO(@M.K.): we seems to reduce landin knot into letrec.
// restore letrec support across whole relay.
return HasStatic(MkSFunc(f),
ll->Push(v, VisitFuncDynamic(u_func, f)));
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
return VisitFunc(GetRef<Function>(op), ll);
} }
Expr Reflect(const PStatic& st) { Expr Reflect(const PStatic& st) {
@ -590,7 +749,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return Reify(executor_(fused_infered), ll); return Reify(executor_(fused_infered), ll);
} }
Func ConstEvaluateFunc(const Expr& expr, LetList* ll) { Func ConstEvaluateFunc(const Expr& expr) {
CHECK_EQ(FreeVars(expr).size(), 0);
return [=](const std::vector<PStatic>& pv, return [=](const std::vector<PStatic>& pv,
const Attrs& attrs, const Attrs& attrs,
const tvm::Array<Type>& type_args, const tvm::Array<Type>& type_args,
@ -599,7 +759,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) { for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic); ns_args.push_back(ps->dynamic);
} }
PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args)); PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
if (StatefulOp(expr)) { if (StatefulOp(expr)) {
return ns; return ns;
} }
@ -616,7 +776,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
PStatic VisitExpr_(const OpNode* op, LetList* ll) final { PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op)); return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
} }
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
@ -680,7 +840,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
CHECK_NE(op->constructor->tag, -1); CHECK_NE(op->constructor->tag, -1);
CHECK_NE(scn->constructor->tag, -1); CHECK_NE(scn->constructor->tag, -1);
if (op->constructor->tag == scn->constructor->tag) { if (op->constructor->tag == scn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK_EQ(op->patterns.size(), scn->fields.size()); CHECK_EQ(op->patterns.size(), scn->fields.size());
MatchStatus current_match_status = MatchStatus::Match; MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) { for (size_t i = 0; i < op->patterns.size(); ++i) {
@ -702,27 +861,119 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
} }
void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
void VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_EQ(pe->func_map_.count(f), 0);
pe->func_map_.insert({f, pe->func_map_.size()});
VisitExpr(f->body);
}
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
};
InitializeFuncIdVisitor(this).VisitExpr(e);
}
Expr RegisterFuncId(const Expr& e) {
struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
CHECK(op->attrs.defined());
CHECK(op->attrs.as<WithFuncIdAttrs>());
Function f = AsFunc(op->args[0]);
FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
if (pe->func_map_.count(f) != 0) {
CHECK_EQ(pe->func_map_.at(f), fid);
}
pe->func_map_.insert({f, fid});
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_GT(pe->func_map_.count(f), 0);
ExprVisitor::VisitExpr_(op);
}
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
};
RegisterFuncIdVisitor(this).VisitExpr(e);
return e;
}
Expr AnnotateFuncId(const Expr& e) {
struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
PartialEvaluator* pe;
explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
Expr VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_GT(pe->func_map_.count(f), 0);
return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Var VisitVar(const Var& v) final {
return v;
}
};
return AnnotateFuncIdMutator(this).VisitExpr(e);
}
private: private:
Environment env_; Environment env_;
Module mod_;
std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
/*! Termination checking is done as follows:
* We have finitely many FunctionIds.
* Each FunctionId maps to a class of semantically equivalent function (ignoring type),
* as both TypeSubst and DeDup create semantically equivalent function.
* We partially map each FunctionId to a std::vector<Time>,
* denoting the minimal TimeFrame of each argument of the function.
* Every time we try to inline a Function,
* we make sure it either does not have a vector<Time>, which means this is the initial call,
* or some argument has a lesser time, which means some earlier argument is passed in.
* In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
* when we PE inside the Function body.
* Termination is guaranteed because the creation time of at least one argument will decrease every call.
*/
std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
std::unordered_map<FuncId, std::vector<Time> > time_map_;
Store store_; Store store_;
DLContext context_ = CPUContext(); DLContext context_ = CPUContext();
FInterpreter executor_ = CPUInterpreter(); FInterpreter executor_ = CPUInterpreter();
}; };
Var DeDupVar(const Var& v) {
return VarNode::make(v->name_hint(), v->type_annotation);
}
TypeVar DeDupTypeVar(const TypeVar& tv) {
return TypeVarNode::make(tv->var->name_hint, tv->kind);
}
/*! \brief Use a fresh Id for every Var to make the result well-formed. */ /*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) { Expr DeDup(const Expr& e) {
class DeDupMutator : public ExprMutator, public PatternMutator { class DeDupMutator : public TypeMutator,
public ExprMutator,
public PatternMutator {
public: public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
Var Fresh(const Var& v) { Var Fresh(const Var& v) {
Var ret = DeDupVar(v); Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret; rename_[v] = ret;
return ret; return ret;
} }
@ -737,18 +988,27 @@ Expr DeDup(const Expr& e) {
} }
Expr VisitExpr_(const LetNode* op) final { Expr VisitExpr_(const LetNode* op) final {
return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body)); Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
}
Type VisitType(const Type& t) final {
return t.defined() ? TypeMutator::VisitType(t) : t;
} }
Expr VisitExpr_(const FunctionNode* op) final { Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params; tvm::Array<Var> params;
for (const Var& param : op->params) { for (const Var& param : op->params) {
params.push_back(Fresh(param)); params.push_back(Fresh(param));
} }
return FunctionNode::make(params, return FunctionNode::make(params,
VisitExpr(op->body), VisitExpr(op->body),
op->ret_type, VisitType(op->ret_type),
op->type_params, type_params,
op->attrs); op->attrs);
} }
@ -756,14 +1016,28 @@ Expr DeDup(const Expr& e) {
return PatternMutator::VisitPattern(p); return PatternMutator::VisitPattern(p);
} }
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}
Var VisitVar(const Var& v) final { Var VisitVar(const Var& v) final {
return Fresh(v); return Fresh(v);
} }
private: private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_; std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
}; };
return DeDupMutator().VisitExpr(e);
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
return ret;
} }
/*! \brief Remap multiple Var sharing the same Id into the same Var. */ /*! \brief Remap multiple Var sharing the same Id into the same Var. */
@ -787,11 +1061,38 @@ Expr Remap(const Expr& e) {
return RemapMutator().VisitExpr(e); return RemapMutator().VisitExpr(e);
} }
Expr PartialEval(const Expr& e) { Expr StripWithFuncId(const Expr& e) {
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
Expr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0]);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Var VisitVar(const Var& v) final {
return v;
}
};
return StripWithFuncIdMutator().VisitExpr(e);
}
Expr PostProcess(const Expr& e) {
return StripWithFuncId(DeDup(Remap(e)));
}
Expr PartialEval(const Expr& e, const Module& m) {
return TransformF([&](const Expr& e) { return TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) { return LetList::With([&](LetList* ll) {
PartialEvaluator pe(FreeVars(e)); PartialEvaluator pe(FreeVars(e), m);
return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic)); pe.InitializeFuncId(e);
return PostProcess(pe.VisitExpr(e, ll)->dynamic);
}); });
}, e); }, e);
} }
@ -804,7 +1105,7 @@ namespace transform {
Pass PartialEval() { Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) { [=](Function f, Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f)); return Downcast<Function>(PartialEval(f, m));
}; };
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
} }

Просмотреть файл

@ -18,14 +18,17 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type from tvm.relay.ir_pass import gradient
from tvm.relay import op, create_executor from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay import create_executor from tvm.relay import create_executor
from nose.tools import nottest from nose.tools import nottest
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call, Type
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
def check_eval(expr, expected_result, mod=None, rtol=1e-07): def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0) ctx = tvm.context("llvm", 0)
@ -35,24 +38,22 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr): def dcpe(expr, mod=None):
return dead_code_elimination(partial_evaluate(expr)) return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
def test_tuple(): def test_tuple():
t = relay.TypeVar("t") t = TypeVar("t")
x = relay.Var("x", t) x = Var("x", t)
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1) body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = relay.Function([x], body, None, [t]) f = Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t])) assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
@nottest
def test_const_inline(): def test_const_inline():
# TODO(MK): fix me d = Var("d")
d = relay.Var("d") double = Function([d], d + d)
double = relay.Function([d], d + d) orig = double(const(4.0))
orig = double(relay.const(4.0)) assert alpha_equal(dcpe(orig), const(8.0))
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
def test_ref(): def test_ref():
@ -60,44 +61,57 @@ def test_ref():
r = relay.Var("r") r = relay.Var("r")
x = relay.Var("x") x = relay.Var("x")
body = relay.RefRead(r) body = relay.RefRead(r)
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body) body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(d), body) body = Let(r, RefCreate(d), body)
square = relay.Function([d], body) square = Function([d], body)
assert alpha_equal(dcpe(square), relay.Function([d], d * d)) assert alpha_equal(dcpe(square), Function([d], d * d))
@nottest
def test_ad(): def test_empty_ad():
# TODO(MK): fix me
shape = (10, 10) shape = (10, 10)
dtype = "float32" dtype = "float32"
t = relay.TensorType(shape, dtype) t = TensorType(shape, dtype)
d = relay.Var("d", t) d = Var("d", t)
f = relay.Function([d], d * d) f = Function([d], d)
g = dcpe(gradient(f))
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
assert alpha_equal(g, expected)
def test_ad():
shape = (10, 10)
dtype = "float32"
t = TensorType(shape, dtype)
d = Var("d", t)
f = Function([d], d * d)
g = dcpe(gradient(f)) g = dcpe(gradient(f))
m = d * d m = d * d
o = relay.op.ones_like(m) x = relay.Var("x")
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d) o = op.ones_like(x)
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])])) x1 = relay.Var("x1")
grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
body = Tuple([x, Tuple([grad])])
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
assert alpha_equal(g, expected) assert alpha_equal(g, expected)
def test_if_ref(): def test_if_ref():
shape = () shape = ()
dtype = "bool" dtype = "bool"
t = relay.TensorType(shape, dtype) t = TensorType(shape, dtype)
d = relay.Var("d", t) d = Var("d", t)
r = relay.Var("r") r = Var("r")
update = relay.Function([], relay.RefWrite(r, relay.RefRead(r) + relay.RefRead(r))) update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
u = relay.Var("u") u = Var("u")
body = relay.If(d, u(), u()) body = If(d, u(), u())
eff = relay.Var("eff") eff = Var("eff")
body = relay.Let(eff, body, relay.RefRead(r)) body = Let(eff, body, RefRead(r))
f = relay.Function([d], relay.Let(r, relay.RefCreate(relay.const(1)), relay.Let(u, update, body))) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
f = infer_type(f) f = infer_type(f)
pe_f = infer_type(partial_evaluate(f)) pe_f = infer_type(partial_evaluate(f))
ex = create_executor() ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True)) f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True)) pe_f_res = ex.evaluate(pe_f)(const(True))
np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy())) np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy())) np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
@ -105,52 +119,162 @@ def test_if_ref():
def test_function_invalidate(): def test_function_invalidate():
shape = () shape = ()
dtype = "bool" dtype = "bool"
t = relay.TensorType(shape, dtype) t = TensorType(shape, dtype)
d = relay.Var("d", t) d = Var("d", t)
r = relay.Var("r") r = Var("r")
fetch = relay.Function([], relay.RefRead(r)) fetch = Function([], RefRead(r))
fet = relay.Var("fetch") fet = Var("fetch")
fet_obscured = relay.Var("fetch_obscured") fet_obscured = Var("fetch_obscured")
u = relay.Var("u") u = Var("u")
body = relay.If(d, fet_obscured(), fet_obscured()) body = If(d, fet_obscured(), fet_obscured())
body = relay.Let(u, relay.RefWrite(r, relay.const(1)), body) body = Let(u, RefWrite(r, const(1)), body)
body = relay.Let(fet_obscured, relay.If(d, fet, fet), body) body = Let(fet_obscured, If(d, fet, fet), body)
body = relay.Let(fet, fetch, body) body = Let(fet, fetch, body)
body = relay.Let(r, relay.RefCreate(relay.const(0)), body) body = Let(r, RefCreate(const(0)), body)
f = relay.Function([d], body) f = Function([d], body)
f = infer_type(f) f = infer_type(f)
pe_f = infer_type(partial_evaluate(f)) pe_f = infer_type(partial_evaluate(f))
ex = create_executor() ex = create_executor()
f_res = ex.evaluate(f)(relay.const(True)) f_res = ex.evaluate(f)(const(True))
pe_f_res = ex.evaluate(pe_f)(relay.const(True)) pe_f_res = ex.evaluate(pe_f)(const(True))
np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy())) np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy())) np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons(): def test_head_cons():
mod = relay.Module() mod = Module()
p = Prelude(mod) p = Prelude(mod)
def hd_impl(): def hd_impl():
a = relay.TypeVar("a") a = TypeVar("a")
x = relay.Var("x", p.l(a)) x = Var("x", p.l(a))
y = relay.Var("y") y = Var("y")
z = relay.Var("z") z = Var("z")
cons_case = relay.Clause(relay.PatternConstructor(p.cons, cons_case = Clause(PatternConstructor(p.cons,
[relay.PatternVar(y), [PatternVar(y),
relay.PatternVar(z)]), PatternVar(z)]),
y) y)
return relay.Function([x], relay.Match(x, [cons_case]), a, [a]) y = Var("y")
t = relay.TypeVar("t") z = Var("z")
x = relay.Var("x", t) return Function([x], Match(x, [cons_case]), a, [a])
hd = relay.Var("hd") t = TypeVar("t")
body = relay.Let(hd, hd_impl(), hd(p.cons(x, p.nil()))) x = Var("x", t)
f = relay.Function([x], body, None, [t]) hd = Var("hd")
body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
f = Function([x], body, None, [t])
f = infer_type(f, mod=mod) f = infer_type(f, mod=mod)
res = dcpe(f) res = dcpe(f)
assert alpha_equal(res, relay.Function([x], x, t, [t])) assert alpha_equal(res, Function([x], x, t, [t]))
def test_map():
mod = Module()
p = Prelude(mod)
f = Var("f")
orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil()))))
expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil())))
assert alpha_equal(dcpe(orig, mod=mod), expected)
def test_loop():
mod = Module()
t = TypeVar("t")
x = Var("x", t)
loop = GlobalVar("loop")
mod[loop] = Function([x], loop(x), t, [t])
res = dcpe(loop(const(1)), mod=mod)
expected = Call(loop, [const(1)], None, [None])
assert alpha_equal(res, expected)
def test_swap_loop():
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
nat = p.nat()
x = Var("x", nat)
y = Var("y", nat)
loop = GlobalVar("loop")
mod[loop] = Function([x, y], loop(y, x), nat)
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = dcpe(prog, mod=mod)
assert alpha_equal(prog, res)
def test_abs_diff():
# TODO(@M.K.): refactor using tuple pattern (not yet implemented)
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
nat = p.nat()
x = Var("x", nat)
y = Var("y", nat)
xp = Var("x'", nat)
yp = Var("y'", nat)
diff = GlobalVar("diff")
y_z_case = Clause(PatternConstructor(p.z, []), x)
y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
x_z_case = Clause(PatternConstructor(p.z, []), y)
x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 4))
def test_match_nat_id():
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
nat = p.nat()
x = Var("x", nat)
y = Var("y", nat)
nat_id = GlobalVar("nat_id")
z_case = Clause(PatternConstructor(p.z, []), p.z())
s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y))
mod[nat_id] = Function([x], Match(x, [z_case, s_case]))
orig = nat_id(make_nat_expr(p, 3))
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
def test_nat_id():
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
nat = p.nat()
x = Var("x", nat)
y = Var("y", nat)
nat_id = GlobalVar("nat_id")
mod[nat_id] = Function([x], x)
orig = nat_id(make_nat_expr(p, 3))
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
def test_global_match_nat_id():
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
nat = p.nat()
x = Var("x", nat)
z_case = Clause(PatternConstructor(p.z, []), p.z())
s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x))
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 3))
def test_double():
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)
orig = p.double(make_nat_expr(p, 3))
res = dcpe(orig, mod=mod)
assert alpha_equal(res, make_nat_expr(p, 6))
if __name__ == '__main__': if __name__ == '__main__':
test_empty_ad()
test_tuple() test_tuple()
test_const_inline() test_const_inline()
test_ref() test_ref()
@ -158,3 +282,11 @@ if __name__ == '__main__':
test_if_ref() test_if_ref()
test_function_invalidate() test_function_invalidate()
test_head_cons() test_head_cons()
test_map()
test_loop()
test_swap_loop()
test_abs_diff()
test_double()
test_nat_id()
test_global_match_nat_id()
test_match_nat_id()