[PASS][FIX] Fix LiftAttrScope with if (#309)
* [PASS][FIX] Fix LiftAttrScope with if * [PASS] Fix on proc sync * fix
This commit is contained in:
Родитель
19381b51ff
Коммит
e4b500b608
|
@ -95,7 +95,7 @@ class AttrScopeLifter : public IRMutator {
|
|||
}
|
||||
|
||||
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
|
||||
if (!op->then_case.defined()) {
|
||||
if (!op->else_case.defined()) {
|
||||
return IRMutator::Mutate_(op, s);
|
||||
}
|
||||
Stmt then_case = this->Mutate(op->then_case);
|
||||
|
|
|
@ -312,7 +312,7 @@ class CoProcTouchedBuffer : public IRVisitor {
|
|||
IRVisitor::Visit_(op);
|
||||
}
|
||||
void Visit_(const Call* op) final {
|
||||
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
|
||||
if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
|
||||
const Variable* buffer = op->args[1].as<Variable>();
|
||||
touched_.insert(buffer);
|
||||
}
|
||||
|
@ -321,6 +321,8 @@ class CoProcTouchedBuffer : public IRVisitor {
|
|||
void Visit_(const AttrStmt* op) final {
|
||||
if (op->attr_key == attr::coproc_scope && !in_scope_) {
|
||||
in_scope_ = true;
|
||||
IterVar iv(op->node.node_);
|
||||
coproc_.insert(iv);
|
||||
IRVisitor::Visit_(op);
|
||||
in_scope_ = false;
|
||||
} else {
|
||||
|
@ -329,6 +331,7 @@ class CoProcTouchedBuffer : public IRVisitor {
|
|||
}
|
||||
|
||||
std::unordered_set<const Variable*> touched_;
|
||||
std::unordered_set<IterVar> coproc_;
|
||||
|
||||
private:
|
||||
bool in_scope_{false};
|
||||
|
@ -344,6 +347,11 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
|
|||
if (!touched_.empty()) {
|
||||
this->Visit(stmt);
|
||||
PlanWriteSync(scope_.back(), nullptr, true);
|
||||
CHECK_EQ(visitor.coproc_.size(), 1U);
|
||||
if (write_sync_.size() == 0) {
|
||||
write_sync_[stmt.get()] = GetWriteSync(
|
||||
(*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -438,7 +446,10 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
|
|||
// Does not consider memory coherence, need runtime.
|
||||
CHECK_NE(co_access.size(), 0U);
|
||||
CHECK_EQ(co_access[0].threads.size(), 1U);
|
||||
std::string sync_name = co_access[0].threads[0]->var->name_hint + ".coproc_sync";
|
||||
return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
|
||||
}
|
||||
|
||||
std::vector<Stmt> GetWriteSync(std::string sync_name) {
|
||||
std::vector<Stmt> stmts;
|
||||
stmts.emplace_back(
|
||||
Evaluate::make(Call::make(
|
||||
|
@ -447,6 +458,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
|
|||
{}, Call::Intrinsic)));
|
||||
return stmts;
|
||||
}
|
||||
|
||||
std::unordered_set<const Variable*> touched_;
|
||||
StorageScope global_scope_ = StorageScope::make("global");
|
||||
};
|
||||
|
|
|
@ -11,9 +11,10 @@ def test_coproc_lift():
|
|||
with ib.for_range(0, 10, name="j") as j:
|
||||
ib.scope_attr(cp, "coproc_uop_scope", value)
|
||||
A[i] = A[i] + 1
|
||||
with ib.for_range(0, 10, name="j") as j:
|
||||
ib.scope_attr(cp, "coproc_uop_scope", value)
|
||||
A[j] = A[j] + 2
|
||||
with ib.if_scope(i.equal(0)):
|
||||
with ib.for_range(0, 10, name="j") as j:
|
||||
ib.scope_attr(cp, "coproc_uop_scope", value)
|
||||
A[j] = A[j] + 2
|
||||
body = ib.get()
|
||||
body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
|
||||
assert body.body.body.node == cp
|
||||
|
|
Загрузка…
Ссылка в новой задаче