[TOPI] improve elemwise schedule (#393)
* [TOPI] improve elemwise schedule * modify fuse
This commit is contained in:
Родитель
0560e1569e
Коммит
7ad3c51e61
|
@ -23,17 +23,10 @@ def schedule_elemwise(outs):
|
|||
|
||||
x = outs[0]
|
||||
num_dim = len(x.shape)
|
||||
block_factor = tvm.ir_pass.Simplify(x.op.output(0).shape[num_dim-1]).value
|
||||
if block_factor % 48 == 0:
|
||||
block_factor = 48
|
||||
elif block_factor % 32 == 0:
|
||||
block_factor = 32
|
||||
bx, tx = s[x].split(x.op.axis[num_dim-1], factor=block_factor)
|
||||
for i in range(num_dim-2, 0, -1):
|
||||
bx = s[x].fuse(bx, x.op.axis[i])
|
||||
|
||||
fused = s[x].fuse(*x.op.axis)
|
||||
num_thread = 64
|
||||
bx, tx = s[x].split(fused, factor=num_thread)
|
||||
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
|
||||
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
|
||||
|
||||
|
||||
return s
|
||||
|
|
Загрузка…
Ссылка в новой задаче