From ceaec5636f9fe3d4b0441895d3549569f813bd87 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Mon, 13 Aug 2018 17:15:36 -0700 Subject: [PATCH] fix squeezenet onnx_backend_test of issues with Softmax, hardmax, logsoftmax and flatten with freedimension on batchaxis. --- Source/CNTKv2LibraryDll/Function.cpp | 18 ++++++++++++++++-- .../CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp | 7 ++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index b80bc533e..cca84aa1d 100644 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -1649,8 +1649,22 @@ namespace CNTK (int)operand.Shape().Rank(), ToLegacyString(ToUTF8(axis.AsString())).c_str()); } - size_t dim0 = cntk_index == 0 ? 1 : operand.Shape().SubShape(0, cntk_index).TotalSize(); - size_t dim1 = cntk_index == operand.Shape().Rank() ? 1 : operand.Shape().SubShape(cntk_index).TotalSize(); + auto getFlattenedDim = [&](bool isDim0) -> size_t { + const NDShape& operandSubShape = isDim0 ? operand.Shape().SubShape(0, cntk_index) : operand.Shape().SubShape(cntk_index); + // If subshape contains free or inferred dimension, we have reshape node try to infer what should the flattened dimension be. + if (operandSubShape.HasFreeDimension() || operandSubShape.HasInferredDimension()) + { + return NDShape::InferredDimension; + } + + if ((isDim0 && cntk_index == 0) || (!isDim0 && cntk_index == operand.Shape().Rank())) + return 1u; + else + return operandSubShape.TotalSize(); + }; + + size_t dim0 = getFlattenedDim(true); + size_t dim1 = getFlattenedDim(false); NDShape newShape({ dim0, dim1 }); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index 983f3ca0d..ad9f69ee2 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -2442,7 +2442,12 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector { cntkFunction = Hardmax(input, ToFixedWStringFromMultiByte(node->Name())); } - return Reshape(cntkFunction, inputs[0].Shape()); + NDShape originalShape = inputs[0].Shape(); + assert(originalShape.Rank() > 0); + // If original shape has free dimension(batch axis), we'll need to have reshape node infer that for us. + if (originalShape[originalShape.Rank() - 1] == NDShape::FreeDimension) + originalShape[originalShape.Rank() - 1] = NDShape::InferredDimension; + return Reshape(cntkFunction, originalShape); } else if (onnxOpName == "Softplus") {