diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 2f190ab7..1fc67311 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -241,6 +241,12 @@ def _elemwise_sum(inputs, _): return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs) +def _expand_dims(inputs, attrs): + op_name, new_attrs = "expand_dims", {} + new_attrs['axis'] = _required_attr(attrs, 'axis') + return _get_nnvm_op(op_name)(*inputs, **new_attrs) + + _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', @@ -288,7 +294,8 @@ _convert_map = { 'reshape' : _reshape, 'sum_axis' : _rename('sum'), 'UpSampling' : _upsampling, - 'clip' : _clip + 'clip' : _clip, + 'expand_dims' : _expand_dims } def _convert_symbol(op_name, inputs, attrs, diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index e6b6dffa..cfb4e553 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -136,6 +136,11 @@ def test_forward_split_squeeze(): mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True) verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1)) +def test_forward_expand_dims(): + data = mx.sym.var('data') + mx_sym = mx.sym.expand_dims(data, axis=1) + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4)) + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -148,3 +153,4 @@ if __name__ == '__main__': test_forward_clip() test_forward_split() test_forward_split_squeeze() + test_forward_expand_dims()