onnx export: fix cases when scan input needs to be broadcasted
Parent: 0e172db667
Commit: 22a86bf3ce
@@ -388,7 +388,7 @@ private:
     static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t>& newShape);
 
     static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex,
-        onnx::TypeProto &inputArgType);
+        onnx::TypeProto &inputArgType, const std::string& scanInputName = "");
 
     static int BatchSizeOverride(const FunctionPtr src, const std::vector<onnxruntime::NodeArg*>& inputs,
         onnx::TypeProto& outputArgType);
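Because the new scanInputName parameter is defaulted to an empty string, pre-existing callers of GetInputAdjustmentForBroadcast compile and behave exactly as before; only the scan-input path supplies a name. A minimal sketch of that pattern, using a hypothetical stand-in function rather than the exporter's real code:

#include <iostream>
#include <string>

// Hypothetical stand-in: the trailing defaulted parameter means both call
// forms below bind to this single declaration, so existing call sites in the
// exporter need no changes; only the scan path supplies a name.
static std::string GetAdjustedName(const std::string& generated,
                                   const std::string& scanInputName = "")
{
    return scanInputName.empty() ? generated : scanInputName;
}

int main()
{
    std::cout << GetAdjustedName("Plus123_Input0") << "\n";                 // old call form
    std::cout << GetAdjustedName("Plus123_Input0", "Scan_input_0") << "\n"; // new scan-aware call form
    return 0;
}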
@@ -5596,8 +5596,12 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
 // z = C.sequence.input_variable((2,)) + C.input_variable((3,2))
 //
 // input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers.
-NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
-    const Variable &input, int inputIndex, onnx::TypeProto &inputArgType)
+//
+// Special note for scan input cases:
+// when input is also an input to a scan op, we have to keep its name so that
+// the main graph and subgraph is well connected.
+NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
+    const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::string& scanInputName)
 {
     // TODO: do we need to get blockroot if it is a block function?
     if (!Operators::SupportBroadcast(src->OpName()))
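The comment added here states the constraint behind the change: an ONNX Scan node and its body graph are stitched together by NodeArg names, so the NodeArg consumed inside the subgraph has to keep the scan input's name even when a Reshape is inserted for broadcasting. A toy illustration of that name-identity rule, using plain std types and hypothetical names rather than onnxruntime's Graph API:

#include <iostream>
#include <map>
#include <string>

// Toy "get or create by name" registry standing in for onnxruntime NodeArgs.
static std::map<std::string, int> registry;
static int nextId = 0;

static int GetOrCreate(const std::string& name)
{
    auto it = registry.find(name);
    if (it != registry.end())
        return it->second;              // same name -> same NodeArg
    return registry[name] = nextId++;   // new name -> brand new NodeArg
}

int main()
{
    int scanBoundary = GetOrCreate("Scan_input_0");  // referenced by the Scan body

    // Keeping the scan input's name reuses the boundary NodeArg ...
    std::cout << (GetOrCreate("Scan_input_0") == scanBoundary) << "\n"; // 1: still connected
    // ... while a freshly generated unique name creates a second, disconnected
    // NodeArg, which is the failure mode this commit addresses.
    std::cout << (GetOrCreate("Input3_uid42") == scanBoundary) << "\n"; // 0: dangling
    return 0;
}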
@@ -5653,11 +5657,16 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr
     //inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type());
     //UpdateONNXType(input.GetDataType(), inputArgType);
     std::string inputNodeArgName;
-    auto inputItr = compositeOutputsMap.find(input);
-    if (inputItr != compositeOutputsMap.end())
-        inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+    if (scanInputName != "")
+        inputNodeArgName = scanInputName;
     else
-        inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+    {
+        auto inputItr = compositeOutputsMap.find(input);
+        if (inputItr != compositeOutputsMap.end())
+            inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+        else
+            inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+    }
 
     std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast");
     onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType);
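The net effect of this hunk on naming: when a scan input name is supplied, the inserted Reshape reads from the scan input's existing NodeArg (GetOrCreateNodeArg returns an existing NodeArg of the same name when there is one), and only the Reshape output receives a fresh name derived with the _reshaped_for_broadcast suffix. A condensed sketch of that selection, with the unique-name machinery simplified to a plain string and purely illustrative names:

#include <iostream>
#include <string>
#include <utility>

// Returns {reshape input name, reshape output name} for the broadcast adjustment.
static std::pair<std::string, std::string>
BroadcastReshapeNames(const std::string& generatedUniqueName,
                      const std::string& scanInputName = "")
{
    // Keep the scan input's name when one is supplied so the main graph and
    // the Scan subgraph remain connected; otherwise use the generated name.
    const std::string inputName =
        scanInputName.empty() ? generatedUniqueName : scanInputName;
    return {inputName, inputName + "_reshaped_for_broadcast"};
}

int main()
{
    auto plain = BroadcastReshapeNames("Input3_uid42");
    auto scan = BroadcastReshapeNames("Input3_uid42", "Scan_input_0");
    std::cout << plain.first << " -> " << plain.second << "\n";
    std::cout << scan.first << " -> " << scan.second << "\n";
    return 0;
}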
@@ -6095,11 +6104,18 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
             }
         }
 
-        onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
-
+        onnxruntime::NodeArg *adjusted = nullptr;
         if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
         {
             inputName = MakeScanInputOutputNodeArgName(inputName);
+
+            // in case of broadcast, we want the input name unchanged.
+            // The inserted reshape op is treated as being inside of the scan subgraph.
+            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType, inputName);
         }
+        else
+        {
+            adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
+        }
 
         onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;
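With the call site split this way, one quick sanity check for an exported model is simply loading it: onnxruntime resolves the whole graph, including Scan subgraphs, at session creation, so a dangling subgraph input surfaces as a load-time error rather than a silent inconsistency. A minimal loader sketch, assuming onnxruntime's C++ API (onnxruntime_cxx_api.h), a build where ORTCHAR_T is char, and a hypothetical file name:

#include <onnxruntime_cxx_api.h>
#include <iostream>

int main()
{
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "scan-broadcast-check");
    Ort::SessionOptions options;
    try
    {
        // Graph resolution happens here; a disconnected Scan subgraph input
        // would be reported as an error when the session is created.
        Ort::Session session(env, "scan_broadcast.onnx", options); // hypothetical exported model
        std::cout << "model loaded and resolved\n";
    }
    catch (const Ort::Exception& e)
    {
        std::cerr << "load failed: " << e.what() << "\n";
        return 1;
    }
    return 0;
}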