This commit is contained in:
Tianqi Chen 2017-09-08 19:00:24 -07:00 коммит произвёл GitHub
Родитель 0c9adc5b1e
Коммит a961b29c5c
2 изменённых файлов: 7 добавлений и 2 удалений

Просмотреть файл

@ -21,6 +21,7 @@ def softmax(x):
m, n = x.shape
k = tvm.reduce_axis((0, n), name='k')
max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
k = tvm.reduce_axis((0, n), name='k')
expsum = tvm.compute(
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(

Просмотреть файл

@ -6,9 +6,12 @@ import topi
from topi.util import get_const_tuple
def verify_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
s = topi.cuda.schedule_softmax(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
@ -26,10 +29,11 @@ def verify_softmax(m, n):
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
check_device(device)
def test_softmax():
verify_softmax(32, 10)
verify_softmax(3, 4)
if __name__ == "__main__":