This reverts commit 47345c6c54.
liqfu 2019-01-03 15:45:28 -08:00
Parent 47345c6c54
Commit 6c56fdd2f9
1 changed file with 240 additions and 183 deletions


@@ -291,13 +291,6 @@ private:
std::vector<ScanLoop> &scanLoops, int createLoopIndex,
bool isFirst);
static onnxruntime::Node* CreateWhereNode(const FunctionPtr& src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex);
static onnxruntime::Node* CreateNodeWithGatherPacked(const FunctionPtr& src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -3785,8 +3778,65 @@ onnxruntime::Node* CNTKToONNXHelper::CreateReconcileDynamicAxisNode(const Functi
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
return CreateSequenceBroadcastAsNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
scanLoops, createLoopIndex);
bool inLoop = createLoopIndex != -1;
if (!inLoop)
// TODO: sequence.broadcast_as may still be in a loop.
// Investigate whether sequence.broadcast_as and reconcile_dynamic_axis are fully equivalent.
return CreateSequenceBroadcastAsNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
scanLoops, createLoopIndex);
std::vector<onnxruntime::NodeArg *> inputs, outputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
ProcessOutputs(src, inputs, outputs, graph);
Variable input = src->Inputs()[0];
Variable broadcastAs = src->Inputs()[1];
Variable output = src->Outputs()[0];
std::vector<int64_t> newShape = ToINTS(*outputs[0]->TypeAsProto());
std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
std::string squeezedBroadcastNodeArgName = nodeName + "_squeezed";
onnxruntime::Node* broadcastAsSqueezed = ExtractShapeWithDynamicAxes(
graph, src->Inputs()[1], inputs[1], squeezedBroadcastNodeArgName, inLoop);
NodeArg *broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastAsSqueezed->OutputDefs()[0]);
// broadcastAsSqueezed has shape [sequence, batch]; append 1s to get to the same rank as the input before broadcasting
if (input.Shape().Rank() != 0)
{
std::vector<int64_t> newShape = ToINTS(*broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto());
for (int i = 0; i < input.Shape().Rank(); i++)
newShape.push_back(1);
std::string broadcastReshapedNodeName = nodeName + "_squeezed_reshaped";
onnxruntime::Node* broadcastReshaped = AddReshapeNode(
*broadcastNodeArg, newShape, broadcastReshapedNodeName, graph);
broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastReshaped->OutputDefs()[0]);
}
NodeArg *inputNodeArg = nullptr;
if (input.DynamicAxes().size() == 1)
{
// input does not have a sequence axis; insert one with size 1 so it is broadcast
std::vector<int64_t> newShape = ToINTS(*inputs[0]->TypeAsProto());
newShape.insert(newShape.begin(), 1);
Node* inputReshapeNode = AddReshapeNode(*inputs[0], newShape, nodeName + "_input_reshape", graph);
inputNodeArg = const_cast<onnxruntime::NodeArg *>(inputReshapeNode->OutputDefs()[0]);
}
else if(input.DynamicAxes().size() == 0)
{
inputNodeArg = inputs[0];
}
else if (input.DynamicAxes().size() == 2)
{
// This is the case with the unfold op. It has been verified with ONNX Runtime.
inputNodeArg = inputs[0];
}
onnxruntime::Node* elementWiseNode = &graph->AddNode(nodeName + "_add", "Add", "",
{ inputNodeArg, broadcastNodeArg }, outputs);
functionNodes.emplace(src, elementWiseNode);
return elementWiseNode;
}
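// The Add above exploits ONNX' NumPy-style broadcasting: the zero tensor
// derived from the broadcast_as operand has shape [seq, batch, 1, ..., 1],
// so Add(input, zeros) expands the input values to [seq, batch, d1, ..., dn].
// A minimal standalone sketch of that shape rule (illustration only, not
// exporter code; BroadcastShape is a hypothetical helper):
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>
static std::vector<int64_t> BroadcastShape(std::vector<int64_t> a, std::vector<int64_t> b)
{
    // Right-align the shorter shape by padding it with leading 1s.
    while (a.size() < b.size()) a.insert(a.begin(), 1);
    while (b.size() < a.size()) b.insert(b.begin(), 1);
    std::vector<int64_t> out(a.size());
    for (size_t i = 0; i < a.size(); i++)
    {
        // Dimensions must match, or one of them must be 1 (the broadcast dim).
        assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}
// e.g. BroadcastShape({3, 4}, {seq, batch, 1, 1}) yields {seq, batch, 3, 4}.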
//
@@ -3816,51 +3866,23 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceBroadcastAsNode(const Functio
// broadcastAsSqueezed has shape [sequence, batch]; append 1s to get to the same rank as the input before broadcasting
if (input.Shape().Rank() != 0)
{
TypeProto typeProto = *broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto();
TensorShapeProto* tensorShapeProto = typeProto.mutable_tensor_type()->mutable_shape();
std::vector<int64_t> unsqueezeAxis;
std::vector<int64_t> newShape = ToINTS(*broadcastAsSqueezed->OutputDefs()[0]->TypeAsProto());
for (int i = 0; i < input.Shape().Rank(); i++)
{
tensorShapeProto->add_dim()->set_dim_value(1);
unsqueezeAxis.push_back(i + broadcastNodeArg->Shape()->dim_size());
}
std::string broadcastUnsqueezeNodeName = nodeName + "_broadcast_as_unsqueeze";
std::string broadcastUnsqueezeOutputNodeArgName = broadcastUnsqueezeNodeName + "_output";
NodeArg *broadcastUnsqueezeOutputNodeArg = &graph->GetOrCreateNodeArg(broadcastUnsqueezeOutputNodeArgName, &typeProto);
Node *unsqueezeNode = &graph->AddNode(broadcastUnsqueezeNodeName, "Unsqueeze", "",
{ broadcastNodeArg }, { broadcastUnsqueezeOutputNodeArg });
unsqueezeNode->AddAttribute("axes", unsqueezeAxis);
broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(unsqueezeNode->OutputDefs()[0]);
newShape.push_back(1);
std::string broadcastReshapedNodeName = nodeName + "_squeezed_reshaped";
onnxruntime::Node* broadcastReshaped = AddReshapeNode(
*broadcastNodeArg, newShape, broadcastReshapedNodeName, graph);
broadcastNodeArg = const_cast<onnxruntime::NodeArg *>(broadcastReshaped->OutputDefs()[0]);
}
NodeArg *inputNodeArg = nullptr;
if (input.DynamicAxes().size() == 1 && createLoopIndex < 0)
if (input.DynamicAxes().size() == 1)
{
// not in a loop, so there should be a sequence axis.
// input does not have a sequence axis; prepend a dim with dim_value = 1 so it is broadcast
onnx::TypeProto unsqueezeNodeOutputTypeProto = MakeTypeProtoWithShape();
onnx::TensorProto_DataType elemType = ConvertDataTypeCNTKToTensorProto(input.GetDataType());
unsqueezeNodeOutputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
ONNX_NAMESPACE::TensorShapeProto *shapeProto = unsqueezeNodeOutputTypeProto.mutable_tensor_type()->mutable_shape();
shapeProto->add_dim()->set_dim_value(1);
for (int dim = 0; dim < inputs[0]->Shape()->dim_size(); dim++)
{
if (inputs[0]->Shape()->dim(dim).has_dim_value())
shapeProto->add_dim()->set_dim_value(inputs[0]->Shape()->dim(dim).dim_value());
else
shapeProto->add_dim()->set_dim_param(inputs[0]->Shape()->dim(dim).dim_param());
}
std::string unsqueezeNodeName = nodeName + "_input_unsqueeze";
NodeArg *unsqueezeOutputNodeArg = &graph->GetOrCreateNodeArg(unsqueezeNodeName + "_output", &unsqueezeNodeOutputTypeProto);
Node* unsqueezeNode = &graph->AddNode(unsqueezeNodeName, "Unsqueeze", "",
{ inputs[0] }, { unsqueezeOutputNodeArg });
unsqueezeNode->AddAttribute("axes", std::vector<int64_t>({ 0 }));
inputNodeArg = const_cast<onnxruntime::NodeArg *>(unsqueezeNode->OutputDefs()[0]);
// input does not have a sequence axis; insert one with size 1 so it is broadcast
std::vector<int64_t> newShape = ToINTS(*inputs[0]->TypeAsProto());
newShape.insert(newShape.begin(), 1);
Node* inputReshapeNode = AddReshapeNode(*inputs[0], newShape, nodeName + "_input_reshape", graph);
inputNodeArg = const_cast<onnxruntime::NodeArg *>(inputReshapeNode->OutputDefs()[0]);
}
else
{
@@ -3885,9 +3907,6 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
if (CNTKToONNXHelper::isProcessingScan)
LogicError("SequenceGather cannot be in a scan loop");
const std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
const std::string onnxOpName = "Compress";
std::vector<onnxruntime::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
@@ -3897,7 +3916,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
//ProcessOutputs(src, outputs, graph);
// Cast inputs[1] from tensor<float> to tensor<bool>
const std::string outputNodeArgName = inputs[1]->Name() + "_cast_to_bool_" + nodeName;
const std::string outputNodeArgName = inputs[1]->Name() + "_cast_to_bool";
Node *castNode = AddCastNode(*inputs[1], graph,
TensorProto_DataType::TensorProto_DataType_BOOL, outputNodeArgName);
@@ -3914,6 +3933,8 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
NodeArg& compressOutputNodeArg = CreateNodeArg(src->Outputs()[0], graph, false);
std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
std::string onnxOpName = "Compress";
Node *compressNode = &graph->AddNode(nodeName, onnxOpName, "", inputs, { &compressOutputNodeArg });
int64_t sequenceAxis = 0;
@@ -3946,34 +3967,9 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceReduceElementsNode(const Func
Node *node = &graph->AddNode(nodeName, onnxOpName, "", inputs, outputs);
SetReduceElementsAttributes(br, node, true);
functionNodes.emplace(src, node);
return node;
}
onnxruntime::Node* CNTKToONNXHelper::CreateWhereNode(const FunctionPtr& src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
const std::unordered_map<Variable, Variable>& compositeOutputsMap,
std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
assert(src->OpName() == L"Where");
std::vector<onnxruntime::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
std::vector<onnxruntime::NodeArg *> outputs;
ProcessOutputs(src, inputs, outputs, graph);
// TODO: implement Where op.
// This is just to unblock Seq2Seq test.
Node* whereNode = AddIdentityOp(*inputs[0], graph, outputs[0]->Name());
functionNodes.emplace(src, whereNode);
return whereNode;
}
onnxruntime::Node* CNTKToONNXHelper::CreateNodeWithGatherPacked(const FunctionPtr& src,
onnxruntime::Graph* graph,
std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -4038,15 +4034,13 @@ bool IsDynamicAxisPackUnpack(const std::string &cntkOpName)
onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Graph* graph,
Variable input, NodeArg* inputNodeArg, const std::string &nodeName, bool inLoop)
{
NodeArg *constantLikeInput = nullptr;
if (input.Shape().Rank() == 0)
{
// already in the shape of [*, #]
constantLikeInput = inputNodeArg;
// TODO:
return nullptr;
}
else
{
// need to slice off static axes and squeeze to get [*, #]
std::vector<int64_t> sliceStarts, sliceEnds;
std::vector<Axis> axes;
for (int i = 0; i < input.Shape().Rank(); i++)
@@ -4063,8 +4057,6 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
Node *squeezeNode = AddSqueezeNode(const_cast<NodeArg &>(*sliceNode->OutputDefs().at(0)), sliceAxes,
nodeName + "_squeeze", graph);
constantLikeInput = const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0));
}
// prepare output NodeArg with shape of [sequence, batch]
onnx::TypeProto typeProto = MakeTypeProtoWithShape();
@@ -4072,17 +4064,18 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
typeProto.mutable_tensor_type()->set_elem_type(elemType);
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(nodeName + "_output", &typeProto);
ONNX_NAMESPACE::TensorShapeProto shapeProto;
if (!inLoop)
shapeProto.add_dim()->set_dim_param(FreeSequenceDimParam);
if (!inLoop || !ScanWithoutBatchAxis)
shapeProto.add_dim()->set_dim_value(BatchSizeProcessor::FreeBatchSize());
outputNodeArg.SetShape(shapeProto);
Node *constantNode = &graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
{ constantLikeInput }, { &outputNodeArg });
constantNode->AddAttribute("value", (float)0);
return constantNode;
Node *constantNode = &graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
{ const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { &outputNodeArg });
constantNode->AddAttribute("value", (float)0);
return constantNode;
}
}
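// ExtractShapeWithDynamicAxes above reduces [*, #, d1, ..., dn] to [*, #]:
// each static axis is sliced down to size 1 and then squeezed away, and
// ConstantLike(value = 0) produces a zero tensor carrying only the dynamic
// dimensions. A standalone sketch of the shape bookkeeping (illustration
// only; SqueezeStaticAxes is a hypothetical helper, and two leading dynamic
// axes are assumed):
#include <cstdint>
#include <vector>
static std::vector<int64_t> SqueezeStaticAxes(const std::vector<int64_t> &shape,
                                              size_t numDynamicAxes = 2)
{
    // Slice(start = 0, end = 1) makes each static axis size 1; Squeeze on
    // those axes then removes them, keeping only the dynamic dimensions.
    return std::vector<int64_t>(shape.begin(), shape.begin() + numDynamicAxes);
}
// e.g. SqueezeStaticAxes({seq, batch, 300, 128}) yields {seq, batch}.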
// To handle "UnpackBatchAxis" , "ToBatchAxis" , "UnpackSequenceOp", "ToSequenceOp" ops.
@@ -4149,6 +4142,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
if (src->OpName() == L"UnpackSequenceOp" && src->Outputs().size() == 2)
{
std::string unpackMaskNodeName = transposeNodeName + "_mask";
// 1. slice from [#,*][d1, d2...] to [#,*][1, 1...]
// 2. squeeze to [#,*][1]
// 3. ConstantLike with value = 1
@@ -4161,11 +4155,26 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
}
std::string sliceOutputArgName = transposeNodeName + "_slice";
Node *sliceNode = AddSliceNode(*outputs[0], sliceAxes, sliceStarts, sliceEnds, sliceOutputArgName, graph);
Node *sliceNode = AddSliceNode(*inputs[0], sliceAxes, sliceStarts, sliceEnds, sliceOutputArgName, graph);
Node *squeezeNode = AddSqueezeNode(const_cast<NodeArg &>(*sliceNode->OutputDefs().at(0)), sliceAxes,
transposeNodeName + "_squeeze", graph);
if (src->Outputs()[1].DynamicAxes().size() == 1 &&
src->Outputs()[1].Shape().Rank() > 0 &&
src->Outputs()[1].Shape()[src->Outputs()[1].Shape().Rank() - 1] == NDShape::FreeDimension)
{
// the second output of UnpackSequenceOp has a shape of [batch][sequence]
// ToTypeProto does not handle this case (it only swaps if there are 2 dynamic axes [batch, sequence][d1...])
// make the output ONNX shape [sequence, batch] (was [batch, sequence])
ONNX_NAMESPACE::TensorShapeProto shapeProto = *outputs[1]->Shape();
const ::onnx::TensorShapeProto_Dimension dim0 = shapeProto.dim(0);
*shapeProto.mutable_dim(0) = shapeProto.dim(1);
*shapeProto.mutable_dim(1) = dim0;
outputs[1]->SetShape(shapeProto);
}
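// The swap above only reorders shape metadata: dim(0) and dim(1) of the
// TensorShapeProto trade places so the declared ONNX shape becomes
// [sequence, batch] instead of [batch, sequence]. The same idiom on a plain
// shape vector (illustration only; SwapLeadingDynamicDims is hypothetical):
#include <cstdint>
#include <utility>
#include <vector>
static void SwapLeadingDynamicDims(std::vector<int64_t> &shape)
{
    if (shape.size() >= 2)
        std::swap(shape[0], shape[1]); // [batch, sequence] -> [sequence, batch]
}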
Node *constantNode = &graph->AddNode(transposeNodeName + "_constant_like", "ConstantLike", "",
{ const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { outputs[1] });
constantNode->AddAttribute("value", (float)1);
@@ -4175,6 +4184,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
else
{
std::string identityNodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
// this is the original output of CNTK UnpackSequence op which is just an identity in ONNX.
Node *identityNode = &graph->AddNode(identityNodeName, "Identity", "", inputs, { outputs[0] });
functionNodes.emplace(src, identityNode);
return identityNode;
@@ -4365,8 +4375,8 @@ void ResolveGraphAndSaveModel(onnxruntime::Model *model)
model->SetProducerName(CNTK_ONNX_PRODUCER_NAME);
// Uncomment the code below for debugging and troubleshooting.
std::string savePath = "E:/LiqunWA/CNTK/ONNX/TestOps";
onnxruntime::Model::Save(*model, savePath + "/" + dstGraph.GetOutputs()[0]->Name() + "_subgraph.onnx");
//std::string savePath = "C:/Temp";
//onnxruntime::Model::Save(*model, savePath + "/" + dstGraph.GetOutputs()[0]->Name() + "_subgraph.onnx");
}
// use this method to attach an identity op so that state inputs/outputs of the subgraph are in the same order as the scan op
@@ -4468,14 +4478,25 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
{
std::vector<onnxruntime::NodeArg *> inputs;
std::vector<FunctionPtr> inputOps = ProcessLoopStepInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
inputs, scanLoops, createLoopIndex);
for (auto f : inputOps)
CreateNode(f, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
bool useProcessLoopStepInputs = false;
if (useProcessLoopStepInputs)
{
// TODO: remove loop-specific code in ProcessInputs to make it more readable.
// use ProcessLoopStepInputs to handle loop-specific cases.
std::vector<FunctionPtr> inputOps = ProcessLoopStepInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
inputs, scanLoops, createLoopIndex);
for (auto f : inputOps)
CreateNode(f, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
}
else
{
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
inputs, scanLoops, createLoopIndex);
}
// ProcessOutputs(src, outputs, graph);
// This is for the final state output - ONNX requires a graph output to be a pure output that
// does not serve as an input to any other node. Thus we have to add an identity node.
AddIdentityOp(*inputs[0], graph, UniqueNodeNameStorage::GetUniqueOutputNodeName(src->Outputs()[0]));
// do not create node from step ops.
@@ -4533,7 +4554,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
std::vector<NodeArg*> output_args;
// sequence_lens
// AddEmptyInput(graph, input_args);
AddEmptyInput(graph, input_args);
int numStates = scanLoops[loopIndex].scanLoopStates.size();
std::vector<int64_t> directions;
@@ -4725,7 +4746,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
GraphProto graphProto(scanGraph.ToGraphProto());
scanNode->AddAttribute("body", graphProto);
scanNode->AddAttribute("scan_output_directions", directions);
scanNode->AddAttribute("directions", directions);
scanNode->AddAttribute("num_scan_inputs", (int64_t)(scanLoops[loopIndex].m_scanInputs.size()));
return false;
@@ -4767,16 +4788,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
// return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
//}
//else
if (cntkOpName == "Where")
{
return CreateWhereNode(src,
graph,
functionNodes,
variableNodes,
compositeOutputsMap,
scanLoops, createLoopIndex);
}
else if (cntkOpName == "GatherPacked")
if (cntkOpName == "GatherPacked")
{
return CreateNodeWithGatherPacked(src,
graph,
@@ -5187,13 +5199,8 @@ std::vector<FunctionPtr> CNTKToONNXHelper::ProcessLoopStepInputs(const FunctionP
if (cntkOpName != "FutureValue" && cntkOpName != "PastValue")
LogicError("ProcessLoopStepInputs is called with wrong op.");
Variable stateInput = src->Inputs()[0];
if (createLoopIndex >= 0 && Operators::IsRNNOp(ToLegacyString(ToUTF8(src->Inputs()[0].Owner()->OpName()))))
{
// an RNN op in a loop shall use the underlying output as input to the step function.
stateInput = stateInput.BlockFunctionVariableMapping();
}
Variable initialStateInput = src->Inputs()[1];
{
// process input
std::string stateInputName = UniqueNodeNameStorage::GetUniqueInputNodeName(stateInput);
@@ -5205,70 +5212,73 @@ std::vector<FunctionPtr> CNTKToONNXHelper::ProcessLoopStepInputs(const FunctionP
bool inSubgraph = createLoopIndex >= 0;
{
// process initial state.
// we need to make sure each scanLoopState has a unique NodeArg.
// process initial state. we need to make sure each of the scanLoopStates has a unique NodeArg.
// This NodeArg is for scan iteration to loop back at each scan iteration.
// It cannot be shared between two states. MakeInitialStateNodeArgName combines the initial state name
// with the step op name to ensure the uniqueness.
// initial state initializer is input to a scan op. it belongs to the parent graph.
// initial state nodearg applies to both subgraph and the parent graph.
// e.g.: initializer name: Constant221FutureValue222
// init state node arg for main graph: Constant221FutureValue222
// init state node arg for subgraph: Constant221FutureValue222_subgraph
std::string initialStateInitializerName = MakeInitialStateNodeArgName(src->Output());
std::string initialStateInputName = inSubgraph ?
MakeScanInputOutputNodeArgName(initialStateInitializerName) : initialStateInitializerName;
// It cannot be shared between two states.
// inputName = inputName + ToLegacyString(ToUTF8(src->Uid()));
// build the initial state with a shape that matches the state input. In the case of scalar values, the data tensor is constructed with the needed shape.
onnx::TypeProto initialStateInputArgType = ToTypeProto(stateInput.Shape(), initialStateInput.HasBatchAxis(), initialStateInput.HasSequenceAxis());
std::string initialStateInputName = MakeInitialStateNodeArgName(src->Output());
if (inSubgraph)
initialStateInputName = MakeScanInputOutputNodeArgName(initialStateInputName);
onnx::TypeProto initialStateInputArgType = ToTypeProto(initialStateInput.Shape(), initialStateInput.HasBatchAxis(), initialStateInput.HasSequenceAxis());
UpdateONNXType(initialStateInput.GetDataType(), initialStateInputArgType);
onnxruntime::NodeArg &initialStateInputArg = graph->GetOrCreateNodeArg(initialStateInputName, &initialStateInputArgType);
inputs.push_back(&initialStateInputArg);
if (inSubgraph)
{
// create initial state constant and final state nodeArg
// define a state output so executors know this is the place
// to run the body function in loops to get the next (t + 1) state.
Variable stateOutput = src->Outputs()[0];
scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialStateInput, nullptr, stateOutput,
src->OpName() == L"PastValue" ? 1 : -1));
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = &initialStateInputArg;
bool isInitialStateConstant = initialStateInput.IsParameter() || initialStateInput.IsConstant();
if (isInitialStateConstant)
{
auto srcTensor = initialStateInput.IsParameter() ?
Parameter(initialStateInput).Value() : Constant(initialStateInput).Value();
onnx::TensorProto dstTensor;
dstTensor.set_name(initialStateInitializerName);
// in case the initial state is a scalar, we need to expand it to the shape of the state.
std::vector<int64_t> initialStateShape = ToINTS(initialStateInputArgType);
// Initial state as an input to the scan op shall have a batch axis.
// initial state is an input to a Scan node. As an input, it needs to have a batch axis.
std::vector<int64_t> initialStateShape =
ToINTS(ToTypeProto(src->Inputs()[0].Shape(), (int)(src->Inputs()[0].DynamicAxes().size())));
if (ScanWithoutBatchAxis)
initialStateShape.insert(initialStateShape.begin(), BatchSizeProcessor::FreeBatchSize());
if (!srcTensor->Shape().IsScalar())
{
// initialStateInputArgType can be off by a batch dimension of size 1 if ScanWithoutBatchAxis is true
onnx::TypeProto initialStateInputArgTypeWithBatch = ToTypeProto(initialStateShape, false);
UpdateONNXType(initialStateInput.GetDataType(), initialStateInputArgTypeWithBatch);
CopyTensor(srcTensor, dstTensor, &initialStateInputArgTypeWithBatch);
}
else
{
FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
}
onnx::TensorProto dstTensor;
dstTensor.set_name(initialStateInputName);
auto srcTensor = initialStateInput.IsParameter() ?
Parameter(initialStateInput).Value() : Constant(initialStateInput).Value();
// FillTensorWithScalar takes a vector of srcs and assumes initialStateShape is a collection of shapes,
// one for each src
FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
// initial state is an input to a scan op; it belongs to the parent graph
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateTensor = dstTensor;
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_hasInitializer = true;
graph->AddOuterScopeNodeArg(initialStateInputName);
}
else
{
//if (!scanLoopState.m_initialState.HasBatchAxis())
//{
// // expand with batch dimension - the initial state may or may not have a batch dimension.
// // 1. in case the initial state is a trained constant, there is no batch dimension.
// // ProcessLoopStepInputs (and ProcessInputs) has this case handled by adding a FreeDimension
// // 2. in case where initial state is from another computation node, it is likely
// // it already has a batch dimension.
//}
}
}
}
if (inSubgraph)
{
// create initial state constant and final state nodeArg
// define a state output so executors know this is the place
// to run the body function in loops to get the next (t + 1) state.
Variable stateOutput = src->Outputs()[0];
scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialStateInput, nullptr, stateOutput,
src->OpName() == L"PastValue" ? 1 : -1));
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = inputs[1];
}
std::vector<FunctionPtr> inputOps;
if (stateInput.IsOutput())
inputOps.push_back(stateInput.Owner());
@@ -5368,6 +5378,19 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
bool isInSubGraph = createLoopIndex >= 0 && createLoopIndex < scanLoops.size();
bool isInitialStateOfSubGraph = false;
if ((createLoopIndex >= 0 && createLoopIndex < scanLoops.size()) && inputIndex == 1)
{
for (auto &f : scanLoops[createLoopIndex].m_loopstepfunctions)
{
if (f->Inputs().size() == 2 && f->Inputs()[inputIndex].Uid() == input.Uid())
{
isInitialStateOfSubGraph = true;
break;
}
}
}
bool isScanInputInSubgraph = createLoopIndex != -1 &&
std::find_if(scanLoops[createLoopIndex].m_scanInputs.begin(), scanLoops[createLoopIndex].m_scanInputs.end(),
[inputName](Variable v) {return inputName == UniqueNodeNameStorage::GetUniqueInputNodeName(v); })
@@ -5386,6 +5409,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
// if initial state is a scalar, it will be created with correct shape later in this method.
ScanLoop &scanLoop = scanLoops[createLoopIndex];
// to match "else if (isInitialStateOfSubGraph)" case
// one initial state may map to multiple final states.
// to make a one-to-one mapping from initial to final states,
// we have to split the initial state.
@@ -5448,6 +5472,14 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
}
inputArgType = ToTypeProto(inputShape, input.HasBatchAxis(), input.HasSequenceAxis());
}
else if (isInitialStateOfSubGraph)
{
// for the initial state, we need to make sure each of the scanLoopStates has a unique NodeArg.
// This NodeArg is for scan iteration to loop back at each scan iteration.
// It cannot be shared between two states.
inputName = MakeInitialStateNodeArgName(src->Output());
inputArgType = ToTypeProto(src->Inputs()[0].Shape(), src->Inputs()[0].HasBatchAxis(), src->Inputs()[0].HasSequenceAxis());
}
else
{
inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
@@ -5507,6 +5539,19 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
UpdateONNXType(input.GetDataType(), inputArgType);
}
if (isInitialStateOfSubGraph)
{
Variable initialState = src->Inputs()[1];
Variable stateInput = src->Inputs()[0];
Variable stateOutput = src->Outputs()[0];
// create initial state constant and final state nodeArg
// define a state output so executors know this is the place
// to run the body function in loops to get the next (t + 1) state.
scanLoops[createLoopIndex].scanLoopStates.push_back(ScanLoopState(initialState, nullptr, stateOutput,
src->OpName() == L"PastValue" ? 1 : -1));
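// The last ScanLoopState argument appears to encode the scan direction:
// 1 (forward) for PastValue, -1 (backward) for FutureValue.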
}
bool addedInitializer = false;
//
// Leaf nodes are data entries to the graph and need their own node with only an output arg.
@ -5522,20 +5567,39 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
onnx::TensorProto dstTensor;
dstTensor.set_name(inputName);
CopyTensor(srcTensor, dstTensor, &inputArgType);
if (CNTKToONNXHelper::globalGraph && createLoopIndex != -1)
if (isInitialStateOfSubGraph)
{
scanLoops[createLoopIndex].initializerAsInput.push_back(inputName);
// in case the initial state is a scalar, we need to expand it to the shape of the state.
// initial state is an input to a Scan node. As an input, it needs to have a batch axis.
std::vector<int64_t> initialStateShape =
ToINTS(ToTypeProto(src->Inputs()[0].Shape(), (int)(src->Inputs()[0].DynamicAxes().size())));
if (ScanWithoutBatchAxis)
initialStateShape.insert(initialStateShape.begin(), BatchSizeProcessor::FreeBatchSize());
// With Bing.Malta50.proto1_128_gru_normv3_ep3_z.model, I could only get ONNX Runtime
// to produce matching results by putting initializers in the subgraphs
// (calling graph->AddInitializedTensor instead).
CNTKToONNXHelper::globalGraph->AddInitializedTensor(dstTensor);
// graph->AddInitializedTensor(dstTensor);
addedInitializer = true;
// FillTensorWithScalar takes a vector of srcs and assumes initialStateShape is a collection of shapes,
// one for each src
FillTensorWithScalarFromSingleSource(srcTensor, dstTensor, initialStateShape);
// initial state is an input to a scan op; it belongs to the parent graph
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateTensor = dstTensor;
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_hasInitializer = true;
}
else
graph->AddInitializedTensor(dstTensor);
{
CopyTensor(srcTensor, dstTensor, &inputArgType);
if (CNTKToONNXHelper::globalGraph && createLoopIndex != -1)
{
scanLoops[createLoopIndex].initializerAsInput.push_back(inputName);
// With Bing.Malta50.proto1_128_gru_normv3_ep3_z.model, I could only get ONNX Runtime
// to produce matching results by putting initializers in the subgraphs
// (calling graph->AddInitializedTensor instead).
CNTKToONNXHelper::globalGraph->AddInitializedTensor(dstTensor);
// graph->AddInitializedTensor(dstTensor);
addedInitializer = true;
}
else
graph->AddInitializedTensor(dstTensor);
}
}
}
}
@@ -5543,8 +5607,7 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType,
compositeOutputsMap);
if ((isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
if (isInitialStateOfSubGraph || (isOutputOfStepFunction && isInSubGraph) || isScanInputInSubgraph)
{
inputName = MakeScanInputOutputNodeArgName(inputName);
}
@@ -5557,6 +5620,9 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
inputs.push_back(&inputArg);
if (isInitialStateOfSubGraph)
scanLoops[createLoopIndex].scanLoopStates.rbegin()->m_initialStateNodeArg = inputs[1];
if (cntkOpName == "Reshape")
{
// The ONNX 1.2 Reshape node takes the shape as an input instead of an attribute.
@@ -5750,15 +5816,8 @@ void CNTKToONNXHelper::TraverseGraph(const FunctionPtr& src,
if (input.IsPlaceholder())
{
input = input.BlockFunctionVariableMapping();
if (!Operators::IsRNNOp(opName) && input.IsPlaceholder())
{
// this could be the case where a block takes an input from another block,
// so try to map one more time.
input = input.BlockFunctionVariableMapping();
if (!input.IsInitialized() || input.IsPlaceholder())
LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
}
LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
}
if (input.IsInitialized() && input.IsOutput())
@@ -6931,8 +6990,6 @@ std::pair<std::vector<int>, std::vector<int>> CNTKToONNXHelper::GetONNXPadsAttri
void CNTKToONNXHelper::FillTensorWithScalarFromSingleSource(const NDArrayViewPtr &src,
onnx::TensorProto& dst, const std::vector<int64_t> dstShape)
{
if (!src->Shape().IsScalar())
LogicError("FillTensorWithScalarFromSingleSource can only work with a scalar.");
auto dataType = src->GetDataType();
SetTensorType(dst, dataType);
int64_t eachSrcSize = std::accumulate(dstShape.begin(), dstShape.end(), (int64_t)1, std::multiplies<int64_t>());
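// FillTensorWithScalarFromSingleSource expands a single scalar source into a
// dense tensor of dstShape, so the element count is the product of the target
// dimensions computed above. A standalone sketch of the expansion
// (illustration only; the real function writes into an onnx::TensorProto,
// and ExpandScalar is a hypothetical helper):
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>
static std::vector<float> ExpandScalar(float value, const std::vector<int64_t> &dstShape)
{
    int64_t size = std::accumulate(dstShape.begin(), dstShape.end(),
                                   (int64_t)1, std::multiplies<int64_t>());
    return std::vector<float>((size_t)size, value); // every element holds the scalar
}
// e.g. ExpandScalar(0.f, {1, 2, 3}) yields a vector of 6 zeros.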