[FRONTEND][MXNET] Add expand_dims supoort (#1317)
* [FRONTEND][MXNET] Add expand_dims supoort * fix lint
This commit is contained in:
Родитель
a83e1e1eff
Коммит
f216b25e01
|
@ -241,6 +241,12 @@ def _elemwise_sum(inputs, _):
|
|||
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
|
||||
|
||||
|
||||
def _expand_dims(inputs, attrs):
|
||||
op_name, new_attrs = "expand_dims", {}
|
||||
new_attrs['axis'] = _required_attr(attrs, 'axis')
|
||||
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||
|
||||
|
||||
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
|
||||
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
|
||||
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
|
||||
|
@ -288,7 +294,8 @@ _convert_map = {
|
|||
'reshape' : _reshape,
|
||||
'sum_axis' : _rename('sum'),
|
||||
'UpSampling' : _upsampling,
|
||||
'clip' : _clip
|
||||
'clip' : _clip,
|
||||
'expand_dims' : _expand_dims
|
||||
}
|
||||
|
||||
def _convert_symbol(op_name, inputs, attrs,
|
||||
|
|
|
@ -136,6 +136,11 @@ def test_forward_split_squeeze():
|
|||
mx_sym = mx.sym.split(data, axis=1, num_outputs=4, squeeze_axis=True)
|
||||
verify_mxnet_frontend_impl(mx_sym, (1, 4, 2, 1), (1, 2, 1))
|
||||
|
||||
def test_forward_expand_dims():
|
||||
data = mx.sym.var('data')
|
||||
mx_sym = mx.sym.expand_dims(data, axis=1)
|
||||
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4), (2, 1, 3, 4))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_forward_mlp()
|
||||
test_forward_vgg()
|
||||
|
@ -148,3 +153,4 @@ if __name__ == '__main__':
|
|||
test_forward_clip()
|
||||
test_forward_split()
|
||||
test_forward_split_squeeze()
|
||||
test_forward_expand_dims()
|
||||
|
|
Загрузка…
Ссылка в новой задаче