[TOPI] improve elemwise schedule (#393)
* [TOPI] improve elemwise schedule
* modify fuse
This commit is contained in:
Родитель
0560e1569e
Коммит
7ad3c51e61
|
@@ -23,17 +23,10 @@ def schedule_elemwise(outs):
|
||||||
|
|
||||||
x = outs[0]
|
x = outs[0]
|
||||||
num_dim = len(x.shape)
|
num_dim = len(x.shape)
|
||||||
block_factor = tvm.ir_pass.Simplify(x.op.output(0).shape[num_dim-1]).value
|
fused = s[x].fuse(*x.op.axis)
|
||||||
if block_factor % 48 == 0:
|
num_thread = 64
|
||||||
block_factor = 48
|
bx, tx = s[x].split(fused, factor=num_thread)
|
||||||
elif block_factor % 32 == 0:
|
|
||||||
block_factor = 32
|
|
||||||
bx, tx = s[x].split(x.op.axis[num_dim-1], factor=block_factor)
|
|
||||||
for i in range(num_dim-2, 0, -1):
|
|
||||||
bx = s[x].fuse(bx, x.op.axis[i])
|
|
||||||
|
|
||||||
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
|
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
|
||||||
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
|
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
|
||||||
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
Загрузка…
Ссылка в новой задаче