From 22e869ec42df45412751bfd50cf31eb1caf3cf99 Mon Sep 17 00:00:00 2001 From: Spandan Tiwari Date: Wed, 5 Dec 2018 13:51:43 -0800 Subject: [PATCH] Add ONNX support for zeros_like, ones_like, and eye_like. --- Source/CNTKv2LibraryDll/API/CNTKLibrary.h | 7 ++++- Source/CNTKv2LibraryDll/Function.cpp | 14 +++++----- .../proto/onnx/CNTKToONNX.cpp | 15 +++++++++-- .../proto/onnx/ONNXToCNTK.cpp | 15 +++++++++++ .../CNTKv2LibraryDll/proto/onnx/Operators.cpp | 6 +++++ bindings/python/cntk/tests/onnx_op_test.py | 26 ++++++++++++++++++- .../python/cntk/tests/onnx_verify_helper.py | 1 + 7 files changed, 74 insertions(+), 10 deletions(-) diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h index 0f8374106..6b6bb1bcb 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h @@ -4161,7 +4161,12 @@ namespace CNTK /// Create an instance of the expand dims operation on specified tensor input operand, for the specified axis /// CNTK_API FunctionPtr ExpandDims(const Variable& operand, const Axis& axis, const std::wstring& name = L""); - + + // + /// Create an instance of a constant-like operation. This produces a tensor with given constant value with the shape and dynamic axes specified by the operand. + /// + CNTK_API FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name = L""); + /// /// Create an instance of a zeros-like operation. This produces zeros with the shape and dynamic axes specified by the operand. /// diff --git a/Source/CNTKv2LibraryDll/Function.cpp b/Source/CNTKv2LibraryDll/Function.cpp index db135b113..26f7d0b92 100644 --- a/Source/CNTKv2LibraryDll/Function.cpp +++ b/Source/CNTKv2LibraryDll/Function.cpp @@ -1714,20 +1714,22 @@ namespace CNTK return AsBlock(std::move(result), { { operandPlaceholder, operand }}, L"ExpandDims", name); } - FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name) + FunctionPtr ConstantLike(const Variable& operand, const double value, const std::wstring& name) { auto additionalProperties = Dictionary(); - additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 0.0; + additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = value; return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name); } + FunctionPtr ZerosLike(const Variable& operand, const std::wstring& name) + { + return ConstantLike(operand, 0.0, name); + } + FunctionPtr OnesLike(const Variable& operand, const std::wstring& name) { - auto additionalProperties = Dictionary(); - additionalProperties[PrimitiveFunctionAttribute::AttributeNameFillValue] = 1.0; - - return UnaryOp(PrimitiveOpType::ConstantOp, operand, std::move(additionalProperties), name); + return ConstantLike(operand, 1.0, name); } FunctionPtr CustomProxyOp(const std::vector& operands, const std::wstring& customOp, const NDShape& outputShape, DataType outputType, const std::wstring& name) diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 2b451c1e9..de84c2e1b 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -5743,6 +5743,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node* size_t k = src->Attributes()[L"numItems"].Value(); node->AddAttribute(attributesMap[L"numItems"], static_cast(k)); } + else if (src->OpName() == L"EyeLikeOp") + { + bool isOutputSparse = src->Attributes().Contains(L"OutputSparse") ? (bool)src->Attributes()[L"OutputSparse"].Value() : false; + if(isOutputSparse) + LogicError("Node '%S': 'OutputSparse' is True. Sparse format export not supported.", src->AsString().c_str()); + } + else if (src->OpName() == L"ConstantOp") + { + if(!src->Attributes().Contains(L"fillValue")) + LogicError("Node '%S': 'fillValue' not present. Cannot export op.", src->AsString().c_str()); + auto fillValue = static_cast(src->Attributes()[L"fillValue"].Value()); + node->AddAttribute("value", fillValue); + } else if (src->OpName() == L"Crop") { const NDShape& inputShape = src->Inputs()[0].Shape(); @@ -6155,8 +6168,6 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime // if (src->OpName() == L"Times") { - if (src->Uid() == L"Times4771") - std::cout << ""; size_t py_api_output_rank_argument = src->Attributes()[L"outputRank"].Value(); auto input1Shape = orderedInputs[0]->Shape(); auto input2Shape = orderedInputs[1]->Shape(); diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp index 9fd9dab35..10769e047 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp @@ -2929,6 +2929,21 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector FunctionPtr cntkFunction = TopK(inputs[0], k, axis, ToFixedWStringFromMultiByte(node->Name())); return cntkFunction; } + else if (onnxOpName == "EyeLike") + { + // Only limited import support is provided. + FunctionPtr cntkFunction = EyeLike(inputs[0], false, ToFixedWStringFromMultiByte(node->Name())); + return cntkFunction; + } + else if (onnxOpName == "ConstantLike") + { + // Limited import support implemented. 'shape' attribute + // node syntax not supported. Only syntax with input tensor + // for shape and 'value' attribute for value is supported. + float value = GetNamedAttributeAsFloat(node, "value", 0.0f); + FunctionPtr cntkFunction = ConstantLike(inputOperand0, static_cast(value), ToFixedWStringFromMultiByte(node->Name())); + return cntkFunction; + } else if (onnxOpName == "Crop") { // inputShape: [W, H, C] x [N] diff --git a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp index 23095e820..7c16bc0f3 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp @@ -477,6 +477,12 @@ namespace ONNX { L"OneHotOp", { { { L"OneHotOp", "OneHotEncoder"}, } } }, + { L"EyeLikeOp",{ { + { L"EyeLikeOp", "EyeLike" }, + } } }, + { L"ConstantOp",{ { + { L"ConstantOp", "ConstantLike" }, + } } }, }; // given a cntkOpName and cntk attribute OpName which is saved in CNTK::Function's attribute, diff --git a/bindings/python/cntk/tests/onnx_op_test.py b/bindings/python/cntk/tests/onnx_op_test.py index cda31f3de..ad27fc046 100644 --- a/bindings/python/cntk/tests/onnx_op_test.py +++ b/bindings/python/cntk/tests/onnx_op_test.py @@ -2062,4 +2062,28 @@ def test_Crop_Manual(tmpdir, dtype): y = C.constant(np.ones((1,2,1), dtype=np.float32)) model = C.crop_manual(x, y, 1, 2, name='crop_manual') data = np.asarray(range(4*4), dtype=np.float32).reshape((1,4,4)) - verify_one_input(model, data, tmpdir, "Crop_Manual_0") \ No newline at end of file + verify_one_input(model, data, tmpdir, "Crop_Manual_0") + +# eye_like +@pytest.mark.parametrize("dtype", DType_Config) +def test_Eye_Like(tmpdir, dtype): + x = C.input_variable((4, 4), dynamic_axes=[], dtype=dtype, name='feature') + model = C.eye_like(x, sparse_output=False) + data = np.asarray(range(4*4), dtype=dtype).reshape((4,4)) + verify_one_input(model, data, tmpdir, "Eye_Like_0") + +# zeros_like +@pytest.mark.parametrize("dtype", DType_Config) +def test_Zeros_Like(tmpdir, dtype): + x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature') + model = C.zeros_like(x, name='zeros_like_op') + data = np.asarray(range(3*4), dtype=dtype).reshape((3,4)) + verify_one_input(model, data, tmpdir, "Zeros_Like_0") + +# ones_like +@pytest.mark.parametrize("dtype", DType_Config) +def test_Ones_Like(tmpdir, dtype): + x = C.input_variable((3, 4), dynamic_axes=[], dtype=dtype, name='feature') + model = C.ones_like(x, name='ones_like_op') + data = np.asarray(range(3*4), dtype=dtype).reshape((3,4)) + verify_one_input(model, data, tmpdir, "Ones_Like_0") \ No newline at end of file diff --git a/bindings/python/cntk/tests/onnx_verify_helper.py b/bindings/python/cntk/tests/onnx_verify_helper.py index 7bf8d4fbc..1f5c8d21a 100644 --- a/bindings/python/cntk/tests/onnx_verify_helper.py +++ b/bindings/python/cntk/tests/onnx_verify_helper.py @@ -30,6 +30,7 @@ known_issues = [ 'MVN_1', 'MVN_2', 'MVN_3', + 'Eye_Like_0', ] def parse_single_result_case(case_str):