From 7dd96387991479efe4aaf2915007f32333cf17a8 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 24 Aug 2018 13:44:28 -0700 Subject: [PATCH] squash of the following changes: - fix flatten onnx export. - fix unsqueeze onnx export. - add comments on temporarily skipped tests. - adjust the importing of softmax, logsoftmax and hardmax with blockfunction - such that they could be exported as is back to onnx. - update reshape onnx export to pass mobilenet round trip test. --- Source/CNTKv2LibraryDll/Function.cpp | 6 +---- .../proto/onnx/CNTKToONNX.cpp | 27 ++++++++++++++++--- .../proto/onnx/ONNXToCNTK.cpp | 23 +++++++++++----- .../CNTKv2LibraryDll/proto/onnx/Operators.cpp | 12 +++++++++ Tests/EndToEndTests/TestDriver.py | 1 + bindings/python/cntk/tests/onnx_model_test.py | 16 +---------- bindings/python/cntk/tests/onnx_op_test.py | 1 - 7 files changed, 55 insertions(+), 31 deletions(-) diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index 6b255e350..f48d1ff4c 100644 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -1604,7 +1604,6 @@ namespace CNTK FunctionPtr Flatten(const Variable& operand, const Axis& axis, const std::wstring& name) { int cntk_index; - int onnx_axis; // We need to express in onnx axis system to help ONNX conversion. 
if (axis.IsStaticAxis()) @@ -1618,7 +1617,6 @@ namespace CNTK // for CNTK reshape, cntk_index shall point to the one after 3 (2): cntk_index = axis + 1 // cntk_index (-1) needs to be converted to positive by rank + cntk_index = 3 int cntk_py_index = -axis.StaticAxisIndex() - 1; - onnx_axis = cntk_py_index + 1; cntk_index = axis.StaticAxisIndex() + 1; cntk_index += operand.Shape().Rank(); } @@ -1629,7 +1627,6 @@ namespace CNTK // onnx_axis = 2, points to 3 in [#][[2], [3,4,5]] // cntk_index = 1, points to 3 in [2,3,4,5] int cntk_py_index = axis.StaticAxisIndex(); - onnx_axis = cntk_py_index + 1; cntk_index = axis.StaticAxisIndex(); } } @@ -1637,7 +1634,6 @@ namespace CNTK { // expected result: [[batch],[flatten sample]]([[#][2,3,4,5]]) cntk_index = 0; - onnx_axis = 1; } else { @@ -1670,7 +1666,7 @@ namespace CNTK NDShape newShape({ dim0, dim1 }); auto additionalProperties = Dictionary(); - additionalProperties[PrimitiveFunctionAttribute::AttributeNameAxis] = Axis(onnx_axis); + additionalProperties[PrimitiveFunctionAttribute::AttributeNameAxis] = Axis(cntk_index); auto operandPlaceholder = PlaceholderVariable(); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 7d03cfbe6..3871fc9a8 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -2965,11 +2965,25 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src, if (cntkOpName == "Reshape" && IsONNX1_2Supported()) { // ONNX1.2 reshape node take shape as input instead of attribute. - const std::vector& shapeVec = src->Output().Shape().Dimensions(); + + // We can construct the shape input for onnx by two ways: 1. cntk node output shape, or 2. cntk node attribute "newShape". + // If the attribute "newShape" is missing, or attributes "beginAxis" and "endAxis" exist, we use cntk node output shape. + // such that we don't need to duplicate the shape inference logic here. 
+ // Otherwise we use the cntk node attribute "newShape". + bool useOutputShape = [&]() { + if (!src->Attributes().Contains(L"newShape") || ((NDShape)src->Attributes()[L"newShape"].Value()).Rank() == 0) + return true; + if (src->Attributes().Contains(L"beginAxis") && ((Axis)src->Attributes()[L"beginAxis"].Value()).StaticAxisIndex() != 0) + return true; + if (src->Attributes().Contains(L"endAxis") && ((Axis)src->Attributes()[L"endAxis"].Value()).StaticAxisIndex() != src->Inputs()[0].Shape().Rank()) + return true; + return false; + }(); + const NDShape shape = useOutputShape ? src->Output().Shape() : (NDShape)src->Attributes()[L"newShape"].Value(); std::vector newShapeVec; size_t numInferredDimensions(0); - for (const auto& axisSize : shapeVec) + for (const auto& axisSize : shape.Dimensions()) { if (axisSize == NDShape::InferredDimension) { @@ -3395,6 +3409,12 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod axisIndex += src->Inputs()[0].DynamicAxes().size(); node->AddAttribute(attributesMap[L"axis"], axisIndex); } + else if (src->OpName() == L"Softmax_onnx" || src->OpName() == L"LogSoftmax_onnx" || src->OpName() == L"Hardmax_onnx") + { + Axis axis = (Axis)(src->Attributes()[L"axis"].Value()); + int64_t axisIndex = ConvertAxisToOnnx(axis, src->Inputs()[0]); + node->AddAttribute(attributesMap[L"axis"], axisIndex); + } else if (src->OpName() == L"Times") { size_t outputRank = src->Attributes()[L"outputRank"].Value(); @@ -3484,7 +3504,8 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod else if (src->OpName() == L"Unsqueeze") { std::vector axes = AsVector(src->Attributes()[L"axisVec"].Value>()); - std::vector ax = ConvertAxesToOnnx(axes, src->Inputs()[0]); + // Pass in output operand, such that Unsqueeze axes can be converted based on output rank. 
+ std::vector ax = ConvertAxesToOnnx(axes, src->Outputs()[0]); node->AddAttribute("axes", ax); } diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index 459ca0368..ad4efbbb8 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -2480,8 +2480,10 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector } else if (onnxOpName == "Softmax" || onnxOpName == "LogSoftmax" || onnxOpName == "Hardmax") { - Axis axis(ConvertONNXAxisToCNTKCppApi(static_cast(GetNamedAttributeAsInt64(node, "axis", 1)), inputs[0])); - Variable input = Flatten(inputs[0], axis); + auto inputOperand0Placeholder = PlaceholderVariable(inputs[0].Shape(), inputs[0].GetDataType(), L"operand", {}); + + Axis axis(ConvertONNXAxisToCNTKCppApi(static_cast(GetNamedAttributeAsInt64(node, "axis", 1)), inputOperand0Placeholder)); + Variable input = Flatten(inputOperand0Placeholder, axis); FunctionPtr cntkFunction; if (onnxOpName == "Softmax") { @@ -2495,12 +2497,18 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector { cntkFunction = Hardmax(input, ToFixedWStringFromMultiByte(node->Name())); } - NDShape originalShape = inputs[0].Shape(); + NDShape originalShape = inputOperand0Placeholder.Shape(); assert(originalShape.Rank() > 0); // If original shape has free dimension(batch axis), we'll need to have reshape node infer that for us. 
if (originalShape[originalShape.Rank() - 1] == NDShape::FreeDimension) originalShape[originalShape.Rank() - 1] = NDShape::InferredDimension; - return Reshape(cntkFunction, originalShape); + cntkFunction = Reshape(cntkFunction, originalShape); + + auto additionalProperties = Dictionary(); + additionalProperties[L"axis"] = axis; + + return AsBlock(std::move(cntkFunction), {{inputOperand0Placeholder, inputs[0]}}, std::move(additionalProperties), + ToFixedWStringFromMultiByte(onnxOpName) + L"_onnx", ToFixedWStringFromMultiByte(node->Name())); } else if (onnxOpName == "Softplus") { @@ -2696,8 +2704,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector // { L"", "Split) else if (onnxOpName == "Slice") { - std::vector axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]); - std::vector starts64 = GetNamedAttributeAsInt64Vec(node, "starts"); std::vector ends64 = GetNamedAttributeAsInt64Vec(node, "ends"); @@ -2716,10 +2722,13 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector e = 0; } + std::vector axes; + if (HasNamedAttribute(node, "axes")) + axes = ConvertONNXAxesToCNTKCppApi(GetNamedAttributeAsInt64Vec(node, "axes"), inputs[0]); // axes is optional so provide a default if (axes.empty()) { - for (int i = 0; i < starts.size(); i++) + for (int i = starts.size() - 1; i >= 0; i--) { Axis axis(i); axes.push_back(axis); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp index f033ebe63..ba9402104 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp @@ -254,6 +254,18 @@ namespace ONNX { L"LogSoftmax", "LogSoftmax" }, { L"axis", "axis" }, } } }, + { L"Hardmax_onnx",{ { + { L"Hardmax_onnx", "Hardmax" }, + { L"axis", "axis" }, + } } }, + { L"Softmax_onnx",{ { + { L"Softmax_onnx", "Softmax" }, + { L"axis", "axis" }, + } } }, + { L"LogSoftmax_onnx",{ { + { 
L"LogSoftmax_onnx", "LogSoftmax" }, + { L"axis", "axis" }, + } } }, { L"Softplus",{ { { L"Softplus", "Softplus" }, } } }, diff --git a/Tests/EndToEndTests/TestDriver.py b/Tests/EndToEndTests/TestDriver.py index c2f436218..0a0eca35e 100755 --- a/Tests/EndToEndTests/TestDriver.py +++ b/Tests/EndToEndTests/TestDriver.py @@ -188,6 +188,7 @@ class Test: @staticmethod def discoverAllTests(): for dirName, subdirList, fileList in os.walk(thisDir): + # Temporarily disable these tests on Windows due to an issue introduced by adding onnx to our CI. if os.path.basename(dirName) == 'Keras' and windows: continue if 'testcases.yml' in fileList: diff --git a/bindings/python/cntk/tests/onnx_model_test.py b/bindings/python/cntk/tests/onnx_model_test.py index ecd2936ef..2201101e9 100644 --- a/bindings/python/cntk/tests/onnx_model_test.py +++ b/bindings/python/cntk/tests/onnx_model_test.py @@ -40,22 +40,8 @@ skip_model_names = [ ] skip_round_trip_model_names = [ - # these are skipped due to known issues with gemm and pooling. - 'bvlc_alexnet', - 'bvlc_googlenet', - 'bvlc_reference_caffenet', - 'bvlc_reference_rcnn_ilsvrc13', - 'inception_v1', - 'resnet50', + # Convolution Nan issue on Linux. 'shufflenet', - 'vgg19', - 'zfnet512', - - 'resnet3d', - 'densenet121', - 'inception_v2', - 'mobilenetv2-1.0', - 'squeezenet', ] @pytest.mark.parametrize('model_name, round_trip', diff --git a/bindings/python/cntk/tests/onnx_op_test.py b/bindings/python/cntk/tests/onnx_op_test.py index d1dfb906c..ef584803d 100644 --- a/bindings/python/cntk/tests/onnx_op_test.py +++ b/bindings/python/cntk/tests/onnx_op_test.py @@ -599,7 +599,6 @@ def test_Exp(tmpdir, dtype): #Flatten @pytest.mark.parametrize("dtype", DType_Config) def test_Flatten(tmpdir, dtype): - pytest.skip('Needs to be fixed after removal of batch axis change.') with C.default_options(dtype = dtype): shape = (2, 3, 4, 5) data = np.reshape(np.arange(np.prod(shape), dtype = dtype), shape)