Expose clip to frontend mxnet (#512)
This commit is contained in:
Родитель
90c1157b5f
Коммит
31edf3f7eb
|
@ -205,6 +205,12 @@ def _upsampling(inputs, attrs):
|
||||||
new_attrs = {'scale':int(scale)}
|
new_attrs = {'scale':int(scale)}
|
||||||
return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)
|
return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)
|
||||||
|
|
||||||
|
def _clip(inputs, attrs):
|
||||||
|
op_name, new_attrs = "clip", {}
|
||||||
|
new_attrs['a_min'] = _required_attr(attrs, 'a_min')
|
||||||
|
new_attrs['a_max'] = _required_attr(attrs, 'a_max')
|
||||||
|
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||||
|
|
||||||
|
|
||||||
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
|
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
|
||||||
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
|
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
|
||||||
|
@ -248,6 +254,7 @@ _convert_map = {
|
||||||
'reshape' : _reshape,
|
'reshape' : _reshape,
|
||||||
'sum_axis' : _rename('sum'),
|
'sum_axis' : _rename('sum'),
|
||||||
'UpSampling' : _upsampling,
|
'UpSampling' : _upsampling,
|
||||||
|
'clip' : _clip
|
||||||
}
|
}
|
||||||
|
|
||||||
def _convert_symbol(op_name, inputs, attrs,
|
def _convert_symbol(op_name, inputs, attrs,
|
||||||
|
|
|
@ -71,7 +71,7 @@ def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **
|
||||||
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
|
13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
|
||||||
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
|
16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
|
||||||
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
|
19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
|
||||||
if not vgg_spec.has_key(num_layers):
|
if num_layers not in vgg_spec:
|
||||||
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
|
raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
|
||||||
layers, filters = vgg_spec[num_layers]
|
layers, filters = vgg_spec[num_layers]
|
||||||
data = mx.sym.Variable(name="data")
|
data = mx.sym.Variable(name="data")
|
||||||
|
|
|
@ -8,24 +8,41 @@ import nnvm.compiler
|
||||||
from nnvm.testing.config import ctx_list
|
from nnvm.testing.config import ctx_list
|
||||||
from nnvm import frontend
|
from nnvm import frontend
|
||||||
import mxnet as mx
|
import mxnet as mx
|
||||||
|
from mxnet import gluon
|
||||||
|
from mxnet.gluon.model_zoo import vision
|
||||||
import model_zoo
|
import model_zoo
|
||||||
|
|
||||||
|
|
||||||
def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000)):
|
def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape=(1, 1000),
|
||||||
|
gluon_impl=False, name=None):
|
||||||
"""Use name different from test to avoid let nose pick it up"""
|
"""Use name different from test to avoid let nose pick it up"""
|
||||||
def get_mxnet_output(symbol, x, dtype='float32'):
|
if gluon_impl:
|
||||||
from collections import namedtuple
|
def get_gluon_output(name, x):
|
||||||
Batch = namedtuple('Batch', ['data'])
|
net = vision.get_model(name)
|
||||||
mod = mx.mod.Module(symbol, label_names=None)
|
net.collect_params().initialize(mx.init.Xavier())
|
||||||
mod.bind(data_shapes=[('data', x.shape)], for_training=False)
|
net_sym = gluon.nn.SymbolBlock(outputs=net(mx.sym.var('data')),
|
||||||
mod.init_params()
|
inputs=mx.sym.var('data'),
|
||||||
mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
|
params=net.collect_params())
|
||||||
out = mod.get_outputs()[0].asnumpy()
|
out = net_sym(mx.nd.array(x.astype(dtype))).asnumpy()
|
||||||
args, auxs = mod.get_params()
|
return out, net_sym
|
||||||
return out, args, auxs
|
else:
|
||||||
|
def get_mxnet_output(symbol, x, dtype='float32'):
|
||||||
|
from collections import namedtuple
|
||||||
|
Batch = namedtuple('Batch', ['data'])
|
||||||
|
mod = mx.mod.Module(symbol, label_names=None)
|
||||||
|
mod.bind(data_shapes=[('data', x.shape)], for_training=False)
|
||||||
|
mod.init_params()
|
||||||
|
mod.forward(Batch([mx.nd.array(x.astype(dtype))]))
|
||||||
|
out = mod.get_outputs()[0].asnumpy()
|
||||||
|
args, auxs = mod.get_params()
|
||||||
|
return out, args, auxs
|
||||||
|
|
||||||
def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
|
def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
|
||||||
new_sym, params = frontend.from_mxnet(symbol, args, auxs)
|
if gluon_impl:
|
||||||
|
new_sym, params = frontend.from_mxnet(symbol)
|
||||||
|
else:
|
||||||
|
new_sym, params = frontend.from_mxnet(symbol, args, auxs)
|
||||||
|
|
||||||
dshape = x.shape
|
dshape = x.shape
|
||||||
shape_dict = {'data': dshape}
|
shape_dict = {'data': dshape}
|
||||||
with nnvm.compiler.build_config(opt_level=3):
|
with nnvm.compiler.build_config(opt_level=3):
|
||||||
|
@ -42,11 +59,17 @@ def verify_mxnet_frontend_impl(mx_symbol, data_shape=(1, 3, 224, 224), out_shape
|
||||||
# random input
|
# random input
|
||||||
dtype = 'float32'
|
dtype = 'float32'
|
||||||
x = np.random.uniform(size=data_shape)
|
x = np.random.uniform(size=data_shape)
|
||||||
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
|
if gluon_impl:
|
||||||
assert "data" not in args
|
gluon_out, gluon_sym = get_gluon_output(name, x)
|
||||||
for target, ctx in ctx_list():
|
for target, ctx in ctx_list():
|
||||||
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
|
tvm_out = get_tvm_output(gluon_sym, x, None, None, target, ctx, dtype)
|
||||||
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
|
np.testing.assert_allclose(gluon_out, tvm_out, rtol=1e-5, atol=1e-5)
|
||||||
|
else:
|
||||||
|
mx_out, args, auxs = get_mxnet_output(mx_symbol, x, dtype)
|
||||||
|
assert "data" not in args
|
||||||
|
for target, ctx in ctx_list():
|
||||||
|
tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
|
||||||
|
np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
def test_forward_mlp():
|
def test_forward_mlp():
|
||||||
mlp = model_zoo.mx_mlp
|
mlp = model_zoo.mx_mlp
|
||||||
|
@ -91,6 +114,12 @@ def test_forward_fc_flatten():
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_forward_clip():
|
||||||
|
data = mx.sym.var('data')
|
||||||
|
data = mx.sym.concat(data, -data, dim=1) # negative part explicity
|
||||||
|
mx_sym = mx.sym.clip(data, a_min=0, a_max=1)
|
||||||
|
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_forward_mlp()
|
test_forward_mlp()
|
||||||
test_forward_vgg()
|
test_forward_vgg()
|
||||||
|
@ -99,3 +128,4 @@ if __name__ == '__main__':
|
||||||
test_forward_rrelu()
|
test_forward_rrelu()
|
||||||
test_forward_softrelu()
|
test_forward_softrelu()
|
||||||
test_forward_fc_flatten()
|
test_forward_fc_flatten()
|
||||||
|
test_forward_clip()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче