From 7ad3c51e61fb0397fb9cf1bcf90a193ff2b4666b Mon Sep 17 00:00:00 2001 From: Yuwei HU Date: Mon, 28 Aug 2017 13:05:22 +0800 Subject: [PATCH] [TOPI] improve elemwise schedule (#393) * [TOPI] improve elemwise schedule * modify fuse --- topi/python/topi/cuda/elemwise.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/topi/python/topi/cuda/elemwise.py b/topi/python/topi/cuda/elemwise.py index 6e1037fe..ed4953c3 100644 --- a/topi/python/topi/cuda/elemwise.py +++ b/topi/python/topi/cuda/elemwise.py @@ -23,17 +23,10 @@ def schedule_elemwise(outs): x = outs[0] num_dim = len(x.shape) - block_factor = tvm.ir_pass.Simplify(x.op.output(0).shape[num_dim-1]).value - if block_factor % 48 == 0: - block_factor = 48 - elif block_factor % 32 == 0: - block_factor = 32 - bx, tx = s[x].split(x.op.axis[num_dim-1], factor=block_factor) - for i in range(num_dim-2, 0, -1): - bx = s[x].fuse(bx, x.op.axis[i]) - + fused = s[x].fuse(*x.op.axis) + num_thread = 64 + bx, tx = s[x].split(fused, factor=num_thread) s[x].bind(bx, tvm.thread_axis("blockIdx.x")) s[x].bind(tx, tvm.thread_axis("threadIdx.x")) - return s