[PASS] Update coproc sync (#634)
This commit is contained in:
Родитель
32b0fff2ea
Коммит
f1aabedc9e
|
@ -201,7 +201,8 @@ def lower(sch,
|
|||
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
|
||||
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
|
||||
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
|
||||
lower_phase2 = [x[1] for x in add_lower_pass if x[0] > 1]
|
||||
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
|
||||
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
|
||||
# normalize schedule first
|
||||
sch = sch.normalize()
|
||||
# Phase 0
|
||||
|
@ -213,6 +214,9 @@ def lower(sch,
|
|||
# Phase 1
|
||||
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
|
||||
stmt = ir_pass.CanonicalSimplify(stmt)
|
||||
for f in lower_phase1:
|
||||
stmt = f(stmt)
|
||||
# Phase 2
|
||||
if not simple_mode:
|
||||
stmt = ir_pass.LoopPartition(stmt)
|
||||
stmt = ir_pass.VectorizeLoop(stmt)
|
||||
|
@ -224,14 +228,14 @@ def lower(sch,
|
|||
cfg.auto_unroll_max_step,
|
||||
cfg.auto_unroll_max_depth,
|
||||
cfg.unroll_explicit)
|
||||
for f in lower_phase1:
|
||||
for f in lower_phase2:
|
||||
stmt = f(stmt)
|
||||
# Phase 2
|
||||
stmt = ir_pass.Simplify(stmt)
|
||||
stmt = ir_pass.LowerStorageAccessInfo(stmt)
|
||||
stmt = ir_pass.RemoveNoOp(stmt)
|
||||
stmt = ir_pass.RewriteUnsafeSelect(stmt)
|
||||
for f in lower_phase2:
|
||||
for f in lower_phase3:
|
||||
stmt = f(stmt)
|
||||
if simple_mode:
|
||||
return stmt
|
||||
|
|
|
@ -338,6 +338,256 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
|
|||
};
|
||||
|
||||
|
||||
class CoProcInstDepDetector : public IRVisitor {
|
||||
public:
|
||||
explicit CoProcInstDepDetector(
|
||||
const IterVar& coproc_axis,
|
||||
const std::string& coproc_name)
|
||||
: coproc_axis_(coproc_axis) {
|
||||
sync_push_name_ = coproc_name + ".coproc_dep_push";
|
||||
sync_pop_name_ = coproc_name + ".coproc_dep_pop";
|
||||
}
|
||||
|
||||
void Plan(Stmt stmt) {
|
||||
this->Visit(stmt);
|
||||
if (last_state_.node != nullptr) {
|
||||
MatchFixEnterPop(first_state_);
|
||||
MatchFixExitPush(last_state_);
|
||||
}
|
||||
}
|
||||
|
||||
void Visit_(const AttrStmt* op) final {
|
||||
if (op->attr_key == attr::coproc_scope &&
|
||||
op->node.same_as(coproc_axis_)) {
|
||||
const IntImm* ctx_id = op->value.as<IntImm>();
|
||||
CHECK(ctx_id != nullptr);
|
||||
curr_state_.clear();
|
||||
curr_state_.node = op->body.get();
|
||||
curr_state_.enter_ctx.insert(ctx_id->value);
|
||||
curr_state_.exit_ctx.insert(ctx_id->value);
|
||||
UpdateState();
|
||||
} else {
|
||||
IRVisitor::Visit_(op);
|
||||
}
|
||||
}
|
||||
|
||||
void Visit_(const For* op) final {
|
||||
SyncState temp_first, temp_last;
|
||||
std::swap(first_state_, temp_first);
|
||||
std::swap(last_state_, temp_last);
|
||||
this->Visit(op->body);
|
||||
curr_state_.clear();
|
||||
if (last_state_.node != nullptr) {
|
||||
curr_state_.node = op;
|
||||
CHECK(first_state_.node != nullptr);
|
||||
// loop carry dependency
|
||||
InjectSync(last_state_, first_state_,
|
||||
&(curr_state_.exit_push),
|
||||
&(curr_state_.enter_pop));
|
||||
curr_state_.enter_ctx = first_state_.enter_ctx;
|
||||
curr_state_.exit_ctx = last_state_.enter_ctx;
|
||||
}
|
||||
std::swap(first_state_, temp_first);
|
||||
std::swap(last_state_, temp_last);
|
||||
if (curr_state_.node != nullptr) {
|
||||
UpdateState();
|
||||
}
|
||||
}
|
||||
|
||||
void Visit_(const IfThenElse* op) final {
|
||||
SyncState temp_first, temp_last, curr_state;
|
||||
std::swap(first_state_, temp_first);
|
||||
std::swap(last_state_, temp_last);
|
||||
{
|
||||
// then stmt
|
||||
this->Visit(op->then_case);
|
||||
if (last_state_.node != nullptr) {
|
||||
curr_state.node = op;
|
||||
MatchFixEnterPop(first_state_);
|
||||
MatchFixExitPush(last_state_);
|
||||
curr_state.enter_ctx.insert(
|
||||
first_state_.enter_ctx.begin(),
|
||||
first_state_.enter_ctx.end());
|
||||
curr_state.exit_ctx.insert(
|
||||
last_state_.exit_ctx.begin(),
|
||||
last_state_.exit_ctx.end());
|
||||
}
|
||||
first_state_.clear();
|
||||
last_state_.clear();
|
||||
}
|
||||
if (op->else_case.defined()) {
|
||||
this->Visit(op->else_case);
|
||||
if (last_state_.node != nullptr) {
|
||||
curr_state.node = op;
|
||||
MatchFixEnterPop(first_state_);
|
||||
MatchFixExitPush(last_state_);
|
||||
curr_state.enter_ctx.insert(
|
||||
first_state_.enter_ctx.begin(),
|
||||
first_state_.enter_ctx.end());
|
||||
curr_state.exit_ctx.insert(
|
||||
last_state_.exit_ctx.begin(),
|
||||
last_state_.exit_ctx.end());
|
||||
}
|
||||
}
|
||||
// update in the trace.
|
||||
std::swap(first_state_, temp_first);
|
||||
std::swap(last_state_, temp_last);
|
||||
std::swap(curr_state_, curr_state);
|
||||
if (curr_state_.node != nullptr) {
|
||||
UpdateState();
|
||||
}
|
||||
}
|
||||
|
||||
// insert before is stored in reverse order
|
||||
// the first element is closest to the node.
|
||||
std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
|
||||
std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;
|
||||
|
||||
private:
|
||||
// state in the sync entry
|
||||
struct SyncState {
|
||||
// The statement of the state.
|
||||
const Node* node{nullptr};
|
||||
// Set of all possible contexts in the entering moment.
|
||||
std::unordered_set<int> enter_ctx;
|
||||
// Set of all possible contexts in the exit moment.
|
||||
std::unordered_set<int> exit_ctx;
|
||||
// existing pop performed at enter
|
||||
std::vector<std::pair<int, int> > enter_pop;
|
||||
// existing push peformed at exit
|
||||
std::vector<std::pair<int, int> > exit_push;
|
||||
// clear the state
|
||||
void clear() {
|
||||
node = nullptr;
|
||||
enter_ctx.clear();
|
||||
exit_ctx.clear();
|
||||
enter_pop.clear();
|
||||
exit_push.clear();
|
||||
}
|
||||
};
|
||||
// inject proper sync into the pair
|
||||
// record the push/pop sequence that could be possibly un-matched.
|
||||
// return the push/pop message at enter/exit of the Block
|
||||
// after considering the existing unmatcheded events and added events
|
||||
void InjectSync(const SyncState& prev,
|
||||
const SyncState& next,
|
||||
std::vector<std::pair<int, int> >* prev_exit_push,
|
||||
std::vector<std::pair<int, int> >* next_enter_pop) {
|
||||
prev_exit_push->clear();
|
||||
next_enter_pop->clear();
|
||||
// quick path
|
||||
if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 &&
|
||||
prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) {
|
||||
int from = *prev.exit_ctx.begin();
|
||||
int to = *next.enter_ctx.begin();
|
||||
if (from != to) {
|
||||
insert_after_[prev.node].emplace_back(MakePush(from, to));
|
||||
insert_before_[next.node].emplace_back(MakePop(from, to));
|
||||
prev_exit_push->emplace_back(std::make_pair(from, to));
|
||||
next_enter_pop->emplace_back(std::make_pair(from, to));
|
||||
}
|
||||
return;
|
||||
}
|
||||
// complicate path.
|
||||
std::vector<std::pair<int, int> > vpush = prev.exit_push;
|
||||
std::vector<std::pair<int, int> > vpop = next.enter_pop;
|
||||
std::vector<std::pair<int, int> > pending;
|
||||
for (int from : prev.exit_ctx) {
|
||||
for (int to : next.enter_ctx) {
|
||||
if (from != to) {
|
||||
pending.emplace_back(std::make_pair(from, to));
|
||||
}
|
||||
}
|
||||
}
|
||||
// policy 1
|
||||
std::vector<Stmt> prev_after, next_before;
|
||||
for (const std::pair<int, int>& p : pending) {
|
||||
if (std::find(prev.exit_push.begin(),
|
||||
prev.exit_push.end(), p) ==
|
||||
prev.exit_push.end()) {
|
||||
vpush.push_back(p);
|
||||
prev_after.emplace_back(MakePush(p.first, p.second));
|
||||
}
|
||||
if (std::find(next.enter_pop.begin(),
|
||||
next.enter_pop.end(), p) ==
|
||||
next.enter_pop.end()) {
|
||||
vpop.push_back(p);
|
||||
next_before.emplace_back(MakePop(p.first, p.second));
|
||||
}
|
||||
}
|
||||
// fix pending
|
||||
for (const std::pair<int, int>& p : vpush) {
|
||||
if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) {
|
||||
prev_after.emplace_back(MakePop(p.first, p.second));
|
||||
} else {
|
||||
prev_exit_push->push_back(p);
|
||||
}
|
||||
}
|
||||
for (const std::pair<int, int>& p : vpop) {
|
||||
if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) {
|
||||
next_before.emplace_back(MakePush(p.first, p.second));
|
||||
} else {
|
||||
next_enter_pop->push_back(p);
|
||||
}
|
||||
}
|
||||
if (prev_after.size() != 0) {
|
||||
auto &v1 = insert_after_[prev.node];
|
||||
v1.insert(v1.end(), prev_after.begin(), prev_after.end());
|
||||
}
|
||||
if (next_before.size() != 0) {
|
||||
auto &v2 = insert_before_[next.node];
|
||||
v2.insert(v2.end(), next_before.begin(), next_before.end());
|
||||
}
|
||||
}
|
||||
|
||||
void MatchFixEnterPop(const SyncState& state) {
|
||||
if (state.enter_pop.size() == 0) return;
|
||||
auto &vec = insert_before_[state.node];
|
||||
for (const std::pair<int, int>& p : state.enter_pop) {
|
||||
vec.push_back(MakePush(p.first, p.second));
|
||||
}
|
||||
}
|
||||
|
||||
void MatchFixExitPush(const SyncState& state) {
|
||||
if (state.exit_push.size() == 0) return;
|
||||
auto &vec = insert_after_[state.node];
|
||||
for (const std::pair<int, int>& p : state.exit_push) {
|
||||
vec.push_back(MakePop(p.first, p.second));
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateState() {
|
||||
if (last_state_.node != nullptr) {
|
||||
std::vector<std::pair<int, int> > t1, t2;
|
||||
InjectSync(last_state_, curr_state_, &t1, &t2);
|
||||
std::swap(last_state_, curr_state_);
|
||||
} else {
|
||||
CHECK(first_state_.node == nullptr);
|
||||
first_state_ = curr_state_;
|
||||
last_state_ = curr_state_;
|
||||
}
|
||||
}
|
||||
|
||||
Stmt MakePush(int from, int to) {
|
||||
return Evaluate::make(Call::make(
|
||||
Int(32), sync_push_name_,
|
||||
{make_const(Int(32), from), make_const(Int(32), to)},
|
||||
Call::Intrinsic));
|
||||
}
|
||||
Stmt MakePop(int from, int to) {
|
||||
return Evaluate::make(Call::make(
|
||||
Int(32), sync_pop_name_,
|
||||
{make_const(Int(32), from), make_const(Int(32), to)},
|
||||
Call::Intrinsic));
|
||||
}
|
||||
// sync states.
|
||||
SyncState first_state_, last_state_, curr_state_;
|
||||
// Variables
|
||||
IterVar coproc_axis_;
|
||||
std::string sync_push_name_, sync_pop_name_;
|
||||
};
|
||||
|
||||
|
||||
class CoProcSyncInserter : public IRMutator {
|
||||
public:
|
||||
Stmt Insert(Stmt stmt) {
|
||||
|
@ -372,6 +622,18 @@ class CoProcSyncInserter : public IRMutator {
|
|||
auto& vec = insert_after_[kv.first];
|
||||
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
|
||||
}
|
||||
// Detect barrier
|
||||
CoProcInstDepDetector sync_detector(
|
||||
*visitor.coproc_.begin(), coproc_name);
|
||||
sync_detector.Plan(stmt);
|
||||
for (const auto& kv : sync_detector.insert_before_) {
|
||||
auto& vec = insert_before_[kv.first];
|
||||
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
|
||||
}
|
||||
for (const auto& kv : sync_detector.insert_after_) {
|
||||
auto& vec = insert_after_[kv.first];
|
||||
vec.insert(vec.end(), kv.second.begin(), kv.second.end());
|
||||
}
|
||||
return Mutate(stmt);
|
||||
}
|
||||
|
||||
|
@ -379,7 +641,8 @@ class CoProcSyncInserter : public IRMutator {
|
|||
Stmt before, after;
|
||||
auto it = insert_before_.find(stmt.get());
|
||||
if (it != insert_before_.end()) {
|
||||
before = MergeSeq(it->second);
|
||||
before = MergeSeq(std::vector<Stmt>(
|
||||
it->second.rbegin(), it->second.rend()));
|
||||
}
|
||||
it = insert_after_.find(stmt.get());
|
||||
if (it != insert_after_.end()) {
|
||||
|
@ -396,10 +659,13 @@ class CoProcSyncInserter : public IRMutator {
|
|||
}
|
||||
|
||||
private:
|
||||
// insert before is stored in reverse order
|
||||
// the first element is closest to the node.
|
||||
std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
|
||||
std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;
|
||||
};
|
||||
|
||||
|
||||
Stmt CoProcSync(Stmt stmt) {
|
||||
return CoProcSyncInserter().Insert(stmt);
|
||||
}
|
||||
|
|
|
@ -189,7 +189,7 @@ class StoragePlanRewriter : public IRMutator {
|
|||
if (attach_map_.count(nullptr)) {
|
||||
std::vector<Stmt> nest;
|
||||
for (StorageEntry* e : attach_map_.at(nullptr)) {
|
||||
CHECK_EQ(e->scope.rank, 0);
|
||||
// CHECK_EQ(e->scope.rank, 0);
|
||||
if (e->new_alloc.defined()) {
|
||||
nest.emplace_back(AttrStmt::make(
|
||||
e->alloc_var, attr::storage_scope,
|
||||
|
@ -395,6 +395,12 @@ class StoragePlanRewriter : public IRMutator {
|
|||
e->new_alloc = Allocate::make(
|
||||
e->alloc_var, alloc_type, e->allocs[0]->extents,
|
||||
e->allocs[0]->condition, Evaluate::make(0));
|
||||
if (e->scope.tag.length() != 0) {
|
||||
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
|
||||
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
|
||||
CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
|
||||
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
|
||||
}
|
||||
} else {
|
||||
// Build a merged allocation
|
||||
Expr combo_size;
|
||||
|
|
|
@ -71,7 +71,7 @@ struct ThreadScope {
|
|||
*/
|
||||
static ThreadScope make(const std::string& s) {
|
||||
ThreadScope r;
|
||||
if (s == "vthread") {
|
||||
if (s == "vthread" || s == "cthread") {
|
||||
// virtual thread at the same level as local
|
||||
r.rank = 1;
|
||||
r.dim_index = -1;
|
||||
|
|
|
@ -58,6 +58,27 @@ def test_coproc_sync():
|
|||
assert(blist[-1].value.args[3].value == 10)
|
||||
|
||||
|
||||
def test_coproc_sync2():
|
||||
ib = tvm.ir_builder.create()
|
||||
n = tvm.var("n")
|
||||
cp = tvm.thread_axis((0, 1), "cop")
|
||||
ty = tvm.thread_axis("cthread")
|
||||
A = ib.allocate("float32", 128, name="A")
|
||||
ib.scope_attr(ty, "virtual_thread", 2)
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 2)
|
||||
A[ty] = 0.0
|
||||
with ib.for_range(0, n, name="i") as i:
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 1)
|
||||
A[ty] = 1.0
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 2)
|
||||
A[ty] = 1.0
|
||||
stmt = ib.get()
|
||||
stmt = tvm.ir_pass.CoProcSync(stmt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_coproc_sync()
|
||||
test_storage_sync()
|
||||
test_coproc_sync2()
|
||||
|
|
Загрузка…
Ссылка в новой задаче