diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py
index 3fc3ca85..82b8e555 100644
--- a/nnvm/python/nnvm/frontend/mxnet.py
+++ b/nnvm/python/nnvm/frontend/mxnet.py
@@ -151,9 +151,10 @@ def _dropout(inputs, attrs):
 
 def _leaky_relu(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type in ['leaky']:
-        op_name, new_attrs = 'leaky_relu', {}
-        new_attrs['alpha'] = attrs.get('slope', 0.25)
+    if act_type in ['leaky', 'prelu']:
+        op_name, new_attrs = 'leaky_relu' if act_type == 'leaky' else act_type, {}
+        if act_type == 'leaky':
+            new_attrs['alpha'] = attrs.get('slope', 0.25)
         sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
     elif act_type == 'elu':
         slope = attrs.get('slope', 0.25)
diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py
index fca19a69..b54a5e42 100644
--- a/nnvm/tests/python/frontend/mxnet/test_forward.py
+++ b/nnvm/tests/python/frontend/mxnet/test_forward.py
@@ -97,6 +97,13 @@ def test_forward_rrelu():
     mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
     verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
 
+def test_forward_prelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
+    gamma = mx.sym.zeros(shape=(6,))
+    mx_sym = mx.sym.LeakyReLU(data, gamma=gamma, act_type='prelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
 def test_forward_softrelu():
     data = mx.sym.var('data')
     data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
@@ -126,6 +133,7 @@ if __name__ == '__main__':
     test_forward_resnet()
     test_forward_elu()
     test_forward_rrelu()
+    test_forward_prelu()
     test_forward_softrelu()
     test_forward_fc_flatten()
     test_forward_clip()