[Relay] use unordered_map instead of map in ANF (#3024)

This commit is contained in:
雾雨魔理沙 2019-04-15 12:56:31 -07:00 коммит произвёл Haichen Shen
Родитель 8d3b392da6
Коммит 5293c6bf66
1 изменённых файлов: 13 добавлений и 7 удалений

Просмотреть файл

@ -34,7 +34,9 @@
namespace tvm {
namespace relay {
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Module mod_;
const DependencyGraph& dg_;
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
std::set<GlobalVar>* visited_;
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
Fill(Module mod,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* visited) :
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
mod_(mod),
dg_(dg),
node_scope_(node_scope),
@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};
Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalFormAux(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
}
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalForm(const Expr& e,
const Module& m,
std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
}
Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv;
std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
return ToANormalForm(e, m, &gv);
}