[CODEGEN] Multiple parallel in one launch (#399)
This commit is contained in:
Родитель
ad8733ea14
Коммит
b03c324304
|
@ -404,7 +404,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
|
|||
std::swap(var_map_, new_vmap);
|
||||
std::swap(parallel_env_, par_env);
|
||||
std::swap(function_, f);
|
||||
CHECK(par_env.hit_parallel_loop)
|
||||
CHECK_NE(par_env.parallel_loop_count, 0)
|
||||
<< "Cannot find parallel loop within parallel launch";
|
||||
builder_->SetInsertPoint(par_launch_end);
|
||||
}
|
||||
|
@ -679,7 +679,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
|
|||
} else if (pname == "parallel_barrier_when_finish") {
|
||||
CHECK(parallel_env_.penv != nullptr)
|
||||
<< "Cannot run barrier without parallel environment";
|
||||
CHECK(!parallel_env_.hit_parallel_loop)
|
||||
CHECK(!parallel_env_.in_parallel_loop)
|
||||
<< "Cannot not place within parallel loop as the workload may differ, "
|
||||
<< " place it between parallel and parallel_launch_point";
|
||||
this->VisitStmt(op->body);
|
||||
|
@ -713,9 +713,9 @@ void CodeGenCPU::VisitStmt_(const For* op) {
|
|||
Type t = op->extent.type();
|
||||
Expr num_task = cast(t, parallel_env_.num_task);
|
||||
Expr task_id = cast(t, parallel_env_.task_id);
|
||||
CHECK(!parallel_env_.hit_parallel_loop)
|
||||
CHECK(!parallel_env_.in_parallel_loop)
|
||||
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
|
||||
parallel_env_.hit_parallel_loop = true;
|
||||
parallel_env_.in_parallel_loop = true;
|
||||
if (parallel_env_.stride_pattern) {
|
||||
CreateSerialFor(MakeValue(task_id),
|
||||
MakeValue(op->extent),
|
||||
|
@ -732,6 +732,8 @@ void CodeGenCPU::VisitStmt_(const For* op) {
|
|||
op->loop_var,
|
||||
op->body);
|
||||
}
|
||||
parallel_env_.in_parallel_loop = false;
|
||||
++parallel_env_.parallel_loop_count;
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "cannot handle for type " << op->for_type;
|
||||
|
|
|
@ -24,7 +24,6 @@ class CodeGenCPU : public CodeGenLLVM {
|
|||
bool dynamic_lookup) override;
|
||||
void AddFunction(const LoweredFunc& f) override;
|
||||
void AddMainFunction(const std::string& entry_func_name) override;
|
||||
|
||||
void VisitStmt_(const AssertStmt* op) override;
|
||||
void VisitStmt_(const AttrStmt* op) override;
|
||||
void VisitStmt_(const For* op) override;
|
||||
|
@ -60,7 +59,8 @@ class CodeGenCPU : public CodeGenLLVM {
|
|||
VarExpr task_id;
|
||||
VarExpr num_task;
|
||||
bool stride_pattern{false};
|
||||
bool hit_parallel_loop{false};
|
||||
bool in_parallel_loop{false};
|
||||
int parallel_loop_count{0};
|
||||
llvm::Value* penv{nullptr};
|
||||
};
|
||||
// Get runtime functions
|
||||
|
|
|
@ -138,9 +138,6 @@ class LinearAccessPatternFinder final : public IRVisitor {
|
|||
in_thread_env_ = true;
|
||||
VisitNewScope(op);
|
||||
in_thread_env_ = false;
|
||||
} else if (op->attr_key == attr::pragma_scope &&
|
||||
op->value.as<StringImm>()->value == "parallel_launch_point") {
|
||||
VisitNewScope(op);
|
||||
} else if (op->attr_key == attr::storage_scope) {
|
||||
const Variable* buf = op->node.as<Variable>();
|
||||
storage_scope_[buf] =
|
||||
|
|
|
@ -61,6 +61,36 @@ def test_llvm_add_pipeline():
|
|||
check_llvm()
|
||||
|
||||
|
||||
def test_llvm_persist_parallel():
    # Regression test for running multiple parallel loops inside a single
    # parallel launch (persistent thread-pool workers), with a barrier
    # between the two parallel stages.
    n = 128
    A = tvm.placeholder((n,), name='A')
    # Two chained elementwise stages: B = A + 1, C = B + 2 (so C = A + 3).
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
    C = tvm.compute(A.shape, lambda *i: B(*i) + 2, name='C')
    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=8)
    # nparts=1 makes xo1 a single-iteration outer loop to anchor the launch point.
    xo1, xo2 = s[C].split(xo, nparts=1)
    s[B].compute_at(s[C], xo1)
    s[B].parallel(s[B].op.axis[0])
    # Barrier after B finishes, before C's parallel loop consumes B's output.
    s[B].pragma(s[B].op.axis[0], "parallel_barrier_when_finish")
    s[C].parallel(xi)
    # One launch point covering both parallel loops (B and C stay in one launch).
    s[C].pragma(xo1, "parallel_launch_point")
    # NOTE(review): stride pattern on xi — presumably each worker strides by
    # num_task over the iteration space; confirm against codegen_cpu.cc.
    s[C].pragma(xi, "parallel_stride_pattern")

    def check_llvm():
        # Skip silently when the LLVM backend is not compiled in.
        if not tvm.module.enabled("llvm"):
            return
        # BUILD and invoke the kernel.
        f = tvm.build(s, [A, C], "llvm")
        ctx = tvm.cpu(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
        f(a, c)
        # End-to-end check: C must equal A + 3 elementwise.
        np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 3)

    check_llvm()
|
||||
|
||||
|
||||
def test_llvm_flip_pipeline():
|
||||
def check_llvm(nn, base):
|
||||
if not tvm.module.enabled("llvm"):
|
||||
|
@ -222,6 +252,7 @@ def test_llvm_select():
|
|||
|
||||
|
||||
# Run the test suite directly when invoked as a script
# (the newly added persist-parallel test runs first).
if __name__ == "__main__":
    test_llvm_persist_parallel()
    test_llvm_select()
    test_llvm_vadd_pipeline()
    test_llvm_add_pipeline()
|
||||
|
|
|
@ -121,7 +121,7 @@ def test_parallel_alloc():
|
|||
A[j] = A[j] + 2
|
||||
body = ib.get()
|
||||
body = tvm.ir_pass.StorageRewrite(body)
|
||||
assert(isinstance(body.body.body.body, tvm.stmt.Allocate))
|
||||
assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
|
||||
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче