fix copro_sync.cc errors of ctx (#1274)

This commit is contained in:
libing4752 2018-06-14 01:53:04 +08:00 коммит произвёл Tianqi Chen
Родитель fed6298b7b
Коммит c55865b46a
2 изменённых файлов: 38 добавлений и 1 удалений

Просмотреть файл

@ -385,7 +385,7 @@ class CoProcInstDepDetector : public IRVisitor {
&(curr_state_.exit_push),
&(curr_state_.enter_pop));
curr_state_.enter_ctx = first_state_.enter_ctx;
curr_state_.exit_ctx = last_state_.enter_ctx;
curr_state_.exit_ctx = last_state_.exit_ctx;
}
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);

Просмотреть файл

@ -78,7 +78,44 @@ def test_coproc_sync2():
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
def test_coproc_sync3():
def __check_list(tvm_array, py_list):
for ti, li in zip(tvm_array, py_list):
if ti.value != li:
return False
return True
ib = tvm.ir_builder.create()
n = tvm.var("n")
cp = tvm.thread_axis((0, 1), "cop")
A = ib.allocate("float32", 128, name="A", scope="global.cache")
with ib.for_range(0, n, name="i") as i:
with ib.for_range(0, n, name="i") as j:
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 1)
A[i] = 1.0
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 2)
A[i] = 1.0
with ib.new_scope():
ib.scope_attr(cp, "coproc_scope", 3)
A[0] = 0.0
stmt = ib.get()
stmt = tvm.ir_pass.CoProcSync(stmt)
slist = tvm.make.stmt_list(stmt.first.body.body)
push_st = slist[2]
slist = tvm.make.stmt_list(slist[-1])
pop_st = slist[0].body.first
assert(push_st.value.name == "cop.coproc_dep_push")
assert(__check_list(push_st.value.args, [2,3]))
assert(pop_st.value.name == "cop.coproc_dep_pop")
assert(__check_list(pop_st.value.args, [2,3]))
if __name__ == "__main__":
test_coproc_sync()
test_storage_sync()
test_coproc_sync2()
test_coproc_sync3()