From 22a86bf3ce3e34a59d4d0f574e7f6af13efda895 Mon Sep 17 00:00:00 2001 From: liqfu Date: Mon, 29 Apr 2019 11:58:00 -0700 Subject: [PATCH] onnx export: fix cases when scan input needs to be broadcasted --- .../proto/onnx/CNTKToONNX.cpp | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp index 24d83d36e..6a91056d0 100644 --- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp +++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp @@ -388,7 +388,7 @@ private: static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector& newShape); static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex, - onnx::TypeProto &inputArgType); + onnx::TypeProto &inputArgType, const std::string& scanInputName = ""); static int BatchSizeOverride(const FunctionPtr src, const std::vector& inputs, onnx::TypeProto& outputArgType); @@ -5596,8 +5596,12 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co // z = C.sequence.input_variable((2,)) + C.input_variable((3,2)) // // input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers. -NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, - const Variable &input, int inputIndex, onnx::TypeProto &inputArgType) +// +// Special note for scan input cases: +// when input is also an input to a scan op, we have to keep its name so that +// the main graph and subgraph are well connected. 
+NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, + const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::string& scanInputName) { // TODO: do we need to get blockroot if it is a block function? if (!Operators::SupportBroadcast(src->OpName())) @@ -5653,11 +5657,16 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr //inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type()); //UpdateONNXType(input.GetDataType(), inputArgType); std::string inputNodeArgName; - auto inputItr = compositeOutputsMap.find(input); - if (inputItr != compositeOutputsMap.end()) - inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second); + if (scanInputName != "") + inputNodeArgName = scanInputName; else - inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input); + { + auto inputItr = compositeOutputsMap.find(input); + if (inputItr != compositeOutputsMap.end()) + inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second); + else + inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input); + } std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast"); onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType); @@ -6095,11 +6104,18 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src, } } - onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType); - + onnxruntime::NodeArg *adjusted = nullptr; if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph) { inputName = MakeScanInputOutputNodeArgName(inputName); + + // in case of broadcast, we want the input name unchanged. + // The inserted reshape op is treated as being inside of the scan subgraph. 
+ adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName); + } + else + { + adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType); } onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;