Parent
47345c6c54
Commit
6c56fdd2f9
@@ -291,13 +291,6 @@ private:
        std::vector<ScanLoop> &scanLoops, int createLoopIndex,
        bool isFirst);

    static onnxruntime::Node* CreateWhereNode(const FunctionPtr& src,
        onnxruntime::Graph* graph,
        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
        std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
        const std::unordered_map<Variable, Variable>& compositeOutputsMap,
        std::vector<ScanLoop> &scanLoops, int createLoopIndex);

    static onnxruntime::Node* CreateNodeWithGatherPacked(const FunctionPtr& src,
        onnxruntime::Graph* graph,
        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -3785,8 +3778,65 @@ onnxruntime::Node* CNTKToONNXHelper::CreateReconcileDynamicAxisNode(const Functi
    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
    return CreateSequenceBroadcastAsNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
        scanLoops, createLoopIndex);
    bool inLoop = createLoopIndex != -1;
    if (!inLoop)
        // TODO: sequence.broadcast_as may still be in a loop.
        // Investigate whether sequence.broadcast_as and reconcile_dynamic_axis are fully equivalent.
        return CreateSequenceBroadcastAsNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
            scanLoops, createLoopIndex);

    std::vector<onnxruntime::NodeArg *> inputs, outputs;
    ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
    ProcessOutputs(src, inputs, outputs, graph);

    Variable input = src->Inputs()[0];
    Variable broadcastAs = src->Inputs()[1];
    Variable output = src->Outputs()[0];

    std::vector<int64_t> newShape = ToINTS(*outputs[0]->TypeAsProto());

    std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
    std::string squeezedBroadcastNodeArgName = nodeName + "_squeezed";
    onnxruntime::Node* broadcastAsSqueezed = ExtractShapeWithDynamicAxes(
        graph, src->Inputs()[1], inputs[1], squeezedBroadcastNodeArgName, inLoop);

    NodeArg *broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastAsSqueezed->OutputDefs()[0]);
    // broadcastAsSqueezed has shape [sequence, batch]; append 1s to get to the same rank as the input before broadcasting
    if (input.Shape().Rank() != 0)
    {
        std::vector<int64_t> newShape = ToINTS(*broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto());
        for (int i = 0; i < input.Shape().Rank(); i++)
            newShape.push_back(1);
        std::string broadcastReshapedNodeName = nodeName + "_squeezed_reshaped";
        onnxruntime::Node* broadcastReshaped = AddReshapeNode(
            *broadcastNodeArg, newShape, broadcastReshapedNodeName, graph);
        broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastReshaped->OutputDefs()[0]);
    }

    NodeArg *inputNodeArg = nullptr;
    if (input.DynamicAxes().size() == 1)
    {
        // input does not have a sequence axis; insert one with size 1 so it is broadcast
        std::vector<int64_t> newShape = ToINTS(*inputs[0]->TypeAsProto());
        newShape.insert(newShape.begin(), 1);
        Node* inputReshapeNode = AddReshapeNode(*inputs[0], newShape, nodeName + "_input_reshape", graph);
        inputNodeArg = const_cast<onnxruntime::NodeArg *>(inputReshapeNode->OutputDefs()[0]);
    }
    else if (input.DynamicAxes().size() == 0)
    {
        inputNodeArg = inputs[0];
    }
    else if (input.DynamicAxes().size() == 2)
    {
        // This is the case with the unfold op. It has been verified with ONNX Runtime.
        inputNodeArg = inputs[0];
    }

    onnxruntime::Node* elementWiseNode = &graph->AddNode(nodeName + "_add", "Add", "",
        { inputNodeArg, broadcastNodeArg }, outputs);

    functionNodes.emplace(src, elementWiseNode);
    return elementWiseNode;
}

//
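The rank alignment above is pure shape arithmetic: the squeezed broadcast operand of shape [sequence, batch] gets one trailing 1 per static input axis so the final ONNX Add can broadcast it against the input. A minimal standalone sketch of that computation (illustrative names, not part of the commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Append one trailing singleton dimension per static input axis so that
// [sequence, batch] aligns with [sequence, batch, d1, ..., dk] for broadcasting.
std::vector<int64_t> AppendOnesForBroadcast(std::vector<int64_t> broadcastShape, size_t inputStaticRank)
{
    for (size_t i = 0; i < inputStaticRank; ++i)
        broadcastShape.push_back(1);
    return broadcastShape;
}

int main()
{
    // sequence = 10, batch = 2, input static shape [3, 4] (rank 2).
    for (int64_t d : AppendOnesForBroadcast({10, 2}, 2))
        std::cout << d << ' '; // prints: 10 2 1 1
    std::cout << '\n';
}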
@@ -3816,51 +3866,23 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceBroadcastAsNode(const Functio
    // broadcastAsSqueezed has shape [sequence, batch]; append 1s to get to the same rank as the input before broadcasting
    if (input.Shape().Rank() != 0)
    {
        TypeProto typeProto = *broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto();
        TensorShapeProto* tensorShapeProto = typeProto.mutable_tensor_type()->mutable_shape();
        std::vector<int64_t> unsqueezeAxis;
        std::vector<int64_t> newShape = ToINTS(*broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto());
        for (int i = 0; i < input.Shape().Rank(); i++)
        {
            tensorShapeProto->add_dim()->set_dim_value(1);
            unsqueezeAxis.push_back(i + broadcastNodeArg->Shape()->dim_size());
        }

        std::string broadcastUnsqueezeNodeName = nodeName + "_broadcast_as_unsqueeze";
        std::string broadcastUnsqueezeOutputNodeArgName = broadcastUnsqueezeNodeName + "_output";
        NodeArg *broadcastUnsqueezeOutputNodeArg = &graph->GetOrCreateNodeArg(broadcastUnsqueezeOutputNodeArgName, &typeProto);

        Node *unsqueezeNode = &graph->AddNode(broadcastUnsqueezeNodeName, "Unsqueeze", "",
            { broadcastNodeArg }, { broadcastUnsqueezeOutputNodeArg });
        unsqueezeNode->AddAttribute("axes", unsqueezeAxis);
        broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(unsqueezeNode->OutputDefs()[0]);
            newShape.push_back(1);
        std::string broadcastReshapedNodeName = nodeName + "_squeezed_reshaped";
        onnxruntime::Node* broadcastReshaped = AddReshapeNode(
            *broadcastNodeArg, newShape, broadcastReshapedNodeName, graph);
        broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastReshaped->OutputDefs()[0]);
    }

    NodeArg *inputNodeArg = nullptr;
    if (input.DynamicAxes().size() == 1 && createLoopIndex < 0)
    if (input.DynamicAxes().size() == 1)
    {
        // not in a loop, so there shall be a sequence axis.
        // input does not have a sequence axis; prepend a dim with dim_value = 1 so it is broadcast
        onnx::TypeProto unsqueezeNodeOutputTYpeProto = MakeTypeProtoWithShape();
        onnx::TensorProto_DataType elemType = ConvertDataTypeCNTKToTensorProto(input.GetDataType());
        unsqueezeNodeOutputTYpeProto.mutable_tensor_type()->set_elem_type(elemType);

        ONNX_NAMESPACE::TensorShapeProto *shapeProto = unsqueezeNodeOutputTYpeProto.mutable_tensor_type()->mutable_shape();
        shapeProto->add_dim()->set_dim_value(1);
        for (int dim = 0; dim < inputs[0]->Shape()->dim_size(); dim++)
        {
            if (inputs[0]->Shape()->dim(dim).has_dim_value())
                shapeProto->add_dim()->set_dim_value(inputs[0]->Shape()->dim(dim).dim_value());
            else
                shapeProto->add_dim()->set_dim_param(inputs[0]->Shape()->dim(dim).dim_param());
        }

        std::string unsqueezeNodeName = nodeName + "_input_unsqueeze";
        NodeArg *unsqueezeOutputNodeArg = &graph->GetOrCreateNodeArg(unsqueezeNodeName + "_output", &unsqueezeNodeOutputTYpeProto);

        Node* unsqueezeNode = &graph->AddNode(unsqueezeNodeName, "Unsqueeze", "",
            { inputs[0] }, { unsqueezeOutputNodeArg });
        unsqueezeNode->AddAttribute("axes", std::vector<int64_t>({ 0 }));
        inputNodeArg = const_cast<onnxruntime::NodeArg *>(unsqueezeNode->OutputDefs()[0]);
        // input does not have a sequence axis; insert one with size 1 so it is broadcast
        std::vector<int64_t> newShape = ToINTS(*inputs[0]->TypeAsProto());
        newShape.insert(newShape.begin(), 1);
        Node* inputReshapeNode = AddReshapeNode(*inputs[0], newShape, nodeName + "_input_reshape", graph);
        inputNodeArg = const_cast<onnxruntime::NodeArg *>(inputReshapeNode->OutputDefs()[0]);
    }
    else
    {
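The hunk above trades an Unsqueeze node (with an explicit "axes" attribute) for a Reshape with the target dims listed outright; both produce the same trailing singleton axes. The removed path computed each unsqueeze axis as existing rank + i. A self-contained sketch of that axis computation (illustrative, not converter code):

#include <cstdint>
#include <iostream>
#include <vector>

// The removed Unsqueeze-based path computed its "axes" attribute as
// existingRank + i for each appended dimension; the Reshape-based path
// reaches the same shape by listing the target dims explicitly.
int main()
{
    int64_t existingRank = 2; // [sequence, batch]
    int64_t staticRank = 3;   // input static shape [d1, d2, d3]

    std::vector<int64_t> unsqueezeAxes;
    for (int64_t i = 0; i < staticRank; ++i)
        unsqueezeAxes.push_back(existingRank + i);

    for (int64_t a : unsqueezeAxes)
        std::cout << a << ' '; // prints: 2 3 4
    std::cout << '\n';
}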
@@ -3885,9 +3907,6 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
    if (CNTKToONNXHelper::isProcessingScan)
        LogicError("SequenceGather cannot be in a scan loop");

    const std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
    const std::string onnxOpName = "Compress";

    std::vector<onnxruntime::NodeArg *> inputs;
    ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
@@ -3897,7 +3916,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
    //ProcessOutputs(src, outputs, graph);

    // Cast inputs[1] from tensor<float> to tensor<bool>
    const std::string outputNodeArgName = inputs[1]->Name() + "_cast_to_bool_" + nodeName;
    const std::string outputNodeArgName = inputs[1]->Name() + "_cast_to_bool";
    Node *castNode = AddCastNode(*inputs[1], graph,
        TensorProto_DataType::TensorProto_DataType_BOOL, outputNodeArgName);
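CNTK's SequenceGather is exported as ONNX Compress: the gather condition arrives as a float tensor and is first cast to bool, then Compress keeps only the slices whose condition is true. A minimal reference implementation of that semantics along axis 0 (a sketch, not the converter's code):

#include <iostream>
#include <vector>

// Reference semantics of ONNX Compress along axis 0: keep slice i of the
// data whenever condition[i] is nonzero (the float condition is first cast
// to bool, as in the converter above).
std::vector<std::vector<float>> Compress(const std::vector<std::vector<float>>& data,
                                         const std::vector<float>& condition)
{
    std::vector<std::vector<float>> out;
    for (size_t i = 0; i < data.size() && i < condition.size(); ++i)
        if (condition[i] != 0.0f) // cast-to-bool
            out.push_back(data[i]);
    return out;
}

int main()
{
    auto out = Compress({{1, 2}, {3, 4}, {5, 6}}, {1, 0, 1});
    std::cout << out.size() << " slices kept\n"; // 2 slices: {1,2} and {5,6}
}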
@@ -3914,6 +3933,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
    NodeArg& compressOutputNodeArg = CreateNodeArg(src->Outputs()[0], graph, false);

    std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
    std::string onnxOpName = "Compress";
    Node *compressNode = &graph->AddNode(nodeName, onnxOpName, "", inputs, { &compressOutputNodeArg });

    int64_t sequenceAxis = 0;
@@ -3946,34 +3967,9 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceReduceElementsNode(const Func
    Node *node = &graph->AddNode(nodeName, onnxOpName, "", inputs, outputs);
    SetReduceElementsAttributes(br, node, true);

    functionNodes.emplace(src, node);

    return node;
}

onnxruntime::Node* CNTKToONNXHelper::CreateWhereNode(const FunctionPtr& src,
    onnxruntime::Graph* graph,
    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
    assert(src->OpName() == L"Where");
    std::vector<onnxruntime::NodeArg *> inputs;
    ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
    std::vector<onnxruntime::NodeArg *> outputs;
    ProcessOutputs(src, inputs, outputs, graph);

    // TODO: implement Where op.
    // This is just to unblock the Seq2Seq test.
    Node* whereNode = AddIdentityOp(*inputs[0], graph, outputs[0]->Name());

    functionNodes.emplace(src, whereNode);
    return whereNode;
}

onnxruntime::Node* CNTKToONNXHelper::CreateNodeWithGatherPacked(const FunctionPtr& src,
    onnxruntime::Graph* graph,
    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -4038,15 +4034,13 @@ bool IsDynamicAxisPackUnpack(const std::string &cntkOpName)
onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Graph* graph,
    Variable input, NodeArg* inputNodeArg, const std::string &nodeName, bool inLoop)
{
    NodeArg *constantLikeInput = nullptr;
    if (input.Shape().Rank() == 0)
    {
        // already in the shape of [*, #]
        constantLikeInput = inputNodeArg;
        // TODO:
        return nullptr;
    }
    else
    {
        // need to slice off static axes and squeeze to get [*, #]
        std::vector<int64_t> sliceStarts, sliceEnds;
        std::vector<Axis> axes;
        for (int i = 0; i < input.Shape().Rank(); i++)
@@ -4063,8 +4057,6 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
        Node *squeezeNode = AddSqueezeNode(const_cast<NodeArg &>(*sliceNode->OutputDefs().at(0)), sliceAxes,
            nodeName + "_squeeze", graph);
        constantLikeInput = const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0));
    }

    // prepare output NodeArg with shape of [sequence, batch]
    onnx::TypeProto typeProto = MakeTypeProtoWithShape();
@@ -4072,17 +4064,18 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
    typeProto.mutable_tensor_type()->set_elem_type(elemType);
    onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(nodeName + "_output", &typeProto);

    ONNX_NAMESPACE::TensorShapeProto shapeProto;
    if (!inLoop)
        shapeProto.add_dim()->set_dim_param(FreeSequenceDimParam);
    if (!inLoop || !ScanWithoutBatchAxis)
        shapeProto.add_dim()->set_dim_value(BatchSizeProcessor::FreeBatchSize());
    outputNodeArg.SetShape(shapeProto);
        ONNX_NAMESPACE::TensorShapeProto shapeProto;
        if (!inLoop)
            shapeProto.add_dim()->set_dim_param(FreeSequenceDimParam);
        if (!inLoop || !ScanWithoutBatchAxis)
            shapeProto.add_dim()->set_dim_value(BatchSizeProcessor::FreeBatchSize());
        outputNodeArg.SetShape(shapeProto);

    Node *constantNode = &graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
        { constantLikeInput }, { &outputNodeArg });
    constantNode->AddAttribute("value", (float)0);
    return constantNode;
        Node *constantNode = &graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
            { const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { &outputNodeArg });
        constantNode->AddAttribute("value", (float)0);
        return constantNode;
    }
}

// To handle "UnpackBatchAxis", "ToBatchAxis", "UnpackSequenceOp", "ToSequenceOp" ops.
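At shape level, ExtractShapeWithDynamicAxes slices every static axis down to one element, squeezes those axes away so only the dynamic [sequence, batch] dims remain, and then emits ConstantLike(value = 0) to materialize a zero tensor of that dynamic shape. A standalone sketch of the shape computation (illustrative names, not converter code):

#include <cstdint>
#include <iostream>
#include <vector>

// Shape-level view of ExtractShapeWithDynamicAxes: slicing every static axis
// down to one element and squeezing those axes leaves only the dynamic
// [sequence, batch] dims; ConstantLike then yields a zero tensor of that shape.
std::vector<int64_t> DynamicDimsOnly(const std::vector<int64_t>& fullShape, size_t dynamicAxes)
{
    // fullShape is [sequence, batch, d1, ..., dk]; keep the first dynamicAxes dims.
    return std::vector<int64_t>(fullShape.begin(), fullShape.begin() + dynamicAxes);
}

int main()
{
    for (int64_t d : DynamicDimsOnly({10, 2, 3, 4}, 2))
        std::cout << d << ' '; // prints: 10 2
    std::cout << '\n';
}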
@@ -4149,6 +4142,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
    if (src->OpName() == L"UnpackSequenceOp" && src->Outputs().size() == 2)
    {
        std::string unpackMaskNodeName = transposeNodeName + "_mask";
        // 1. slice from [#,*][d1, d2...] to [#,*][1, 1...]
        // 2. squeeze to [#,*][1]
        // 3. ConstantLike with value = 1
@@ -4161,11 +4155,26 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
        }

        std::string sliceOutputArgName = transposeNodeName + "_slice";
        Node *sliceNode = AddSliceNode(*outputs[0], sliceAxes, sliceStarts, sliceEnds, sliceOutputArgName, graph);
        Node *sliceNode = AddSliceNode(*inputs[0], sliceAxes, sliceStarts, sliceEnds, sliceOutputArgName, graph);

        Node *squeezeNode = AddSqueezeNode(const_cast<NodeArg &>(*sliceNode->OutputDefs().at(0)), sliceAxes,
            transposeNodeName + "_squeeze", graph);

        if (src->Outputs()[1].DynamicAxes().size() == 1 &&
            src->Outputs()[1].Shape().Rank() > 0 &&
            src->Outputs()[1].Shape()[src->Outputs()[1].Shape().Rank() - 1] == NDShape::FreeDimension)
        {
            // the second output of UnpackSequenceOp has a shape of [batch][sequence]
            // ToTypeProto does not handle this case (it only swaps if there are 2 dynamic axes [batch, sequence][d1...])
            // make the output ONNX shape [sequence, batch] (was [batch, sequence])
            ONNX_NAMESPACE::TensorShapeProto shapeProto = *outputs[1]->Shape();
            const ::onnx::TensorShapeProto_Dimension dim0 = shapeProto.dim(0);
            *shapeProto.mutable_dim(0) = shapeProto.dim(1);
            *shapeProto.mutable_dim(1) = dim0;
            outputs[1]->SetShape(shapeProto);
        }

        Node *constantNode = &graph->AddNode(transposeNodeName + "_constant_like", "ConstantLike", "",
            { const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { outputs[1] });
        constantNode->AddAttribute("value", (float)1);
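The dim exchange above is just a swap of the first two entries of the TensorShapeProto, turning the mask's [batch, sequence] layout into the [sequence, batch] layout the rest of the exporter expects. The same operation on a plain dims vector (a sketch, not converter code):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// The mask output arrives as [batch, sequence]; ONNX-side code expects
// [sequence, batch], so the first two dims are swapped, mirroring the
// mutable_dim(0)/mutable_dim(1) exchange above.
int main()
{
    std::vector<int64_t> shape = {2, 10}; // [batch, sequence]
    std::swap(shape[0], shape[1]);        // -> [sequence, batch]
    std::cout << shape[0] << ' ' << shape[1] << '\n'; // prints: 10 2
}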
@@ -4175,6 +4184,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
    else
    {
        std::string identityNodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
        // this is the original output of the CNTK UnpackSequence op, which is just an Identity in ONNX.
        Node *identityNode = &graph->AddNode(identityNodeName, "Identity", "", inputs, { outputs[0] });
        functionNodes.emplace(src, identityNode);
        return identityNode;
@@ -4365,8 +4375,8 @@ void ResolveGraphAndSaveModel(onnxruntime::Model *model)
    model->SetProducerName(CNTK_ONNX_PRODUCER_NAME);

    // Uncomment the code below for debugging and troubleshooting.
    std::string savePath = "E:/LiqunWA/CNTK/ONNX/TestOps";
    onnxruntime::Model::Save(*model, savePath + "/" + dstGraph.GetOutputs()[0]->Name() + "_subgraph.onnx");
    //std::string savePath = "C:/Temp";
    //onnxruntime::Model::Save(*model, savePath + "/" + dstGraph.GetOutputs()[0]->Name() + "_subgraph.onnx");
}

// use this method to attach an identity op so that state inputs/outputs of the subgraph are in the same order as the scan op
@@ -4468,14 +4478,25 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
{
    std::vector<onnxruntime::NodeArg *> inputs;

    std::vector<FunctionPtr> inputOps = ProcessLoopStepInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
        inputs, scanLoops, createLoopIndex);
    for (auto f : inputOps)
        CreateNode(f, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
    bool useProcessLoopStepInputs = false;
    if (useProcessLoopStepInputs)
    {
        // TODO: remove loop-specific code in ProcessInputs to make it more readable.
        // Use ProcessLoopStepInputs to handle loop-specific cases.
        std::vector<FunctionPtr> inputOps = ProcessLoopStepInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
            inputs, scanLoops, createLoopIndex);
        for (auto f : inputOps)
            CreateNode(f, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
    }
    else
    {
        ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
            inputs, scanLoops, createLoopIndex);
    }

    // ProcessOutputs(src, outputs, graph);
    // This is for the final state output - ONNX requires a graph output to be a pure output that
    // does not serve as an input to any other node. Thus we have to add an identity node.
    AddIdentityOp(*inputs[0], graph, UniqueNodeNameStorage::GetUniqueOutputNodeName(src->Outputs()[0]));

    // do not create node from step ops.
@@ -4533,7 +4554,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
    std::vector<NodeArg*> output_args;

    // sequence_lens
    // AddEmptyInput(graph, input_args);
    AddEmptyInput(graph, input_args);

    int numStates = scanLoops[loopIndex].scanLoopStates.size();
    std::vector<int64_t> directions;
@@ -4725,7 +4746,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
    GraphProto graphProto(scanGraph.ToGraphProto());
    scanNode->AddAttribute("body", graphProto);
    scanNode->AddAttribute("scan_output_directions", directions);
    scanNode->AddAttribute("directions", directions);
    scanNode->AddAttribute("num_scan_inputs", (int64_t)(scanLoops[loopIndex].m_scanInputs.size()));

    return false;
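For context: the direction attribute the hunk renames (different ONNX Scan versions call it "directions" or "scan_output_directions") marks each scanned sequence as walked forward (0) or in reverse (1) along the sequence axis. A self-contained reference loop for one state and one scan input (an illustrative sketch, not the converter's code):

#include <iostream>
#include <vector>

// Reference semantics of an ONNX Scan body over one scan input with one
// state: direction 0 walks the sequence forward, 1 walks it backward.
std::vector<float> RunScan(const std::vector<float>& scanInput, float initialState, long direction)
{
    std::vector<float> scanOutput(scanInput.size());
    float state = initialState;
    for (size_t step = 0; step < scanInput.size(); ++step)
    {
        size_t t = direction == 0 ? step : scanInput.size() - 1 - step;
        state = state + scanInput[t]; // stand-in for the loop body graph
        scanOutput[t] = state;
    }
    return scanOutput;
}

int main()
{
    for (float v : RunScan({1, 2, 3}, 0.0f, 1)) // reverse scan
        std::cout << v << ' '; // prints: 6 5 3
    std::cout << '\n';
}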
@@ -4767,16 +4788,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
    // return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
    //}
    //else
    if (cntkOpName == "Where")
    {
        return CreateWhereNode(src,
            graph,
            functionNodes,
            variableNodes,
            compositeOutputsMap,
            scanLoops, createLoopIndex);
    }
    else if (cntkOpName == "GatherPacked")
    if (cntkOpName == "GatherPacked")
    {
        return CreateNodeWithGatherPacked(src,
            graph,
@@ -5187,13 +5199,8 @@ std::vector<FunctionPtr> CNTKToONNXHelper::ProcessLoopStepInputs(const FunctionP
    if (cntkOpName != "FutureValue" && cntkOpName != "PastValue")
        LogicError("ProcessLoopStepInputs is called with wrong op.");
    Variable stateInput = src->Inputs()[0];
    if (createLoopIndex >= 0 && Operators::IsRNNOp(ToLegacyString(ToUTF8(src->Inputs()[0].Owner()->OpName()))))
    {
        // RNN op in a loop: use the underlying output as input to the step function.
        stateInput = stateInput.BlockFunctionVariableMapping();
    }

    Variable initialStateInput = src->Inputs()[1];

    {
        // process input
        std::string stateInputName = UniqueNodeNameStorage::GetUniqueInputNodeName(stateInput);
@@ -5205,70 +5212,73 @@ std::vector<FunctionPtr> CNTKToONNXHelper::ProcessLoopStepInputs(const FunctionP

    bool inSubgraph = createLoopIndex >= 0;
    {
        // process initial state.
        // we need to make sure each scanLoopState has a unique NodeArg.
        // process initial state. we need to make sure each of the scanLoopStates has a unique NodeArg.
        // This NodeArg is for the scan iteration to loop back to at each scan iteration.
        // It cannot be shared between two states. MakeInitialStateNodeArgName combines the initial state name
        // with the step op name to ensure uniqueness.
        // The initial state initializer is an input to a scan op; it belongs to the parent graph.
        // The initial state NodeArg applies to both the subgraph and the parent graph.
        // e.g.: initializer name: Constant221FutureValue222
        // init state node arg for main graph: Constant221FutureValue222
        // init state node arg for subgraph: Constant221FutureValue222_subgraph
        std::string initialStateInitializerName = MakeInitialStateNodeArgName(src->Output());
        std::string initialStateInputName = inSubgraph ?
            MakeScanInputOutputNodeArgName(initialStateInitializerName) : initialStateInitializerName;
        // It cannot be shared between two states.
        // inputName = inputName + ToLegacyString(ToUTF8(src->Uid()));

        // build the initial state with a shape that matches the state input. In case of scalar values, the data tensor is constructed with the needed shape.
        onnx::TypeProto initialStateInputArgType = ToTypeProto(stateInput.Shape(), initialStateInput.HasBatchAxis(), initialStateInput.HasSequenceAxis());
        std::string initialStateInputName = MakeInitialStateNodeArgName(src->Output());
        if (inSubgraph)
            initialStateInputName = MakeScanInputOutputNodeArgName(initialStateInputName);

        onnx::TypeProto initialStateInputArgType = ToTypeProto(initialStateInput.Shape(), initialStateInput.HasBatchAxis(), initialStateInput.HasSequenceAxis());
        UpdateONNXType(initialStateInput.GetDataType(), initialStateInputArgType);
        onnxruntime::NodeArg &initialStateInputArg = graph->GetOrCreateNodeArg(initialStateInputName, &initialStateInputArgType);
        inputs.push_back(&initialStateInputArg);

        if (inSubgraph)
        {
            // create the initial state constant and final state nodeArg
            // define a state output so executors know this is the place
            // to run the body function in loops to get the next t + 1 state.
            Variable stateOutput = src->Outputs()[0];
            scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialStateInput, nullptr, stateOutput,
                src->OpName() == L"PastValue" ? 1 : -1));
            scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = &initialStateInputArg;

            bool isInitialStateConstant = initialStateInput.IsParameter() || initialStateInput.IsConstant();
            if (isInitialStateConstant)
            {
                auto srcTensor = initialStateInput.IsParameter() ?
                    Parameter(initialStateInput).Value() : Constant(initialStateInput).Value();

                onnx::TensorProto dstTensor;
                dstTensor.set_name(initialStateInitializerName);

                // in case the initial state is a scalar, we need to expand it to the shape of the state.
                std::vector<int64_t> initialStateShape = ToINTS(initialStateInputArgType);
                // The initial state, as an input to the scan op, shall have a batch axis.
                // The initial state is an input to a Scan node. As an input, it needs to have a batch axis.
                std::vector<int64_t> initialStateShape =
                    ToINTS(ToTypeProto(src->Inputs()[0].Shape(), (int)(src->Inputs()[0].DynamicAxes().size())));
                if (ScanWithoutBatchAxis)
                    initialStateShape.insert(initialStateShape.begin(), BatchSizeProcessor::FreeBatchSize());

                if (!srcTensor->Shape().IsScalar())
                {
                    // initialStateInputArgType can be off by a batch dimension of size 1 if ScanWithoutBatchAxis is true
                    onnx::TypeProto initialStateInputArgTypeWithBatch = ToTypeProto(initialStateShape, false);
                    UpdateONNXType(initialStateInput.GetDataType(), initialStateInputArgTypeWithBatch);
                    CopyTensor(srcTensor, dstTensor, &initialStateInputArgTypeWithBatch);
                }
                else
                {
                    FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
                }
                onnx::TensorProto dstTensor;
                dstTensor.set_name(initialStateInputName);

                auto srcTensor = initialStateInput.IsParameter() ?
                    Parameter(initialStateInput).Value() : Constant(initialStateInput).Value();
                // FillTensorWithScalar takes a vector of srcs and assumes initialStateShape is a collection of shapes,
                // one for each src
                FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
                // The initial state is an input to a scan op; it belongs to the parent graph.
                scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateTensor = dstTensor;
                scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_hasInitializer = true;

                graph->AddOuterScopeNodeArg(initialStateInputName);
            }
            else
            {
                //if (!scanLoopState.m_initialState.HasBatchAxis())
                //{
                //    // expand with a batch dimension - the initial state may or may not have a batch dimension.
                //    // 1. in case the initial state is a trained constant, there is no batch dimension.
                //    //    ProcessLoopStepInputs (and ProcessInputs) handles this case by adding a FreeDimension
                //    // 2. in case the initial state is from another computation node, it is likely
                //    //    that it already has a batch dimension.

                //}

            }
        }
    }

    if (inSubgraph)
    {
        // create the initial state constant and final state nodeArg
        // define a state output so executors know this is the place
        // to run the body function in loops to get the next t + 1 state.
        Variable stateOutput = src->Outputs()[0];
        scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialStateInput, nullptr, stateOutput,
            src->OpName() == L"PastValue" ? 1 : -1));
        scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = inputs[1];
    }

    std::vector<FunctionPtr> inputOps;
    if (stateInput.IsOutput())
        inputOps.push_back(stateInput.Owner());
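The naming scheme the comments above describe keeps one initializer name shared with the parent graph while the subgraph NodeArg gets a distinguishing suffix, so each scan state loops back through its own, unshared NodeArg. An illustrative reconstruction (the example names come from the comment; MakeSubgraphArgName is a hypothetical stand-in, not a converter function):

#include <iostream>
#include <string>

// Hypothetical stand-in for the subgraph-name derivation described in the
// comments above; the real converter uses MakeScanInputOutputNodeArgName.
std::string MakeSubgraphArgName(const std::string& initializerName)
{
    return initializerName + "_subgraph";
}

int main()
{
    std::string initializer = "Constant221FutureValue222";
    std::cout << initializer << '\n';                      // parent-graph NodeArg
    std::cout << MakeSubgraphArgName(initializer) << '\n'; // subgraph NodeArg
}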
@@ -5368,6 +5378,19 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
    bool isInSubGraph = createLoopIndex >= 0 && createLoopIndex < scanLoops.size();

    bool isInitialStateOfSubGraph = false;
    if ((createLoopIndex >= 0 && createLoopIndex < scanLoops.size()) && inputIndex == 1)
    {
        for (auto &f : scanLoops[createLoopIndex].m_loopstepfunctions)
        {
            if (f->Inputs().size() == 2 && f->Inputs()[inputIndex].Uid() == input.Uid())
            {
                isInitialStateOfSubGraph = true;
                break;
            }
        }
    }

    bool isScanInputInSubgraph = createLoopIndex != -1 &&
        std::find_if(scanLoops[createLoopIndex].m_scanInputs.begin(), scanLoops[createLoopIndex].m_scanInputs.end(),
            [inputName](Variable v) {return inputName == UniqueNodeNameStorage::GetUniqueInputNodeName(v); })
@@ -5386,6 +5409,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
    // if the initial state is a scalar, it will be created with the correct shape later in this method.

    ScanLoop &scanLoop = scanLoops[createLoopIndex];
    // to match the "else if (isInitialStateOfSubGraph)" case:
    // one initial state may map to multiple final states.
    // to make a one-to-one mapping from initial to final states,
    // we have to split the initial state.
@@ -5448,6 +5472,14 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
        }
        inputArgType = ToTypeProto(inputShape, input.HasBatchAxis(), input.HasSequenceAxis());
    }
    else if (isInitialStateOfSubGraph)
    {
        // for the initial state, we need to make sure each of the scanLoopStates has a unique NodeArg.
        // This NodeArg is for the scan iteration to loop back to at each scan iteration.
        // It cannot be shared between two states.
        inputName = MakeInitialStateNodeArgName(src->Output());
        inputArgType = ToTypeProto(src->Inputs()[0].Shape(), src->Inputs()[0].HasBatchAxis(), src->Inputs()[0].HasSequenceAxis());
    }
    else
    {
        inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
@@ -5507,6 +5539,19 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
        UpdateONNXType(input.GetDataType(), inputArgType);
    }

    if (isInitialStateOfSubGraph)
    {
        Variable initialState = src->Inputs()[1];
        Variable stateInput = src->Inputs()[0];
        Variable stateOutput = src->Outputs()[0];

        // create the initial state constant and final state nodeArg
        // define a state output so executors know this is the place
        // to run the body function in loops to get the next t + 1 state.
        scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialState, nullptr, stateOutput,
            src->OpName() == L"PastValue" ? 1 : -1));
    }

    bool addedInitializer = false;
    //
    // Leaf nodes are data entry to the graph and need their own node with only an output arg.
@@ -5522,20 +5567,39 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
        onnx::TensorProto dstTensor;
        dstTensor.set_name(inputName);

        CopyTensor(srcTensor, dstTensor, &inputArgType);
        if (CNTKToONNXHelper::globalGraph && createLoopIndex != -1)
        if (isInitialStateOfSubGraph)
        {
            scanLoops[createLoopIndex].initializerAsInput.push_back(inputName);
            // in case the initial state is a scalar, we need to expand it to the shape of the state.
            // The initial state is an input to a Scan node. As an input, it needs to have a batch axis.
            std::vector<int64_t> initialStateShape =
                ToINTS(ToTypeProto(src->Inputs()[0].Shape(), (int)(src->Inputs()[0].DynamicAxes().size())));
            if (ScanWithoutBatchAxis)
                initialStateShape.insert(initialStateShape.begin(), BatchSizeProcessor::FreeBatchSize());

            // With Bing.Malta50.proto1_128_gru_normv3_ep3_z.model, I could only get ONNX Runtime
            // to produce matching results by putting initializers in the subgraphs
            // (calling graph->AddInitializedTensor instead).
            CNTKToONNXHelper::globalGraph->AddInitializedTensor(dstTensor);
            // graph->AddInitializedTensor(dstTensor);
            addedInitializer = true;
            // FillTensorWithScalar takes a vector of srcs and assumes initialStateShape is a collection of shapes,
            // one for each src
            FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
            // The initial state is an input to a scan op; it belongs to the parent graph.
            scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateTensor = dstTensor;
            scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_hasInitializer = true;
        }
        else
            graph->AddInitializedTensor(dstTensor);
        {
            CopyTensor(srcTensor, dstTensor, &inputArgType);
            if (CNTKToONNXHelper::globalGraph && createLoopIndex != -1)
            {
                scanLoops[createLoopIndex].initializerAsInput.push_back(inputName);

                // With Bing.Malta50.proto1_128_gru_normv3_ep3_z.model, I could only get ONNX Runtime
                // to produce matching results by putting initializers in the subgraphs
                // (calling graph->AddInitializedTensor instead).
                CNTKToONNXHelper::globalGraph->AddInitializedTensor(dstTensor);
                // graph->AddInitializedTensor(dstTensor);
                addedInitializer = true;
            }
            else
                graph->AddInitializedTensor(dstTensor);
        }
    }
}
}
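The restructured branch above encodes a three-way placement decision for constant inputs: initial states become scan-state tensors carried by the ScanLoopState record, other initializers referenced inside a loop body are registered on the outer (global) graph so ONNX Runtime resolves them, and everything else lands on the current graph. A simplified sketch of that decision, under the assumption that this is the full set of cases (illustrative, not converter code):

#include <iostream>

// Sketch of the placement decision above: initial states become scan-state
// tensors; other initializers referenced inside a loop body go to the outer
// graph; everything else goes to the current graph.
enum class Placement { ScanState, OuterGraph, CurrentGraph };

Placement PlaceInitializer(bool isInitialStateOfSubGraph, bool insideLoop)
{
    if (isInitialStateOfSubGraph)
        return Placement::ScanState;
    return insideLoop ? Placement::OuterGraph : Placement::CurrentGraph;
}

int main()
{
    std::cout << (PlaceInitializer(false, true) == Placement::OuterGraph) << '\n'; // prints: 1
}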
@@ -5543,8 +5607,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
    onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType,
        compositeOutputsMap);

    if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
    if (isInitialStateOfSubGraph || (isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
    {
        inputName = MakeScanInputOutputNodeArgName(inputName);
    }
@@ -5557,6 +5620,9 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
    inputs.push_back(&inputArg);

    if (isInitialStateOfSubGraph)
        scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = inputs[1];

    if (cntkOpName == "Reshape")
    {
        // ONNX 1.2 Reshape nodes take the shape as an input instead of an attribute.
@@ -5750,15 +5816,8 @@ void CNTKToONNXHelper::TraverseGraph(const FunctionPtr& src,
    if (input.IsPlaceholder())
    {
        input = input.BlockFunctionVariableMapping();

        if (!Operators::IsRNNOp(opName) && input.IsPlaceholder())
        {
            // this could be the case that a block takes an input from another block,
            // so try mapping one more time.
            input = input.BlockFunctionVariableMapping();
            if (!input.IsInitialized() || input.IsPlaceholder())
                LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
        }
        LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
    }

    if (input.IsInitialized() && input.IsOutput())
@@ -6931,8 +6990,6 @@ std::pair<std::vector<int>, std::vector<int>> CNTKToONNXHelper::GetONNXPadsAttri
void CNTKToONNXHelper::FillTensorWithScalarFromSingleSource(const NDArrayViewPtr &src,
    onnx::TensorProto& dst, const std::vector<int64_t> dstShape)
{
    if (!src->Shape().IsScalar())
        LogicError("FillTensorWithScalarFromSingleSource can only work with a scalar.");
    auto dataType = src->GetDataType();
    SetTensorType(dst, dataType);
    int64_t eachSrcSize = std::accumulate(dstShape.begin(), dstShape.end(), (int64_t)1, std::multiplies<int64_t>());
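FillTensorWithScalarFromSingleSource validates that the source is a scalar, then fills a destination whose element count is the product of dstShape with that single value. A standalone sketch of that expansion using a plain buffer in place of the onnx::TensorProto (illustrative, not the converter's code):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Standalone sketch of the scalar-expansion helper above: the destination
// element count is the product of dstShape, and every element is set to the
// single scalar value (the real code writes into an onnx::TensorProto).
std::vector<float> FillWithScalar(float scalar, const std::vector<int64_t>& dstShape)
{
    int64_t size = std::accumulate(dstShape.begin(), dstShape.end(), (int64_t)1,
                                   std::multiplies<int64_t>());
    return std::vector<float>(size, scalar);
}

int main()
{
    auto t = FillWithScalar(0.5f, {2, 3});
    std::cout << t.size() << '\n'; // prints: 6
}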