save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
This commit is contained in:
雾雨魔理沙 2019-06-15 15:08:46 -07:00 коммит произвёл Jared Roesch
Родитель 50dd03ca86
Коммит df88c411f5
7 изменённых файлов: 625 добавлений и 172 удалений

Просмотреть файл

@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline bindings that are used only once.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e);
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);
/*!
* \brief Fold constant expressions.
@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefits: removing runtime overhead, and allowing more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);
/*!
* \brief Bind the free variables to a Relay expression.

Просмотреть файл

@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param inline_once whether or not to inline bindings that are used only once.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
/*!
* \brief Fold constant expressions.

Просмотреть файл

@ -129,7 +129,7 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
@ -175,7 +175,7 @@ def free_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
@ -213,7 +213,7 @@ def all_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
@ -304,32 +306,34 @@ def canonicalize_ops(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)
def dead_code_elimination(expr):
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which do not affect the program result (dead code).
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline bindings that occur only once.
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr)
return _ir_pass.dead_code_elimination(expr, inline_once)
def alpha_equal(lhs, rhs):
@ -337,15 +341,15 @@ def alpha_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
@ -378,12 +382,12 @@ def structural_hash(value):
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.
fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)
def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)
def unmatched_cases(match, mod=None):
"""

Просмотреть файл

@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});
Let LetNode::make(Var var, Expr value, Expr body) {
@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});
} // namespace relay

Просмотреть файл

@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}
@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;
bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}
Expr VisitExpr_(const VarNode* op) final {
@ -144,8 +152,8 @@ class CalcDep : private ExprVisitor {
};
};
Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
@ -153,10 +161,10 @@ TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
namespace transform {
Pass DeadCodeElimination() {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}

Просмотреть файл

@ -74,28 +74,19 @@
*
* The partial evaluator makes several assumptions, so there is room for improvement:
*
* 0: The partial evaluator treats global variables as opaque.
* Doing PartialEval on a module level will solve this.
*
* 1: The partial evaluator assume all functions as terminating.
* We need to has a max_expand parameter that shrink on every compile time evaluation,
* to make sure PE does not infinite loop.
* Additionally, we might add a termination analysis pass that lift this requirement
* for function that analysis found terminating.
*
* 2: Every time an unknown effect happened, we clear the whole store.
* 0: Every time an unknown effect happened, we clear the whole store.
* It is too conservative: if a local reference is created (and does not get passed outside),
* an unknown global function call/global reference write cannot modify it.
* We can pair PE with escape analysis/alias analysis.
*
* 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
* 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
*
* 4: When doing pattern matching, we can simplify the match even for dynamic case.
* 2: When doing pattern matching, we can simplify the match even for dynamic case.
* Right now it is all or nothing: either a complete match, or the original dynamic code.
* Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
* We then can reify the result.
*
* 5: Every time a function is called, it's code will get expanded and partially evaluated.
* 3: Every time a function is called, its code will get expanded and partially evaluated.
* We can do a binding time analysis to cache the result and avoid re-partial evaluation.
*
* These assumptions do not affect the correctness of the algorithm, however.
@ -104,6 +95,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include "../ir/type_functor.h"
#include "pass_util.h"
#include "let_list.h"
@ -132,6 +124,8 @@ struct VarEqual {
}
};
Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
public:
@ -150,10 +144,20 @@ class Static : public NodeRef {
using ContainerType = StaticNode;
};
using Time = size_t;
struct PStaticNode : Node {
static Time time() {
static Time time_ = 0;
Time ret = time_;
time_++;
return ret;
}
Static pstatic; // may be null
Expr dynamic;
PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { }
Time created_time;
PStaticNode(const Static& pstatic, const Expr& dynamic) :
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};
@ -341,6 +345,7 @@ class Store {
};
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
CHECK(stat.defined());
return PStatic(make_node<PStaticNode>(stat, dynamic));
}
@ -383,15 +388,78 @@ FInterpreter CPUInterpreter() {
return CreateInterpreter(Module(nullptr), CPUContext(), target);
}
bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}
using FuncId = int;
/*!
* \brief Annotate a function with a FuncId.
*/
struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
FuncId fid;
TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
TVM_ATTR_FIELD(fid)
.describe("The FuncId that an function is annotated with.")
.set_default(-1);
}
};
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
Op WithFuncIdOp() {
static const Op& op = Op::Get("annotation.with_funcid");
return op;
}
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_node<WithFuncIdAttrs>();
attrs->fid = fid;
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
}
RELAY_REGISTER_OP("annotation.with_funcid")
.describe(R"code(Annotate a function with a funcid.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("func", "Function", "The input data.");
Expr StripWithFuncId(const Expr& e);
Expr DeDup(const Expr& e);
Function AsFunc(const Expr& e) {
if (e.as<FunctionNode>()) {
return Downcast<Function>(e);
} else if (const CallNode* c = e.as<CallNode>()) {
CHECK(c->op.same_as(WithFuncIdOp()));
CHECK_EQ(c->args.size(), 1);
return AsFunc(c->args[0]);
} else {
LOG(FATAL) << "Unknown case";
throw;
}
}
class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
public:
PartialEvaluator(const tvm::Array<Var>& free_vars) {
PartialEvaluator(const tvm::Array<Var>& free_vars,
const Module& mod) :
mod_(mod) {
for (const Var& v : free_vars) {
env_.Insert(v, NoStatic(v));
}
}
PStatic VisitExpr(const Expr& e, LetList* ll) final {
PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
CHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
return ret;
}
PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef<Expr>(op)));
}
@ -421,7 +489,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
return NoStatic(GetRef<Expr>(op));
GlobalVar gv = GetRef<GlobalVar>(op);
if (gv_map_.count(gv) == 0) {
if (mod_.defined()) {
Function func = mod_->Lookup(gv);
InitializeFuncId(func);
Func f = VisitFuncStatic(func, gv);
gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
func = AsFunc(PostProcess(VisitFuncDynamic(func, f)));
mod_->Update(gv, func);
} else {
gv_map_.insert({gv, NoStatic(gv)});
}
}
return gv_map_.at(gv);
}
PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
@ -485,6 +566,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll);
}
PStatic f = VisitExpr(op->op, ll);
std::vector<PStatic> x;
tvm::Array<Expr> x_dyn;
@ -501,19 +586,40 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
Function func = GetRef<Function>(op);
struct TimeFrame {
PartialEvaluator* pe_;
FuncId fid_;
std::vector<Time> old_time;
bool has_old_time;
TimeFrame(PartialEvaluator* pe,
FuncId fid,
const std::vector<Time>& args_time) : pe_(pe), fid_(fid) {
has_old_time = pe_->time_map_.count(fid_) > 0;
old_time = pe_->time_map_[fid_];
pe_->time_map_[fid_] = args_time;
}
~TimeFrame() {
if (has_old_time) {
pe_->time_map_[fid_] = old_time;
} else {
pe_->time_map_.erase(fid_);
}
}
};
Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
if (func->IsPrimitive()) {
return HasStatic(MkSFunc(ConstEvaluateFunc(func, ll)), func);
return ConstEvaluateFunc(func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
for (const auto& v : FreeVars(GetRef<Expr>(op))) {
for (const auto& v : FreeVars(func)) {
free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
}
Func f = [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
return [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
for (size_t i = 0; i < pv.size(); ++i) {
@ -529,10 +635,50 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], Type());
}
return VisitExpr(TypeSubst(func->body, subst), ll);
std::vector<Time> args_time;
for (const auto& v : pv) {
args_time.push_back(v->created_time);
}
CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func);
auto recurse = [&]() {
TimeFrame tf(this, fid, args_time);
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
};
if (time_map_.count(fid) == 0) {
return recurse();
} else {
/* We check to see that at least one argument decreases
* with respect to all previous invocations.
* The depth of the recursion is bounded by
* the sum of the times of all arguments at the first call.
*/
bool can_recurse = false;
std::vector<Time>& min_time = time_map_.at(fid);
CHECK_EQ(args_time.size(), min_time.size());
for (size_t i = 0; i < args_time.size(); ++i) {
if (args_time[i] < min_time[i]) {
can_recurse = true;
}
args_time[i] = std::min(args_time[i], min_time[i]);
}
if (can_recurse) {
return recurse();
} else {
std::vector<Expr> dyn;
for (const auto& v : pv) {
dyn.push_back(v->dynamic);
}
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
}
}
});
};
Expr dyn = store_.Extend<Expr>([&]() {
}
Expr VisitFuncDynamic(const Function& func, const Func& f) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params, LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
@ -546,7 +692,20 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return f(pv, Attrs(), type_args, ll)->dynamic;
}), func->ret_type, func->type_params, func->attrs);
});
return HasStatic(MkSFunc(f), ll->Push(dyn));
}
PStatic VisitFunc(const Function& func, LetList* ll) {
Var v = VarNode::make("x", Type());
Func f = VisitFuncStatic(func, v);
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
// TODO(@M.K.): we seem to reduce the Landin knot into letrec.
// restore letrec support across whole relay.
return HasStatic(MkSFunc(f),
ll->Push(v, VisitFuncDynamic(u_func, f)));
}
PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
return VisitFunc(GetRef<Function>(op), ll);
}
Expr Reflect(const PStatic& st) {
@ -590,7 +749,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return Reify(executor_(fused_infered), ll);
}
Func ConstEvaluateFunc(const Expr& expr, LetList* ll) {
Func ConstEvaluateFunc(const Expr& expr) {
CHECK_EQ(FreeVars(expr).size(), 0);
return [=](const std::vector<PStatic>& pv,
const Attrs& attrs,
const tvm::Array<Type>& type_args,
@ -599,7 +759,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
PStatic ns = NoStatic(CallNode::make(expr, ns_args, attrs, type_args));
PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
if (StatefulOp(expr)) {
return ns;
}
@ -616,7 +776,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op), ll)), GetRef<Expr>(op));
return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
}
PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
@ -680,7 +840,6 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(scn->constructor->tag, -1);
if (op->constructor->tag == scn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK_EQ(op->patterns.size(), scn->fields.size());
MatchStatus current_match_status = MatchStatus::Match;
for (size_t i = 0; i < op->patterns.size(); ++i) {
@ -702,27 +861,119 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}
void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
void VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_EQ(pe->func_map_.count(f), 0);
pe->func_map_.insert({f, pe->func_map_.size()});
VisitExpr(f->body);
}
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
};
InitializeFuncIdVisitor(this).VisitExpr(e);
}
Expr RegisterFuncId(const Expr& e) {
struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
CHECK(op->attrs.defined());
CHECK(op->attrs.as<WithFuncIdAttrs>());
Function f = AsFunc(op->args[0]);
FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
if (pe->func_map_.count(f) != 0) {
CHECK_EQ(pe->func_map_.at(f), fid);
}
pe->func_map_.insert({f, fid});
}
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_GT(pe->func_map_.count(f), 0);
ExprVisitor::VisitExpr_(op);
}
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
};
RegisterFuncIdVisitor(this).VisitExpr(e);
return e;
}
Expr AnnotateFuncId(const Expr& e) {
struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
PartialEvaluator* pe;
explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { }
Expr VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
CHECK_GT(pe->func_map_.count(f), 0);
return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Var VisitVar(const Var& v) final {
return v;
}
};
return AnnotateFuncIdMutator(this).VisitExpr(e);
}
private:
Environment env_;
Module mod_;
std::unordered_map<GlobalVar, PStatic, NodeHash, NodeEqual> gv_map_;
/*! Termination checking is done as follows:
* We have finitely many FunctionIds.
* Each FunctionId maps to a class of semantically equivalent functions (ignoring type),
* as both TypeSubst and DeDup create semantically equivalent functions.
* We partially map each FunctionId to a std::vector<Time>,
* denoting the minimal TimeFrame of each argument of the function.
* Every time we try to inline a Function,
* we make sure it either does not have a vector<Time>, which means this is the initial call,
* or some argument has a lesser time, which means some earlier argument is passed in.
* In any case, we remap the mapping to a minimal vector<Time> across all previous invocations
* when we PE inside the Function body.
* Termination is guaranteed because the creation time of at least one argument will decrease every call.
*/
std::unordered_map<Function, FuncId, NodeHash, NodeEqual> func_map_;
std::unordered_map<FuncId, std::vector<Time> > time_map_;
Store store_;
DLContext context_ = CPUContext();
FInterpreter executor_ = CPUInterpreter();
};
Var DeDupVar(const Var& v) {
return VarNode::make(v->name_hint(), v->type_annotation);
}
TypeVar DeDupTypeVar(const TypeVar& tv) {
return TypeVarNode::make(tv->var->name_hint, tv->kind);
}
/*! \brief Use a fresh Id for every Var to make the result well-formed. */
Expr DeDup(const Expr& e) {
class DeDupMutator : public ExprMutator, public PatternMutator {
class DeDupMutator : public TypeMutator,
public ExprMutator,
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
Var Fresh(const Var& v) {
Var ret = DeDupVar(v);
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}
@ -737,18 +988,27 @@ Expr DeDup(const Expr& e) {
}
Expr VisitExpr_(const LetNode* op) final {
return LetNode::make(Fresh(op->var), VisitExpr(op->value), VisitExpr(op->body));
Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
}
Type VisitType(const Type& t) final {
return t.defined() ? TypeMutator::VisitType(t) : t;
}
Expr VisitExpr_(const FunctionNode* op) final {
tvm::Array<TypeVar> type_params;
for (const TypeVar& type_param : op->type_params) {
type_params.push_back(Fresh(type_param));
}
tvm::Array<Var> params;
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
VisitExpr(op->body),
op->ret_type,
op->type_params,
VisitType(op->ret_type),
type_params,
op->attrs);
}
@ -756,14 +1016,28 @@ Expr DeDup(const Expr& e) {
return PatternMutator::VisitPattern(p);
}
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}
Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
}
Var VisitVar(const Var& v) final {
return Fresh(v);
}
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};
return DeDupMutator().VisitExpr(e);
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
return ret;
}
/*! \brief Remap multiple Var sharing the same Id into the same Var. */
@ -787,11 +1061,38 @@ Expr Remap(const Expr& e) {
return RemapMutator().VisitExpr(e);
}
Expr PartialEval(const Expr& e) {
Expr StripWithFuncId(const Expr& e) {
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
Expr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0]);
} else {
return ExprMutator::VisitExpr_(op);
}
}
Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
}
Var VisitVar(const Var& v) final {
return v;
}
};
return StripWithFuncIdMutator().VisitExpr(e);
}
Expr PostProcess(const Expr& e) {
return StripWithFuncId(DeDup(Remap(e)));
}
Expr PartialEval(const Expr& e, const Module& m) {
return TransformF([&](const Expr& e) {
return LetList::With([&](LetList* ll) {
PartialEvaluator pe(FreeVars(e));
return Remap(DeDup(pe.VisitExpr(e, ll)->dynamic));
PartialEvaluator pe(FreeVars(e), m);
pe.InitializeFuncId(e);
return PostProcess(pe.VisitExpr(e, ll)->dynamic);
});
}, e);
}
@ -804,7 +1105,7 @@ namespace transform {
Pass PartialEval() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(PartialEval(f));
return Downcast<Function>(PartialEval(f, m));
};
return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
}

Просмотреть файл

@ -18,14 +18,17 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import partial_evaluate, dead_code_elimination
from tvm.relay.ir_pass import gradient, alpha_equal, infer_type
from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination
from tvm.relay.ir_pass import gradient
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
from tvm.relay import create_executor
from nose.tools import nottest
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match
from tvm.relay import GlobalVar, Call, Type
from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
ctx = tvm.context("llvm", 0)
@ -35,24 +38,22 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
def dcpe(expr):
return dead_code_elimination(partial_evaluate(expr))
def dcpe(expr, mod=None):
return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True)
def test_tuple():
t = relay.TypeVar("t")
x = relay.Var("x", t)
body = relay.TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = relay.Function([x], body, None, [t])
t = TypeVar("t")
x = Var("x", t)
body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1)
f = Function([x], body, None, [t])
assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t]))
@nottest
def test_const_inline():
# TODO(MK): fix me
d = relay.Var("d")
double = relay.Function([d], d + d)
orig = double(relay.const(4.0))
assert alpha_equal(dcpe(double(relay.const(4.0))), relay.const(8.0))
d = Var("d")
double = Function([d], d + d)
orig = double(const(4.0))
assert alpha_equal(dcpe(orig), const(8.0))
def test_ref():
@ -60,44 +61,57 @@ def test_ref():
r = relay.Var("r")
x = relay.Var("x")
body = relay.RefRead(r)
body = relay.Let(x, relay.RefWrite(r, relay.RefRead(r) * relay.RefRead(r)), body)
body = relay.Let(r, relay.RefCreate(d), body)
square = relay.Function([d], body)
assert alpha_equal(dcpe(square), relay.Function([d], d * d))
body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body)
body = Let(r, RefCreate(d), body)
square = Function([d], body)
assert alpha_equal(dcpe(square), Function([d], d * d))
@nottest
def test_ad():
# TODO(MK): fix me
def test_empty_ad():
shape = (10, 10)
dtype = "float32"
t = relay.TensorType(shape, dtype)
d = relay.Var("d", t)
f = relay.Function([d], d * d)
t = TensorType(shape, dtype)
d = Var("d", t)
f = Function([d], d)
g = dcpe(gradient(f))
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
assert alpha_equal(g, expected)
def test_ad():
shape = (10, 10)
dtype = "float32"
t = TensorType(shape, dtype)
d = Var("d", t)
f = Function([d], d * d)
g = dcpe(gradient(f))
m = d * d
o = relay.op.ones_like(m)
grad = relay.op.zeros_like(d) + relay.op.collapse_sum_like(o * d, d) + relay.op.collapse_sum_like(o * d, d)
expected = relay.Function([d], relay.Tuple([m, relay.Tuple([grad])]))
x = relay.Var("x")
o = op.ones_like(x)
x1 = relay.Var("x1")
grad = op.zeros_like(d) + op.collapse_sum_like(x1 * d, d) + op.collapse_sum_like(x1 * d, d)
body = Tuple([x, Tuple([grad])])
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
assert alpha_equal(g, expected)
def test_if_ref():
    # A ref is doubled through a closure called from both branches of an
    # `if`; the partially evaluated program must agree with the direct
    # evaluation: the final ref contents are 1 + 1 == 2.
    shape = ()
    dtype = "bool"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    r = Var("r")
    update = Function([], RefWrite(r, RefRead(r) + RefRead(r)))
    u = Var("u")
    body = If(d, u(), u())
    eff = Var("eff")
    body = Let(eff, body, RefRead(r))
    f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body)))
    f = infer_type(f)
    pe_f = infer_type(partial_evaluate(f))
    ex = create_executor()
    f_res = ex.evaluate(f)(const(True))
    pe_f_res = ex.evaluate(pe_f)(const(True))
    np.testing.assert_allclose(f_res.asnumpy(), 2 * np.ones_like(f_res.asnumpy()))
    np.testing.assert_allclose(pe_f_res.asnumpy(), 2 * np.ones_like(pe_f_res.asnumpy()))
def test_function_invalidate():
    # `fetch` reads a ref that is overwritten after `fetch` is bound; the
    # partial evaluator must not specialize `fetch` to the ref's old
    # contents (0) — both programs must observe the written value 1.
    shape = ()
    dtype = "bool"
    t = TensorType(shape, dtype)
    d = Var("d", t)
    r = Var("r")
    fetch = Function([], RefRead(r))
    fet = Var("fetch")
    fet_obscured = Var("fetch_obscured")
    u = Var("u")
    body = If(d, fet_obscured(), fet_obscured())
    body = Let(u, RefWrite(r, const(1)), body)
    body = Let(fet_obscured, If(d, fet, fet), body)
    body = Let(fet, fetch, body)
    body = Let(r, RefCreate(const(0)), body)
    f = Function([d], body)
    f = infer_type(f)
    pe_f = infer_type(partial_evaluate(f))
    ex = create_executor()
    f_res = ex.evaluate(f)(const(True))
    pe_f_res = ex.evaluate(pe_f)(const(True))
    np.testing.assert_allclose(f_res.asnumpy(), np.ones_like(f_res.asnumpy()))
    np.testing.assert_allclose(pe_f_res.asnumpy(), np.ones_like(pe_f_res.asnumpy()))
def test_head_cons():
    # `hd (cons x nil)` should be simplified to `x` by evaluating the match.
    mod = Module()
    p = Prelude(mod)
    def hd_impl():
        # A local head function: match on `cons` and return the first field.
        a = TypeVar("a")
        x = Var("x", p.l(a))
        y = Var("y")
        z = Var("z")
        cons_case = Clause(PatternConstructor(p.cons,
                                              [PatternVar(y),
                                               PatternVar(z)]),
                           y)
        return Function([x], Match(x, [cons_case]), a, [a])
    t = TypeVar("t")
    x = Var("x", t)
    hd = Var("hd")
    body = Let(hd, hd_impl(), hd(p.cons(x, p.nil())))
    f = Function([x], body, None, [t])
    f = infer_type(f, mod=mod)
    res = dcpe(f)
    assert alpha_equal(res, Function([x], x, t, [t]))
def test_map():
    """Partial evaluation unrolls `map f [1, 2, 3]` into explicit applications."""
    mod = Module()
    prelude = Prelude(mod)
    fn = Var("f")
    # Build the list 1 :: 2 :: 3 :: nil by prepending in reverse order.
    subject = prelude.nil()
    for c in (const(3), const(2), const(1)):
        subject = prelude.cons(c, subject)
    orig = prelude.map(fn, subject)
    # The expected residual applies `f` to each element individually.
    expected = prelude.nil()
    for c in (const(3), const(2), const(1)):
        expected = prelude.cons(fn(c), expected)
    assert alpha_equal(dcpe(orig, mod=mod), expected)
def test_loop():
    """A non-terminating self-call must be left as a residual call by dcpe."""
    mod = Module()
    ty = TypeVar("t")
    arg = Var("x", ty)
    loop = GlobalVar("loop")
    # loop x = loop x : diverges, so it cannot be evaluated away.
    mod[loop] = Function([arg], loop(arg), ty, [ty])
    expected = Call(loop, [const(1)], None, [None])
    res = dcpe(loop(const(1)), mod=mod)
    assert alpha_equal(res, expected)
def test_swap_loop():
    """An argument-swapping infinite loop is preserved verbatim by dcpe."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat = prelude.nat()
    a = Var("x", nat)
    b = Var("y", nat)
    loop = GlobalVar("loop")
    # loop x y = loop y x : never terminates, must stay residual.
    mod[loop] = Function([a, b], loop(b, a), nat)
    prog = loop(make_nat_expr(prelude, 1), make_nat_expr(prelude, 2))
    assert alpha_equal(prog, dcpe(prog, mod=mod))
def test_abs_diff():
    # Absolute difference on Peano naturals via nested matches:
    # |x - y| computed by peeling successors off both arguments and
    # swapping them on each recursive call. dcpe(|7 - 3|) must yield 4.
    # TODO(@M.K.): refactor using tuple pattern (not yet implemented)
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    y = Var("y", nat)
    xp = Var("x'", nat)
    yp = Var("y'", nat)
    diff = GlobalVar("diff")
    # Inner match on y (only reached when x = s(xp)):
    #   y = z      -> x          (y exhausted, remainder is x)
    #   y = s(yp)  -> diff yp xp (drop one successor from each, swap args)
    y_z_case = Clause(PatternConstructor(p.z, []), x)
    y_s_case = Clause(PatternConstructor(p.s, [PatternVar(yp)]), diff(yp, xp))
    # Outer match on x:
    #   x = z      -> y          (x exhausted, remainder is y)
    #   x = s(xp)  -> match on y as above
    x_z_case = Clause(PatternConstructor(p.z, []), y)
    x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case]))
    mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case]))
    orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 4))
def test_match_nat_id():
    """Identity on nats written via pattern matching evaluates fully."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat = prelude.nat()
    arg = Var("x", nat)
    pred = Var("y", nat)
    nat_id = GlobalVar("nat_id")
    # z -> z, s(pred) -> s(pred): structurally rebuilds the input.
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]), prelude.s(pred)),
    ]
    mod[nat_id] = Function([arg], Match(arg, clauses))
    res = dcpe(nat_id(make_nat_expr(prelude, 3)), mod=mod)
    assert alpha_equal(res, make_nat_expr(prelude, 3))
def test_nat_id():
    """The plain identity function applied to a nat literal reduces to it."""
    mod = Module()
    p = Prelude(mod)
    add_nat_definitions(p)
    nat = p.nat()
    x = Var("x", nat)
    nat_id = GlobalVar("nat_id")
    mod[nat_id] = Function([x], x)
    orig = nat_id(make_nat_expr(p, 3))
    res = dcpe(orig, mod=mod)
    assert alpha_equal(res, make_nat_expr(p, 3))
def test_global_match_nat_id():
    """A top-level match that rebuilds a nat literal is fully evaluated."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    nat = prelude.nat()
    pred = Var("x", nat)
    # z -> z, s(pred) -> s(pred): the match reconstructs its scrutinee.
    clauses = [
        Clause(PatternConstructor(prelude.z, []), prelude.z()),
        Clause(PatternConstructor(prelude.s, [PatternVar(pred)]), prelude.s(pred)),
    ]
    res = dcpe(Match(make_nat_expr(prelude, 3), clauses), mod=mod)
    assert alpha_equal(res, make_nat_expr(prelude, 3))
def test_double():
    """`double 3` from the prelude partially evaluates to the literal 6."""
    mod = Module()
    prelude = Prelude(mod)
    add_nat_definitions(prelude)
    program = prelude.double(make_nat_expr(prelude, 3))
    assert alpha_equal(dcpe(program, mod=mod), make_nat_expr(prelude, 6))
if __name__ == '__main__':
    # Run every test defined in this file when executed as a script.
    test_empty_ad()
    test_tuple()
    test_const_inline()
    test_ref()
    test_ad()
    test_if_ref()
    test_function_invalidate()
    test_head_cons()
    test_map()
    test_loop()
    test_swap_loop()
    test_abs_diff()
    test_double()
    test_nat_id()
    test_global_match_nat_id()
    test_match_nat_id()