fix copro_sync.cc errors of ctx (#1274)
This commit is contained in:
Родитель
fed6298b7b
Коммит
c55865b46a
|
@ -385,7 +385,7 @@ class CoProcInstDepDetector : public IRVisitor {
|
|||
&(curr_state_.exit_push),
|
||||
&(curr_state_.enter_pop));
|
||||
curr_state_.enter_ctx = first_state_.enter_ctx;
|
||||
curr_state_.exit_ctx = last_state_.enter_ctx;
|
||||
curr_state_.exit_ctx = last_state_.exit_ctx;
|
||||
}
|
||||
std::swap(first_state_, temp_first);
|
||||
std::swap(last_state_, temp_last);
|
||||
|
|
|
@ -78,7 +78,44 @@ def test_coproc_sync2():
|
|||
stmt = ib.get()
|
||||
stmt = tvm.ir_pass.CoProcSync(stmt)
|
||||
|
||||
def test_coproc_sync3():
|
||||
def __check_list(tvm_array, py_list):
|
||||
for ti, li in zip(tvm_array, py_list):
|
||||
if ti.value != li:
|
||||
return False
|
||||
return True
|
||||
|
||||
ib = tvm.ir_builder.create()
|
||||
n = tvm.var("n")
|
||||
cp = tvm.thread_axis((0, 1), "cop")
|
||||
A = ib.allocate("float32", 128, name="A", scope="global.cache")
|
||||
with ib.for_range(0, n, name="i") as i:
|
||||
with ib.for_range(0, n, name="i") as j:
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 1)
|
||||
A[i] = 1.0
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 2)
|
||||
A[i] = 1.0
|
||||
with ib.new_scope():
|
||||
ib.scope_attr(cp, "coproc_scope", 3)
|
||||
A[0] = 0.0
|
||||
|
||||
stmt = ib.get()
|
||||
stmt = tvm.ir_pass.CoProcSync(stmt)
|
||||
slist = tvm.make.stmt_list(stmt.first.body.body)
|
||||
push_st = slist[2]
|
||||
slist = tvm.make.stmt_list(slist[-1])
|
||||
pop_st = slist[0].body.first
|
||||
|
||||
assert(push_st.value.name == "cop.coproc_dep_push")
|
||||
assert(__check_list(push_st.value.args, [2,3]))
|
||||
assert(pop_st.value.name == "cop.coproc_dep_pop")
|
||||
assert(__check_list(pop_st.value.args, [2,3]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_coproc_sync()
|
||||
test_storage_sync()
|
||||
test_coproc_sync2()
|
||||
test_coproc_sync3()
|
||||
|
|
Загрузка…
Ссылка в новой задаче