[PASS][FIX] Fix LiftAttrScope with if (#309)

* [PASS][FIX] Fix LiftAttrScope with if

* [PASS] Fix on proc sync

* fix
This commit is contained in:
Tianqi Chen 2017-08-10 18:38:09 -07:00 коммит произвёл GitHub
Родитель 19381b51ff
Коммит e4b500b608
3 изменённых файлов: 19 добавлений и 6 удалений

Просмотреть файл

@ -95,7 +95,7 @@ class AttrScopeLifter : public IRMutator {
}
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
if (!op->then_case.defined()) {
if (!op->else_case.defined()) {
return IRMutator::Mutate_(op, s);
}
Stmt then_case = this->Mutate(op->then_case);

Просмотреть файл

@ -312,7 +312,7 @@ class CoProcTouchedBuffer : public IRVisitor {
IRVisitor::Visit_(op);
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
const Variable* buffer = op->args[1].as<Variable>();
touched_.insert(buffer);
}
@ -321,6 +321,8 @@ class CoProcTouchedBuffer : public IRVisitor {
void Visit_(const AttrStmt* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv(op->node.node_);
coproc_.insert(iv);
IRVisitor::Visit_(op);
in_scope_ = false;
} else {
@ -329,6 +331,7 @@ class CoProcTouchedBuffer : public IRVisitor {
}
std::unordered_set<const Variable*> touched_;
std::unordered_set<IterVar> coproc_;
private:
bool in_scope_{false};
@ -344,6 +347,11 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
if (!touched_.empty()) {
this->Visit(stmt);
PlanWriteSync(scope_.back(), nullptr, true);
CHECK_EQ(visitor.coproc_.size(), 1U);
if (write_sync_.size() == 0) {
write_sync_[stmt.get()] = GetWriteSync(
(*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
}
}
}
@ -438,7 +446,10 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
// Does not consider memory coherence, need runtime.
CHECK_NE(co_access.size(), 0U);
CHECK_EQ(co_access[0].threads.size(), 1U);
std::string sync_name = co_access[0].threads[0]->var->name_hint + ".coproc_sync";
return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
}
std::vector<Stmt> GetWriteSync(std::string sync_name) {
std::vector<Stmt> stmts;
stmts.emplace_back(
Evaluate::make(Call::make(
@ -447,6 +458,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
{}, Call::Intrinsic)));
return stmts;
}
std::unordered_set<const Variable*> touched_;
StorageScope global_scope_ = StorageScope::make("global");
};

Просмотреть файл

@ -11,9 +11,10 @@ def test_coproc_lift():
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[j] = A[j] + 2
with ib.if_scope(i.equal(0)):
with ib.for_range(0, 10, name="j") as j:
ib.scope_attr(cp, "coproc_uop_scope", value)
A[j] = A[j] + 2
body = ib.get()
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
assert body.body.body.node == cp