Removing support for user-provided node names and using uid for ONNX names.
This commit is contained in:
Родитель
e0b26561b8
Коммит
b73670c3bd
|
@ -1810,7 +1810,7 @@ LotusIR::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src,
|
|||
if (Xs[0].Owner().get() != nullptr)
|
||||
CreateNode(Xs[0].Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
|
||||
auto nodeName = src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()));
|
||||
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
|
||||
LotusIR::Node *lstmNode = graph->AddNode(nodeName, "LSTM", "", nodeInputs, nodeOutputs);
|
||||
|
||||
lstmNode->AddAttribute("activations", activations);
|
||||
|
@ -2080,7 +2080,7 @@ LotusIR::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src,
|
|||
if (Xs[0].Owner().get() != nullptr)
|
||||
CreateNode(Xs[0].Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
|
||||
auto nodeName = src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()));
|
||||
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
|
||||
LotusIR::Node *gruNode = graph->AddNode(nodeName, "GRU", "", nodeInputs, nodeOutputs);
|
||||
|
||||
gruNode->AddAttribute("activations", activations);
|
||||
|
@ -2265,7 +2265,7 @@ LotusIR::Node *CNTKToONNXHelper::CreateRNNNode(const FunctionPtr &src,
|
|||
if (Xs[0].Owner().get() != nullptr)
|
||||
CreateNode(Xs[0].Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
|
||||
auto nodeName = src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()));
|
||||
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
|
||||
LotusIR::Node *rnnNode = graph->AddNode(nodeName, "RNN", "", nodeInputs, nodeOutputs);
|
||||
|
||||
rnnNode->AddAttribute("activations", activations);
|
||||
|
@ -3230,6 +3230,8 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
|
|||
|
||||
node->AddAttribute("axis", ax);
|
||||
}
|
||||
|
||||
node->AddAttribute("keepdims", keepReducedDimensions);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3304,7 +3306,7 @@ LotusIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, LotusIR::Graph*
|
|||
{
|
||||
LotusIR::Node* node = nullptr;
|
||||
std::vector<LotusIR::NodeArg *> orderedInputs = MapInputsOrderToONNX(src, inputs);
|
||||
auto nodeName = src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()));
|
||||
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
|
||||
|
||||
if (L"Embedding" == src->OpName())
|
||||
{
|
||||
|
@ -3597,7 +3599,7 @@ LotusIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Funct
|
|||
// Input operand X
|
||||
if (inputNeedsShapeAdapter)
|
||||
{
|
||||
std::string adapterBasename = (src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()))) + "_Adapter_" + std::to_string(i);
|
||||
std::string adapterBasename = ToLegacyString(ToUTF8(src->Uid())) + "_Adapter_" + std::to_string(i);
|
||||
LotusIR::NodeArg* shapeAdaptedInputOperandArg = LSTMOutputShapeAdapter(*layerInputOperandArg, ornnOutputArgType, graph,
|
||||
numDirections, hiddenSize, ornnOutput.GetDataType(), adapterBasename);
|
||||
inputs.push_back(shapeAdaptedInputOperandArg);
|
||||
|
@ -3625,7 +3627,7 @@ LotusIR::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const Funct
|
|||
|
||||
// ==== Step 6. Add ONNX LSTM node ====
|
||||
auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
|
||||
auto rnnNodeName = (src->Name().empty() ? ToLegacyString(ToUTF8(src->Uid())) : ToLegacyString(ToUTF8(src->Name()))) + std::to_string(i);
|
||||
auto rnnNodeName = ToLegacyString(ToUTF8(src->Uid())) + std::to_string(i);
|
||||
functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
|
||||
|
||||
std::vector<std::string> singleDirectionActivation;
|
||||
|
|
Загрузка…
Ссылка в новой задаче