From 4a6238d9798905dc01927d3b5a665acbbca727c2 Mon Sep 17 00:00:00 2001 From: Peyman Manikashani Date: Thu, 23 Aug 2018 11:05:56 -0700 Subject: [PATCH] reduction all axes export fix --- .../CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 8cb824ba2..0739a324c 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -3464,11 +3464,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod Axis axis = (Axis)(src->Attributes()[L"axis"].Value()); // Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true. // For ONNX export we need to make sure we export keepdims as 0 (false). - if (axis == Axis::DefaultBatchAxis()) + // The same applies for All axes + if (axis == Axis::DefaultBatchAxis() || axis == Axis::AllAxes() || axis == Axis::AllStaticAxes()) keepReducedDimensions = 0; - int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]); - - node->AddAttribute("axis", ax); + if (node->OpType() != "ArgMax" && node->OpType() != "ArgMin") + { + std::vector axes = ConvertAxesToOnnx(std::vector({ axis }), src->Inputs()[0]); + node->AddAttribute("axes", axes); + } + else + { + int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]); + node->AddAttribute("axis", ax); + } } node->AddAttribute("keepdims", keepReducedDimensions);