From a961b29c5c363edca18d3d30291b28a259d35f2c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 8 Sep 2017 19:00:24 -0700 Subject: [PATCH] [TOPI] Fix softmax bug (#437) --- topi/python/topi/nn/softmax.py | 1 + topi/tests/python/test_topi_softmax.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/topi/python/topi/nn/softmax.py b/topi/python/topi/nn/softmax.py index 4c39a9f1..e3b19cff 100644 --- a/topi/python/topi/nn/softmax.py +++ b/topi/python/topi/nn/softmax.py @@ -21,6 +21,7 @@ def softmax(x): m, n = x.shape k = tvm.reduce_axis((0, n), name='k') max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k)) + k = tvm.reduce_axis((0, n), name='k') expsum = tvm.compute( (m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k)) return tvm.compute( diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py index 0afd7460..b5eb9363 100644 --- a/topi/tests/python/test_topi_softmax.py +++ b/topi/tests/python/test_topi_softmax.py @@ -6,9 +6,12 @@ import topi from topi.util import get_const_tuple def verify_softmax(m, n): - A = tvm.placeholder((m, n), name='A') B = topi.nn.softmax(A) + # confirm lower works + s = tvm.create_schedule([B.op]) + tvm.lower(s, [A, B], simple_mode=True) + s = topi.cuda.schedule_softmax(B) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) @@ -26,10 +29,11 @@ def verify_softmax(m, n): np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) for device in ['cuda', 'opencl', 'metal']: - check_device(device) + check_device(device) def test_softmax(): verify_softmax(32, 10) + verify_softmax(3, 4) if __name__ == "__main__":