Fixed a cloning bug in placeholder shape information

Yuqing Tang 2018-06-03 12:12:09 -07:00
Parent 0642d734c3
Commit 199bc5c30b
2 changed files with 23 additions and 1 deletion

View file

@@ -805,7 +805,9 @@ namespace CNTK
             if (existingPlaceholderReplacement == placeholderReplacements.end())
             {
-                clonedInput = PlaceholderVariable();
+                // We need to carry the shape information over to the new placeholder; otherwise, deeply chained recurrences with reshaping ops (e.g. expand_dims) will fail.
+                // However, we cannot carry over the dynamic axes, as the placeholder might later be replaced with variables that have different dynamic axes.
+                clonedInput = PlaceholderVariable(cloneeInput.Shape());
                 placeholderReplacements[clonedInput] = cloneeInput;
             }
             else

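To illustrate the reasoning in the new comment, here is a minimal Python sketch (not part of the commit; the names p and expanded are illustrative) of the same idea in the CNTK Python API: a placeholder created with a known static shape lets rank-dependent ops such as expand_dims resolve a negative axis index, while its dynamic axes remain unbound so it can later be substituted with variables carrying different sequence axes.

import cntk as C

# A placeholder that carries only the static shape, analogous to
# PlaceholderVariable(cloneeInput.Shape()) above; its dynamic axes are left
# unspecified, so it can still be replaced by inputs with different
# sequence axes.
p = C.placeholder(shape=(40,))

# A rank-dependent reshaping op: the negative axis index is computed from
# the placeholder's known rank, which an un-shaped placeholder could not
# provide.
expanded = C.expand_dims(p, -len(p.shape) - 1)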
View file

@@ -558,3 +558,23 @@ def test_clone_with_different_dynamic_axes():
     rnn = C.layers.Recurrence(C.layers.LSTM(5))(question_input)
     rnn_cloned = rnn.clone(C.CloneMethod.share, {question_input:answer_input})
+
+
+def test_clone_with_deep_rnn_chaining():
+    def seq_op_func(seqinp):
+        l = seqinp
+        r = C.sequence.future_value(l)
+        r = C.expand_dims(r, -len(seqinp.shape) - 1)
+        res = l + r
+        return res
+
+    def rnn_seq(features):
+        step_func = C.layers.GRU(1)
+        seq = C.layers.Recurrence(step_func)(features)
+        return seq
+
+    feat = C.sequence.input_variable((40,), name='sequence_inp')
+    c1 = rnn_seq(feat)
+    seq_op_res = seq_op_func(c1)
+    net = rnn_seq(seq_op_res)
+    cloned = net.clone('freeze')
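
As a hedged follow-up (not part of the commit, and continuing from the names defined in the test above): the frozen clone can be evaluated on random data to confirm that the cloned graph also runs, using CNTK's convention of one (sequence_length, feature_dim) array per sequence.

import numpy as np

# A single batch holding one sequence of length 5 for the (40,)-dimensional
# sequence input declared in the test.
data = [np.random.rand(5, 40).astype(np.float32)]
out = cloned.eval({cloned.arguments[0]: data})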