[TOPI] Fix softmax bug (#437)
This commit is contained in:
Родитель
0c9adc5b1e
Коммит
a961b29c5c
|
@ -21,6 +21,7 @@ def softmax(x):
|
|||
m, n = x.shape
|
||||
k = tvm.reduce_axis((0, n), name='k')
|
||||
max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
|
||||
k = tvm.reduce_axis((0, n), name='k')
|
||||
expsum = tvm.compute(
|
||||
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
|
||||
return tvm.compute(
|
||||
|
|
|
@ -6,9 +6,12 @@ import topi
|
|||
from topi.util import get_const_tuple
|
||||
|
||||
def verify_softmax(m, n):
|
||||
|
||||
A = tvm.placeholder((m, n), name='A')
|
||||
B = topi.nn.softmax(A)
|
||||
# confirm lower works
|
||||
s = tvm.create_schedule([B.op])
|
||||
tvm.lower(s, [A, B], simple_mode=True)
|
||||
|
||||
s = topi.cuda.schedule_softmax(B)
|
||||
|
||||
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
|
||||
|
@ -26,10 +29,11 @@ def verify_softmax(m, n):
|
|||
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
|
||||
|
||||
for device in ['cuda', 'opencl', 'metal']:
|
||||
check_device(device)
|
||||
check_device(device)
|
||||
|
||||
def test_softmax():
|
||||
verify_softmax(32, 10)
|
||||
verify_softmax(3, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Загрузка…
Ссылка в новой задаче