Add PReLU support to mxnet frontend (#1249)

This commit is contained in:
Tatsuya Nishiyama 2018-06-08 13:32:52 +09:00 коммит произвёл Tianqi Chen
Родитель 5ba24773d6
Коммит 5a15664ebe
2 изменённых файлов: 12 добавлений и 3 удалений

Просмотреть файл

@ -151,9 +151,10 @@ def _dropout(inputs, attrs):
def _leaky_relu(inputs, attrs):
act_type = _required_attr(attrs, 'act_type')
if act_type in ['leaky']:
op_name, new_attrs = 'leaky_relu', {}
new_attrs['alpha'] = attrs.get('slope', 0.25)
if act_type in ['leaky', 'prelu']:
op_name, new_attrs = act_type, {}
if act_type == 'leaky':
new_attrs['alpha'] = attrs.get('slope', 0.25)
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
elif act_type == 'elu':
slope = attrs.get('slope', 0.25)

Просмотреть файл

@ -97,6 +97,13 @@ def test_forward_rrelu():
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_prelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
gamma = mx.sym.zeros(shape=(6,))
mx_sym = mx.sym.LeakyReLU(data, gamma=gamma, act_type='prelu')
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
def test_forward_softrelu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
@ -126,6 +133,7 @@ if __name__ == '__main__':
test_forward_resnet()
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu()
test_forward_fc_flatten()
test_forward_clip()