onnx export: fix cases when scan input needs to be broadcasted

This commit is contained in:
liqfu 2019-04-29 11:58:00 -07:00
Родитель 0e172db667
Коммит 22a86bf3ce
1 изменённых файлов: 25 добавлений и 9 удалений

Просмотреть файл

@ -388,7 +388,7 @@ private:
static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t>& newShape);
static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex,
onnx::TypeProto &inputArgType);
onnx::TypeProto &inputArgType, const std::string& scanInputName = "");
static int BatchSizeOverride(const FunctionPtr src, const std::vector<onnxruntime::NodeArg*>& inputs,
onnx::TypeProto& outputArgType);
@ -5596,8 +5596,12 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
// z = C.sequence.input_variable((2,)) + C.input_variable((3,2))
//
// input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers.
NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
const Variable &input, int inputIndex, onnx::TypeProto &inputArgType)
//
// Special note for scan input cases:
// when input is also an input to a scan op, we have to keep its name so that
// the main graph and subgraph is well connected.
NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::string& scanInputName)
{
// TODO: do we need to get blockroot if it is a block function?
if (!Operators::SupportBroadcast(src->OpName()))
@ -5653,11 +5657,16 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr
//inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type());
//UpdateONNXType(input.GetDataType(), inputArgType);
std::string inputNodeArgName;
auto inputItr = compositeOutputsMap.find(input);
if (inputItr != compositeOutputsMap.end())
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
if (scanInputName != "")
inputNodeArgName = scanInputName;
else
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
{
auto inputItr = compositeOutputsMap.find(input);
if (inputItr != compositeOutputsMap.end())
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
else
inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
}
std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast");
onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType);
@ -6095,11 +6104,18 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
}
}
onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
onnxruntime::NodeArg *adjusted = nullptr;
if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
{
inputName = MakeScanInputOutputNodeArgName(inputName);
// in case of broadcast, we want the input name unchanged.
// The inserted reshape op is treated as being inside of the scan subgraph.
adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName);
}
else
{
adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
}
onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;