diff --git a/Makefile b/Makefile
index 60322204b..60e184b19 100644
--- a/Makefile
+++ b/Makefile
@@ -544,9 +544,14 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/posix/stacktrace.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/checker.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/assertions.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/model_helpers.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/status.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/controlflow/defs.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/experiments/defs.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/experiments/experiments_functions.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/function.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/generator/defs.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/generator/old.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/logical/defs.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/logical/old.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/math/defs.cc \
@@ -561,6 +566,7 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/traditionalml/defs.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/data_type_utils.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/schema.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/shape_inference/implementation.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/onnx-ml.pb.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/onnx-operators-ml.pb.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
index 0c78f1e21..039a987a7 100644
--- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
+++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
@@ -224,8 +224,11 @@
+
+
+
@@ -292,10 +295,15 @@
+
+
+
+
+
@@ -309,6 +317,7 @@
+
diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
index a19372b0e..98a9567c0 100644
--- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
+++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
@@ -169,6 +169,24 @@
proto\onnx\core\platform\windows
+
+ proto\onnx\onnx_repo\onnx\defs
+
+
+ proto\onnx\onnx_repo\onnx\common
+
+
+ proto\onnx\onnx_repo\onnx\defs\experiments
+
+
+ proto\onnx\onnx_repo\onnx\defs\generator
+
+
+ proto\onnx\onnx_repo\onnx\common
+
+
+ proto\onnx\onnx_repo\onnx\shape_inference
+
@@ -394,6 +412,15 @@
proto\onnx
+
+ proto\onnx\onnx_repo\onnx\defs
+
+
+ proto\onnx\onnx_repo\onnx\common
+
+
+ proto\onnx\onnx_repo\onnx\common
+
@@ -504,6 +531,9 @@
{938a6293-26e8-4aad-9aa3-200d9b96102b}
+
+ {b8ebfd65-98ba-44fb-b10d-ac1e7e8e5246}
+
diff --git a/Source/CNTKv2LibraryDll/Logger.h b/Source/CNTKv2LibraryDll/Logger.h
index 7ae8997d8..640297807 100644
--- a/Source/CNTKv2LibraryDll/Logger.h
+++ b/Source/CNTKv2LibraryDll/Logger.h
@@ -9,17 +9,15 @@
#include "core/common/logging/isink.h"
namespace CNTK {
-class CNTKClogSink : public onnxruntime::Logging::ISink {
+class CNTKClogSink : public onnxruntime::logging::ISink {
public:
CNTKClogSink()
: stream_{&(std::clog)}, flush_{true}
{}
-    void SendImpl(const onnxruntime::Logging::Timestamp &timestamp,
-                  const std::string &logger_id, const onnxruntime::Logging::Capture &message) override
+    void SendImpl(const onnxruntime::logging::Timestamp &timestamp,
+                  const std::string &logger_id, const onnxruntime::logging::Capture &message) override
{
- UNUSED_PARAMETER(timestamp);
-
std::ostringstream msg;
msg << " [" << message.SeverityPrefix() << ":" << message.Category() << ":" << logger_id << ", "
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
index ce6ccfaf1..b28721350 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
@@ -241,6 +241,13 @@ private:
        std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
        const std::unordered_map<Variable, Variable>& compositeOutputsMap,
        std::vector<ScanLoop> &scanLoops, int createLoopIndex);
+
+ static onnxruntime::Node* CreatePastFutureValueNode(const FunctionPtr& src,
+ onnxruntime::Graph* graph,
+        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+        std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+        const std::unordered_map<Variable, Variable>& compositeOutputsMap,
+        std::vector<ScanLoop> &scanLoops, int createLoopIndex);
static onnxruntime::Node* CreateSequenceIsFirstOrLastNode(const FunctionPtr& src,
onnxruntime::Graph* graph,
        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -248,6 +255,14 @@ private:
        const std::unordered_map<Variable, Variable>& compositeOutputsMap,
        std::vector<ScanLoop> &scanLoops, int createLoopIndex,
bool isFirst);
+
+ static onnxruntime::Node* CreateNodeWithGatherPacked(const FunctionPtr& src,
+ onnxruntime::Graph* graph,
+        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+        std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+        const std::unordered_map<Variable, Variable>& compositeOutputsMap,
+        std::vector<ScanLoop> &scanLoops, int createLoopIndex);
+
static onnxruntime::Node* CreateSequenceSliceNode(const FunctionPtr& src,
onnxruntime::Graph* graph,
        std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
@@ -298,7 +313,8 @@ private:
const std::string &nodeArgName);
    static onnxruntime::Node *AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t>& newShape);
- static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex, onnx::TypeProto &inputArgType);
+ static NodeArg* GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src, const Variable &input, int inputIndex,
+        onnx::TypeProto &inputArgType, const std::unordered_map<Variable, Variable>& compositeOutputsMap);
// process loops to produce Scan ops.
    // Return true to continue processing src; otherwise the node has already been processed.
@@ -331,7 +347,8 @@ private:
static void ProcessOutputsForBatchAxisOp(const FunctionPtr& rootNode,
        std::vector<onnxruntime::NodeArg *>& outputs, Graph *graph);
- static onnxruntime::NodeArg &CreateNodeArg(const Variable &input, onnxruntime::Graph* graph, const std::string &replace_name = "");
+ static onnxruntime::NodeArg &CreateNodeArg(const Variable &variable, onnxruntime::Graph* graph,
+ bool isInput, const std::string &replace_name = "");
    static onnxruntime::Node *AddSliceNode(onnxruntime::NodeArg &inputArg, const std::vector<int64_t> &axes,
        const std::vector<int64_t> &sliceStarts, const std::vector<int64_t> &sliceEnds,
const std::string &outArgName, onnxruntime::Graph* graph);
@@ -351,7 +368,8 @@ private:
static onnxruntime::Node *AddArgMaxNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, int axis);
static onnxruntime::Node *AddCastNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, onnx::TensorProto_DataType toType,
const std::string &outputNodeArgName);
- static NodeArg& AddTransposeBatchSequenceAxesNode(onnxruntime::NodeArg &nodeArg, bool isInput, onnxruntime::Graph* graph);
+ static NodeArg& AddTransposeBatchSequenceAxesNode(onnxruntime::NodeArg &nodeArg, bool isInput,
+ onnxruntime::Graph* graph, const std::string& scanNodeName);
    static onnxruntime::Node *AddTransposeNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, const std::vector<int64_t> &perm,
onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName);
@@ -927,7 +945,7 @@ void CNTKToONNXHelper::HandleRootCombineOp(const FunctionPtr& src, onnxruntime::
for (auto input : src->Inputs())
{
std::string nodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
- const NodeArg* nodeArg = dst->FindNodeArg(nodeArgName);
+ const NodeArg* nodeArg = dst->GetNodeArg(nodeArgName);
if (!nodeArg)
continue;
@@ -1942,6 +1960,14 @@ std::string CNTKToONNXHelper::ToOPName(const FunctionPtr& src)
const AttributesMapping& attributeMap = Operators::FindAttributeMap(src->OpName(), cntkAttributeOpName);
+ opName = attributeMap.map.at(cntkAttributeOpName);
+ }
+ else if (src->OpName() == L"RandomDistribution")
+ {
+        wstring cntkAttributeOpName = (wstring)src->Attributes()[PrimitiveFunctionAttribute::AttributeNameRandomDistributionType].Value<wstring>();
+
+ const AttributesMapping& attributeMap = Operators::FindAttributeMap(src->OpName(), cntkAttributeOpName);
+
opName = attributeMap.map.at(cntkAttributeOpName);
}
}
@@ -1963,10 +1989,16 @@ bool CNTKToONNXHelper::OpInputsHasBatchAxis(const FunctionPtr& src)
bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable& input, size_t inputIndex)
{
- // In CNTK block functions, they expose all constants inside the block. For block functions that
+    // 1. CNTK block functions expose all constants inside the block. For block functions that
     //    map directly to an ONNX op, we don't care about constants inside the block.
-    if (input.IsConstant())
+    // 2. For some CNTK ops, we want to process only selected inputs.
+    //    For example, in a v1 model the Sequence::Gather op is decomposed into a subgraph of GatherPacked, PackedIndex, and Where.
+    //    Inputs to the composite Sequence::Gather op are GatherPacked's inputs[0] and Where's inputs[0]. These are the
+    //    inputs we need to process; other inputs to ops in the subgraph are treated as invalid so they are not processed.
+ if (input.IsConstant() ||
+ src->OpName() == L"GatherPacked")
return !Operators::IsValidInputs(src->OpName(), inputIndex);
+
return false;
}
@@ -2744,17 +2776,21 @@ onnxruntime::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src,
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeOutputNameBeforeReshape, &outputArgType);
nodeOutputs.push_back(&outputArg);
- {
- Variable Yh = Yhs[0];
- std::string nodeName = ToLegacyString(ToUTF8(Yh.Uid())) + "_h";
- // TODO: batchSize is fixed to one. Needs to find out how to handle bacth axis as a free dimension.
- const int batchSize = 1;
- const bool doReverseVec = false;
-        auto outputArgType = ToTypeProto(std::vector<int64_t>({ (int64_t)Yhs.size(), batchSize, (int)Yh.Shape()[0] }), doReverseVec);
- UpdateONNXType(Yh.GetDataType(), outputArgType);
- onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeName, &outputArgType);
- nodeOutputs.push_back(&outputArg);
- }
+    // TODO: to be consistent with RNN and LSTM, where Yhs is the only output.
+    // It is true that either C.layers.Recurrence(C.layers.GRU... or
+    // C.layers.Sequential([C.layers.Recurrence(C.layers.LSTM
+    // both have a single output.
+ //{
+ // Variable Yh = Yhs[0];
+ // std::string nodeName = ToLegacyString(ToUTF8(Yh.Uid())) + "_h";
+    // // TODO: batchSize is fixed to one. Need to find out how to handle the batch axis as a free dimension.
+ // const int batchSize = 1;
+ // const bool doReverseVec = false;
+    // auto outputArgType = ToTypeProto(std::vector<int64_t>({ (int64_t)Yhs.size(), batchSize, (int)Yh.Shape()[0] }), doReverseVec);
+ // UpdateONNXType(Yh.GetDataType(), outputArgType);
+ // onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeName, &outputArgType);
+ // nodeOutputs.push_back(&outputArg);
+ //}
}
// TODO: Except X, all other inputs to GRU are treated as constant.
@@ -3055,13 +3091,18 @@ onnxruntime::Node *CNTKToONNXHelper::AddReshapeNodeImpl(Graph *graph, const stri
}
// create a NodeArg for an input variable.
-onnxruntime::NodeArg &CNTKToONNXHelper::CreateNodeArg(const Variable &input, onnxruntime::Graph* graph, const std::string &replace_name)
+onnxruntime::NodeArg &CNTKToONNXHelper::CreateNodeArg(const Variable &variable, onnxruntime::Graph* graph, bool isInput, const std::string &replace_name)
{
- onnx::TypeProto inputTypeProto = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
- onnx::TensorProto_DataType elemType = ConvertDataTypeCNTKToTensorProto(input.GetDataType());
- inputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
- onnxruntime::NodeArg &inputArg = graph->GetOrCreateNodeArg(
- replace_name.empty() ? UniqueNodeNameStorage::GetUniqueInputNodeName(input) : replace_name, &inputTypeProto);
+ onnx::TypeProto typeProto = ToTypeProto(variable.Shape(), variable.HasBatchAxis(), variable.HasSequenceAxis());
+ onnx::TensorProto_DataType elemType = ConvertDataTypeCNTKToTensorProto(variable.GetDataType());
+ typeProto.mutable_tensor_type()->set_elem_type(elemType);
+
+ std::string nodeArgName = replace_name;
+ if (nodeArgName.empty())
+ nodeArgName = isInput ?
+ UniqueNodeNameStorage::GetUniqueInputNodeName(variable) :
+ UniqueNodeNameStorage::GetUniqueOutputNodeName(variable);
+ onnxruntime::NodeArg &inputArg = graph->GetOrCreateNodeArg(nodeArgName, &typeProto);
return inputArg;
}
@@ -3139,8 +3180,8 @@ onnxruntime::Node *CNTKToONNXHelper::AddSqueezeNode(onnxruntime::NodeArg &inputA
}
// add an expand node
-onnxruntime::Node *CNTKToONNXHelper::AddExpandNode(onnxruntime::NodeArg &inputArg, const std::vector<int64_t> &newShape, const std::string &outArgName,
-    onnxruntime::Graph* graph)
+onnxruntime::Node *CNTKToONNXHelper::AddExpandNode(onnxruntime::NodeArg &inputArg, const std::vector<int64_t> &newShape,
+    const std::string &outArgName, onnxruntime::Graph* graph)
{
onnxruntime::NodeArg &shapeNodeArg = CreateAddShapeNodeArg(graph, newShape, outArgName + "_expand_shape");
@@ -3187,7 +3228,10 @@ onnxruntime::Node *CNTKToONNXHelper::AddAddNode(onnxruntime::NodeArg &nodeArg1,
onnxruntime::Node *CNTKToONNXHelper::AddIdentityOp(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph, const std::string &out_arg_name)
{
- onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, nullptr);
+ onnx::TypeProto outputTypeProto(*nodeArg.TypeAsProto());
+ outputTypeProto.mutable_tensor_type()->set_elem_type(nodeArg.TypeAsProto()->tensor_type().elem_type());
+
+ onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, &outputTypeProto);
onnxruntime::Node* identityNode = graph->AddNode(
nodeArg.Name() + string("_identity"), "Identity", "", { &nodeArg}, { &outputArg });
return identityNode;
@@ -3205,8 +3249,10 @@ onnxruntime::Node *CNTKToONNXHelper::AddArgMaxNode(onnxruntime::NodeArg &nodeArg
onnxruntime::Node *CNTKToONNXHelper::AddCastNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph,
onnx::TensorProto_DataType toType, const std::string &outputNodeArgName)
{
- // onnxruntime::NodeArg inputArg(nodeArg.Name(), nullptr);
- onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeArg.Name() + "_cast_out_" + outputNodeArgName, nullptr);
+ TypeProto outputTypeProto(*nodeArg.TypeAsProto());
+ outputTypeProto.mutable_tensor_type()->set_elem_type(toType);
+
+ onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeArg.Name() + "_cast_out_" + outputNodeArgName, &outputTypeProto);
onnxruntime::Node* castNode = graph->AddNode(nodeArg.Name() + string("_cast_") + outputNodeArgName,
"Cast", "", { &nodeArg }, { &outputArg });
castNode->AddAttribute("to", (int64_t)toType);
@@ -3217,7 +3263,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddCastNode(onnxruntime::NodeArg &nodeArg,
// This is different from the convention of CNTK exporter and ONNX RNN ops where sequence is the first dimension.
// To compensate for this difference, call this method before and after a Scan op to swap the batch and sequence axes.
NodeArg& CNTKToONNXHelper::AddTransposeBatchSequenceAxesNode(onnxruntime::NodeArg &nodeArg,
- bool isInput, onnxruntime::Graph* graph)
+ bool isInput, onnxruntime::Graph* graph, const std::string& scanNodeName)
{
const TypeProto& typeProto = *nodeArg.TypeAsProto();
int rank = typeProto.tensor_type().shape().dim_size();
@@ -3236,8 +3282,10 @@ NodeArg& CNTKToONNXHelper::AddTransposeBatchSequenceAxesNode(onnxruntime::NodeAr
*newdim = typeProto.tensor_type().shape().dim((int)i);
}
- std::string otherNodeArgName = nodeArg.Name() + (isInput ? "transposed_to_batch_sequence_output" : "transposed_to_sequence_batch_input");
- std::string nodeName = nodeArg.Name() + (isInput ? "transposed_to_batch_sequence" : "transposed_to_sequence_batch");
+ std::string otherNodeArgName = nodeArg.Name() +
+ (isInput ? "_transposed_to_batch_sequence_output_" : "_transposed_to_sequence_batch_input_") + scanNodeName;
+ std::string nodeName = nodeArg.Name() +
+ (isInput ? "_transposed_to_batch_sequence_" : "_transposed_to_sequence_batch_") + scanNodeName;
onnxruntime::NodeArg &otherArg = graph->GetOrCreateNodeArg(otherNodeArgName, &otherTypeProto);
    std::vector<int64_t> perm(rank);
std::generate(perm.begin(), perm.end(), [axis = 0]() mutable { return axis++; });
@@ -3252,7 +3300,7 @@ NodeArg& CNTKToONNXHelper::AddTransposeBatchSequenceAxesNode(onnxruntime::NodeAr
onnxruntime::Node *CNTKToONNXHelper::AddTransposeNode(onnxruntime::NodeArg &nodeArg, onnxruntime::Graph* graph,
    const std::vector<int64_t> &perm, onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName)
{
- onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeArg.Name() + "transpose_out", &transposeOutputArgType);
+ onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(outputNodeArgName, &transposeOutputArgType);
onnx::TensorProto_DataType elementType = nodeArg.TypeAsProto()->tensor_type().elem_type();
    const_cast<TypeProto*>(outputArg.TypeAsProto())->mutable_tensor_type()->set_elem_type(elementType);
onnxruntime::Node* transposeNode = graph->AddNode(nodeArg.Name() + string("_transpose"), "Transpose", "", { &nodeArg }, { &outputArg });
@@ -3341,6 +3389,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSoftmaxLikeNode(const FunctionPtr& sr
    std::vector<onnxruntime::NodeArg *> inputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
+
    std::vector<onnxruntime::NodeArg *> outputs;
ProcessOutputs(src, outputs, graph);
@@ -3405,6 +3454,72 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSoftmaxLikeNode(const FunctionPtr& sr
return softmaxLikeNode;
}
+onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr& src,
+ onnxruntime::Graph* graph,
+    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
+    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
+{
+
+ bool past = src->OpName() == L"PastValue";
+    std::vector<onnxruntime::NodeArg *> inputs, outputs;
+ ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
+
+ ProcessOutputs(src, outputs, graph);
+
+ // 1. slice off first or last timeframe from input[0] -> input_sliced_node
+ // 2. expand initial value input[1] to the shape of input[0] without sequence axis (the first axis) -> init_value_expanded
+ // 3. concat input_sliced_node with init_value_expanded or other way around -> Past(Future)Value node
+
+ // 1. slice input
+ int64_t sliceAxis = 0, sliceStart, sliceEnd;
+ if (past)
+ {
+ sliceStart = 0;
+ sliceEnd = -1;
+ }
+ else
+ {
+ sliceStart = 1;
+        sliceEnd = std::numeric_limits<int64_t>::max();
+ }
+
+ const std::string sliceOutputArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(src->Inputs()[0]) +
+ "_slice_" + UniqueNodeNameStorage::GetUniqueNodeName(src);
+ Node *sliceNode = AddSliceNode(*inputs[0], { sliceAxis }, { sliceStart }, { sliceEnd }, sliceOutputArgName, graph);
+
+ // 2. expand init_value
+    std::vector<int64_t> expandShape = ToINTS(*inputs[0]->TypeAsProto());
+ // sequence dimension is one for init_value
+ expandShape[0] = 1;
+ const std::string expandOutputName = UniqueNodeNameStorage::GetUniqueInputNodeName(src->Inputs()[1]) + "_expand_" +
+ UniqueNodeNameStorage::GetUniqueNodeName(src);
+ Node *initValueExpand = AddExpandNode(*inputs[1], expandShape, expandOutputName, graph);
+
+ // 3. concat
+ std::string outputNodeArgName = UniqueNodeNameStorage::GetUniqueOutputNodeName(src->Outputs()[0]);
+
+ Node * concatNode;
+ if (past)
+ {
+ concatNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
+            { const_cast<NodeArg *>(initValueExpand->OutputDefs()[0]), const_cast<NodeArg *>(sliceNode->OutputDefs()[0]) },
+ outputs);
+ }
+ else
+ {
+ concatNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
+            { const_cast<NodeArg *>(sliceNode->OutputDefs()[0]), const_cast<NodeArg *>(initValueExpand->OutputDefs()[0]) },
+ outputs);
+ }
+
+ // concat on sequence axis
+ concatNode->AddAttribute("axis", (int64_t)0);
+ functionNodes.emplace(src, concatNode);
+ return concatNode;
+}
+
// the idea is to create an EyeLike node and slice the first slice for IsFirst, the last slice for IsLast op.
onnxruntime::Node* CNTKToONNXHelper::CreateSequenceIsFirstOrLastNode(const FunctionPtr& src,
onnxruntime::Graph* graph,
@@ -3466,8 +3581,15 @@ onnxruntime::Node* CNTKToONNXHelper::CreateTupleNode(const FunctionPtr& src,
    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
-    std::vector<onnxruntime::NodeArg *> inputs;
+    std::vector<onnxruntime::NodeArg *> inputs, outputs;
ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
+
+ ProcessOutputs(src, outputs, graph);
+
+ assert(inputs.size() == outputs.size());
+ for (int i = 0; i < inputs.size(); i++)
+ graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src) + std::to_string(i), "Identity", "", { inputs[i] }, { outputs[i] });
+
return nullptr;
}
@@ -3508,7 +3630,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceBroadcastAsNode(const Functio
// [#][d0, d1]
    std::vector<int64_t> newShape = ToINTS(ToTypeProto(input.Shape(), (int)input.DynamicAxes().size()));
- onnxruntime::NodeArg &inputNodeArg = CreateNodeArg(input, graph);
+ onnxruntime::NodeArg &inputNodeArg = CreateNodeArg(input, graph, true);
if (input.DynamicAxes().size() == 0)
{
newShape.insert(newShape.begin(), (int64_t)FreeBatchSize);
@@ -3531,8 +3653,40 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
if (CNTKToONNXHelper::isProcessingScan)
LogicError("SequenceGather cannot be in a scan loop");
- // waiting ONNX to have Compress or Where op
- NOT_IMPLEMENTED;
+    std::vector<onnxruntime::NodeArg *> inputs;
+ ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap, inputs, scanLoops, createLoopIndex);
+
+    // TODO: cannot call ProcessOutputs because we want the final output to have the expected NodeArg name
+    // to maintain the graph connection.
+    //std::vector<onnxruntime::NodeArg *> outputs;
+ //ProcessOutputs(src, outputs, graph);
+
+    // Cast inputs[1] from tensor<float> to tensor<bool>
+ const std::string outputNodeArgName = inputs[1]->Name() + "_cast_to_bool";
+ Node *castNode = AddCastNode(*inputs[1], graph,
+ TensorProto_DataType::TensorProto_DataType_BOOL, outputNodeArgName);
+
+    // We want to create a 1-D boolean tensor as the condition input to the ONNX Compress.
+    // The CNTK condition input has sequence and batch axes, and possibly additional static axes.
+    // All dimensions of the static axes must be one.
+    // TODO: how to handle cases where batch_size is not 1?
+    std::vector<int64_t> squeezeAxes(inputs[1]->Shape()->dim_size() - 1);
+ std::generate(squeezeAxes.begin(), squeezeAxes.end(), [axis = 1]() mutable { return axis++; });
+
+    Node *castSqueezeNode = AddSqueezeNode(const_cast<NodeArg &>(*castNode->OutputDefs()[0]),
+        squeezeAxes, castNode->Name() + "_squeezed", graph);
+    inputs[1] = const_cast<NodeArg *>(castSqueezeNode->OutputDefs()[0]);
+
+ NodeArg& compressOutputNodeArg = CreateNodeArg(src->Outputs()[0], graph, false);
+
+ std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
+ std::string onnxOpName = "Compress";
+ Node *compressNode = graph->AddNode(nodeName, onnxOpName, "", inputs, { &compressOutputNodeArg });
+
+ int64_t sequenceAxis = 0;
+ compressNode->AddAttribute("axis", sequenceAxis);
+ functionNodes.emplace(src, compressNode);
+ return compressNode;
}
onnxruntime::Node* CNTKToONNXHelper::CreateSequenceReduceElementsNode(const FunctionPtr& src,
@@ -3562,6 +3716,56 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceReduceElementsNode(const Func
return node;
}
+onnxruntime::Node* CNTKToONNXHelper::CreateNodeWithGatherPacked(const FunctionPtr& src,
+ onnxruntime::Graph* graph,
+    std::unordered_map<FunctionPtr, onnxruntime::Node*>& functionNodes,
+    std::unordered_map<Variable, onnxruntime::Node*>& variableNodes,
+    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
+    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
+{
+ assert(src->OpName() == L"GatherPacked");
+
+ auto packedIndex = src->Inputs()[1].Owner();
+ if (packedIndex->OpName() != L"PackedIndex")
+ LogicError("GatherPacked not from Sequence.Gather cannot be handled.");
+
+ auto whereFunc = packedIndex->Inputs()[1].Owner();
+ if (whereFunc->OpName() != L"Where")
+ LogicError("GatherPacked not from Sequence.Gather cannot be handled.");
+
+    // _cntkBlockOPInvalidIndices is set for "GatherPacked" so that only the second input is processed
+    std::vector<onnxruntime::NodeArg *> gatherPackedInputs;
+ ProcessInputs(src, graph, functionNodes, variableNodes, compositeOutputsMap,
+ gatherPackedInputs, scanLoops, createLoopIndex);
+ assert(gatherPackedInputs.size() == 1);
+
+    std::vector<onnxruntime::NodeArg *> whereInputs;
+ ProcessInputs(whereFunc, graph, functionNodes, variableNodes, compositeOutputsMap,
+ whereInputs, scanLoops, createLoopIndex);
+
+    // Cast from tensor<float> to tensor<bool>
+ const std::string outputNodeArgName = whereInputs[0]->Name() + "_cast_to_bool";
+ Node *castNode = AddCastNode(*whereInputs[0], graph,
+ TensorProto_DataType::TensorProto_DataType_BOOL, outputNodeArgName);
+
+    // Squeeze the condition to one dimension (sequence axis = 0)
+    std::vector<int64_t> squeezeAxes(castNode->OutputDefs()[0]->Shape()->dim_size() - 1);
+ std::generate(squeezeAxes.begin(), squeezeAxes.end(), [axis = 1]() mutable { return axis++; });
+
+    Node *castSqueezeNode = AddSqueezeNode(const_cast<NodeArg &>(*castNode->OutputDefs()[0]),
+        squeezeAxes, castNode->Name() + "_squeezed", graph);
+
+    std::vector<onnxruntime::NodeArg *> outputs;
+ ProcessOutputs(src, outputs, graph);
+
+ Node *compressNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Compress", "",
+        { gatherPackedInputs[0], const_cast<NodeArg *>(castSqueezeNode->OutputDefs()[0]) }, outputs);
+ int64_t sequenceAxis = 0;
+ compressNode->AddAttribute("axis", sequenceAxis);
+ functionNodes.emplace(src, compressNode);
+ return compressNode;
+}
+
// To parse Sequence.Slice node graph to collect axis/begin index/end index
// and to build an ONNX Slice op.
// IMPORTANT NOTE:
@@ -3750,19 +3954,21 @@ void ResolveGraphAndSaveModel(onnxruntime::Model *model)
// use this method to attach an identity op so that state inputs/outputs of the subgraph are in the same order as the scan op
// extendedNodeArgOfSubgraph -> nodeArg -> Scan
// Scan -> nodeArg -> extendedNodeArgOfSubgraph
-void AttachNodeArg(onnxruntime::Graph* scanGraph, const std::string &subgraphNodeArgName, bool isInput)
+void AttachNodeArg(onnxruntime::Graph* scanGraph, const std::string &subgraphNodeArgName, bool isInput, bool isState)
{
NodeArg& nodeArgOfSubgraph = scanGraph->GetOrCreateNodeArg(subgraphNodeArgName, nullptr);
- NodeArg& extendedNodeArgOfSubgraph = scanGraph->GetOrCreateNodeArg(subgraphNodeArgName + "_extended", nodeArgOfSubgraph.TypeAsProto());
+ std::string extendedNodeAndNodeArgName = isState ? "state_" : "scan_";
+ extendedNodeAndNodeArgName += subgraphNodeArgName;
+ extendedNodeAndNodeArgName += isInput ? "_extended_to_" : "_extended_from_";
+
+ NodeArg& extendedNodeArgOfSubgraph = scanGraph->GetOrCreateNodeArg(extendedNodeAndNodeArgName, nodeArgOfSubgraph.TypeAsProto());
if (isInput)
{
- scanGraph->AddNode(subgraphNodeArgName + "_extended_to_", "Identity", "",
- { &extendedNodeArgOfSubgraph }, { &nodeArgOfSubgraph });
+ scanGraph->AddNode(extendedNodeAndNodeArgName, "Identity", "", { &extendedNodeArgOfSubgraph }, { &nodeArgOfSubgraph });
}
else
{
- scanGraph->AddNode(subgraphNodeArgName + "_extended_from_", "Identity", "",
- { &nodeArgOfSubgraph }, { &extendedNodeArgOfSubgraph });
+ scanGraph->AddNode(extendedNodeAndNodeArgName, "Identity", "", { &nodeArgOfSubgraph }, { &extendedNodeArgOfSubgraph });
}
}
@@ -3788,7 +3994,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
CNTKToONNXHelper::isProcessingScan = true;
// we are creating the createLoopIndex_th loop body, skip all ops that are not part of the loop body.
        ScanLoop &currentLoop = scanLoops[createLoopIndex];
- if (std::find(currentLoop.m_body.begin(), currentLoop.m_body.end(), src) == currentLoop.m_body.end())
+ if (!currentLoop.IsInBody(src))
{
return false;
}
@@ -3840,6 +4046,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
// create a subgraph
CreateNode(src, &scanGraph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, loopIndex);
+ std::string scanNodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
// continue to create the global graph
CNTKToONNXHelper::isProcessingScan = false;
for (auto & loopBodyInput : scanLoops[loopIndex].m_inputs)
@@ -3879,7 +4086,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
subGraphInitialStateNodeArg.Name(), &scanInitialStateTypeProto);
input_args.push_back(&scanInitialStateNodeArg);
- AttachNodeArg(&scanGraph, subGraphInitialStateNodeArg.Name(), true);
+ AttachNodeArg(&scanGraph, subGraphInitialStateNodeArg.Name(), true, true);
{
// as with initial state, output state does have batch axis but not sequence axis.
@@ -3887,11 +4094,17 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
true, false);
scanFinalStateTypeProto.mutable_tensor_type()->set_elem_type(
ConvertDataTypeCNTKToTensorProto(scanLoopState.m_stateOutput.GetDataType()));
- onnxruntime::NodeArg &scanFinalStateNodeArg = graph->GetOrCreateNodeArg(
- ToLegacyString(ToUTF8(scanLoopState.m_stateOutput.Uid())), &scanFinalStateTypeProto);
+
+ // TODO: UniqueNodeNameStorage is causing model validation failure.
+ std::string stateOutputName = ToLegacyString(ToUTF8(scanLoopState.m_stateOutput.Uid()));
+ // std::string stateOutputName = UniqueNodeNameStorage::GetUniqueInputNodeName(scanLoopState.m_stateOutput);
+
+ onnxruntime::NodeArg &scanFinalStateNodeArg =
+ graph->GetOrCreateNodeArg(stateOutputName, &scanFinalStateTypeProto);
output_args.push_back(&scanFinalStateNodeArg);
- AttachNodeArg(&scanGraph, scanFinalStateNodeArg.Name(), false);
+
+ AttachNodeArg(&scanGraph, stateOutputName, false, true);
}
if (scanLoopState.m_hasInitializer)
@@ -3901,12 +4114,17 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
for (auto &scanInput : scanLoops[loopIndex].m_scanInputs)
{
- NodeArg& scanInputNodeArg = CreateNodeArg(scanInput, graph);
+ std::string subgraphNodeArgName;
+ auto inputItr = compositeOutputsMap.find(scanInput);
+ if (inputItr != compositeOutputsMap.end())
+ subgraphNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+ else
+ subgraphNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(scanInput);
- std::string subgraphNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(scanInput);
+ NodeArg& scanInputNodeArg = CreateNodeArg(scanInput, graph, true, subgraphNodeArgName);
- AttachNodeArg(&scanGraph, subgraphNodeArgName, true);
- NodeArg& transposedScanInputNodeArg = AddTransposeBatchSequenceAxesNode(scanInputNodeArg, true, graph);
+ AttachNodeArg(&scanGraph, subgraphNodeArgName, true, false);
+ NodeArg& transposedScanInputNodeArg = AddTransposeBatchSequenceAxesNode(scanInputNodeArg, true, graph, scanNodeName);
input_args.push_back(&transposedScanInputNodeArg);
// IMPORTANT: can only support single direction for now.
@@ -3925,12 +4143,20 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
for (auto &scanOutput : scanLoops[loopIndex].m_scanOutputs)
{
// add the NodeArg to the main graph because it may not get a chance to get created if two scans are connected back-to-back
- NodeArg& scanOutputNodeArg = CreateNodeArg(scanOutput, graph);
- AttachNodeArg(&scanGraph, scanOutputNodeArg.Name(), false);
- NodeArg& transposedScanOutputNodeArg = AddTransposeBatchSequenceAxesNode(scanOutputNodeArg, false, graph);
+            // If the scan output is also the final state, rename the scan output to avoid an output name collision.
+
+ NodeArg* scanOutputNodeArg;
+ if (IsStepFunction(scanOutput.Owner()))
+ scanOutputNodeArg = &CreateNodeArg(scanOutput, graph, false,
+ UniqueNodeNameStorage::GetUniqueOutputNodeName(scanOutput) + "_finalstate_as_scanoutput");
+ else
+ scanOutputNodeArg = &CreateNodeArg(scanOutput, graph, false);
+
+ AttachNodeArg(&scanGraph, UniqueNodeNameStorage::GetUniqueOutputNodeName(scanOutput), false, false);
+ NodeArg& transposedScanOutputNodeArg = AddTransposeBatchSequenceAxesNode(*scanOutputNodeArg, false, graph, scanNodeName);
output_args.push_back(&transposedScanOutputNodeArg);
}
- Node *scanNode = graph->AddNode(ToLegacyString(ToUTF8(src->Uid())), "Scan", "", input_args, output_args);
+ Node *scanNode = graph->AddNode(scanNodeName, "Scan", "", input_args, output_args);
ResolveGraphAndSaveModel(scanSubModel.get());
@@ -3957,12 +4183,12 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
    const std::unordered_map<Variable, Variable>& compositeOutputsMap,
    std::vector<ScanLoop> &scanLoops, int createLoopIndex)
{
- if (!ProcessLoopsAndCheckCNTKNodeContinueCreate(src, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex))
- return nullptr;
-
auto iter = functionNodes.find(src);
if (iter != functionNodes.end())
return iter->second;
+
+ if (!ProcessLoopsAndCheckCNTKNodeContinueCreate(src, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex))
+ return nullptr;
onnxruntime::Node* functionNode = nullptr;
std::string cntkOpName = ToLegacyString(ToUTF8(src->OpName()));
@@ -3978,7 +4204,16 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
// return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap);
//}
//else
- if (cntkOpName == "Sequence::Slice")
+ if (cntkOpName == "GatherPacked")
+ {
+ return CreateNodeWithGatherPacked(src,
+ graph,
+ functionNodes,
+ variableNodes,
+ compositeOutputsMap,
+ scanLoops, createLoopIndex);
+ }
+ else if (cntkOpName == "Sequence::Slice")
{
return CreateSequenceSliceNode(src,
graph,
@@ -4041,6 +4276,15 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
compositeOutputsMap,
scanLoops, createLoopIndex, false);
}
+ else if (cntkOpName == "PastValue" || cntkOpName == "FutureValue")
+ {
+ if (createLoopIndex != -1)
+            // ProcessLoopsAndCheckCNTKNodeContinueCreate should have already handled
+            // PastValue or FutureValue ops in a loop.
+ LogicError("PastValue or FutureValue ops inside a loop shall not reach here.");
+ return CreatePastFutureValueNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
+ }
else if (cntkOpName == "Sequence::Gather")
{
return CreateSequenceGatherNode(src,
@@ -4054,20 +4298,34 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
{
return CreateSoftmaxLikeNode(src, graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
}
+    // In the following RNN cases, we need to unblock the RNN block
+    // when it is in a loop.
else if (cntkOpName == "RNNStep")
{
- return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
- scanLoops, createLoopIndex);
+ if (createLoopIndex == -1)
+ return CreateRNNNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
+ else
+ functionNode = CreateNode(src->BlockRoot(), graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
}
else if (cntkOpName == "GRU")
{
- return CreateGRUNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
- scanLoops, createLoopIndex);
+ if (createLoopIndex == -1)
+ return CreateGRUNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
+ else
+ functionNode = CreateNode(src->BlockRoot(), graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
}
else if (cntkOpName == "LSTM")
{
- return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
- scanLoops, createLoopIndex);
+ if (createLoopIndex == -1)
+ return CreateLSTMNode(src, graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
+ else
+ functionNode = CreateNode(src->BlockRoot(), graph, functionNodes, variableNodes, compositeOutputsMap,
+ scanLoops, createLoopIndex);
}
else if (cntkOpName == "Combine")
{
@@ -4192,7 +4450,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
Variable SkipDynamicAxisPackUnpack(Variable input, bool &dynamicAxisPackUnpackSkipped)
{
dynamicAxisPackUnpackSkipped = false;
- std::set ops({ L"UnpackBatchAxis" , L"ToBatchAxis" , L"UnpackSequenceOp" , L"UnpackBatchAxis" });
+ std::set ops({ L"UnpackBatchAxis" , L"ToBatchAxis" , L"UnpackSequenceOp", L"ToSequenceOp" });
while (input.Owner() && ops.find(input.Owner()->OpName()) != ops.end())
{
input = input.Owner()->Inputs()[0];
@@ -4204,7 +4462,7 @@ Variable SkipDynamicAxisPackUnpack(Variable input, bool &dynamicAxisPackUnpackSk
bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, const std::string &nodeArgName)
{
- const NodeArg* inputNodeArg = graph->FindNodeArg(nodeArgName);
+ const NodeArg* inputNodeArg = graph->GetNodeArg(nodeArgName);
if (inputNodeArg)
{
onnx::TensorProto_DataType inputType = inputNodeArg->TypeAsProto()->tensor_type().elem_type();
@@ -4240,7 +4498,7 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
//
// input is not necessarily an input to src. It may be obtained via skipping of batch/sequence pack/unpack wrappers.
NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* graph, const FunctionPtr src,
- const Variable &input, int inputIndex, onnx::TypeProto &inputArgType)
+    const Variable &input, int inputIndex, onnx::TypeProto &inputArgType, const std::unordered_map<Variable, Variable>& compositeOutputsMap)
{
// TODO: do we need to get blockroot if it is a block function?
if (!Operators::SupportBroadcast(src->OpName()))
@@ -4290,7 +4548,13 @@ NodeArg* CNTKToONNXHelper::GetInputAdjustmentForBroadcast(onnxruntime::Graph* gr
//auto inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
//inputArgType.mutable_tensor_type()->set_elem_type(inputArgType.tensor_type().elem_type());
//UpdateONNXType(input.GetDataType(), inputArgType);
- std::string inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+ std::string inputNodeArgName;
+ auto inputItr = compositeOutputsMap.find(input);
+ if (inputItr != compositeOutputsMap.end())
+ inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(inputItr->second);
+ else
+ inputNodeArgName = UniqueNodeNameStorage::GetUniqueInputNodeName(input);
+
std::string outputArgName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(inputNodeArgName + "_reshaped_for_broadcast");
onnxruntime::NodeArg &nodeArg = graph->GetOrCreateNodeArg(inputNodeArgName, &inputArgType);
Node *reshapeNode = AddReshapeNode(nodeArg, newShape, outputArgName, graph);
@@ -4356,7 +4620,12 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
continue;
}
- if (FilterInput(src, input, inputIndex))
+ if (src->OpName() == L"Sequence::Slice" && inputIndex != src->Inputs().size() - 1)
+ {
+ // for Sequence::Slice, only the last input is the real valid input.
+ continue;
+ }
+ else if (FilterInput(src, input, inputIndex))
continue;
//
@@ -4397,6 +4666,22 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis(), input.HasSequenceAxis());
}
}
+ else if (input.Owner() && ONNX::Operators::IsRNNOp(ToLegacyString(ToUTF8(input.Owner()->OpName()))) &&
+ createLoopIndex >= 0 && createLoopIndex < scanLoops.size())
+ {
+        // We are processing a subgraph and hit an RNN (e.g. LSTM) block.
+        // Because the LSTM block is constructed as a whole, compositeOutputsMap has no mapping for it.
+        // When the LSTM is inside the loop, the block is decomposed in the scan loop,
+        // so we need to use its internal names (instead of the block names).
+        BlockFunction* block = dynamic_cast<BlockFunction*>(input.Owner().get());
+
+        // map from block output to underlying variable
+        std::unordered_map<Variable, Variable> bm = block->CompositeOutputsMap();
+ if (bm.find(input) == bm.end())
+ LogicError("cannot map PastValue/Future's input to LSTM underlying output");
+
+ inputName = UniqueNodeNameStorage::GetUniqueInputNodeName(bm[input]);
+ }
//
    // If this input is an output, then it is the output of an upstream node. Recursively add all upstream nodes.
@@ -4492,6 +4777,11 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
// we already completed preparation of this input and can proceed to the next input.
continue;
}
+ else if (createLoopIndex >= 0 && createLoopIndex < scanLoops.size())
+ {
+            // inside a scan subgraph: make sure the input NodeArg carries the CNTK input's element type
+ UpdateONNXType(input.GetDataType(), inputArgType);
+ }
}
}
else
@@ -4538,7 +4828,8 @@ void CNTKToONNXHelper::ProcessInputs(const FunctionPtr& src,
}
}
- onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType);
+ onnxruntime::NodeArg *adjusted = GetInputAdjustmentForBroadcast(graph, src, input, inputIndex, inputArgType,
+ compositeOutputsMap);
onnxruntime::NodeArg &inputArg = adjusted == nullptr ? graph->GetOrCreateNodeArg(inputName, &inputArgType) : *adjusted;
@@ -4613,6 +4904,42 @@ void CNTKToONNXHelper::ProcessOutputs(const FunctionPtr& src,
}
else if (OpNeedONNXTypeMap(onnxOpName))
{
+ TensorProto_DataType onnx_type = MapAndUpdateONNXType(onnxOpName, false, outputIndex, output.GetDataType(), nullptr);
+ TensorProto_DataType cntk_type = ConvertDataTypeCNTKToTensorProto(output.GetDataType());
+ // TODO: handle all cases
+ if (((onnxOpName == "TopK" && outputIndex == 1) ||
+ onnxOpName == "ArgMax" || onnxOpName == "ArgMin" ||
+ onnxOpName == "Greater" || onnxOpName == "Equal" || onnxOpName == "Less" ||
+ onnxOpName == "Not" || onnxOpName == "Or" || onnxOpName == "Xor") &&
+ cntk_type != onnx_type)
+ {
+ // output NodeArg has not been created yet.
+ // a Cast op needs to be inserted to get the desired type in ONNX.
+
+ // cast ONNX op output type (onnx_type) to CNTK output type (output.GetDataType()).
+ // element type of the input to the Cast op is onnx_type.
+ // element type of the output (outputArgType) of the Cast op is CNTK output.GetDataType()
+ // input and output of the cast op have the same shape.
+ UpdateONNXType(output.GetDataType(), outputArgType);
+
+ auto castInputArgType = ToTypeProto(output.Shape(), output.HasBatchAxis(), output.HasSequenceAxis());
+ castInputArgType.mutable_tensor_type()->set_elem_type(onnx_type);
+
+ std::string outputArgNodeName = UniqueNodeNameStorage::GetUniqueOutputNodeName(output);
+ // std::string outputArgNodeName = ToLegacyString(ToUTF8(output.Uid()));
+ onnxruntime::NodeArg &castInputArg = graph->GetOrCreateNodeArg(
+ outputArgNodeName + "_post_cast_input", &castInputArgType);
+ onnxruntime::NodeArg &castOutputArg = graph->GetOrCreateNodeArg(outputArgNodeName, &outputArgType);
+
+ onnxruntime::Node* castNode = graph->AddNode(castInputArg.Name() + string("_cast_") + outputArgNodeName,
+ "Cast", "", { &castInputArg }, { &castOutputArg });
+ castNode->AddAttribute("to", (int64_t)cntk_type);
+
+ outputs.push_back(&castInputArg);
+
+                // we already completed preparation of this output and can proceed to the next output.
+ continue;
+ }
MapAndUpdateONNXType(onnxOpName, false, outputIndex, output.GetDataType(), &outputArgType);
}
else
@@ -4640,9 +4967,9 @@ void CNTKToONNXHelper::TraverseGraph(const FunctionPtr& src,
return;
}
- if (!Operators::IsRNNOp(opName) && !Operators::IsSequenceBlockOp(opName) &&
+ if (!Operators::IsRNNOp(opName) && !Operators::IsSequenceBlockOp(opName) && opName != "Tuple" &&
src->IsBlock() &&
- (!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName())) ||
+ (!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName())) ||
IsUnSupportedLayerNormalization(src))
{
        auto blockSrc = dynamic_cast<BlockFunction*>(src.get());
@@ -4786,7 +5113,7 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
node->AddAttribute(attributesMap[L"newShape"], ToINTS(shape));
}
}
- if (src->OpName() == L"ReduceL1" || src->OpName() == L"ReduceL2" || src->OpName() == L"ReduceSumSquare")
+ else if (src->OpName() == L"ReduceL1" || src->OpName() == L"ReduceL2" || src->OpName() == L"ReduceSumSquare")
{
SetReduceElementsAttributes(src, node);
}
@@ -4853,6 +5180,8 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
        beginIndex.push_back((int)(src->Attributes()[L"beginIndex"].Value<int>()));
        endIndex.push_back((int)(src->Attributes()[L"endIndex"].Value<int>()));
+ if (*beginIndex.rbegin() == -1 && *endIndex.rbegin() == 0)
+            *endIndex.rbegin() = std::numeric_limits<int>::max();
}
    std::vector<int64_t> beginIndex64 = Cast<int, int64_t>(beginIndex);
@@ -5158,6 +5487,35 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, onnxruntime::Node*
{
SetReduceElementsAttributes(src, node);
}
+ else if ((src->OpName() == L"RandomDistribution") ||
+ (src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom") ||
+ (src->OpName() == L"UniformRandomLike") || (src->OpName() == L"NormalRandomLike"))
+ {
+ std::string onnxOp = node->OpType();
+        auto randomArgs = AsVector<double>(src->Attributes()[L"randomDistributionArgs"].Value<std::vector<DictionaryValue>>());
+        auto seed = (int64_t)src->Attributes()[L"rngSeed"].Value<size_t>();
+
+ if ((onnxOp == "RandomNormal") || (onnxOp == "RandomNormalLike"))
+ {
+ node->AddAttribute("mean", (float)randomArgs[0]);
+ node->AddAttribute("scale", (float)randomArgs[1]);
+ }
+ else
+ {
+ node->AddAttribute("low", (float)randomArgs[0]);
+ node->AddAttribute("high", (float)randomArgs[1]);
+ }
+
+ node->AddAttribute("seed", (float)seed);
+ if ((onnxOp == "RandomUniform") || (onnxOp == "RandomNormal"))
+ {
+            auto shape = (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
+ node->AddAttribute("shape", ToINTS(shape));
+
+            DataType dataType = (DataType)src->Attributes()[L"newDataType"].Value<int>();
+ node->AddAttribute("dtype", (int64_t)ConvertDataTypeCNTKToTensorProto(dataType));
+ }
+ }
}
}
@@ -5169,7 +5527,10 @@ void CNTKToONNXHelper::SetReduceElementsAttributes(const FunctionPtr src, Node *
        reductionOpName = src->Attributes()[L"reductionOpName"].Value<wstring>();
}
-    auto keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
+    // default to keeping reduced dimensions when the attribute is absent
+    int64_t keepReducedDimensions = 1;
+    if (src->Attributes().Contains(L"reductionKeepDimensions"))
+        keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
bool forceKeepReducedDimensions = false;
std::vector reductionAxes;
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ControlFlowHelper.h b/Source/CNTKv2LibraryDll/proto/onnx/ControlFlowHelper.h
index 8e6f6b094..7e48cd5f6 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/ControlFlowHelper.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/ControlFlowHelper.h
@@ -2,6 +2,10 @@
#include "CNTKLibrary.h"
#include "Internals/ComputationGraphAlgorithms.h"
#include "core/graph/graph.h"
+#include "Operators.h"
+#include <algorithm>
+
+using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
@@ -35,10 +39,62 @@ namespace CNTK
m_scanOutputs(scanOutputs),
m_body(body),
m_scanOpCreated(false)
- {}
+ {
+ // collect nodes in RNN ops as part of the body
+ for (auto &f : m_body)
+ {
+ if (ONNX::Operators::IsRNNOp(ToLegacyString(ToUTF8(f->OpName()))))
+ {
+                std::vector<FunctionPtr> rnnInternalBody;
+ CollectInternalNodes(f->BlockRoot(), rnnInternalBody);
+ m_rnnInternalBodies.insert(std::make_pair(f, rnnInternalBody));
+ }
+ }
+
+        // If an RNN is in the loop, we want to map scan outputs that come from an LSTM
+        // to the LSTM block's underlying variables.
+ for (auto &rnn : this->m_rnnInternalBodies)
+ {
+ FunctionPtr rnnF = rnn.first;
+            BlockFunction* block = dynamic_cast<BlockFunction*>(rnnF.get());
+            std::unordered_map<Variable, Variable> bm = block->CompositeOutputsMap();
+ for (auto &blockOutput : rnnF->Outputs())
+ {
+ for (int i = 0; i < m_scanOutputs.size(); i++)
+ {
+ if (m_scanOutputs[i] == blockOutput)
+ {
+ if (bm.find(blockOutput) == bm.end())
+ LogicError("cannot map PastValue/Future's input to LSTM underlying output");
+ m_scanOutputs[i] = bm[blockOutput];
+ }
+ }
+ }
+ }
+ }
+
+ bool IsInBody(const FunctionPtr src)
+ {
+ if (std::find(this->m_body.begin(), this->m_body.end(), src) != this->m_body.end())
+ return true;
+ for (auto &rnn : this->m_rnnInternalBodies)
+ {
+ if (std::find(rnn.second.begin(), rnn.second.end(), src) != rnn.second.end())
+ return true;
+ }
+ return false;
+ }
+
+    static void CollectInternalNodes(FunctionPtr src, std::vector<FunctionPtr> &rnnInternalBody)
+ {
+ src->PreorderTraverse([&rnnInternalBody](const FunctionPtr& function) {
+ rnnInternalBody.push_back(function);
+ }, false);
+ }
        std::vector<Variable> m_inputs, m_outputs, m_scanInputs, m_scanOutputs;
        std::vector<FunctionPtr> m_body;
+        std::unordered_map<FunctionPtr, std::vector<FunctionPtr>> m_rnnInternalBodies;
        std::vector<ScanLoopState> scanLoopStates;
        std::vector<FunctionPtr> m_visited;
bool m_scanOpCreated;
@@ -74,6 +130,16 @@ namespace CNTK
return L"( " + f->Name() + L": " + f->Uid() + L")";
}
+ bool IsStepFunction(FunctionPtr f)
+ {
+ return f->OpName() == L"PastValue" || f->OpName() == L"FutureValue";
+ }
+
+    void AddScanOutputVariable(std::vector<Variable>& scanoutput, Variable output)
+ {
+ scanoutput.push_back(output);
+ }
+
    void BuildLoops(const std::vector<FunctionPtr>& roots,
        std::vector<ScanLoop> &scanLoops)
{
@@ -160,7 +226,7 @@ namespace CNTK
{
outputs.push_back(input);
if (input.DynamicAxes().size() == 2)
- scanoutputs[l].push_back(input);
+ AddScanOutputVariable(scanoutputs[l], input);
}
}
}
@@ -175,7 +241,9 @@ namespace CNTK
if (std::find(loop.Nodes().begin(), loop.Nodes().end(), root) != loop.Nodes().end())
for (auto output : root->Outputs())
if (std::find(scanoutputs[l].begin(), scanoutputs[l].end(), output) == scanoutputs[l].end())
- scanoutputs[l].push_back(output);
+ {
+ AddScanOutputVariable(scanoutputs[l], output);
+ }
}
}
@@ -189,7 +257,7 @@ namespace CNTK
            const std::vector<FunctionPtr> &nodes = loop.Nodes();
for (auto &f : nodes)
{
- if (f->OpName() == L"PastValue" || f->OpName() == L"FutureValue")
+ if (IsStepFunction(f))
loopstepfunctions[l].push_back(f);
else if (f->OpName() != L"LSTM" && f->OpName() != L"GRU" && f->OpName() != L"RNNStep")
filterOutBlockRNNs[l] = true;
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNX.cpp
index 4c8bac090..9e2cf0aea 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/ONNX.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNX.cpp
@@ -23,28 +23,28 @@ namespace CNTK
{
std::once_flag ONNXFormat::op_schema_initializer_flag_;
static std::string defaultLoggerId{"Default"};
-    static onnxruntime::Logging::LoggingManager default_logging_manager_{
-        std::unique_ptr<onnxruntime::Logging::ISink>{new CNTKClogSink{}},
+    static onnxruntime::logging::LoggingManager default_logging_manager_{
+        std::unique_ptr<onnxruntime::logging::ISink>{new CNTKClogSink{}},
[](){
- onnxruntime::Logging::Severity severity;
+ onnxruntime::logging::Severity severity;
switch (GetTraceLevel())
{
case TraceLevel::Error:
- severity = onnxruntime::Logging::Severity::kERROR;
+ severity = onnxruntime::logging::Severity::kERROR;
break;
case TraceLevel::Warning:
- severity = onnxruntime::Logging::Severity::kWARNING;
+ severity = onnxruntime::logging::Severity::kWARNING;
break;
case TraceLevel::Info:
- severity = onnxruntime::Logging::Severity::kINFO;
+ severity = onnxruntime::logging::Severity::kINFO;
break;
default:
- severity = onnxruntime::Logging::Severity::kFATAL;
+ severity = onnxruntime::logging::Severity::kFATAL;
}
return severity;
}(),
false,
- onnxruntime::Logging::LoggingManager::InstanceType::Default,
+ onnxruntime::logging::LoggingManager::InstanceType::Default,
&defaultLoggerId };
static void PrintGraph(FunctionPtr function, int spaces, bool useName = false)
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
index 938da4b3c..89ff07e73 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
@@ -461,7 +461,7 @@ Constant ONNXToCNTKHelper::CreateConstant(const onnx::TensorProto &valueProto, c
{
        // It does not work using std::vector<bool> because the resulting memory layout is not what we expect.
bool *srcData = new bool[shape.TotalSize()];
- onnxruntime::Utils::TensorUtils::UnpackTensor(valueProto, srcData, shape.TotalSize());
+ onnxruntime::utils::TensorUtils::UnpackTensor(valueProto, srcData, shape.TotalSize());
// CNTK does not support bool. We need to convert to float.
        std::vector<float> srcFloatData(shape.TotalSize());
@@ -476,7 +476,7 @@ Constant ONNXToCNTKHelper::CreateConstant(const onnx::TensorProto &valueProto, c
case TensorProto_DataType_INT32:
{
        std::vector<int32_t> srcData(shape.TotalSize());
- onnxruntime::Utils::TensorUtils::UnpackTensor(valueProto, &srcData[0], shape.TotalSize());
+ onnxruntime::utils::TensorUtils::UnpackTensor(valueProto, &srcData[0], shape.TotalSize());
// CNTK does not support int. We need to convert to float.
        std::vector<float> srcFloatData(shape.TotalSize());
@@ -490,7 +490,7 @@ Constant ONNXToCNTKHelper::CreateConstant(const onnx::TensorProto &valueProto, c
case TensorProto_DataType_INT64:
{
        std::vector<int64_t> srcData(shape.TotalSize());
- onnxruntime::Utils::TensorUtils::UnpackTensor(valueProto, &srcData[0], shape.TotalSize());
+ onnxruntime::utils::TensorUtils::UnpackTensor(valueProto, &srcData[0], shape.TotalSize());
// CNTK does not support int64_t. We need to convert to float.
        std::vector<float> srcFloatData(shape.TotalSize());
@@ -1235,7 +1235,7 @@ std::vector<FunctionPtr> CreateRNNConstantOp(const Graph *graph, const Node *nod
const DeviceDescriptor &computeDevice)
{
const onnx::TensorProto *valueProto;
- if (!graph->GetInitializedTensor(node->Name(), &valueProto))
+ if (!graph->GetInitializedTensor(node->Name(), valueProto))
{
NodeAttributes::const_iterator itValue = node->GetAttributes().find("value");
if (itValue == node->GetAttributes().cend())
@@ -1260,7 +1260,7 @@ std::vector<Variable> ONNXToCNTKHelper::CreateRNNLeafVariableOrConstant(const No
string parentONNXOpName = parentNode->OpType();
std::string nodeName = nodeArg->Name();
const onnx::TensorProto *valueProto;
- if (graph->GetInitializedTensor(nodeName, &valueProto))
+ if (graph->GetInitializedTensor(nodeName, valueProto))
{
int index = CalculateNodeArgInputIndex(nodeArg, parentNode);
return CreateRNNConstant(parentNode, index, nodeName, *valueProto, computeDevice);
@@ -1379,7 +1379,7 @@ Variable ONNXToCNTKHelper::CreateLeafVariableOrConstant(const NodeArg *nodeArg,
std::string nodeName = nodeArg->Name();
const onnx::TensorProto *valueProto;
- if (graph->GetInitializedTensor(nodeName, &valueProto))
+ if (graph->GetInitializedTensor(nodeName, valueProto))
{
return CreateConstant(*valueProto, nodeName, computeDevice); // There is no batch axis added on here.
}
@@ -1438,14 +1438,14 @@ ConvAutoPadType ONNXToCNTKHelper::ConvertStrToConvAutoPadType(const string &str)
std::vector<int64_t> ONNXToCNTKHelper::GetShapeFromInput(const NodeArg *shapeInput, const Graph *graph)
{
const onnx::TensorProto *valueProto;
- if (!graph->GetInitializedTensor(shapeInput->Name(), &valueProto))
+ if (!graph->GetInitializedTensor(shapeInput->Name(), valueProto))
{
LogicError("Non-constant shape input for Reshape is not implemented.");
};
auto shapeSize = valueProto->dims(0);
    std::vector<int64_t> dimData(shapeSize);
- onnxruntime::Utils::TensorUtils::UnpackTensor(*valueProto, &dimData[0], shapeSize);
+ onnxruntime::utils::TensorUtils::UnpackTensor(*valueProto, &dimData[0], shapeSize);
return dimData;
}
@@ -1922,7 +1922,7 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
)
{
string onnxOpName = node->OpType();
- Variable inputOperand0 = (inputPlaceholder.IsInitialized()) ? inputPlaceholder : inputs[0];
+ Variable inputOperand0 = (inputPlaceholder.IsInitialized() || inputs.empty()) ? inputPlaceholder : inputs[0];
if (onnxOpName == "LSTM")
{
@@ -2200,8 +2200,9 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
{
const NDShape &shape = GetNamedAttributeAsShape(node, "shape", false);
- // ONNX only has float type for random generators
- CNTK::DataType dataType = CNTK::DataType::Float;
+            TensorProto_DataType onnxDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(
+ node, "dtype", TensorProto_DataType::TensorProto_DataType_FLOAT));
+ CNTK::DataType dataType = ConvertDataTypeTensorProtoToCNTK(onnxDataType);
double low = GetNamedAttributeAsFloat(node, "low");
double high = GetNamedAttributeAsFloat(node, "high");
@@ -2212,7 +2213,11 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
else if (onnxOpName == "RandomNormal")
{
const NDShape &shape = GetNamedAttributeAsShape(node, "shape", false);
- CNTK::DataType dataType = CNTK::DataType::Float;
+
+            TensorProto_DataType onnxDataType = static_cast<TensorProto_DataType>(GetNamedAttributeAsInt64(
+ node, "dtype", TensorProto_DataType::TensorProto_DataType_FLOAT));
+ CNTK::DataType dataType = ConvertDataTypeTensorProtoToCNTK(onnxDataType);
+
double mean = GetNamedAttributeAsFloat(node, "mean");
double scale = GetNamedAttributeAsFloat(node, "scale");
unsigned long seed = GetNamedAttributeAsInt64(node, "seed");
@@ -2684,8 +2689,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
}
else if (onnxOpName == "Concat")
{
- if (node->Name() == "Splice3547")
- std::cout << std::endl;
// We allow the 'axis' attribute to be optional, and not required (as
// given in Concat's ONNX spec), to be consistent with other frameworks.
// 'axis' can be enforced as a required attribute, if needed.
@@ -2962,6 +2965,15 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
FunctionPtr cntkFunction = Crop(inputOperand0, referent, leftBorder, topBorder, ToFixedWStringFromMultiByte(node->Name()));
return cntkFunction;
}
+ else if (onnxOpName == "OneHotEncoder")
+ {
+ // TODO: this only works in this specific case.
+ std::vector<int64_t> cats = GetNamedAttributeAsInt64Vec(node, "cats_int64s");
+ int numClass = cats.size();
+ Axis axis = ConvertONNXAxisToCNTKCppApi(2, inputs[0]);
+ FunctionPtr cntkFunction = OneHotOp(inputs[0], numClass, false, axis);
+ return cntkFunction;
+ }
else
{
LogicError("ONNX (%s) is not supported in CNTK", onnxOpName.c_str());
@@ -3488,6 +3500,39 @@ FunctionPtr ONNXToCNTKHelper::CreateCNTKFCNode(const std::wstring &nodeName, con
return cntkFunction;
}
+// onnx graph library treats output NodeArgs as outputs.
+// when creating a CNTK model, we build a map from Nodes to FunctionPtrs.
+// To figure out the outputs of a CNTK model, we need to filter out
+// output variables of output Functions that are not in the graph outputs.
+void FilterGraphOutputs(std::vector<Variable> &outputVariables)
+{
+ std::set<FunctionPtr> visited;
+ std::vector<Variable> sinkedVariables;
+ for (auto v : outputVariables)
+ {
+ if (v.Owner())
+ {
+ v.Owner()->PreorderTraverse([&visited, &sinkedVariables](const FunctionPtr& function) {
+ if (visited.find(function) != visited.end())
+ return;
+ visited.insert(function);
+ for (auto inputVariable : function->Inputs())
+ if (std::find(sinkedVariables.begin(), sinkedVariables.end(), inputVariable) == sinkedVariables.end())
+ sinkedVariables.push_back(inputVariable);
+
+ }, false);
+ }
+ }
+
+ for (std::vector<Variable>::iterator it = outputVariables.begin(); it != outputVariables.end();)
+ {
+ if (std::find(sinkedVariables.begin(), sinkedVariables.end(), *it) != sinkedVariables.end())
+ it = outputVariables.erase(it);
+ else
+ ++it;
+ }
+}
+
FunctionPtr ONNXToCNTK::CreateGraph(onnxruntime::Graph *src, const DeviceDescriptor &computeDevice)
{
FunctionPtr cntkModel;
@@ -3510,20 +3555,31 @@ FunctionPtr ONNXToCNTK::CreateGraph(onnxruntime::Graph *src, const DeviceDescrip
}
}
- // ONNX puts all outputs in an graph as input to the "_Graph_Sink" node.
- ONNXToCNTKMap::iterator itNodeFn = std::find_if(constructedFunctions.begin(), constructedFunctions.end(),
- [](ONNXToCNTKMap::value_type nodeFn) { return nodeFn.first->Name() == "_Graph_Sink"; });
- if (itNodeFn == constructedFunctions.end())
+ std::vector<FunctionPtr> functions;
+ const std::vector<const NodeArg*>& graphOutputs = src->GetOutputs();
+ // collect output Nodes based on output NodeArgs
+ std::set<const Node*> outputNodes;
+ for (int i = 0; i < graphOutputs.size(); i++)
{
- return nullptr;
+ const NodeArg* nodeArg = graphOutputs[i];
+ for (auto &node : src->Nodes())
+ {
+ if (std::find(outputNodes.begin(), outputNodes.end(), &node) == outputNodes.end())
+ {
+ for (auto nodeOutput : node.OutputDefs())
+ if (nodeOutput == nodeArg)
+ {
+ outputNodes.insert(&node);
+ break;
+ }
+ }
+ }
}
- std::vector<FunctionPtr> functions;
- for (Node::NodeConstIterator it = itNodeFn->first->InputNodesBegin(); it != itNodeFn->first->InputNodesEnd(); ++it)
+ // collect output FunctionPtrs from output Nodes
+ for (auto &node : outputNodes)
{
- // TODO: consulting onnxruntime to see how to do this solidly.
- // https://msasg.visualstudio.com/DefaultCollection/Shared%20Data/AIToolkits-CNTK/_queries?id=1134732&_a=edit&triage=true
- std::vector<FunctionPtr> &constructedFuncts = constructedFunctions[*it];
+ std::vector<FunctionPtr> &constructedFuncts = constructedFunctions[node];
for (int index = 0; index < constructedFuncts.size(); index++)
{
FunctionPtr &constructedFunct = constructedFuncts[index];
@@ -3543,7 +3599,17 @@ FunctionPtr ONNXToCNTK::CreateGraph(onnxruntime::Graph *src, const DeviceDescrip
else
{
// in case multiple outputs are in a graph, combine them into one CNTK graph.
- return Combine(std::vector<Variable>(functions.begin(), functions.end()));
+ std::vector<Variable> outputVariables;
+ for (auto f : functions)
+ {
+ for (auto v : f->Outputs())
+ {
+ outputVariables.push_back(v);
+ }
+ }
+ if (outputVariables.size() > graphOutputs.size())
+ FilterGraphOutputs(outputVariables);
+ return Combine(outputVariables);
}
}
@@ -3643,7 +3709,7 @@ std::pair> ONNXToCNTKHelper::CheckNodeBelongsToOp
if (firstParentNode != nullptr)
{
it = firstParentNode->OutputNodesBegin();
- if (it != node->OutputNodesEnd())
+ if (it != firstParentNode->OutputNodesEnd())
{
grandParentNode = *it;
}
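
Aside (not part of the patch): the new output-collection path above replaces the removed "_Graph_Sink" lookup, and FilterGraphOutputs keeps only variables that no other collected output consumes. A standalone sketch of that filtering idea, using plain strings instead of CNTK's Variable/FunctionPtr types (all names illustrative):

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
    std::string output;               // variable this op produces
    std::vector<std::string> inputs;  // variables it consumes
};

int main() {
    // a -> b -> c : only "c" is a true graph output
    std::vector<Op> ops = {{"a", {}}, {"b", {"a"}}, {"c", {"b"}}};
    std::vector<std::string> candidates = {"a", "b", "c"};

    // collect every variable consumed downstream ("sinked" in the patch's wording)
    std::set<std::string> consumed;
    for (const auto& op : ops)
        consumed.insert(op.inputs.begin(), op.inputs.end());

    // keep only variables that nobody consumes
    candidates.erase(std::remove_if(candidates.begin(), candidates.end(),
                                    [&](const std::string& v) { return consumed.count(v) != 0; }),
                     candidates.end());

    for (const auto& v : candidates)
        std::cout << v << "\n";  // prints: c
}
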
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp
index 34d8b1adc..4aeeb767b 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp
@@ -116,6 +116,7 @@ namespace ONNX
// From Generator
{ L"RandomDistribution", { {
{ L"UniformRandom", "RandomUniform" },
+ { L"uniform", "RandomUniform" },
// { L"", "low" },
// { L"", "high" },
{ L"rngSeed", "seed" },
@@ -123,6 +124,7 @@ namespace ONNX
} } },
{ L"RandomDistribution", { {
{ L"NormalRandom", "RandomNormal" },
+ { L"normal", "RandomNormal" },
// { L"", "mean" },
// { L"", "scale" },
{ L"rngSeed", "seed" },
@@ -528,7 +530,8 @@ namespace ONNX
bool Operators::IsSequenceBlockOp(const std::string &opName)
{
- return opName == "Sequence::ReduceElements" || opName == "Sequence::BroadcastAs";
+ return opName == "Sequence::ReduceElements" || opName == "Sequence::BroadcastAs" ||
+ opName == "Sequence::Gather" || opName == "Sequence::Softmax";
}
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
@@ -550,7 +553,8 @@ namespace ONNX
{ L"Softsign",{ 0 } },
{ L"ImageScaler",{ 0, 1, 2, 3 } },
{ L"MeanVarianceNormalization",{ 0 } },
- { L"Sequence::Slice",{ 0, 1 } },
+ { L"Sequence::Slice",{ 0, 1, 2, 3, 4 } },
+ { L"GatherPacked",{ 1 } },
};
std::unordered_map<std::wstring, std::vector<int>> Operators::_cntkToONNXInputIndices = {
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc
index 85704f3e3..93dbe9f8a 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc
@@ -7,7 +7,7 @@
#include "gsl/gsl_util"
namespace onnxruntime {
-namespace Logging {
+namespace logging {
void Capture::CapturePrintf(msvc_printf_check const char* format, ...) {
va_list arglist;
@@ -47,5 +47,5 @@ Capture::~Capture() {
logger_->Log(*this);
}
}
-} // namespace Logging
+} // namespace logging
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc
index 1f912fbd1..bc6521c07 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc
@@ -3,6 +3,7 @@
#include
#include
+#include <utility>
#include "core/common/exceptions.h"
#include "core/common/logging/isink.h"
@@ -16,7 +17,7 @@
#endif
namespace onnxruntime {
-namespace Logging {
+namespace logging {
const char* Category::onnxruntime = "onnxruntime";
const char* Category::System = "System";
@@ -133,7 +134,7 @@ void LoggingManager::CreateDefaultLogger(const std::string& logger_id) {
}
std::unique_ptr<Logger> LoggingManager::CreateLogger(std::string logger_id) {
- return CreateLogger(logger_id, default_min_severity_, default_filter_user_data_, default_max_vlog_level_);
+ return CreateLogger(std::move(logger_id), default_min_severity_, default_filter_user_data_, default_max_vlog_level_);
}
std::unique_ptr<Logger> LoggingManager::CreateLogger(std::string logger_id,
@@ -179,8 +180,8 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category,
// create Capture in separate scope so it gets destructed (leading to log output) before we throw.
{
- ::onnxruntime::Logging::Capture c{::onnxruntime::Logging::LoggingManager::DefaultLogger(),
- ::onnxruntime::Logging::Severity::kFATAL, category, ::onnxruntime::Logging::DataType::SYSTEM, location};
+ ::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(),
+ ::onnxruntime::logging::Severity::kFATAL, category, ::onnxruntime::logging::DataType::SYSTEM, location};
va_list args;
va_start(args, format_str);
@@ -190,7 +191,7 @@ std::exception LoggingManager::LogFatalAndCreateException(const char* category,
exception_msg = c.Message();
}
- return LotusException(location, exception_msg);
+ return OnnxRuntimeException(location, exception_msg);
}
unsigned int GetThreadId() {
@@ -212,5 +213,5 @@ unsigned int GetProcessId() {
#endif
}
-} // namespace Logging
+} // namespace logging
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h
new file mode 100644
index 000000000..42577ba26
--- /dev/null
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <iostream>
+#include "core/common/logging/sinks/ostream_sink.h"
+
+namespace onnxruntime {
+namespace logging {
+/// <summary>
+/// A std::cerr based ISink
+/// </summary>
+/// <seealso cref="OStreamSink" />
+class CErrSink : public OStreamSink {
+ public:
+ CErrSink() : OStreamSink(std::cerr, /*flush*/ false) { // std::cerr isn't buffered so no flush required
+ }
+};
+} // namespace logging
+} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/clog_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h
similarity index 90%
rename from Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/clog_sink.h
rename to Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h
index d52dfe774..9b0adf92f 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/clog_sink.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h
@@ -7,7 +7,7 @@
#include "core/common/logging/sinks/ostream_sink.h"
namespace onnxruntime {
-namespace Logging {
+namespace logging {
/// <summary>
/// A std::clog based ISink
/// </summary>
@@ -17,5 +17,5 @@ class CLogSink : public OStreamSink {
CLogSink() : OStreamSink(std::clog, /*flush*/ true) {
}
};
-} // namespace Logging
+} // namespace logging
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h
new file mode 100644
index 000000000..f27abb9e6
--- /dev/null
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "core/common/logging/isink.h"
+#include "core/common/logging/logging.h"
+
+namespace onnxruntime {
+namespace logging {
+/// <summary>
+/// Class that abstracts multiple ISink instances being written to.
+/// </summary>
+/// <seealso cref="ISink" />
+class CompositeSink : public ISink {
+ public:
+ /// <summary>
+ /// Initializes a new instance of the <see cref="CompositeSink"/> class.
+ /// Use AddSink to add sinks.
+ /// </summary>
+ CompositeSink() {}
+
+ /// <summary>
+ /// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value).
+ /// </summary>
+ /// <param name="sink">The sink.</param>
+ /// <returns>This instance to allow chaining.</returns>
+ CompositeSink& AddSink(std::unique_ptr<ISink> sink) {
+ sinks_.push_back(std::move(sink));
+ return *this;
+ }
+
+ private:
+ void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override {
+ for (auto& sink : sinks_) {
+ sink->Send(timestamp, logger_id, message);
+ }
+ }
+
+ std::vector<std::unique_ptr<ISink>> sinks_;
+};
+} // namespace logging
+} // namespace onnxruntime
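
Aside (not part of the patch): AddSink above returns *this, so sinks compose by chaining. A hedged usage sketch, assuming only the classes this patch adds (CompositeSink, CLogSink, CErrSink):

#include <memory>

// headers from this patch:
// "core/common/logging/sinks/composite_sink.h"
// "core/common/logging/sinks/clog_sink.h"
// "core/common/logging/sinks/cerr_sink.h"

std::unique_ptr<onnxruntime::logging::ISink> MakeDefaultSink() {
    using namespace onnxruntime::logging;
    auto composite = std::make_unique<CompositeSink>();
    composite->AddSink(std::make_unique<CLogSink>())   // std::clog, flushed per message
              .AddSink(std::make_unique<CErrSink>());  // std::cerr, unbuffered
    return composite;  // can be handed to a LoggingManager-style owner
}
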
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h
new file mode 100644
index 000000000..ba3ff3e0b
--- /dev/null
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <fstream>
+#include "core/common/logging/sinks/ostream_sink.h"
+
+namespace onnxruntime {
+namespace logging {
+/// <summary>
+/// ISink that writes to a file.
+/// </summary>
+/// <seealso cref="OStreamSink" />
+class FileSink : public OStreamSink {
+ public:
+ /// <summary>
+ /// Initializes a new instance of the <see cref="FileSink"/> class.
+ /// </summary>
+ /// <param name="filename">The filename to write to.</param>
+ /// <param name="append">If set to true [append to file]. Otherwise truncate.</param>
+ /// <param name="filter_user_data">If set to true [removes user data].</param>
+ /// Filtering of user data can alternatively be done at the <see cref="LoggingManager"/> level.
+ FileSink(std::unique_ptr<std::ofstream> file, bool filter_user_data)
+ : OStreamSink(*file, /*flush*/ true), file_(std::move(file)), filter_user_data_{filter_user_data} {
+ }
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="FileSink"/> class.
+ /// </summary>
+ /// <param name="filename">The filename to write to.</param>
+ /// <param name="append">If set to true [append to file]. Otherwise truncate.</param>
+ /// <param name="filter_user_data">If set to true [removes user data].</param>
+ /// Filtering of user data can alternatively be done at the <see cref="LoggingManager"/> level.
+ FileSink(const std::string& filename, bool append, bool filter_user_data)
+ : FileSink{std::make_unique<std::ofstream>(filename, std::ios::out | (append ? std::ios::app : std::ios::trunc)),
+ filter_user_data} {
+ }
+
+ private:
+ void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override {
+ if (!filter_user_data_ || message.DataType() != DataType::USER) {
+ OStreamSink::SendImpl(timestamp, logger_id, message);
+ }
+ }
+
+ std::unique_ptr<std::ofstream> file_;
+ bool filter_user_data_;
+};
+} // namespace logging
+} // namespace onnxruntime
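
Aside (not part of the patch): a sketch of the filename-based FileSink constructor above; "run.log" is an illustrative name, not anything from the patch. This appends to the file and drops USER-tagged messages:

#include <memory>

// #include "core/common/logging/sinks/file_sink.h"  (added by this patch)

std::unique_ptr<onnxruntime::logging::ISink> MakeFileSink() {
    using onnxruntime::logging::FileSink;
    return std::make_unique<FileSink>("run.log", /*append*/ true, /*filter_user_data*/ true);
}
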
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/ostream_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h
similarity index 94%
rename from Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/ostream_sink.h
rename to Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h
index 7d17bf14a..bf5cec174 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/include/core/common/logging/sinks/ostream_sink.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h
@@ -11,7 +11,7 @@
#include "core/common/logging/isink.h"
namespace onnxruntime {
-namespace Logging {
+namespace logging {
/// <summary>
/// A std::ostream based ISink
/// </summary>
@@ -29,5 +29,5 @@ class OStreamSink : public ISink {
std::ostream* stream_;
const bool flush_;
};
-} // namespace Logging
+} // namespace logging
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc
index ea675bb2d..146671994 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc
@@ -4,15 +4,15 @@
#include "profiler.h"
namespace onnxruntime {
-namespace Profiling {
+namespace profiling {
using namespace std::chrono;
-::onnxruntime::TimePoint Profiling::Profiler::StartTime() const {
+::onnxruntime::TimePoint profiling::Profiler::StartTime() const {
return std::chrono::high_resolution_clock::now();
}
-void Profiler::StartProfiling(const Logging::Logger* session_logger, const std::string& file_name) {
- LOTUS_ENFORCE(session_logger != nullptr);
+void Profiler::StartProfiling(const logging::Logger* session_logger, const std::string& file_name) {
+ ONNXRUNTIME_ENFORCE(session_logger != nullptr);
session_logger_ = session_logger;
enabled_ = true;
profile_stream_ = std::ofstream(file_name, std::ios::out | std::ios::trunc);
@@ -32,8 +32,8 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category,
if (events_.size() < max_num_events_) {
long long dur = TimeDiffMicroSeconds(start_time);
long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);
- events_.emplace_back(category, Logging::GetProcessId(),
- Logging::GetThreadId(), event_name, ts, dur, std::move(event_args));
+ events_.emplace_back(category, logging::GetProcessId(),
+ logging::GetThreadId(), event_name, ts, dur, std::move(event_args));
} else {
if (session_logger_ && !max_events_reached) {
LOGS(*session_logger_, ERROR)
@@ -80,8 +80,8 @@ std::string Profiler::WriteProfileData() {
// Conditionally sync the GPU if the syncGPU flag is set.
//
void ProfilerSyncGpu() {
- LOTUS_NOT_IMPLEMENTED("Needs to implement only for gpus");
+ ONNXRUNTIME_NOT_IMPLEMENTED("Needs to implement only for gpus");
}
-} // namespace Profiling
+} // namespace profiling
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h
index 29bffa02f..3470677f3 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h
@@ -8,7 +8,7 @@
namespace onnxruntime {
-namespace Profiling {
+namespace profiling {
enum EventCategory {
SESSION_EVENT = 0,
@@ -60,7 +60,7 @@ class Profiler {
/*
Start profiler and record beginning time.
*/
- void StartProfiling(const Logging::Logger* session_logger, const std::string& file_name);
+ void StartProfiling(const logging::Logger* session_logger, const std::string& file_name);
/*
Produce current time point for any profiling action.
@@ -84,19 +84,19 @@ class Profiler {
std::string WriteProfileData();
private:
- LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(Profiler);
+ ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Profiler);
// Mutex controlling access to profiler data
std::mutex mutex_;
bool enabled_{false};
std::ofstream profile_stream_;
std::string profile_stream_file_;
- const Logging::Logger* session_logger_{nullptr};
+ const logging::Logger* session_logger_{nullptr};
TimePoint profiling_start_time_;
std::vector<EventRecord> events_;
bool max_events_reached{false};
static constexpr size_t max_num_events_ = 1000000;
};
-} // namespace Profiling
+} // namespace profiling
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc
index 8c9d80ac2..85b486b81 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc
@@ -8,7 +8,7 @@ namespace onnxruntime {
namespace common {
Status::Status(StatusCategory category, int code, const std::string& msg) {
// state_ will be allocated here causing the status to be treated as a failure
- LOTUS_ENFORCE(code != static_cast<int>(MLStatus::OK));
+ ONNXRUNTIME_ENFORCE(code != static_cast<int>(MLStatus::OK));
state_ = std::make_unique<State>(category, code, msg);
}
@@ -44,7 +44,7 @@ std::string Status::ToString() const {
result += "SystemError";
result += " : ";
result += std::to_string(errno);
- } else if (common::LOTUS == state_->category) {
+ } else if (common::ONNXRUNTIME == state_->category) {
result += "[LotusError]";
result += " : ";
result += std::to_string(Code());
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h
index 75dad9f49..217c65189 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h
@@ -145,7 +145,7 @@ class TaskThreadPool {
}
private:
- LOTUS_DISALLOW_COPY_ASSIGN_AND_MOVE(TaskThreadPool);
+ ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TaskThreadPool);
/// @brief Entry point for pool threads.
void MainLoop(std::size_t index) {
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc
index ea88f574d..3889a1024 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc
@@ -7,6 +7,7 @@
#pragma warning(disable : 4244)
#endif
#include "core/framework/tensorutils.h"
+#include "core/framework/allocator.h"
#include <algorithm>
@@ -50,34 +51,34 @@ static void UnpackTensorWithRawData(const ONNX_NAMESPACE::TensorProto& tensor, /
}
namespace onnxruntime {
-namespace Utils {
-#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
- template <> \
+namespace utils {
+#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
+ template <> \
Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, int64_t expected_size) { \
- if (nullptr == p_data) { \
- const size_t size = tensor.has_raw_data() ? tensor.raw_data().size() : tensor.field_size(); \
- if (size == 0) \
- return Status::OK(); \
- else \
- return Status(common::LOTUS, common::INVALID_ARGUMENT); \
- } \
- if (nullptr == p_data || Type != tensor.data_type()) { \
- return Status(common::LOTUS, common::INVALID_ARGUMENT); \
- } \
- if (tensor.has_raw_data()) { \
- if (tensor.raw_data().size() != ((expected_size) * sizeof(T))) \
- return Status(common::LOTUS, common::FAIL, \
- "UnpackTensor: the pre-allocated size does not match the raw data size"); \
- UnpackTensorWithRawData(tensor, p_data); \
- return Status::OK(); \
- } \
- if (tensor.field_size() != expected_size) \
- return Status(common::LOTUS, common::FAIL, \
- "UnpackTensor: the pre-allocated size does not match the size in proto"); \
- const auto span = gsl::make_span(p_data, expected_size); \
- auto& data = tensor.field_name(); \
- std::copy(data.cbegin(), data.cend(), span.begin()); \
- return Status::OK(); \
+ if (nullptr == p_data) { \
+ const size_t size = tensor.has_raw_data() ? tensor.raw_data().size() : tensor.field_size(); \
+ if (size == 0) \
+ return Status::OK(); \
+ else \
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+ } \
+ if (nullptr == p_data || Type != tensor.data_type()) { \
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
+ } \
+ if (tensor.has_raw_data()) { \
+ if (tensor.raw_data().size() != ((expected_size) * sizeof(T))) \
+ return Status(common::ONNXRUNTIME, common::FAIL, \
+ "UnpackTensor: the pre-allocated size does not match the raw data size"); \
+ UnpackTensorWithRawData(tensor, p_data); \
+ return Status::OK(); \
+ } \
+ if (tensor.field_size() != expected_size) \
+ return Status(common::ONNXRUNTIME, common::FAIL, \
+ "UnpackTensor: the pre-allocated size does not match the size in proto"); \
+ const auto span = gsl::make_span(p_data, expected_size); \
+ auto& data = tensor.field_name(); \
+ std::copy(data.cbegin(), data.cend(), span.begin()); \
+ return Status::OK(); \
}
//TODO: uint32 uint64 complex64 complex128
@@ -101,14 +102,14 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
if (tensor.string_data_size() == 0)
return Status::OK();
else
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (ONNX_NAMESPACE::TensorProto_DataType_STRING != tensor.data_type()) {
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (tensor.string_data_size() != expected_size)
- return Status(common::LOTUS, common::FAIL,
+ return Status(common::ONNXRUNTIME, common::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
const auto data = gsl::make_span(p_data, expected_size);
@@ -127,15 +128,15 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
if (size == 0)
return Status::OK();
else
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (ONNX_NAMESPACE::TensorProto_DataType_BOOL != tensor.data_type()) {
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (tensor.has_raw_data()) {
if (tensor.raw_data().size() != (expected_size) * sizeof(bool))
- return Status(common::LOTUS, common::FAIL,
+ return Status(common::ONNXRUNTIME, common::FAIL,
"UnpackTensor: the pre-allocate size does not match the raw data size");
UnpackTensorWithRawData(tensor, p_data);
@@ -143,7 +144,7 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
}
if (tensor.int32_data_size() != expected_size)
- return Status(common::LOTUS, common::FAIL,
+ return Status(common::ONNXRUNTIME, common::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
const auto data = gsl::make_span(p_data, expected_size);
@@ -160,15 +161,15 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
if (size == 0)
return Status::OK();
else
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 != tensor.data_type()) {
- return Status(common::LOTUS, common::INVALID_ARGUMENT);
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
}
if (tensor.has_raw_data()) {
if (tensor.raw_data().size() != (expected_size) * sizeof(uint16_t))
- return Status(common::LOTUS, common::FAIL,
+ return Status(common::ONNXRUNTIME, common::FAIL,
"UnpackTensor: the pre-allocate size does not match the raw data size");
UnpackTensorWithRawData(tensor, p_data);
@@ -176,7 +177,7 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
}
if (tensor.int32_data_size() != expected_size)
- return Status(common::LOTUS, common::FAIL,
+ return Status(common::ONNXRUNTIME, common::FAIL,
"UnpackTensor: the pre-allocate size does not match the size in proto");
const auto data = gsl::make_span(p_data, expected_size);
@@ -186,44 +187,45 @@ Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
return Status::OK();
}
-#define LOTUS_CASE_PROTO_TRACE(X, Y) \
- case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
- size *= sizeof(Y); \
+#define CASE_PROTO_TRACE(X, Y) \
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
+ if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(size, sizeof(Y), out)) { \
+ return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto"); \
+ } \
break;
+template <size_t alignment>
common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) {
const auto& dims = tensor_proto.dims();
- int64_t size = 1;
+ size_t size = 1;
for (int i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
- size = -1;
- break;
+ return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto");
+ }
+ if (!IAllocator::CalcMemSizeForArray(size, static_cast<size_t>(dims[i]), &size)) {
+ return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto");
}
- size *= dims[i];
}
- //If 'size' is too big, size*sizeof(T) could overflow. Then Allocator may allocate less memory than needed
- //Here max(sizeof(T)) is 8. 63 - 8 = 55.
- if (size < 0 || size >= (1LL << 55)) return common::Status(common::LOTUS, common::FAIL, "Invalid TensorProto");
switch (tensor_proto.data_type()) {
- LOTUS_CASE_PROTO_TRACE(FLOAT, float);
- LOTUS_CASE_PROTO_TRACE(DOUBLE, double);
- LOTUS_CASE_PROTO_TRACE(BOOL, bool);
- LOTUS_CASE_PROTO_TRACE(INT8, int8_t);
- LOTUS_CASE_PROTO_TRACE(INT16, int16_t);
- LOTUS_CASE_PROTO_TRACE(INT32, int32_t);
- LOTUS_CASE_PROTO_TRACE(INT64, int64_t);
- LOTUS_CASE_PROTO_TRACE(UINT8, uint8_t);
- LOTUS_CASE_PROTO_TRACE(UINT16, uint16_t);
- LOTUS_CASE_PROTO_TRACE(UINT32, uint32_t);
- LOTUS_CASE_PROTO_TRACE(UINT64, uint64_t);
- LOTUS_CASE_PROTO_TRACE(FLOAT16, MLFloat16);
+ CASE_PROTO_TRACE(FLOAT, float);
+ CASE_PROTO_TRACE(DOUBLE, double);
+ CASE_PROTO_TRACE(BOOL, bool);
+ CASE_PROTO_TRACE(INT8, int8_t);
+ CASE_PROTO_TRACE(INT16, int16_t);
+ CASE_PROTO_TRACE(INT32, int32_t);
+ CASE_PROTO_TRACE(INT64, int64_t);
+ CASE_PROTO_TRACE(UINT8, uint8_t);
+ CASE_PROTO_TRACE(UINT16, uint16_t);
+ CASE_PROTO_TRACE(UINT32, uint32_t);
+ CASE_PROTO_TRACE(UINT64, uint64_t);
+ CASE_PROTO_TRACE(FLOAT16, MLFloat16);
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING:
default:
- return common::Status(common::LOTUS, common::NOT_IMPLEMENTED);
+ return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
- *out = size;
return Status::OK();
}
-} // namespace Utils
+template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
+} // namespace utils
} // namespace onnxruntime
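
Aside (not part of the patch): CASE_PROTO_TRACE now delegates to overflow-checked helpers instead of the old ad-hoc "size >= (1LL << 55)" bound. A standalone sketch of that style of check (illustrative names, not the IAllocator signatures):

#include <cstddef>
#include <cstdint>
#include <limits>

// Returns false when a * b would overflow size_t; stores a * b in *out otherwise.
bool CheckedMul(std::size_t a, std::size_t b, std::size_t* out) {
    if (b != 0 && a > std::numeric_limits<std::size_t>::max() / b)
        return false;
    *out = a * b;
    return true;
}

// Accumulate a tensor element count dimension by dimension, failing cleanly
// instead of silently wrapping (mirrors the "Invalid TensorProto" paths above).
bool NumElements(const std::int64_t* dims, int n, std::size_t* out) {
    std::size_t size = 1;
    for (int i = 0; i < n; ++i) {
        if (dims[i] < 0 || !CheckedMul(size, static_cast<std::size_t>(dims[i]), &size))
            return false;
    }
    *out = size;
    return true;
}
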
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h
index 218c97b6e..a38e3c282 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h
@@ -13,10 +13,11 @@ namespace ONNX_NAMESPACE {
class TensorProto;
}
namespace onnxruntime {
-namespace Utils {
+namespace utils {
//How much memory it will need for putting the content of this tensor into a plain array
//string/complex64/complex128 tensors are not supported.
//The output value could be zero or -1.
+template <size_t alignment>
common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
class TensorUtils {
public:
@@ -26,5 +27,5 @@ class TensorUtils {
int64_t expected_size);
}; // namespace Utils
-} // namespace Utils
+} // namespace utils
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc
index 4d41c3ff6..036f3e11e 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc
@@ -4,12 +4,75 @@
#include "core/graph/function_impl.h"
#include "core/graph/graph.h"
#include "core/graph/function_container.h"
+#include "onnx/shape_inference/implementation.h"
namespace onnxruntime {
+void TypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_,
+ std::unique_ptr<ONNX_NAMESPACE::OpSchema>& op_schema_,
+ /*out*/
+ std::unordered_map<std::string, int>& input_name_idx_map,
+ std::unordered_map<std::string, int>& output_name_idx_map) {
+ std::vector<std::pair<std::string, std::string>> input_types_list(onnx_func_proto_->input_size());
+ std::vector<std::pair<std::string, std::string>> output_types_list(onnx_func_proto_->output_size());
+ std::unordered_map<std::string, std::vector<std::string>> type_constraint_map;
+ for (int i = 0; i < onnx_func_proto_->input_size(); ++i) {
+ input_name_idx_map[onnx_func_proto_->input().Get(i)] = i;
+ }
+ for (int i = 0; i < onnx_func_proto_->output_size(); ++i) {
+ output_name_idx_map[onnx_func_proto_->output().Get(i)] = i;
+ }
+ auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
+ for (auto& node : onnx_func_proto_->node()) {
+ const auto node_op_schema = schema_registry->GetSchema(node.op_type(), (int)onnx_func_proto_->since_version(), node.domain());
+ for (int i = 0; i < node.input_size(); ++i) {
+ auto& in_name = node.input().Get(i);
+ if (input_name_idx_map.count(in_name)) {
+ int idx = input_name_idx_map[in_name];
+ const auto& p = node_op_schema->inputs().at(i);
+ std::string type_str = p.GetTypeStr() + "in" + std::to_string(i);
+ input_types_list[idx] = std::make_pair(in_name, type_str);
+ if (!type_constraint_map.count(type_str)) {
+ for (auto s : p.GetTypes()) {
+ type_constraint_map[type_str].emplace_back(*s);
+ }
+ }
+ }
+ }
+ for (int i = 0; i < node.output_size(); ++i) {
+ auto& out_name = node.output().Get(i);
+ if (output_name_idx_map.count(out_name)) {
+ int idx = output_name_idx_map[out_name];
+ const auto& p = node_op_schema->outputs().at(i);
+ std::string type_str = p.GetTypeStr() + "out" + std::to_string(i);
+ output_types_list[idx] = std::make_pair(out_name, type_str);
+ if (!type_constraint_map.count(type_str)) {
+ for (auto s : p.GetTypes()) {
+ type_constraint_map[type_str].emplace_back(*s);
+ }
+ }
+ }
+ }
+ }
+
+ int i = 0;
+ for (auto& input : input_types_list) {
+ op_schema_->Input(i, input.first, "", input.second);
+ ++i;
+ }
+ i = 0;
+ for (auto& output : output_types_list) {
+ op_schema_->Output(i, output.first, "", output.second);
+ ++i;
+ }
+
+ for (auto& tc : type_constraint_map) {
+ op_schema_->TypeConstraint(tc.first, tc.second, "");
+ }
+}
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
- std::unique_ptr<IndexedSubGraph> customized_func) {
- parent_graph_ = &graph;
+ std::unique_ptr<IndexedSubGraph> customized_func)
+ : parent_graph_(&graph) {
customized_func_body_ = std::move(customized_func);
auto meta_def = customized_func_body_->GetMetaDef();
op_schema_ = std::make_unique<ONNX_NAMESPACE::OpSchema>();
@@ -31,11 +94,8 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
}
op_schema_->Finalize();
//construct body
- std::unordered_map<std::string, int> domain_to_version;
- //TODO: set correct domain and version
- domain_to_version[onnxruntime::kOnnxDomain] = 7;
body_ = std::make_unique<onnxruntime::Model>("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
- /*TODO: get custom schema*/ nullptr, domain_to_version);
+ IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), graph.DomainToVersionMap());
auto& sub_graph = body_->MainGraph();
//Add node and node args
@@ -55,7 +115,80 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
sub_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
}
//TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it.
- LOTUS_ENFORCE(sub_graph.Resolve().IsOK());
+ ONNXRUNTIME_ENFORCE(sub_graph.Resolve().IsOK());
+}
+
+FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
+ const onnxruntime::NodeIndex& node_index,
+ const ONNX_NAMESPACE::FunctionProto* onnx_func_proto)
+ : parent_graph_(&graph) {
+ onnx_func_proto_ = onnx_func_proto;
+ auto node_in_parent_graph = parent_graph_->GetNode(node_index);
+ op_schema_ = std::make_unique<ONNX_NAMESPACE::OpSchema>();
+ op_schema_->SetName(onnx_func_proto_->name());
+ op_schema_->SetDomain(onnx_func_proto_->node().Get(0).domain());
+ op_schema_->SetDoc(onnx_func_proto_->doc_string());
+ op_schema_->SinceVersion((ONNX_NAMESPACE::OperatorSetVersion)onnx_func_proto_->since_version());
+ std::unordered_map<std::string, int> input_name_idx_map;
+ std::unordered_map<std::string, int> output_name_idx_map;
+ TypeConstraintHelper(onnx_func_proto_, this->op_schema_, input_name_idx_map, output_name_idx_map);
+
+ op_schema_->TypeAndShapeInferenceFunction(
+ [this](ONNX_NAMESPACE::InferenceContext& ctx) {
+ auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
+ const ONNX_NAMESPACE::FunctionProto* func_ptr = this->GetFuncProto();
+ if (nullptr != func_ptr) {
+ ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*func_ptr, schema_registry, ctx);
+ }
+ });
+
+ op_schema_->Finalize();
+ //construct body
+ std::unordered_map<std::string, int> domain_to_version;
+ //TODO: set correct domain and version
+ domain_to_version[onnxruntime::kOnnxDomain] = (int)onnx_func_proto_->since_version();
+ body_ = std::make_unique<onnxruntime::Model>(onnx_func_proto_->name(), false, onnxruntime::ModelMetaData(),
+ IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
+ auto& sub_graph = body_->MainGraph();
+ //Add node and node args into subgraph
+ auto attr_map = node_in_parent_graph->GetAttributes();
+ for (auto& node : onnx_func_proto_->node()) {
+ std::vector<onnxruntime::NodeArg*> inputs, outputs;
+ for (int idx = 0; idx < node.input_size(); ++idx) {
+ std::string tensor_name = node.input().Get(idx);
+ if (input_name_idx_map.count(tensor_name)) {
+ ONNX_NAMESPACE::NodeProto temp_node_proto;
+ node_in_parent_graph->ToProto(temp_node_proto);
+ const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.input().Get(input_name_idx_map[tensor_name]));
+ auto& n_input = sub_graph.GetOrCreateNodeArg(
+ tensor_name, node_arg->TypeAsProto());
+ inputs.push_back(&n_input);
+ } else {
+ auto& n_input = sub_graph.GetOrCreateNodeArg(
+ tensor_name, nullptr);
+ inputs.push_back(&n_input);
+ }
+ }
+ for (int idx = 0; idx < node.output_size(); ++idx) {
+ std::string tensor_name = node.output().Get(idx);
+ auto& n_output = sub_graph.GetOrCreateNodeArg(tensor_name, nullptr);
+ outputs.push_back(&n_output);
+ }
+
+ onnxruntime::NodeAttributes new_attr_map;
+ for (auto& attr : node.attribute()) {
+ if (attr.has_ref_attr_name()) {
+ if (attr_map.count(attr.ref_attr_name())) {
+ new_attr_map[attr.name()] = attr_map[attr.ref_attr_name()];
+ }
+ } else {
+ new_attr_map[attr.name()] = attr;
+ }
+ }
+ sub_graph.AddNode(node.name(), node.op_type(), node.doc_string(), inputs, outputs, &new_attr_map, node.domain());
+ }
+ auto status = sub_graph.Resolve();
+ ONNXRUNTIME_ENFORCE(status.IsOK());
}
const ONNX_NAMESPACE::OpSchema& FunctionImpl::OpSchema() const {
@@ -70,6 +203,10 @@ const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const {
return *customized_func_body_;
}
+const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
+ return onnx_func_proto_;
+}
+
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
std::unique_ptr<IndexedSubGraph> customized_func) {
return std::make_unique<FunctionImpl>(graph, std::move(customized_func));
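
Aside (not part of the patch): TypeConstraintHelper above synthesizes an OpSchema from a FunctionProto body. For orientation, a hand-written equivalent for a one-input/one-output function, using the standard ONNX OpSchema builder API ("Tin0"/"Tout0" mimic the generated type-string keys; the type lists are illustrative):

#include "onnx/defs/schema.h"

ONNX_NAMESPACE::OpSchema MakeIllustrativeFunctionSchema() {
    ONNX_NAMESPACE::OpSchema schema;
    schema.SetName("MyFunc")   // onnx_func_proto_->name()
          .SetDomain("")       // domain of the first body node
          .SinceVersion(7);    // onnx_func_proto_->since_version()
    schema.Input(0, "X", "", "Tin0");    // input name + generated type-string key
    schema.Output(0, "Y", "", "Tout0");
    schema.TypeConstraint("Tin0", {"tensor(float)", "tensor(double)"}, "");
    schema.TypeConstraint("Tout0", {"tensor(float)", "tensor(double)"}, "");
    schema.Finalize();
    return schema;
}
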
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h
index 4e38dc762..a3a1e0c75 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h
@@ -14,22 +14,29 @@ class Node;
namespace onnxruntime {
// Function representation class.
-class FunctionImpl : public Function {
+class FunctionImpl final : public Function {
public:
FunctionImpl(const onnxruntime::Graph& graph,
- std::unique_ptr<IndexedSubGraph> customized_func);
+ std::unique_ptr<IndexedSubGraph> customized_func);
- virtual const ONNX_NAMESPACE::OpSchema& OpSchema() const override;
+ FunctionImpl(const onnxruntime::Graph& graph,
+ const onnxruntime::NodeIndex& node_index,
+ const ONNX_NAMESPACE::FunctionProto* onnx_func);
- virtual const onnxruntime::GraphBase& Body() const override;
+ const ONNX_NAMESPACE::OpSchema& OpSchema() const override;
- virtual const IndexedSubGraph& GetIndexedSubGraph() const override;
+ const onnxruntime::GraphBase& Body() const override;
+
+ const IndexedSubGraph& GetIndexedSubGraph() const override;
+
+ const ONNX_NAMESPACE::FunctionProto* GetFuncProto() const;
private:
- const onnxruntime::Graph* parent_graph_;
+ const onnxruntime::Graph* const parent_graph_;
std::unique_ptr<IndexedSubGraph> customized_func_body_;
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
std::unique_ptr<onnxruntime::Model> body_;
+ const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_;
};
} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc
index 87bf3bc36..a9247c6b2 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc
+++ b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc
@@ -13,6 +13,7 @@
#include "gsl/pointers"
#include "core/graph/function.h"
+#include "core/graph/function_impl.h"
#include "core/graph/graph.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/op.h"
@@ -194,7 +195,7 @@ void Node::ToProto(NodeProto& proto) const {
// Set attributes.
proto.clear_attribute();
for (auto attribute : attributes_) {
- const gsl::not_null<AttributeProto*> attr = proto.add_attribute();
+ const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
*attr = attribute.second;
}
@@ -318,8 +319,9 @@ Status Node::UpdateInputArgCount() {
definitions_.input_arg_count.cend(), 0);
if (total_arg_count < 0 || static_cast(total_arg_count) != definitions_.input_defs.size()) {
- return LOTUS_MAKE_STATUS(LOTUS, FAIL,
- "The sum of input arg count is not equal to size of input defs in node (", name_, ")");
+ return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "The sum of input arg count is not equal to size of input defs in node (",
+ name_, ")");
}
// op_ is always valid when this is called
@@ -370,12 +372,12 @@ const NodeAttributes& Node::GetAttributes() const noexcept {
}
void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg*, bool)> func) const {
- for (const gsl::not_null<const NodeArg*> arg : InputDefs()) {
+ for (const auto* arg : InputDefs()) {
if (!arg->Exists())
continue;
func(&*arg, true);
}
- for (const gsl::not_null<const NodeArg*> arg : OutputDefs()) {
+ for (const auto* arg : OutputDefs()) {
if (!arg->Exists())
continue;
func(&*arg, false);
@@ -383,7 +385,7 @@ void Node::ForEachDef(std::function func) const {
- for (const gsl::not_null<const NodeArg*> arg : InputDefs()) {
+ for (const auto* arg : InputDefs()) {
if (!arg->Exists())
continue;
func(&*arg);
@@ -391,7 +393,7 @@ void Node::ForEachInputDef(std::function func
};
void Node::ForEachOutputDef(std::function func) const {
- for (const gsl::not_null<const NodeArg*> arg : OutputDefs()) {
+ for (const auto* arg : OutputDefs()) {
if (!arg->Exists())
continue;
func(&*arg);
@@ -402,7 +404,7 @@ void Node::ReplaceDefs(const std::map
std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
for (auto pair : replacements)
- for (const gsl::not_null<std::vector<NodeArg*>*> defs : all_defs)
+ for (auto* defs : all_defs)
for (auto& def : *defs)
if (def == pair.first)
def = pair.second;
@@ -428,14 +430,14 @@ using google::protobuf::RepeatedPtrField;
Graph::Graph(GraphProto* graph_proto,
const std::unordered_map& domain_to_version,
Version ir_version,
- ILotusOpSchemaCollectionPtr schema_registry)
+ IOnnxRuntimeOpSchemaCollectionPtr schema_registry)
: GraphBase(/* resolve needed */ true, /* proto sync needed */ false, domain_to_version, ir_version),
graph_proto_{graph_proto},
graph_type_{Type::Main},
schema_registry_(schema_registry),
function_container_(std::make_unique()) {
- LOTUS_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null");
+ ONNXRUNTIME_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null");
ArgNameToTypeMap name_to_type_map;
// these are all empty unless we received a graph_proto as input
@@ -443,14 +445,14 @@ Graph::Graph(GraphProto* graph_proto,
// Copy constant nodes _value to name_to_initial_tensor_
for (auto& node : graph_proto_->node()) {
if (node.op_type() == kConstant) {
- const gsl::not_null<TensorProto*> tensor = graph_proto_->add_initializer();
+ const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
*tensor = node.attribute(0).t();
*(tensor->mutable_name()) = node.output(0);
}
}
// remove constant nodes
- const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes = graph_proto_->mutable_node();
+ const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes{graph_proto_->mutable_node()};
graph_mutable_nodes->erase(
std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
[](NodeProto& p) {
@@ -469,6 +471,8 @@ Graph::Graph(GraphProto* graph_proto,
for (auto& graph_input : graph_proto_->input()) {
if (graph_input.has_name() && graph_input.has_type()) {
name_to_type_map[graph_input.name()] = graph_input.type();
+ // always create a NodeArg for graph input in case it's from an initializer
+ GetOrCreateNodeArg(graph_input.name(), &graph_input.type());
}
}
@@ -486,13 +490,10 @@ Graph::Graph(GraphProto* graph_proto,
name_to_type_map[node_arg.name()] = node_arg.type();
}
}
- }
- // Add nodes.
- AddSourceSinkNodes();
-
- for (auto node_proto : graph_proto_->node()) {
- AddNode(node_proto, name_to_type_map);
+ for (auto node_proto : graph_proto_->node()) {
+ AddNode(node_proto, name_to_type_map);
+ }
}
}
@@ -514,7 +515,7 @@ Status GraphBase::VerifyNoDuplicateName(/*in*/ const std::unordered_set
- for (const gsl::not_null<const NodeArg*> output_def : node.OutputDefs()) {
+ for (const auto* output_def : node.OutputDefs()) {
if (output_def->Exists()) {
auto& output_arg_name = output_def->Name();
if (inputs_and_initializers.count(output_arg_name)) {
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Error: Duplicate definition of name (" + output_arg_name + ").");
return status;
}
auto result = output_args.insert({output_arg_name, &node});
if (!result.second) {
// Two outputs with same name, so that insertion fails.
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Error: Duplicate definition of name (" + output_arg_name + ").");
return status;
}
@@ -549,14 +550,10 @@ Status GraphBase::BuildConnections(const std::unordered_map&
std::unordered_set inner_nodes;
for (auto& node : Nodes()) {
- if (IsSourceNode(node) || IsSinkNode(node)) {
- continue;
- }
-
for (auto& control_input : node.ControlInputs()) {
auto name_to_index_iter = node_name_to_index.find(control_input);
if (node_name_to_index.end() == name_to_index_iter) {
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"The control input (" + control_input + ") of Node (" +
node.Name() + ") does not exist in the graph.");
return status;
@@ -566,7 +563,7 @@ Status GraphBase::BuildConnections(const std::unordered_map&
const NodeIndex dst_node_index = node.Index();
auto dst = GetNode(dst_node_index);
auto src = GetNode(src_node_index);
- LOTUS_ENFORCE(dst && src, "ControlInputs should not have invalid nodes. dst=", dst, " src=", src);
+ ONNXRUNTIME_ENFORCE(dst && src, "ControlInputs should not have invalid nodes. dst=", dst, " src=", src);
src->MutableRelationships().output_nodes.insert(dst);
dst->MutableRelationships().input_nodes.insert(src);
}
@@ -575,7 +572,7 @@ Status GraphBase::BuildConnections(const std::unordered_map&
if (input_args.size() > 0) {
// This node needs inputs.
- for (const gsl::not_null<const NodeArg*> input_arg : input_args) {
+ for (const auto* input_arg : input_args) {
if (!input_arg->Exists()) {
// This input could be optional and it does not exist in this case.
continue;
@@ -585,52 +582,28 @@ Status GraphBase::BuildConnections(const std::unordered_map&
if (output_args.end() == output_arg_iter) {
// No such output_arg matching this input_arg.
// This input arg should be fed when running evaluation.
-
- // Add a control edge between node and this node.
- NO_CHANGE_ON_SYNC_FLAG(AddControlEdge(source_node_index_, node.Index()));
continue;
}
// Setup input/output relationship between <*node_iter>
// and .
Node& output_node = *output_arg_iter->second;
-
node.MutableRelationships().input_nodes.insert(&output_node);
-
auto new_edge = std::make_unique<Node::EdgeEnd>(output_node, *input_arg);
node.MutableRelationships().input_edges.insert(new_edge.get());
owned_edges_.push_back(std::move(new_edge));
output_node.MutableRelationships().output_nodes.insert(&node);
-
new_edge = std::make_unique<Node::EdgeEnd>(node, *input_arg);
output_node.MutableRelationships().output_edges.insert(new_edge.get());
owned_edges_.push_back(std::move(new_edge));
inner_nodes.insert(&output_node);
}
- } else {
- if (node.OutputDefs().size() <= 0) {
- // This is a useless node.
- // It has no input/output.
- RemoveNode(node.Index());
- }
-
- // This is a starting node.
- // Add a control edge between node and this node.
- NO_CHANGE_ON_SYNC_FLAG(AddControlEdge(source_node_index_, node.Index()));
- }
- }
-
- for (auto& node : Nodes()) {
- if (IsSourceNode(node) || IsSinkNode(node)) {
- continue;
- }
-
- if (inner_nodes.empty() || inner_nodes.end() == inner_nodes.find(&node)) {
- // This is an ending node.
- // Add a control edge from this node to sink node.
- NO_CHANGE_ON_SYNC_FLAG(AddControlEdge(node.Index(), sink_node_index_));
+ } else if (node.OutputDefs().size() <= 0) {
+ // This is a useless node.
+ // It has no input/output.
+ RemoveNode(node.Index());
}
}
@@ -684,7 +657,7 @@ void GraphBase::ReverseDFSFrom(const std::vector& from,
sorted_nodes.push_back((*iter));
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
- for (gsl::not_null<const Node*> in : sorted_nodes) {
+ for (const auto* in : sorted_nodes) {
const NodeIndex idx = in->Index();
if (!visited[idx]) {
stack.emplace_back(in, false);
@@ -702,89 +675,101 @@ void GraphBase::ReverseDFSFrom(const std::vector& from,
}
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
-Status GraphBase::CheckIsAcyclic(std::vector<NodeIndex>& nodes_in_topological_order) const {
- nodes_in_topological_order.clear();
+Status GraphBase::PerformTopologicalSortAndCheckIsAcyclic() {
+ std::vector<NodeIndex>& nodes_in_topological_order{MutableNodesInTopologicalOrder()};
+ nodes_in_topological_order.clear();
- // nodes that have been processed and added to nodes_in_topological_order.
- std::unordered_set<NodeIndex> processed_nodes;
- std::unordered_set<NodeIndex> output_nodes;
- std::unordered_set<NodeIndex> nodes_added_for_processing;
- std::stack<NodeIndex> stack;
+ // nodes that have been processed and added to nodes_in_topological_order.
+ std::unordered_set<NodeIndex> processed_nodes;
+ std::unordered_set<NodeIndex> output_nodes;
+ std::unordered_set<NodeIndex> nodes_added_for_processing;
+ std::stack<NodeIndex> stack;
- // push the top level nodes into nodes_in_topological_order in the order they were added
- // to ensure that is consistent.
- auto& nodes_in_original_order = Nodes();
- for (GraphNodes::ConstNodeIterator it = nodes_in_original_order.cbegin(); it != nodes_in_original_order.cend(); ++it)
- {
- const Node& node = *it;
- auto index = node.Index();
+ // push the top level nodes into nodes_in_topological_order in the order they were added
+ // to ensure that is consistent.
+ auto& nodes_in_original_order = Nodes();
+ std::for_each(nodes_in_original_order.cbegin(), nodes_in_original_order.cend(),
+ [&](const Node& node) {
+ auto index = node.Index();
- // find the top level nodes in the graph
- if (node.GetRelationships().input_edges.size() == 0 && index != sink_node_index_) {
- // add to the topological list, and ensure we skip these nodes when walking the graph
- nodes_in_topological_order.push_back(index);
- processed_nodes.insert(index);
+ // find the top level nodes in the graph.
+ // need to also consider nodes that only have Constants as inputs as top level nodes,
+ // as the constant will get replaced by an initializer.
+ auto input_edges = node.GetRelationships().input_edges;
+ auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), [](Node::EdgeEnd* edge) {
+ return edge->GetNode().OpType() != kConstant;
+ });
- // mark this as added as we've fully processed it and don't need to do it again later
- nodes_added_for_processing.insert(index);
- }
+ if (!has_inputs) {
+ // add to the topological list, and ensure we skip these nodes when walking the graph
+ nodes_in_topological_order.push_back(index);
+ processed_nodes.insert(index);
+
+ // mark this as added as we've fully processed it and don't need to do it again later
+ nodes_added_for_processing.insert(index);
+ }
+ });
+
+ // start at the bottom and work our way up the graph
+ for (auto iter = Nodes().begin(); iter != Nodes().end(); ++iter) {
+ if (0 == iter->relationships_.output_edges.size()) {
+ // This is a leaf node.
+ stack.push(iter->Index());
+ }
+ }
+
+ while (!stack.empty()) {
+ const NodeIndex current = stack.top();
+ stack.pop();
+
+ if (processed_nodes.find(current) != processed_nodes.end()) {
+ continue;
}
- // start at the bottom and work our way up the graph
- stack.push(sink_node_index_);
-
- while (!stack.empty()) {
- const NodeIndex current = stack.top();
- stack.pop();
-
- if (processed_nodes.find(current) != processed_nodes.end()) {
- continue;
- }
-
- if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
- // we popped the stack and are back to a node that was added previously,
- // so we know all the upstream nodes from it have been fully processed,
- nodes_in_topological_order.push_back(current);
- processed_nodes.insert(current);
- output_nodes.erase(current);
- continue;
- }
-
- const Node* node = GetNode(current);
- if (!node) {
- continue;
- }
-
- stack.push(current);
- output_nodes.insert(current);
-
- // push the node's inputs onto the stack in reverse order so that when we finish processing each one
- // and pop them from the stack they get added to nodes_in_topological_order in their original order
- for (auto iter = std::make_reverse_iterator(node->InputNodesEnd()),
- end = std::make_reverse_iterator(node->InputNodesBegin());
- iter != end; ++iter) {
- const NodeIndex idx = (*iter)->Index();
- if (output_nodes.find(idx) != output_nodes.end()) {
- Status status(LOTUS, FAIL, "Error: the graph is not acyclic.");
- return status;
- }
-
- // avoid re-processing nodes
- if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
- stack.push(idx);
- }
- }
-
- nodes_added_for_processing.insert(current);
+ if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
+ // we popped the stack and are back to a node that was added previously,
+ // so we know all the upstream nodes from it have been fully processed,
+ nodes_in_topological_order.push_back(current);
+ processed_nodes.insert(current);
+ output_nodes.erase(current);
+ continue;
}
- if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order.size()) {
- return Status::OK();
+ const Node* node = GetNode(current);
+ if (!node) {
+ continue;
}
- else {
- return Status(LOTUS, FAIL, "Error: the graph is not acyclic.");
+
+ stack.push(current);
+ output_nodes.insert(current);
+
+ // push the node's inputs onto the stack in reverse order so that when we finish processing each one
+ // and pop them from the stack they get added to nodes_in_topological_order in their original order
+ for (auto iter = std::make_reverse_iterator(node->InputNodesEnd()),
+ end = std::make_reverse_iterator(node->InputNodesBegin());
+ iter != end; ++iter) {
+ const NodeIndex idx = (*iter)->Index();
+ if (output_nodes.find(idx) != output_nodes.end()) {
+ Status status(ONNXRUNTIME, FAIL, "Error: the graph is not acyclic.");
+ return status;
+ }
+
+ // avoid re-processing nodes
+ if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
+ stack.push(idx);
+ }
}
+
+ nodes_added_for_processing.insert(current);
+ }
+
+ if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order.size()) {
+ return Status::OK();
+ } else {
+ return Status(ONNXRUNTIME, FAIL, "Error: the graph is not acyclic.");
+ }
}
+
bool FullyDefinedType(const TypeProto& type_proto) {
switch (type_proto.value_case()) {
case TypeProto::kTensorType: {
@@ -900,7 +885,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
if (input_def->Type() == nullptr) {
// Logic error: This should not happen if we properly checked that every use has
// a corresponding def, for which type-inference already produced a valid type
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Node (" + nodeName + ") input arg (" +
input_def->Name() + ") does not have type information set by parent node.");
return status;
@@ -914,7 +899,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
if (input_type == nullptr) input_type = &null_pointer;
// Type error in input model/graph.
- Status status(LOTUS, INVALID_GRAPH,
+ Status status(ONNXRUNTIME, INVALID_GRAPH,
"Type Error: Type '" + *input_type + "' of input parameter (" + input_def->Name() +
") of operator (" + op.Name() + ") in node (" + nodeName + ") is invalid.");
return status;
@@ -934,7 +919,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
// However, this will need to be extended to handle the If-Then-Else and Loop
// constructs in future which will have variadic inputs and outputs of different types.
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Type Error: Type parameter (" + op_formal_parameter.GetTypeStr() +
") bound to different types (" + *(param_to_type_iter->second) +
" and " + *(input_def->Type()) +
@@ -947,9 +932,9 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
// Apply ONNX's shape/type inference to node
std::vector<TypeProto> onnx_inferred_types;
try {
- LOTUS_RETURN_IF_ERROR(InferOutputTypesAndShapes(node, onnx_inferred_types));
+ ONNXRUNTIME_RETURN_IF_ERROR(InferOutputTypesAndShapes(node, onnx_inferred_types));
} catch (const std::exception& ex) {
- return Status(LOTUS, FAIL, ex.what());
+ return Status(ONNXRUNTIME, FAIL, ex.what());
}
// Infer and verify node output arg type information.
@@ -985,14 +970,14 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
inferred_type = existing_type;
} else {
// This should not happen: indicates incompleteness in ONNX inference.
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Node (" + nodeName + ") output arg (" + output_def->Name() + ") type inference failed");
return status;
}
if ((existing_type != inferred_type) && (existing_type != nullptr)) {
// A type exists for this output but does not match the inferred type.
- return Status(LOTUS, FAIL,
+ return Status(ONNXRUNTIME, FAIL,
"Type Error: Type (" + *existing_type + ") of output arg (" +
output_def->Name() + ") of node (" + nodeName +
") does not match expected type (" + *inferred_type + ").");
@@ -1020,7 +1005,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
// Check that the type of every input is specified:
for (auto* graph_input : GetInputs()) {
if (nullptr == graph_input->Type()) {
- Status status(LOTUS, FAIL, "Model input (" + graph_input->Name() + ") does not have type information.");
+ Status status(ONNXRUNTIME, FAIL, "Model input (" + graph_input->Name() + ") does not have type information.");
return status;
}
}
@@ -1031,7 +1016,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
// Infer/check type and shape for all initializers from their values
for (auto& initializer_pair : name_to_initial_tensor_) {
const std::string& name = initializer_pair.first;
- auto* node_arg = FindNodeArg(name);
+ auto* node_arg = GetNodeArg(name);
// If node_arg is null, we ignore this as a potentially unused initializer here
if (nullptr != node_arg) {
const TensorProto* tensor_proto = initializer_pair.second;
@@ -1042,7 +1027,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
if (nullptr == existing_type)
node_arg->SetType(inferred_type);
else if (inferred_type != existing_type) {
- return Status(LOTUS, FAIL,
+ return Status(ONNXRUNTIME, FAIL,
"Type Error: Value of initializer " + name + " does not match its type.");
}
@@ -1056,12 +1041,12 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
node_arg->SetShape(inferred_shape);
else {
if (p_existing_shape->dim_size() != tensor_proto->dims_size())
- return Status(LOTUS, FAIL,
+ return Status(ONNXRUNTIME, FAIL,
"Type Error: Shape of initializer " + name + " does not match its type.");
for (int i = 0; i < p_existing_shape->dim_size(); ++i) {
auto& d = p_existing_shape->dim(i);
if (d.has_dim_value() && (d.dim_value() != tensor_proto->dims(i)))
- return Status(LOTUS, FAIL,
+ return Status(ONNXRUNTIME, FAIL,
"Type Error: Shape of initializer " + initializer_pair.first + " does not match its type.");
}
}
@@ -1070,25 +1055,20 @@ Status Graph::InferAndVerifyTypeMatch(Node& node,
return Status::OK();
}
-Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topological_order,
-                                   const std::unordered_map<std::string, Node*>& output_args) {
- LOTUS_RETURN_IF_ERROR(TypeCheckInputsAndInitializers());
-
- for (auto nodeIndex : nodes_in_topological_order) {
- if (IsSourceNode(nodeIndex) || IsSinkNode(nodeIndex)) {
- continue;
- }
+Status Graph::VerifyNodeAndOpMatch(const std::unordered_set<std::string>& inputs_and_initializers) {
+ ONNXRUNTIME_RETURN_IF_ERROR(TypeCheckInputsAndInitializers());
+  // Initialize the list of topologically available tensor names
+ LexicalScopeContext lsc;
+ for (auto& initializer_input_name : inputs_and_initializers) {
+ lsc.output_names.insert(initializer_input_name);
+ }
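+  // lsc now holds every name visible at graph entry; each node's outputs are
+  // appended below, so a node may only reference names defined earlier in the order.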
+ for (auto& nodeIndex : NodesInTopologicalOrder()) {
// Node verification.
auto& node = *GetNode(nodeIndex);
CheckerContext ctx;
    ctx.set_ir_version(gsl::narrow_cast<int>(IrVersion()));
ctx.set_opset_imports(DomainToVersionMap());
ctx.set_schema_registry(schema_registry_.get());
- LexicalScopeContext lsc;
- for (auto& kv : output_args) {
- GSL_SUPPRESS(es .84)
- lsc.output_names.insert(kv.first);
- }
NodeProto node_proto;
node.ToProto(node_proto);
auto& node_name = node.Name();
@@ -1097,18 +1077,33 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
if (!node.Op()) {
try {
checker::check_node(node_proto, ctx, lsc);
+        // Accumulate the output names of the node just checked so later nodes can reference them
+ for (auto& output_name : node_proto.output()) {
+ lsc.output_names.insert(output_name);
+ }
} catch (const std::exception& ex) {
- return Status(LOTUS, FAIL, ex.what());
+ return Status(ONNXRUNTIME, FAIL, ex.what());
}
auto maxInclusiveVersion = DomainToVersionMap().find(domain)->second;
node.op_ = schema_registry_->GetSchema(node.OpType(), maxInclusiveVersion, node.Domain());
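+      // No schema matched this op type, so fall back to the ONNX function
+      // registry and let a registered function body stand in for the op.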
+ if (!node.op_) {
+ ONNX_NAMESPACE::FunctionBuilderRegistry& function_registry =
+ FunctionBuilderRegistry::OnnxInstance();
+ auto onnx_function_proto = function_registry.GetFunction(node.OpType(), maxInclusiveVersion, ONNX_DOMAIN);
+ if (!onnx_function_proto) {
+ return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op");
+ }
+        auto func_ptr = std::make_unique<FunctionImpl>(*this, node.Index(), onnx_function_proto);
+ function_container_->functions_.push_back(std::move(func_ptr));
+ node.SetFunctionBody(*function_container_->functions_.back());
+ }
}
- LOTUS_RETURN_IF_ERROR(node.UpdateInputArgCount());
+ ONNXRUNTIME_RETURN_IF_ERROR(node.UpdateInputArgCount());
- // currently an Op is required by ValidateVersion, so we use gsl::not_null.
+ // currently an Op is required by ValidateVersion, so we use gsl::not_null to validate that.
// This may change in the future to allow a null Op
-    const gsl::not_null<const OpSchema*> p_op = node.Op();
+    const gsl::not_null<const OpSchema*> p_op{node.Op()};
// Attribute verification and fill node attribute with
// default value defined in operator definition if needed.
@@ -1125,7 +1120,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
}
// TODO: Handle optional attribute but no default value specified in op definition.
} else {
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Node (" + node_name + ") attribute (" + attr_def.first +
") is required but not specified.");
return status;
@@ -1133,7 +1128,7 @@ Status Graph::VerifyNodeAndOpMatch(const std::vector<NodeIndex>& nodes_in_topolo
}
}
- NO_CHANGE_ON_SYNC_FLAG(LOTUS_RETURN_IF_ERROR(InferAndVerifyTypeMatch(node, *p_op)));
+ NO_CHANGE_ON_SYNC_FLAG(ONNXRUNTIME_RETURN_IF_ERROR(InferAndVerifyTypeMatch(node, *p_op)));
}
return Status::OK();
@@ -1146,7 +1141,7 @@ Status Graph::VerifyInputAndInitializerNames(/*OUT*/ std::unordered_set<std::string>& inputs_and_initializers) {
     auto result = inputs_and_initializers.insert(input->Name());
if (!result.second) {
- Status status(LOTUS, FAIL,
+ Status status(ONNXRUNTIME, FAIL,
"Error: Duplicate definition-site for (" + input->Name() + ").");
return status;
}
@@ -1168,19 +1163,16 @@ Status Graph::Resolve(bool no_proto_sync_required) {
for (auto& node : Nodes()) {
node.MutableRelationships().Clear();
}
- //add control edge for source and sink
- //otherwise, if the graph only contain initializers, CheckIsAcyclic will fail as the graph is not connected.
- NO_CHANGE_ON_SYNC_FLAG(AddControlEdge(source_node_index_, sink_node_index_));
+ ONNXRUNTIME_RETURN_IF_ERROR(SetGraphInputsOutputs());
  std::unordered_map<std::string, Node*> output_args;
  std::unordered_set<std::string> inputs_and_initializers;
  std::unordered_map<std::string, NodeIndex> node_name_to_index;
- LOTUS_RETURN_IF_ERROR(VerifyInputAndInitializerNames(inputs_and_initializers));
- LOTUS_RETURN_IF_ERROR(VerifyNoDuplicateName(inputs_and_initializers, output_args, node_name_to_index));
- LOTUS_RETURN_IF_ERROR(BuildConnections(output_args, node_name_to_index));
- LOTUS_RETURN_IF_ERROR(CheckIsAcyclic(NodesInTopologicalOrder()));
- LOTUS_RETURN_IF_ERROR(VerifyNodeAndOpMatch(NodesInTopologicalOrder(), output_args));
- LOTUS_RETURN_IF_ERROR(SetGraphInputsOutputs());
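+  // Resolve pipeline: verify names, build edges, topo-sort with cycle check,
+  // then verify each node against its op schema.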
+ ONNXRUNTIME_RETURN_IF_ERROR(VerifyInputAndInitializerNames(inputs_and_initializers));
+ ONNXRUNTIME_RETURN_IF_ERROR(VerifyNoDuplicateName(inputs_and_initializers, output_args, node_name_to_index));
+ ONNXRUNTIME_RETURN_IF_ERROR(BuildConnections(output_args, node_name_to_index));
+ ONNXRUNTIME_RETURN_IF_ERROR(PerformTopologicalSortAndCheckIsAcyclic());
+ ONNXRUNTIME_RETURN_IF_ERROR(VerifyNodeAndOpMatch(inputs_and_initializers));
CleanUnusedInitializers();
@@ -1195,30 +1187,16 @@ Status Graph::Resolve(bool no_proto_sync_required) {
return Status::OK();
}
-Status GraphBase::GetNodesInTopologicalOrder(gsl::not_null<const std::vector<NodeIndex>**> pp_nodes) const {
+Status GraphBase::GetNodesInTopologicalOrder(const std::vector<NodeIndex>*& pp_nodes) const {
if (graph_resolve_needed_) {
- return Status(::onnxruntime::common::LOTUS, ::onnxruntime::common::FAIL,
+ return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL,
"Resolve() must be called before using the graph as modifications have been made to it.");
}
- *pp_nodes = &nodes_in_topological_order_;
+ pp_nodes = &nodes_in_topological_order_;
return Status::OK();
}
-void GraphBase::AddSourceSinkNodes() {
-  std::vector<NodeArg*> empty_args;
-
- source_node_index_ = AddNode("_Graph_Source", kNoOp,
- "Source node internally in a graph.", empty_args, empty_args)
- ->Index();
-
- sink_node_index_ = AddNode("_Graph_Sink", kNoOp,
- "Sink node internally in a graph.", empty_args, empty_args)
- ->Index();
-
- NO_CHANGE_ON_SYNC_FLAG(AddControlEdge(source_node_index_, sink_node_index_));
-}
-
const std::string& Graph::Name() const noexcept {
return graph_proto_->name();
}
@@ -1240,9 +1218,8 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
return;
}
-  const gsl::not_null<TensorProto*> tensorAdded = graph_proto_->add_initializer();
+  const gsl::not_null<TensorProto*> tensorAdded{graph_proto_->add_initializer()};
*(tensorAdded) = tensor;
- name_to_initial_tensorIndex_[tensor.name()] = graph_proto_->initializer_size() - 1;
name_to_initial_tensor_[tensor.name()] = tensorAdded;
SetGraphProtoSyncNeeded();
@@ -1250,27 +1227,25 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
}
void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
- auto iter = name_to_initial_tensorIndex_.find(tensor_name);
- if (name_to_initial_tensorIndex_.end() != iter) {
- removed_initializer_indexes_.push_back(iter->second);
- name_to_initial_tensorIndex_.erase(tensor_name);
+ auto iter = name_to_initial_tensor_.find(tensor_name);
+ if (name_to_initial_tensor_.end() != iter) {
name_to_initial_tensor_.erase(tensor_name);
SetGraphProtoSyncNeeded();
SetGraphResolveNeeded();
}
}
-bool Graph::GetInitializedTensor(const std::string& tensor_name, gsl::not_null<const TensorProto**> value) const {
+bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const {
auto iter = name_to_initial_tensor_.find(tensor_name);
if (name_to_initial_tensor_.end() == iter) {
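+    // Null the out-param so callers cannot read a stale pointer on failure.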
+ value = nullptr;
return false;
}
- *value = iter->second;
+ value = iter->second;
return true;
}
void Graph::CleanAllInitializedTensors() noexcept {
- name_to_initial_tensorIndex_.clear();
name_to_initial_tensor_.clear();
removed_initializer_indexes_.clear();
@@ -1292,30 +1267,8 @@ const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
return value_info_;
}
-// Ensure the NodeArgs in the input are created and in this Graph's node arg map
-static void AddNodeArgs(const std::vector<NodeArg*>& input_args,
-                        std::unordered_map<std::string, NodeArg*>& node_arg_map) {
-  for (const gsl::not_null<NodeArg*> input_arg : input_args) {
- if (!input_arg->Exists()) continue;
- auto& key = input_arg->Name();
- auto existing_entry = node_arg_map.find(key);
-
- NodeArg* node_arg = existing_entry == node_arg_map.end() ? nullptr : existing_entry->second;
-
- if (node_arg == nullptr) {
- node_arg_map[key] = input_arg;
- } else {
- // check that if an existing entry was found, it was for the same instance
- LOTUS_ENFORCE(node_arg == input_arg,
- "Existing entry in NodeArg map for ", key, " != input definition.");
- }
- }
-}
-
-static std::vector