Merge branch 'liqun/NewSequenceSliceStage2Stage'

This commit is contained in:
Liqun Fu 2018-08-15 06:46:32 +00:00
Parent 6547e2ce7f ae163b33a8
Commit ef836bc293
6 changed files with 236 additions and 54 deletions

View file

@@ -600,24 +600,35 @@ namespace CNTK
VerifyStaticAxis(ax, m_inputs[0].Shape());
size_t sliceAxisDim = m_inputs[0].Shape()[ax.StaticAxisIndex()];
-int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
-int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
-if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
-    RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
-        AsString().c_str(),
-        beginIndex[i],
-        endIndex[i],
-        realBeginIndex,
-        realEndIndex,
-        m_inputs[0].AsString().c_str(),
-        m_inputs[0].Shape().AsString().c_str());
-// propagate as much as we can
-// Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the
-// corresponding outputDim is also a free dimension
-if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
-    ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
+if (sliceAxisDim == NDShape::FreeDimension && (beginIndex[i] < 0 || endIndex[i] <= 0))
{
-    outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
+    // not able to calculate the real indices; do not narrow either.
+    // note that endIndex[i] = 0 means up to (and including) the last element.
+    // One case for this condition is exporting and importing, in ONNX format, a CNTK Sequence.Slice op.
+    // In that case, if the batch size is larger than 1 and the input sequences have various lengths (a jagged array),
+    // model evaluation will not match the original CNTK model.
+}
+else
+{
+    int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim;
+    int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim;
+    if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0))
+        RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.",
+            AsString().c_str(),
+            beginIndex[i],
+            endIndex[i],
+            realBeginIndex,
+            realEndIndex,
+            m_inputs[0].AsString().c_str(),
+            m_inputs[0].Shape().AsString().c_str());
+    // propagate as much as we can
+    // Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the
+    // corresponding outputDim is also a free dimension
+    if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) &&
+        ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim)))
+    {
+        outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]);
+    }
}
}
outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true);
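
The distinction the new branch draws is easier to see outside the C++. Below is a short Python sketch (not part of the commit; FREE_DIMENSION is a stand-in for NDShape::FreeDimension, whose real value differs) of how begin/end indices are normalized, and why relative indices cannot be resolved while the axis is still a free dimension:

FREE_DIMENSION = 10**9   # stand-in sentinel for NDShape::FreeDimension (illustrative only)

def narrowed_range(slice_axis_dim, begin_index, end_index):
    # Mirrors the normalization above: a negative begin index and a non-positive
    # end index are offsets from the end of the axis (end == 0 means "to the end").
    if slice_axis_dim == FREE_DIMENSION and (begin_index < 0 or end_index <= 0):
        return None   # real indices cannot be computed yet; keep the free dimension
    real_begin = begin_index if begin_index >= 0 else begin_index + slice_axis_dim
    real_end = end_index if end_index > 0 else end_index + slice_axis_dim
    if slice_axis_dim < real_end or real_end < real_begin or real_begin < 0:
        raise ValueError("invalid slice range [%d,%d) for dimension %d"
                         % (begin_index, end_index, slice_axis_dim))
    return real_begin, real_end

print(narrowed_range(5, -2, 0))               # (3, 5): the last two elements
print(narrowed_range(FREE_DIMENSION, 1, 3))   # (1, 3): absolute indices can still narrow
print(narrowed_range(FREE_DIMENSION, -2, 0))  # None: relative indices cannot be resolved yet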

View file

@@ -83,7 +83,7 @@ private:
std::vector<LotusIR::NodeArg *>& outputs, Graph *graph);
static LotusIR::Node *AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
-    LotusIR::Graph* graph, int dynamicAxisCount);
+    LotusIR::Graph* graph);
static LotusIR::Node *AddMatMulNode(LotusIR::NodeArg &nodeArg1, LotusIR::NodeArg &nodeArg2, LotusIR::Graph* graph,
const std::string &out_arg_name);
static LotusIR::Node *AddArgMaxNode(LotusIR::NodeArg &nodeArg, LotusIR::Graph* graph, int axis);
@@ -1181,10 +1181,13 @@ bool IsUnSupportedLayerNormalization(const FunctionPtr src)
FunctionPtr SkipBatchAndSequenceAxisOp(const FunctionPtr src)
{
if ((src->OpName() == L"ToSequenceOp" && src->Inputs()[0].Owner() &&
src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") ||
(src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
src->Inputs()[0].Owner()->OpName() == L"UnpackSequenceOp"))
return src->Inputs()[0].Owner()->Inputs()[0].Owner();
+else if (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() &&
+    src->Inputs()[0].Owner()->OpName() == L"Sequence::Slice")
+    return src->Inputs()[0].Owner();
else
return src;
}
@@ -1342,11 +1345,12 @@ bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable&
}
/*
-CNTK python static axis is zero based. Free/Inferred axis is not static.
-ONNX batch axis, if exists, is 0. in this case static axes start from 1.
-CNTK cpp get static axis in a dis-normalized form (e.g. -axis - 1)
-In general CNTK node attribute contains axis in this dis-normalized form.
-This function converts dis-normalized form to ONNX form.
+CNTK Python static axes are zero based. Batch and sequence axes are not static axes.
+CNTK C++ gets the static axis in a sanitized form (e.g. -axis - 1, via sanitize_axis).
+In general a CNTK node attribute contains the axis
+in this de-normalized form (i.e. indexed from the last dimension).
+This function converts the axis to the ONNX form
+(i.e. indexed from the first dimension of the shape, including both static and dynamic axes).
*/
int64_t CNTKToONNXHelper::ConvertAxisToOnnx(const Axis &axis, const Variable &operand)
{
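
To make the comment above concrete, here is a rough Python sketch (not part of the commit) of the index arithmetic it describes. The helper name and the exact treatment of the sanitized negative encoding are illustrative assumptions, not the actual ConvertAxisToOnnx implementation:

def cntk_axis_to_onnx_axis(static_axis_index, static_rank, dynamic_axis_count):
    # Assumption: static_axis_index counts static axes from the last dimension (zero based),
    # while the ONNX axis counts from the first dimension of the full shape,
    # which also includes the dynamic (batch/sequence) axes.
    return dynamic_axis_count + (static_rank - 1 - static_axis_index)

# hypothetical operand: batch + sequence axes and a rank-2 static shape
print(cntk_axis_to_onnx_axis(0, static_rank=2, dynamic_axis_count=2))  # -> 3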
@@ -2315,9 +2319,9 @@ LotusIR::Node *CNTKToONNXHelper::AddReshapeNodeAccordingToONNXVersion(Graph *gra
LotusIR::Node *CNTKToONNXHelper::AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector<int> &newShape, const std::string &outArgName,
-    LotusIR::Graph *graph, int dynamicAxisCount)
+    LotusIR::Graph *graph)
{
-onnx::TypeProto typeProto = ToTypeProto(newShape, dynamicAxisCount);
+onnx::TypeProto typeProto = ToTypeProto(newShape, false);
UpdateONNXType(CNTK::DataType::Float, typeProto);
LotusIR::NodeArg &outputArg = graph->GetOrCreateNodeArg(outArgName, &typeProto);
@@ -2386,6 +2390,144 @@ LotusIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionP
return reshapeNode;
}
// Parses a Sequence.Slice node graph to collect the axis/begin index/end index
// and builds an ONNX Slice op.
// IMPORTANT NOTE:
// This function converts a CNTK Sequence::Slice op to an ONNX Slice op.
// CNTK Sequence::Slice can handle jagged-array input (i.e. sequences of various lengths).
// ONNX Slice does not support the jagged-array data format.
// Therefore, when the batch size is larger than 1 and the input sequences have different lengths,
// we do not expect model evaluation to produce matching numbers between CNTK and ONNX.
// With the following CNTK example:
// model = C.sequence.slice(C.sequence.input_variable((1)), -2, -1)
// model.eval([[0, 1, 2], [0, 1, 2, 3, 4]])
// CNTK output is:
// array([[1.], [3.]], dtype = float32)
// the output from the exported ONNX model will be:
// array([[padding_value], [3.]], dtype = float32)
LotusIR::Node* CNTKToONNXHelper::CreateSequenceSliceNode(const FunctionPtr& src,
LotusIR::Graph* graph,
std::unordered_map<FunctionPtr, LotusIR::Node*>& functionNodes,
std::unordered_map<Variable, LotusIR::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap)
{
auto f = src->BlockRoot();
int64_t beginIndex = 0, endIndex = 0;
auto packedIndex = f->Inputs()[1].Owner();
auto whereFunc = packedIndex->Inputs()[1].Owner();
auto inputToWhere = whereFunc->Inputs()[0].Owner();
// input to Where node can be:
// ElementTimes - both indices are non-zero, beginIndex/endIndex are from First/Second inputs
// FutureValue - beginIndex is negative, endIndex is zero
// PastValue - endIndex is positive, beginIndex is zero
// 1 Minus FutureValue - endIndex is negative, beginIndex is zero
// 1 Minus PastValue - beginIndex is positive, endIndex is zero
auto reportLogicError = [&src]()
{
LogicError("Failed to parse Sequence.Slice node %s(%s).", ToLegacyString(ToUTF8(src->Name())).c_str(), ToLegacyString(ToUTF8(src->Uid())).c_str());
};
if (inputToWhere->OpName() == L"ElementTimes")
{
{
auto beginToWhere = inputToWhere->Inputs()[0].Owner();
if (beginToWhere->OpName() == L"Minus")
{
auto beginToMinusMustBeAPastValueOp = beginToWhere->Inputs()[1].Owner();
if (beginToMinusMustBeAPastValueOp->OpName() == L"PastValue")
beginIndex = static_cast<int64_t>(beginToMinusMustBeAPastValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
else
reportLogicError();
}
else if (beginToWhere->OpName() == L"FutureValue")
{
beginIndex = -static_cast<int64_t>(beginToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else
reportLogicError();
}
{
auto endToWhere = inputToWhere->Inputs()[1].Owner();
if (endToWhere->OpName() == L"Minus")
{
auto endToMinusMustBeAFutureValueOp = endToWhere->Inputs()[1].Owner();
if (endToMinusMustBeAFutureValueOp->OpName() == L"FutureValue")
endIndex = -static_cast<int64_t>(endToMinusMustBeAFutureValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
else
reportLogicError();
}
else if (endToWhere->OpName() == L"PastValue")
{
endIndex = static_cast<int64_t>(endToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else
reportLogicError();
}
}
else if (inputToWhere->OpName() == L"FutureValue")
{
beginIndex = -static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToWhere->OpName() == L"PastValue")
{
endIndex = static_cast<int64_t>(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToWhere->OpName() == L"Minus")
{
auto inputToMinus = inputToWhere->Inputs()[1].Owner();
if (inputToMinus->OpName() == L"FutureValue")
{
endIndex = -static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
else if (inputToMinus->OpName() == L"PastValue")
{
beginIndex = static_cast<int64_t>(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>());
}
}
if (endIndex == 0)
// this is where CNTK and numpy disagree: numpy would output an empty matrix,
// whereas CNTK outputs from beginIndex to (and including) the last element.
endIndex = INT_MAX;
std::vector<LotusIR::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs);
//std::vector<LotusIR::NodeArg *> outputs;
//ProcessOutputs(src, outputs, graph);
auto outputArgType = ToTypeProto(src->Output().Shape(), src->Output().HasBatchAxis(), src->Output().HasSequenceAxis());
UpdateONNXType(src->Output().GetDataType(), outputArgType);
std::string outputName = ToLegacyString(ToUTF8(src->BlockRoot()->Output().Uid()));
std::string sliceOutputName = outputName;
bool seq_dim_is_1 = endIndex - beginIndex == 1 || (endIndex == INT_MAX && beginIndex == -1);
if (seq_dim_is_1)
{
// it appears that Sequence.Slice squeezes the sequence axis out if the slice length is 1
sliceOutputName += "_PreReshape";
}
LotusIR::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(sliceOutputName, &outputArgType);
const std::string & nodeName = ToLegacyString(ToUTF8(src->Name()));
LotusIR::Node *sequenceSliceNode = graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg });
sequenceSliceNode->AddAttribute("axes", std::vector<int64_t>({ int64_t(0) }));
sequenceSliceNode->AddAttribute("ends", std::vector<int64_t>({ endIndex }));
sequenceSliceNode->AddAttribute("starts", std::vector<int64_t>({ beginIndex }));
if (seq_dim_is_1)
{
// CNTK Sequence.Slice op squeezes the sequence axis if it is of dimension 1.
// insert reshape to remove sequence axis
std::vector<int> newShape(reverse(Cast<size_t, int>(src->Output().Shape().Dimensions())));
// add batch size at end
newShape.insert(newShape.begin(), 1);
const std::string outArgName = sliceOutputName;
return AddReshapeNode(outputNodeArg, newShape, outputName, graph);
}
else
return sequenceSliceNode;
}
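
The caveat in the IMPORTANT NOTE above can be reproduced directly from Python. The snippet below is a sketch built from the example in that comment (it is not part of the commit) and shows the one case where CNTK and the exported ONNX model are expected to disagree: a batch of sequences with different lengths.

import numpy as np
import cntk as C

# Sequence.Slice selecting the second-to-last element of every sequence.
model = C.sequence.slice(C.sequence.input_variable((1,)), -2, -1)

# A jagged batch: two sequences of lengths 3 and 5.
data = [np.arange(3, dtype=np.float32).reshape(3, 1),
        np.arange(5, dtype=np.float32).reshape(5, 1)]

# CNTK slices each sequence relative to its own length: [[1.], [3.]]
print(model.eval(data))

# The exported ONNX Slice runs on one dense, padded [batch, max_len, 1] tensor,
# so for the shorter sequence it returns a padding value instead of 1.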
//
// This is the main horsepower: it navigates the CNTK graph recursively while keeping track of all visited nodes and variables,
// and creates the corresponding ONNX graph.
@@ -2417,7 +2559,15 @@ LotusIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& initialSrc,
//    return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
//}
//else
-if (cntkOpName == "RNNStep")
+if (cntkOpName == "Sequence::Slice")
+{
+    return CreateSequenceSliceNode(src,
+        graph,
+        functionNodes,
+        variableNodes,
+        compositeOutputsMap);
+}
+else if (cntkOpName == "RNNStep")
{
return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
}

View file

@@ -2678,7 +2678,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
// { L"", "Split)
else if (onnxOpName == "Slice")
{
-// axes is optional so provide a default
std::vector<Axis> axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]);
std::vector<int64_t> starts64 = GetNamedAttributeAsInt64Vec(node, "starts");
@@ -2692,7 +2691,14 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
std::vector<int> starts = VecInt64ToVecInt(starts64);
std::vector<int> ends = VecInt64ToVecInt(ends64);
for (auto &e : ends)
{
// CNTK treats an endIndex of 0 as up to (and including) the last element.
if (e == INT_MAX)
e = 0;
}
// axes is optional so provide a default
if (axes.empty())
{
for (int i = 0; i < starts.size(); i++)
@@ -2702,13 +2708,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
}
}
-bool workaroundONNXRT = false;
-if (workaroundONNXRT)
-{
-    axes.erase(axes.begin());
-    starts.erase(starts.begin());
-    ends.erase(ends.begin());
-}
FunctionPtr cntkFunction = Slice(inputs[0], axes, starts, ends, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
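
Together with CreateSequenceSliceNode on the export side, the loop added above closes a round trip for the end-index convention: CNTK uses 0 to mean "up to and including the last element", the exporter writes INT_MAX into the ONNX "ends" attribute, and the importer maps it back. A small Python sketch of that mapping (helper names are illustrative, not actual library functions):

INT_MAX = 2 ** 31 - 1

def cntk_end_to_onnx(end_index):
    # CNTK: end index 0 means "up to and including the last element".
    # ONNX/numpy: 0 would mean an empty slice, so INT_MAX is used for "to the end".
    return INT_MAX if end_index == 0 else end_index

def onnx_end_to_cntk(end_index):
    # Inverse mapping applied when importing an ONNX Slice back into CNTK.
    return 0 if end_index == INT_MAX else end_index

assert onnx_end_to_cntk(cntk_end_to_onnx(0)) == 0
assert cntk_end_to_onnx(-1) == -1   # negative end indices pass through unchanged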

View file

@@ -463,11 +463,21 @@ Variable ToBatchAndSequence(Variable input)
return operandWithBatchAndSequenceAxis;
}
-FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction)
+FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction, bool doTranspose)
{
FunctionPtr cntkFunctionWithoutSequenceAxis = Sequence::Unpack(rnnFunction, 0, L"");
FunctionPtr cntkFunctionWithoutDynamicAxis = UnpackBatch(cntkFunctionWithoutSequenceAxis, L"");
-return cntkFunctionWithoutDynamicAxis;
+if (doTranspose)
+{
+    FunctionPtr transpose = TransposeAxes(cntkFunctionWithoutDynamicAxis,
+        Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 2),
+        Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 1), L"");
+    return transpose;
+}
+else
+    // in the case of RNN ops, a transpose is already inserted after the op, so we do not transpose again.
+    // TODO: do not transpose after RNN ops so that we have a single code path here.
+    return cntkFunctionWithoutDynamicAxis;
}
FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &inputs, const std::string &direction,
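
The doTranspose branch above exists because, once the sequence and batch axes are unpacked into ordinary static dimensions, the original CNTK model and the ONNX-imported graph order those two dimensions differently; the Python tests perform the same swap on the data side via TranposeDynamicAxis. A numpy sketch of the swap (shapes are illustrative):

import numpy as np

# a value whose two leading dimensions are the unpacked dynamic axes
unpacked = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)

# swap the two leading axes, the same permutation TranposeDynamicAxis builds in the tests
perm = list(range(unpacked.ndim))
perm[0], perm[1] = perm[1], perm[0]
swapped = np.transpose(unpacked, perm)
print(swapped.shape)   # (5, 2, 3)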
@@ -557,7 +567,7 @@ FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector<Variable> &i
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}
-FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}
@@ -622,7 +632,7 @@ FunctionPtr CreateGRU(const LotusIR::Node *node, const std::vector<Variable> &in
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}
-FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}
@@ -678,7 +688,7 @@ FunctionPtr CreateRNN(const LotusIR::Node *node, const std::vector<Variable> &in
rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name()));
}
-FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction);
+FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false);
return unpackedRnnFunction;
}

View file

@@ -184,4 +184,4 @@ std::vector<CNTK::FunctionPtr> GetRNNBlocksFromSingleOrBidirectionalRNN(const CN
CNTK::Variable ToBatchAndSequence(CNTK::Variable input);
-CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction);
+CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction, bool doTranspose = true);

View file

@@ -93,20 +93,17 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
perm[0], perm[1] = perm[1], perm[0]
return np.transpose(data, perm)
-assert model.output.has_sequence_axis() and model.output.has_batch_axis()
# data here is reference to the outside data object. create deepcopy to avoid changing the outside data since it might get reused.
data = deepcopy(data)
opname = model.owner.op_name
loaded_model = try_save_load_resave_onnx_model(model, tmpdir, name, loaded_model)
-model_shape = model.shape
-# When both batch and sequence axes exist, model input will have batch and sequence axes
-# swapped to match the onnx model. In this case, input and output data shall be adjusted accordingly.
-model_shape = (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model_shape
-assert model_shape == loaded_model.shape
+# in cases like RNN models, where the model has both batch and sequence dynamic axes,
+# the imported model will have the dynamic axes as free dimensions in its static shape.
+if model.output.dynamic_axes == (C.Axis('defaultBatchAxis'), C.Axis('defaultDynamicAxis')):
+    assert (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model.shape == loaded_model.shape
dataOnnx = TranposeDynamicAxis(data)
if device:
@@ -116,13 +113,13 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N
o0 = model.eval({model.arguments[0]:data})
o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx})
-if (type(o0) is list):
-    o0 = o0[0]
-if (type(o1) is list):
-    o1 = o1[0]
+o0 = np.array(o0)
+o1 = np.array(o1)
-# squeeze the batch axis (=1) to match original model output
-o1 = np.squeeze(o1, axis=1)
+# if there is a sequence axis in the output, it must be swapped with the batch axis
+# to match the original CNTK model's output
+if model.outputs[0].has_sequence_axis():
+    o1 = TranposeDynamicAxis(o1)
assert np.allclose(o0, o1)
return loaded_model
@@ -1352,6 +1349,21 @@ def test_Slice(tmpdir, dtype):
model = C.slice(x1, [0,1], [1,0], [2,1]);
verify_one_input(model, data, tmpdir, 'Slice2_1')
#Sequence.Slice
@pytest.mark.parametrize("beginIndex, endIndex", (
(-2, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-4, 2), (0, 1), (1, 2)))
@pytest.mark.parametrize("dtype", DType_Config)
def test_SequenceSlice(tmpdir, dtype, beginIndex, endIndex):
batch_size = 1
sequence_length = 5
feature_shape = (3,)
shape = (batch_size, sequence_length, *feature_shape)
data = np.reshape(range(0, np.prod(shape)), shape).astype(dtype)
testName = "test_sequence_slice_{0}.{1}".format(beginIndex, endIndex)
print(testName)
model = C.sequence.slice(C.sequence.input_variable((feature_shape)), beginIndex, endIndex)
verify_sequence_model(model, data, tmpdir, testName)
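
For intuition, each (beginIndex, endIndex) pair above corresponds, on a dense single-batch input, to an ordinary numpy slice along the sequence axis, with the CNTK convention that an end index of 0 means "to the end". A sketch (not part of the test suite):

import numpy as np

data = np.reshape(range(0, 15), (1, 5, 3)).astype(np.float32)  # (batch, sequence, features)

def dense_sequence_slice(x, begin_index, end_index):
    # end index 0 means "up to and including the last element" in CNTK
    end = None if end_index == 0 else end_index
    return x[:, begin_index:end, :]

print(dense_sequence_slice(data, -2, -1))  # the second-to-last step: [[[ 9. 10. 11.]]]
print(dense_sequence_slice(data, 1, 0))    # steps 1 through 4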
#Softmax
@pytest.mark.parametrize("dtype", DType_Config)
def test_Softmax(tmpdir, dtype):