Fix inceptionv3 (#1446)
This commit is contained in:
Родитель
00c87b376d
Коммит
8dbe779466
|
@ -56,6 +56,8 @@ def _pooling(inputs, attrs):
|
|||
new_attrs['strides'] = attrs.get('stride', (1, 1))
|
||||
new_attrs['padding'] = attrs.get('pad', (0, 0))
|
||||
new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
|
||||
if pool_type == 'avg':
|
||||
new_attrs['count_include_pad'] = attrs.get('count_include_pad', True)
|
||||
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||
|
||||
def _batch_norm(inputs, attrs):
|
||||
|
|
|
@ -10,6 +10,20 @@ from nnvm.testing.config import ctx_list
|
|||
|
||||
|
||||
def test_conv2d():
|
||||
def run_test_conv2d(sym, dtype, dshape, kshape, oshape, shape_dict, padding):
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(sym, target, shape_dict)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
|
||||
m.run(x=data, y_weight=kernel, y_bias=bias)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
c_np = topi.testing.conv2d_nchw_python(
|
||||
data.asnumpy(), kernel.asnumpy(), 1, padding)
|
||||
c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
x = sym.Variable("x")
|
||||
y = sym.conv2d(x, channels=10, kernel_size=(3,3),
|
||||
name="y", padding=(1,1))
|
||||
|
@ -18,18 +32,17 @@ def test_conv2d():
|
|||
kshape = (10, 3, 3, 3)
|
||||
oshape = (1, 10, 18, 18)
|
||||
shape_dict = {"x": dshape}
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
|
||||
m.run(x=data, y_weight=kernel, y_bias=bias)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
c_np = topi.testing.conv2d_nchw_python(
|
||||
data.asnumpy(), kernel.asnumpy(), 1, 1)
|
||||
c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
run_test_conv2d(y, dtype, dshape, kshape, oshape, shape_dict, (1,1))
|
||||
|
||||
x = sym.Variable("x")
|
||||
y = sym.conv2d(x, channels=10, kernel_size=(1,3),
|
||||
name="y", padding=(0,1))
|
||||
dtype = "float32"
|
||||
dshape = (1, 3, 224, 224)
|
||||
kshape = (10, 3, 1, 3)
|
||||
oshape = (1, 10, 224, 224)
|
||||
shape_dict = {"x": dshape}
|
||||
run_test_conv2d(y, dtype, dshape, kshape, oshape, shape_dict, (0,1))
|
||||
|
||||
|
||||
def test_mixed_precision():
|
||||
|
|
|
@ -141,6 +141,14 @@ def test_forward_expand_dims():
|
|||
mx_sym = mx.sym.expand_dims(data, axis=1)
|
||||
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
|
||||
|
||||
def test_forward_pooling():
|
||||
data = mx.sym.var('data')
|
||||
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='avg')
|
||||
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
|
||||
|
||||
mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
|
||||
verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_forward_mlp()
|
||||
test_forward_vgg()
|
||||
|
@ -154,3 +162,4 @@ if __name__ == '__main__':
|
|||
test_forward_split()
|
||||
test_forward_split_squeeze()
|
||||
test_forward_expand_dims()
|
||||
test_forward_pooling()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
|
||||
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
|
||||
"""Convolution in python"""
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
@ -18,8 +18,8 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
|
|||
stride : int or a list/tuple of two ints
|
||||
Stride size, or [stride_height, stride_width]
|
||||
|
||||
padding : int or str
|
||||
Padding size, or ['VALID', 'SAME']
|
||||
padding : int or str or a list/tuple of two ints
|
||||
Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -34,12 +34,11 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
|
|||
stride_h, stride_w = stride
|
||||
if isinstance(padding, int):
|
||||
pad_h = pad_w = padding * 2
|
||||
elif padding == 'VALID':
|
||||
pad_h = 0
|
||||
pad_w = 0
|
||||
else: # 'SAME'
|
||||
pad_h = kernel_h - 1
|
||||
pad_w = kernel_w - 1
|
||||
elif isinstance(padding, (list, tuple)):
|
||||
pad_h, pad_w = padding[0] * 2, padding[1] * 2
|
||||
else:
|
||||
pad_h = 0 if padding == 'VALID' else kernel_h - 1
|
||||
pad_w = 0 if padding == 'VALID' else kernel_w - 1
|
||||
pad_top = int(np.ceil(float(pad_h) / 2))
|
||||
pad_bottom = pad_h - pad_top
|
||||
pad_left = int(np.ceil(float(pad_w) / 2))
|
||||
|
@ -53,9 +52,14 @@ def conv2d_nchw_python(a_np, w_np, stride, padding):
|
|||
for n in range(batch):
|
||||
for f in range(out_channel):
|
||||
for c in range(in_channel):
|
||||
if pad_h > 0:
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
apad = np.zeros((in_height + pad_h, in_width + pad_w))
|
||||
apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
|
||||
if pad_h == 0:
|
||||
apad[:, pad_left:-pad_right] = a_np[n, c]
|
||||
elif pad_w == 0:
|
||||
apad[pad_top:-pad_bottom, :] = a_np[n, c]
|
||||
else:
|
||||
apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
|
||||
else:
|
||||
apad = a_np[n, c]
|
||||
out = scipy.signal.convolve2d(
|
||||
|
|
|
@ -56,7 +56,7 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
|
|||
out_height = (in_height + 2 * HPAD - kernel_height) // HSTR + 1
|
||||
out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1
|
||||
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
if DOPAD:
|
||||
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
|
||||
else:
|
||||
|
@ -95,7 +95,7 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
|
|||
sch = _get_schedule(wkl)
|
||||
|
||||
HPAD, WPAD = wkl.hpad, wkl.wpad
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
|
||||
A, W = data, kernel_vec
|
||||
A0, A1 = data_pad, data_vec
|
||||
|
@ -163,7 +163,7 @@ def _declaration_conv_NCHWc(wkl, sch, data, kernel):
|
|||
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
|
||||
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
|
||||
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
if DOPAD:
|
||||
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
|
||||
else:
|
||||
|
|
|
@ -58,7 +58,7 @@ def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
|
|||
out_width = (in_width + 2 * WPAD - kernel_width) // WSTR + 1
|
||||
|
||||
# pack data
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
if DOPAD:
|
||||
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
|
||||
else:
|
||||
|
@ -108,7 +108,7 @@ def _schedule_conv(s, data, data_pad, data_vec, kernel, kernel_vec, conv_out, ou
|
|||
sch = _get_schedule(wkl)
|
||||
|
||||
HPAD, WPAD = wkl.hpad, wkl.wpad
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
|
||||
A, W = data, kernel_vec
|
||||
A0, A1 = data_pad, data_vec
|
||||
|
@ -181,7 +181,7 @@ def _declaration_conv_NCHWc(wkl, sch, data, kernel):
|
|||
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
|
||||
|
||||
# pack data
|
||||
DOPAD = (HPAD != 0 and WPAD != 0)
|
||||
DOPAD = (HPAD != 0 or WPAD != 0)
|
||||
if DOPAD:
|
||||
data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче