Add more cases to keras _convert_reshape (#3846)
This commit is contained in:
Родитель
ec7790e355
Коммит
2ebf1bd14e
|
@ -490,11 +490,26 @@ def _convert_concat(inexpr, keras_layer, _):
|
|||
|
||||
def _convert_reshape(inexpr, keras_layer, _):
|
||||
_check_data_format(keras_layer)
|
||||
ch = keras_layer.input_shape[-1]
|
||||
assert ch == keras_layer.target_shape[-1], \
|
||||
inshape = keras_layer.input_shape # includes batch
|
||||
tshape = keras_layer.target_shape # no batch
|
||||
if len(inshape) == 3 and len(tshape) == 1:
|
||||
# (?, a, b) -> (-1, ab)
|
||||
shape = (-1, tshape[0])
|
||||
elif len(inshape) in [2, 3] and len(tshape) == 2:
|
||||
# (?, cc) -> (-1, c, c)
|
||||
# (?, a, b) -> (-1, c, c)
|
||||
assert tshape[0] == tshape[1], \
|
||||
"Only supports square target shapes, but got {}".format(tshape)
|
||||
shape = (-1, ) + tshape
|
||||
else:
|
||||
# (?, h, w, c) -> (-1, c, H, W)
|
||||
# (?, h, w, c) -> (-1, c, hw)
|
||||
# (?, hw, c) -> (-1, c, h, w)
|
||||
ch = inshape[-1]
|
||||
assert ch == tshape[-1], \
|
||||
"Only supports last dimension in target shape being equal to " \
|
||||
"the channel number of input tensor."
|
||||
shape = (-1, ch) + keras_layer.target_shape[:-1]
|
||||
shape = (-1, ch) + tshape[:-1]
|
||||
return _op.reshape(inexpr, newshape=shape)
|
||||
|
||||
|
||||
|
|
|
@ -193,10 +193,36 @@ def test_forward_upsample(interpolation='nearest'):
|
|||
|
||||
|
||||
def test_forward_reshape():
|
||||
# input_shape len is 3, target_shape len is 3
|
||||
data = keras.layers.Input(shape=(32, 32, 3))
|
||||
x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
|
||||
x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model)
|
||||
# input_shape len is 3, target_shape len is 2
|
||||
data = keras.layers.Input(shape=(32, 8, 3))
|
||||
x = keras.layers.Reshape(target_shape=(256, 3))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model)
|
||||
# input_shape len is 2, target_shape len is 3
|
||||
data = keras.layers.Input(shape=(256, 3))
|
||||
x = keras.layers.Reshape(target_shape=(8, 32, 3))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model)
|
||||
# input_shape len is 2, target_shape len is 1
|
||||
data = keras.layers.Input(shape=(2, 8))
|
||||
x = keras.layers.Reshape(target_shape=(16,))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model, need_transpose=False)
|
||||
# input_shape len is 1, target_shape len is 2
|
||||
data = keras.layers.Input(shape=(16,))
|
||||
x = keras.layers.Reshape(target_shape=(4, 4))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model, need_transpose=False)
|
||||
# input_shape len is 2, target_shape len is 2
|
||||
data = keras.layers.Input(shape=(2, 8))
|
||||
x = keras.layers.Reshape(target_shape=(4, 4))(data)
|
||||
keras_model = keras.models.Model(data, x)
|
||||
verify_keras_frontend(keras_model, need_transpose=False)
|
||||
|
||||
|
||||
def test_forward_crop():
|
||||
|
|
Загрузка…
Ссылка в новой задаче