[Relay][Frontend][Keras] Fix a case missed by the ReLU conversion in the Keras converter (#3917)

* [Relay][Frontend][Keras] Fix a case missed by the ReLU conversion in the Keras converter

* [Relay][Frontend][Keras] Add a test case for the missed ReLU case in the Keras converter

Neo Chien authored 2019-09-11 01:41:16 +08:00; committed by Yao Wang
Parent: 42195a48e0
Commit: 5bff6ccede
2 changed files with 17 additions and 3 deletions


@@ -128,8 +128,15 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
             axis = axis + 1 if axis < dims - 1 else 1
         return _op.nn.softmax(inexpr, axis=axis)
     if act_type == 'ReLU':
-        if keras_layer.max_value:
+        threshold = _expr.const(keras_layer.threshold, dtype='float32')
+        if keras_layer.max_value and float(keras_layer.threshold) == 0:
+            # f(x) = max_value, for x >= max_value
+            # f(x) = x, for threshold <= x < max_value
             return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
+        elif keras_layer.max_value and _op.greater(threshold, inexpr).astype('float32'):
+            # f(x) = negative_slope * (inexpr - threshold)
+            negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
+            return _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
         return _op.nn.relu(inexpr)
     if act_type == 'LeakyReLU':
         return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
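For reference, Keras's ReLU(max_value, threshold, negative_slope) layer computes the piecewise function described by the comments in the hunk above. The NumPy sketch below is illustrative only (the function name and sample values are not part of the commit); it shows the reference behavior that the Relay expression built here is meant to reproduce:

import numpy as np

def keras_relu_reference(x, max_value=None, threshold=0., negative_slope=0.):
    x = np.asarray(x, dtype='float32')
    # f(x) = negative_slope * (x - threshold), for x < threshold
    out = np.where(x < threshold, negative_slope * (x - threshold), x)
    # f(x) = max_value, for x >= max_value
    if max_value is not None:
        out = np.minimum(out, max_value)
    return out

print(keras_relu_reference([-2., 0.5, 3., 10.], max_value=6., threshold=1., negative_slope=0.5))
# expected: [-1.5, -0.25, 3., 6.]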
@@ -162,11 +169,11 @@ def _convert_merge(inexpr, keras_layer, _):
             axes = [keras_layer.axes, keras_layer.axes]
         if isinstance(axes, list):
             if len(axes) != 2:
-                raise tvm.error.OpAttributeUnimplemented(
+                raise tvm.error.OpAttributeUnImplemented(
                     'Dot with axes {} is not supported.'.format(keras_layer.axes))
             for i, axis in enumerate(axes):
                 if axis not in [1, 2]:
-                    raise tvm.error.OpAttributeUnimplemented(
+                    raise tvm.error.OpAttributeUnImplemented(
                         'Dot with axes {} is not supported.'.format(keras_layer.axes))
                 if axes[i] == 2:
                     inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
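A note on the rename in the hunk above: the exception class that tvm.error actually defines is OpAttributeUnImplemented (capital 'I'), so the old spelling would itself fail with an AttributeError instead of reporting the unsupported Dot axes. A minimal sketch of the intended usage, with a hypothetical helper name and assuming a TVM build from this period:

import tvm

def _check_dot_axes(axes):
    # Reject Dot configurations the Keras frontend cannot convert.
    if not isinstance(axes, list) or len(axes) != 2:
        raise tvm.error.OpAttributeUnImplemented(
            'Dot with axes {} is not supported.'.format(axes))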
@@ -191,9 +198,11 @@ def _convert_merge(inexpr, keras_layer, _):
             'Operator {} is not supported in frontend Keras.'.format(merge_type))
     return ret
 
+
 def _convert_permute(inexpr, keras_layer, _):
     return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
 
+
 def _convert_dense(inexpr, keras_layer, etab):
     weightList = keras_layer.get_weights()
     weight = etab.new_const(weightList[0].transpose([1, 0]))


@@ -123,6 +123,11 @@ def test_forward_activations():
                  keras.layers.Activation('selu'),
                  keras.layers.ReLU(),
                  keras.layers.ReLU(max_value=6.),
+                 keras.layers.ReLU(max_value=6., threshold=0.),
+                 keras.layers.ReLU(max_value=6., threshold=1.),
+                 keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.),
+                 keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.5),
+                 keras.layers.ReLU(max_value=6., threshold=1., negative_slope=1.),
                  keras.layers.LeakyReLU(alpha=0.3),
                  keras.layers.PReLU(weights=np.random.rand(1, 32, 32, 3)),
                  keras.layers.ELU(alpha=0.5),
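To exercise one of the newly covered configurations end to end, here is a minimal usage sketch; it assumes the standalone keras package used by these tests, builds a single-layer model, and imports it with relay.frontend.from_keras (the test file's own verify_keras_frontend helper does the same import plus a numerical comparison against Keras outputs):

import keras
import tvm
from tvm import relay

# One-layer model using a ReLU variant with threshold and negative_slope set.
data = keras.layers.Input(shape=(32, 32, 3))
out = keras.layers.ReLU(max_value=6., threshold=1., negative_slope=0.5)(data)
model = keras.models.Model(data, out)

# The Relay Keras frontend of this era targets NCHW, so the input shape is
# given channels-first; the key must match the Keras input layer's name.
shape_dict = {model.input_names[0]: (1, 3, 32, 32)}
mod, params = relay.frontend.from_keras(model, shape_dict)
print(mod)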