Add PReLU support to mxnet frontend (#1249)
This commit is contained in:
Родитель
5ba24773d6
Коммит
5a15664ebe
|
@ -151,9 +151,10 @@ def _dropout(inputs, attrs):
|
|||
|
||||
def _leaky_relu(inputs, attrs):
|
||||
act_type = _required_attr(attrs, 'act_type')
|
||||
if act_type in ['leaky']:
|
||||
op_name, new_attrs = 'leaky_relu', {}
|
||||
new_attrs['alpha'] = attrs.get('slope', 0.25)
|
||||
if act_type in ['leaky', 'prelu']:
|
||||
op_name, new_attrs = act_type, {}
|
||||
if act_type == 'leaky':
|
||||
new_attrs['alpha'] = attrs.get('slope', 0.25)
|
||||
sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
|
||||
elif act_type == 'elu':
|
||||
slope = attrs.get('slope', 0.25)
|
||||
|
|
|
@ -97,6 +97,13 @@ def test_forward_rrelu():
|
|||
mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
|
||||
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
|
||||
|
||||
def test_forward_prelu():
|
||||
data = mx.sym.var('data')
|
||||
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
|
||||
gamma = mx.sym.zeros(shape=(6,))
|
||||
mx_sym = mx.sym.LeakyReLU(data, gamma=gamma, act_type='prelu')
|
||||
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
|
||||
|
||||
def test_forward_softrelu():
|
||||
data = mx.sym.var('data')
|
||||
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
|
||||
|
@ -126,6 +133,7 @@ if __name__ == '__main__':
|
|||
test_forward_resnet()
|
||||
test_forward_elu()
|
||||
test_forward_rrelu()
|
||||
test_forward_prelu()
|
||||
test_forward_softrelu()
|
||||
test_forward_fc_flatten()
|
||||
test_forward_clip()
|
||||
|
|
Загрузка…
Ссылка в новой задаче