diff --git a/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp index b891717ef..e74f4cd0f 100644 --- a/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp +++ b/Source/CNTKv2LibraryDll/PrimitiveFunction.cpp @@ -600,24 +600,35 @@ namespace CNTK VerifyStaticAxis(ax, m_inputs[0].Shape()); size_t sliceAxisDim = m_inputs[0].Shape()[ax.StaticAxisIndex()]; - int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim; - int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim; - if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0)) - RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.", - AsString().c_str(), - beginIndex[i], - endIndex[i], - realBeginIndex, - realEndIndex, - m_inputs[0].AsString().c_str(), - m_inputs[0].Shape().AsString().c_str()); - // propagate as much as we can - // Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the - // corresponding outputDim is also a free dimension - if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) && - ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim))) + if (sliceAxisDim == NDShape::FreeDimension && (beginIndex[i] < 0 || endIndex[i] <= 0)) { - outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]); + // not able to calculate real indices. do not narrow either. + // note that endIndex[i] = 0 means to (and include) the last. + // One case for this condition is to export and import, in ONNX format, a CNTK Sequence.Slice op. 
+ // In this case, if batch size is larger than 1 and input data are a jagged array (i.e. sequences of various lengths), + // model evaluation will not match the original CNTK model. + } + else + { + int realBeginIndex = (beginIndex[i] >= 0) ? beginIndex[i] : beginIndex[i] + sliceAxisDim; + int realEndIndex = (endIndex[i] > 0) ? endIndex[i] : endIndex[i] + sliceAxisDim; + if ((sliceAxisDim < realEndIndex) || (realEndIndex < realBeginIndex) || (realBeginIndex < 0)) + RuntimeError("Function '%S': Slice operation index range [%d,%d), interpreted as [%d,%d), is invalid for input '%S' shape '%S'.", + AsString().c_str(), + beginIndex[i], + endIndex[i], + realBeginIndex, + realEndIndex, + m_inputs[0].AsString().c_str(), + m_inputs[0].Shape().AsString().c_str()); + // propagate as much as we can + // Note: If the sliceAxisDim is a free dimension and the slice size is relative to the sliceAxisDim then the + // corresponding outputDim is also a free dimension + if ((((sliceAxisDim != NDShape::FreeDimension) && (sliceAxisDim != NDShape::InferredDimension)) || (((beginIndex[i] >= 0) && (endIndex[i] > 0)) || ((beginIndex[i] < 0) && (endIndex[i] <= 0)))) && + ((ax.StaticAxisIndex() < (int)outputTensorShape.GetRank()) && (0 <= realBeginIndex) && (realBeginIndex <= realEndIndex) && (realEndIndex <= sliceAxisDim))) + { + outputTensorShape.NarrowTo(ax.StaticAxisIndex(), realBeginIndex, realEndIndex, strides[i]); + } } } outputShape = AsNDShape(outputTensorShape, /*allowNonFlattenableTensorShapes = */ true); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 2bff97463..11d42689c 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -83,7 +83,7 @@ private: std::vector& outputs, Graph *graph); static LotusIR::Node *AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector &newShape, const std::string &outArgName, - LotusIR::Graph* graph, int 
dynamicAxisCount); + LotusIR::Graph* graph); static LotusIR::Node *AddMatMulNode(LotusIR::NodeArg &nodeArg1, LotusIR::NodeArg &nodeArg2, LotusIR::Graph* graph, const std::string &out_arg_name); static LotusIR::Node *AddArgMaxNode(LotusIR::NodeArg &nodeArg, LotusIR::Graph* graph, int axis); @@ -1181,10 +1181,13 @@ bool IsUnSupportedLayerNormalization(const FunctionPtr src) FunctionPtr SkipBatchAndSequenceAxisOp(const FunctionPtr src) { if ((src->OpName() == L"ToSequenceOp" && src->Inputs()[0].Owner() && - src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") || + src->Inputs()[0].Owner()->OpName() == L"ToBatchAxis") || (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() && src->Inputs()[0].Owner()->OpName() == L"UnpackSequenceOp")) return src->Inputs()[0].Owner()->Inputs()[0].Owner(); + else if (src->OpName() == L"UnpackBatchAxis" && src->Inputs()[0].Owner() && + src->Inputs()[0].Owner()->OpName() == L"Sequence::Slice") + return src->Inputs()[0].Owner(); else return src; } @@ -1342,11 +1345,12 @@ bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable& } /* -CNTK python static axis is zero based. Free/Inferred axis is not static. -ONNX batch axis, if exists, is 0. in this case static axes start from 1. -CNTK cpp get static axis in a dis-normalized form (e.g. -axis - 1) -In general CNTK node attribute contains axis in this dis-normalized form. -This function converts dis-normalized form to ONNX form. +CNTK python static axis is zero based. Batch and Sequence axis is not static axis. +CNTK cpp get static axis in a sanitized form (e.g. -axis - 1 by sanitize_axis) +In general CNTK node attribute contains axis +in a dis-normalized form (e.g. index from the last dimension). +This function converts axis to ONNX form +(e.g. index from the first dimension of the shape including both static and dynamic axes). 
*/ int64_t CNTKToONNXHelper::ConvertAxisToOnnx(const Axis &axis, const Variable &operand) { @@ -2315,9 +2319,9 @@ LotusIR::Node *CNTKToONNXHelper::AddReshapeNodeAccordingToONNXVersion(Graph *gra LotusIR::Node *CNTKToONNXHelper::AddReshapeNode(LotusIR::NodeArg &nodeArg, const std::vector &newShape, const std::string &outArgName, - LotusIR::Graph *graph, int dynamicAxisCount) + LotusIR::Graph *graph) { - onnx::TypeProto typeProto = ToTypeProto(newShape, dynamicAxisCount); + onnx::TypeProto typeProto = ToTypeProto(newShape, false); UpdateONNXType(CNTK::DataType::Float, typeProto); LotusIR::NodeArg &outputArg = graph->GetOrCreateNodeArg(outArgName, &typeProto); @@ -2386,6 +2390,144 @@ LotusIR::Node *CNTKToONNXHelper::InsertReshapeNodeToCNTKFunction(const FunctionP return reshapeNode; } +// To parse Sequence.Slice node graph to collect axis/begin index/end index +// and to build an ONNX Slice op. +// IMPORTANT NOTE: +// This function converts a CNTK Sequence::Slice op to ONNX Slice op. +// CNTK Sequence::Slice has the ability to handle input of jagged arrays (i.e. sequences of various lengths). +// ONNX Slice does not support the jagged-array data format. +// Therefore in case of batch size larger than 1 and input data are jagged arrays, +// we do not expect model evaluation to generate matching numbers between CNTK and ONNX. 
+// with this following CNTK example: +// model = C.sequence.slice(C.sequence.input_variable((1)), -2, -1) +// model.eval([[0, 1, 2], [0, 1, 2, 3, 4]]) +// CNTK output is: +// array([[1.], [3.]], dtype = float32) +// output from exported ONNX model will be: +// array([[padding_value], [3.]], dtype = float32) +LotusIR::Node* CNTKToONNXHelper::CreateSequenceSliceNode(const FunctionPtr& src, + LotusIR::Graph* graph, + std::unordered_map& functionNodes, + std::unordered_map& variableNodes, + const std::unordered_map& compositeOutputsMap) +{ + auto f = src->BlockRoot(); + int64_t beginIndex = 0, endIndex = 0; + + auto packedIndex = f->Inputs()[1].Owner(); + auto whereFunc = packedIndex->Inputs()[1].Owner(); + auto inputToWhere = whereFunc->Inputs()[0].Owner(); + // input to Where node can be: + // ElementTimes - both indices are non-zero, beginIndex/endIndex are from First/Second inputs + // FutureValue - beginIndex is negative, endIndex is zero + // PastValue - endIndex is positive, beginIndex is zero + // 1 Minus FutureValue - endIndex is negative, beginIndex is zero + // 1 Minus PastValue - beginIndex is positive, endIndex is zero + auto reportLogicError = [&src]() + { + LogicError("Failed to parse Sequence.Slice node %s(%s).", ToLegacyString(ToUTF8(src->Name())).c_str(), ToLegacyString(ToUTF8(src->Uid())).c_str()); + }; + if (inputToWhere->OpName() == L"ElementTimes") + { + { + auto beginToWhere = inputToWhere->Inputs()[0].Owner(); + if (beginToWhere->OpName() == L"Minus") + { + auto beginToMinusMustBeAPastValueOp = beginToWhere->Inputs()[1].Owner(); + if (beginToMinusMustBeAPastValueOp->OpName() == L"PastValue") + beginIndex = static_cast(beginToMinusMustBeAPastValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + else + reportLogicError(); + } + else if (beginToWhere->OpName() == L"FutureValue") + { + beginIndex = -static_cast(beginToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + else + reportLogicError(); + } + { + 
auto endToWhere = inputToWhere->Inputs()[1].Owner(); + if (endToWhere->OpName() == L"Minus") + { + auto endToMinusMustBeAFutureValueOp = endToWhere->Inputs()[1].Owner(); + if (endToMinusMustBeAFutureValueOp->OpName() == L"FutureValue") + endIndex = -static_cast(endToMinusMustBeAFutureValueOp->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + else + reportLogicError(); + } + else if (endToWhere->OpName() == L"PastValue") + { + endIndex = static_cast(endToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + else + reportLogicError(); + } + } + else if (inputToWhere->OpName() == L"FutureValue") + { + beginIndex = -static_cast(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + else if (inputToWhere->OpName() == L"PastValue") + { + endIndex = static_cast(inputToWhere->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + else if (inputToWhere->OpName() == L"Minus") + { + auto inputToMinus = inputToWhere->Inputs()[1].Owner(); + if (inputToMinus->OpName() == L"FutureValue") + { + endIndex = -static_cast(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + else if (inputToMinus->OpName() == L"PastValue") + { + beginIndex = static_cast(inputToMinus->Attributes()[PrimitiveFunction::AttributeNameOffset].Value()); + } + } + + if (endIndex == 0) + // this is where CNTK and numpy disagree. numpy will output an empty matrix + // where CNTK outputs from beginIndex to (and include) the last. 
+ endIndex = INT_MAX; + + std::vector inputs; + ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs); + + //std::vector outputs; + //ProcessOutputs(src, outputs, graph); + auto outputArgType = ToTypeProto(src->Output().Shape(), src->Output().HasBatchAxis(), src->Output().HasSequenceAxis()); + UpdateONNXType(src->Output().GetDataType(), outputArgType); + + std::string outputName = ToLegacyString(ToUTF8(src->BlockRoot()->Output().Uid())); + std::string sliceOutputName = outputName; + bool seq_dim_is_1 = endIndex - beginIndex == 1 || (endIndex == INT_MAX && beginIndex == -1); + if (seq_dim_is_1) + { + // it appears that sequence.slice squeezes sequence axis out if slice length is 1 + sliceOutputName += "_PreReshape"; + } + + LotusIR::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(sliceOutputName, &outputArgType); + + const std::string & nodeName = ToLegacyString(ToUTF8(src->Name())); + LotusIR::Node *sequenceSliceNode = graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg }); + sequenceSliceNode->AddAttribute("axes", std::vector({ int64_t(0) })); + sequenceSliceNode->AddAttribute("ends", std::vector({ endIndex })); + sequenceSliceNode->AddAttribute("starts", std::vector({ beginIndex })); + if (seq_dim_is_1) + { + // CNTK Sequence.Slice op squeezes the sequence axis if it is of dimension 1. + // insert reshape to remove sequence axis + std::vector newShape(reverse(Cast(src->Output().Shape().Dimensions()))); + // add batch size at end + newShape.insert(newShape.begin(), 1); + const std::string outArgName = sliceOutputName; + return AddReshapeNode(outputNodeArg, newShape, outputName, graph); + } + else + return sequenceSliceNode; +} + // // This is the main horsepower, it navigate CNTK graph recursivley while keep track of all visited nodes and variables, // and create the corresponding ONNX graph. 
@@ -2417,7 +2559,15 @@ LotusIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& initialSrc, // return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); //} //else - if (cntkOpName == "RNNStep") + if (cntkOpName == "Sequence::Slice") + { + return CreateSequenceSliceNode(src, + graph, + functionNodes, + variableNodes, + compositeOutputsMap); + } + else if (cntkOpName == "RNNStep") { return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap); } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index e891599d8..39781fc99 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -2678,7 +2678,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector // { L"", "Split) else if (onnxOpName == "Slice") { - // axes is optional so provide a default std::vector axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]); std::vector starts64 = GetNamedAttributeAsInt64Vec(node, "starts"); @@ -2692,7 +2691,14 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector std::vector starts = VecInt64ToVecInt(starts64); std::vector ends = VecInt64ToVecInt(ends64); + for (auto &e : ends) + { + // CNTK treats endIndex of 0 as to (and include) the last. 
+ if (e == INT_MAX) + e = 0; + } + // axes is optional so provide a default if (axes.empty()) { for (int i = 0; i < starts.size(); i++) @@ -2702,13 +2708,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector } } - bool workaroundONNXRT = false; - if (workaroundONNXRT) - { - axes.erase(axes.begin()); - starts.erase(starts.begin()); - ends.erase(ends.begin()); - } FunctionPtr cntkFunction = Slice(inputs[0], axes, starts, ends, ToFixedWStringFromMultiByte(node->Name())); return cntkFunction; } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp index fef8314db..f05dd47a8 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp @@ -463,11 +463,21 @@ Variable ToBatchAndSequence(Variable input) return operandWithBatchAndSequenceAxis; } -FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction) +FunctionPtr UnpackBatchAndSequence(FunctionPtr rnnFunction, bool doTranspose) { FunctionPtr cntkFunctionWithoutSequenceAxis = Sequence::Unpack(rnnFunction, 0, L""); FunctionPtr cntkFunctionWithoutDynamicAxis = UnpackBatch(cntkFunctionWithoutSequenceAxis, L""); - return cntkFunctionWithoutDynamicAxis; + if (doTranspose) + { + FunctionPtr transpose = TransposeAxes(cntkFunctionWithoutDynamicAxis, + Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 2), + Axis(cntkFunctionWithoutDynamicAxis->Output().Shape().Rank() - 1), L""); + return transpose; + } + else + // in case of RNN ops, transpose is inserted after the op so we do not do transpose again + // TODO: do not transpose after RNN ops so we have one code path here. 
+ return cntkFunctionWithoutDynamicAxis; } FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector &inputs, const std::string &direction, @@ -557,7 +567,7 @@ FunctionPtr CreateLSTM(const LotusIR::Node *node, const std::vector &i rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name())); } - FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction); + FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false); return unpackedRnnFunction; } @@ -622,7 +632,7 @@ FunctionPtr CreateGRU(const LotusIR::Node *node, const std::vector &in rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name())); } - FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction); + FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false); return unpackedRnnFunction; } @@ -678,7 +688,7 @@ FunctionPtr CreateRNN(const LotusIR::Node *node, const std::vector &in rnnFunction = Splice(operands, Axis(0), ToFixedWStringFromMultiByte(node->Name())); } - FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction); + FunctionPtr unpackedRnnFunction = UnpackBatchAndSequence(rnnFunction, false); return unpackedRnnFunction; } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h index 0c789ab4d..c6274364d 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h +++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.h @@ -184,4 +184,4 @@ std::vector GetRNNBlocksFromSingleOrBidirectionalRNN(const CN CNTK::Variable ToBatchAndSequence(CNTK::Variable input); -CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction); \ No newline at end of file +CNTK::FunctionPtr UnpackBatchAndSequence(CNTK::FunctionPtr rnnFunction, bool doTranspose = true); \ No newline at end of file diff --git a/bindings/python/cntk/tests/onnx_op_test.py b/bindings/python/cntk/tests/onnx_op_test.py index f55f2e998..c3b7b641c 100644 --- 
a/bindings/python/cntk/tests/onnx_op_test.py +++ b/bindings/python/cntk/tests/onnx_op_test.py @@ -93,20 +93,17 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N perm[0], perm[1] = perm[1], perm[0] return np.transpose(data, perm) - assert model.output.has_sequence_axis() and model.output.has_batch_axis() - # data here is reference to the outside data object. create deepcopy to avoid changing the outside data since it might get reused. data = deepcopy(data) opname = model.owner.op_name loaded_model = try_save_load_resave_onnx_model(model, tmpdir, name, loaded_model) - model_shape = model.shape - # When both batch and sequence axes exist, model input will have batch and sequence axes - # swapped to match the onnx model. In this case, input and output data shall be adjusted accordingly. - model_shape = (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model_shape - assert model_shape == loaded_model.shape + # in cases like with RNN models where models have both batch and sequence axis + # as dynamic axis, imported models will have the dynamic axes as free_dimensions in static shapes. 
+ if model.output.dynamic_axes == (C.Axis('defaultBatchAxis'), C.Axis('defaultDynamicAxis')): + assert (CNTK_FREEDIM_AXIS_DENOTATION, CNTK_FREEDIM_AXIS_DENOTATION, ) + model.shape == loaded_model.shape dataOnnx = TranposeDynamicAxis(data) if device: @@ -116,13 +113,13 @@ def verify_sequence_model(model, data, tmpdir, name, device=None, loaded_model=N o0 = model.eval({model.arguments[0]:data}) o1 = loaded_model.eval({loaded_model.arguments[0]:dataOnnx}) - if (type(o0) is list): - o0 = o0[0] - if (type(o1) is list): - o1 = o1[0] + o0 = np.array(o0) + o1 = np.array(o1) - # squeeze the batch axis (=1) to match original model output - o1 = np.squeeze(o1, axis=1) + # if there is a sequence axis in the output, it must be swapped with batch axis + # to match the original CNTK model's output + if model.outputs[0].has_sequence_axis(): + o1 = TranposeDynamicAxis(o1) assert np.allclose(o0, o1) return loaded_model @@ -1352,6 +1349,21 @@ def test_Slice(tmpdir, dtype): model = C.slice(x1, [0,1], [1,0], [2,1]); verify_one_input(model, data, tmpdir, 'Slice2_1') +#Sequence.Slice +@pytest.mark.parametrize("beginIndex, endIndex", ( + (-2, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-4, 2), (0, 1), (1, 2))) +@pytest.mark.parametrize("dtype", DType_Config) +def test_SequenceSlice(tmpdir, dtype, beginIndex, endIndex): + batch_size = 1 + sequence_length = 5 + feature_shape = (3,) + shape = (batch_size, sequence_length, *feature_shape) + data = np.reshape(range(0, np.prod(shape)), shape).astype(dtype) + testName = "test_sequence_slice_{0}.{1}".format(beginIndex, endIndex) + print(testName) + model = C.sequence.slice(C.sequence.input_variable((feature_shape)), beginIndex, endIndex) + verify_sequence_model(model, data, tmpdir, testName) + #Softmax @pytest.mark.parametrize("dtype", DType_Config) def test_Softmax(tmpdir, dtype):