diff --git a/.gitattributes b/.gitattributes
index c3e6bdd25..40ec94114 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -163,5 +163,6 @@ Examples/Extensibility/BinaryConvolution/BinaryConvolutionLib/halide/halide_conv
Tests/EndToEndTests/Speech/Data/mlf2.bin binary
external/gsl text
Source/CNTKv2LibraryDll/proto/onnx/onnx_repo text
+Source/CNTKv2LibraryDll/proto/onnx/onnxruntime text
#certificates
*.pfx binary
diff --git a/.gitmodules b/.gitmodules
index dfd1fdd2b..dbe1c5ca5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,3 +8,6 @@
[submodule "Source/CNTKv2LibraryDll/proto/onnx/onnx_repo"]
path = Source/CNTKv2LibraryDll/proto/onnx/onnx_repo
url = https://github.com/onnx/onnx.git
+[submodule "Source/CNTKv2LibraryDll/proto/onnx/onnxruntime"]
+ path = Source/CNTKv2LibraryDll/proto/onnx/onnxruntime
+ url = https://github.com/Microsoft/onnxruntime.git
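Note: after checking out this change, the new submodule must be initialized before building; with the standard git workflow that is:

    git submodule update --init -- Source/CNTKv2LibraryDll/proto/onnx/onnxruntime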
diff --git a/Documentation/current_iteration.md b/Documentation/current_iteration.md
index 60147c9a6..460b90af9 100644
--- a/Documentation/current_iteration.md
+++ b/Documentation/current_iteration.md
@@ -12,7 +12,7 @@ To setup build and runtime environment on Windows:
* Install [Visual Studio 2017](https://www.visualstudio.com/downloads/). Note: for CUDA 10 and beyond, it is no longer required to install and run with the specific VC Tools version 14.11.
* Install [Nvidia CUDA 10](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64)
* From PowerShell, run:
- [DevInstall.ps1](./Tools/devInstall/Windows/DevInstall.ps1)
+ [DevInstall.ps1](../Tools/devInstall/Windows/DevInstall.ps1)
* Start Visual Studio 2017 and open [CNTK.sln](./CNTK.sln).
To set up the build and runtime environment on Linux using docker, please build an Ubuntu 16.04 docker image using the Dockerfiles [here](./Tools/docker). For other Linux systems, please refer to the Dockerfiles to set up the dependent libraries for CNTK.
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 8910814be..1c184d73e 100644
--- a/Makefile
+++ b/Makefile
@@ -97,14 +97,15 @@ GSL_PATH:=$(SOURCEDIR)/../external/gsl
ONNX_PATH:=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx
ONNX_REPO_PATH:=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo
ONNX_REPO_PATH+=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx
-ONNX_REPO_PATH+=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/include
+ONNX_REPO_PATH+=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime
+ONNX_REPO_PATH+=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/include/onnxruntime
INCLUDEPATH:= $(addprefix $(SOURCEDIR)/, Common/Include CNTKv2LibraryDll CNTKv2LibraryDll/API CNTKv2LibraryDll/API/Internals CNTKv2LibraryDll/Generated/Linux CNTKv2LibraryDll/proto ../Examples/Extensibility/CPP Math CNTK ActionsLib ComputationNetworkLib SGDLib SequenceTrainingLib CNTK/BrainScript Readers/ReaderLib PerformanceProfilerDll)
INCLUDEPATH+=$(PROTOBUF_PATH)/include
INCLUDEPATH+=$(GSL_PATH)/include
INCLUDEPATH+=$(ONNX_PATH)
INCLUDEPATH+=$(ONNX_REPO_PATH)
# COMMON_FLAGS include settings that are passed both to NVCC and C++ compilers.
-COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__
+COMMON_FLAGS:= $(COMMON_FLAGS) -DONNX_NAMESPACE=onnx -DONNX_ML=1 -DHAS_MPI=$(HAS_MPI) -D_POSIX_SOURCE -D_XOPEN_SOURCE=600 -D__USE_XOPEN2K -std=c++14 -DCUDA_NO_HALF -D__CUDA_NO_HALF_OPERATORS__ -DPLATFORM_POSIX
CPPFLAGS:=
CXXFLAGS:= $(SSE_FLAGS) $(CXXFLAGS) -fopenmp -fpermissive -fPIC -Werror -fcheck-new
LIBPATH:=
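Note: -DPLATFORM_POSIX here, together with the matching PLATFORM_WINDOWS added to CNTKv2LibraryDll.vcxproj below, selects the platform layer of the vendored onnxruntime sources at compile time. A minimal sketch of the kind of switch such macros drive (illustrative only, not actual onnxruntime code):

    #if defined(PLATFORM_WINDOWS)
    static const char kPathSeparator = '\\';
    #elif defined(PLATFORM_POSIX)
    static const char kPathSeparator = '/';
    #else
    #error "expected PLATFORM_WINDOWS or PLATFORM_POSIX to be defined"
    #endif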
@@ -526,28 +527,29 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/tensorboard.pb.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardFileWriter.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardUtils.cpp \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/common/status.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/graph_transformer_mgr.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/graph_viewer.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/model.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/op.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph/schema_registry.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/env.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/env_time.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/posix/env.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/posix/env_time.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/platform/posix/stacktrace.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/common/logging/capture.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/common/logging/logging.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/common/profiler.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/common/status.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/framework/tensorutils.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/function.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/graph.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/graph_transformer_mgr.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/graph_viewer.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/model.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/op.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/graph/schema_registry.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/platform/env.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/platform/env_time.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/platform/posix/env.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/platform/posix/env_time.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/platform/posix/stacktrace.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/checker.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/assertions.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/model_helpers.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/common/status.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/controlflow/defs.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/controlflow/old.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/experiments/defs.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/experiments/experiments_functions.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/function.cc \
@@ -564,7 +566,8 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/rnn/old.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/tensor/defs.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/tensor/old.cc \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/traditionalml/defs.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/traditionalml/defs.cc \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/traditionalml/old.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/data_type_utils.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/defs/schema.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/shape_inference/implementation.cc \
@@ -572,7 +575,7 @@ CNTKLIBRARY_COMMON_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/onnx-operators-ml.pb.cc \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp \
- $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
+ $(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNX.cpp \
@@ -1304,7 +1307,7 @@ $(UNITTEST_EVAL) : $(UNITTEST_EVAL_OBJ) | $(EVAL_LIB) $(READER_LIBS)
@echo $(SEPARATOR)
@mkdir -p $(dir $@)
@echo building $@ for $(ARCH) with build type $(BUILDTYPE)
- $(CXX) $(LDFLAGS) $(patsubst %,-L%, $(LIBDIR) $(LIBPATH) $(GDK_NVML_LIB_PATH) $(BOOSTLIB_PATH)) $(patsubst %, $(RPATH)%, $(ORIGINLIBDIR) $(LIBPATH) $(BOOSTLIB_PATH)) -o $@ $^ $(BOOSTLIBS) $(LIBS) -l$(EVAL) $(L_READER_LIBS) $(lMULTIVERSO)
+ $(CXX) $(LDFLAGS) $(patsubst %,-L%, $(LIBDIR) $(LIBPATH) $(GDK_NVML_LIB_PATH) $(BOOSTLIB_PATH)) $(patsubst %, $(RPATH)%, $(ORIGINLIBDIR) $(LIBPATH) $(BOOSTLIB_PATH)) -o $@ $^ $(BOOSTLIBS) $(LIBS) -l$(EVAL) $(L_READER_LIBS) $(lMULTIVERSO) -ldl
#TODO: create a project-specific makefile or rules to avoid adding project-specific paths to the global path
INCLUDEPATH += $(SOURCEDIR)/Readers/CNTKTextFormatReader
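Note: the -ldl added to the link line above is presumably required because the onnxruntime POSIX platform sources now compiled into the library use the dynamic loader (dlopen/dlsym), which glibc provides via libdl. A self-contained C++ illustration of the dependency this flag satisfies (not CNTK code):

    #include <dlfcn.h>
    #include <cstdio>

    int main()
    {
        void* handle = dlopen("libm.so.6", RTLD_LAZY); // unresolved without -ldl
        if (!handle) { std::fprintf(stderr, "%s\n", dlerror()); return 1; }
        auto cosine = reinterpret_cast<double (*)(double)>(dlsym(handle, "cos"));
        std::printf("cos(0) = %f\n", cosine(0.0));
        return dlclose(handle);
    }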
@@ -1699,17 +1702,18 @@ DEP := $(patsubst %.o, %.d, $(OBJ))
BUILD_CONFIGURATION := Makefile $(BUILD_TOP)/Config.make
+ONNXRUNTIME_PROTO_PATH=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/protobuf
%onnx-ml.pb.cc : %onnx-ml.proto $(BUILD_CONFIGURATION)
@echo $(SEPARATOR)
- @echo compiling protobuf $<
+ @echo compiling protobuf from $(ONNXRUNTIME_PROTO_PATH)
# protoc is confused if --proto_path is not set to an absolute path in the usage below
- $(PROTOC) --proto_path=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/ --cpp_out=$(dir $<).. $<
+ $(PROTOC) --proto_path=$(ONNXRUNTIME_PROTO_PATH)/ --cpp_out=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/ $(ONNXRUNTIME_PROTO_PATH)/onnx-ml.proto
%onnx-operators-ml.pb.cc : %onnx-operators-ml.proto $(BUILD_CONFIGURATION)
@echo $(SEPARATOR)
- @echo compiling protobuf $<
+ @echo compiling protobuf from $(ONNXRUNTIME_PROTO_PATH)
# protoc is confused if --proto_path is not set to an absolute path in the usage below
- $(PROTOC) --proto_path=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/ --cpp_out=$(dir $<).. $<
+ $(PROTOC) --proto_path=$(ONNXRUNTIME_PROTO_PATH)/ --cpp_out=$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/ $(ONNXRUNTIME_PROTO_PATH)/onnx-operators-ml.proto
%.pb.cc : %.proto $(BUILD_CONFIGURATION)
@echo $(SEPARATOR)
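Note: both rules above now always generate from the onnxruntime copies of the .proto files while still emitting the resulting .pb.cc/.pb.h into the onnx_repo tree, so existing include paths keep working. With the variables expanded (assuming SOURCEDIR=Source), the first rule invokes:

    protoc --proto_path=Source/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/protobuf/ \
           --cpp_out=Source/CNTKv2LibraryDll/proto/onnx/onnx_repo/onnx/ \
           Source/CNTKv2LibraryDll/proto/onnx/onnxruntime/onnxruntime/core/protobuf/onnx-ml.proto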
diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
index 11ea7f7f7..c913e15b6 100644
--- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
+++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
@@ -66,7 +66,7 @@
- .\proto\onnx;.\proto\onnx\core\include;.\proto\onnx\onnx_repo;.\proto\onnx\onnx_repo\onnx;.\API;.\API\Internals;.\proto;$(BOOST_INCLUDE_PATH);$(SolutionDir)\Source\CNTKv2LibraryDll;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude);$(ProtobufInclude);$(SolutionDir)Source\PerformanceProfilerDll;..\..\external\gsl\include;$(ProjectDir)Generated\Windows
+ .\proto\onnx;.\proto\onnx\onnxruntime\onnxruntime;.\proto\onnx\onnxruntime\include\onnxruntime;.\proto\onnx\onnx_repo;.\proto\onnx\onnx_repo\onnx;.\API;.\API\Internals;.\proto;$(BOOST_INCLUDE_PATH);$(SolutionDir)\Source\CNTKv2LibraryDll;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude);$(ProtobufInclude);$(SolutionDir)Source\PerformanceProfilerDll;..\..\external\gsl\include;$(ProjectDir)Generated\Windows
$(SolutionDir)Source\1BitSGD;$(ProjectDir)Generated\Windows;%(AdditionalIncludeDirectories)
CNTK_PARALLEL_TRAINING_SUPPORT;%(PreprocessorDefinitions)
true
@@ -84,7 +84,7 @@
NotUsing
Level4
Disabled
- ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)
+ ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;PLATFORM_WINDOWS;%(PreprocessorDefinitions)
true
true
4800;4610;4512;4510;4267;4127;4125;4100;4456;4189;4996;4503;4146
@@ -101,7 +101,7 @@
Level4
NotUsing
- ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)
+ ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;PLATFORM_WINDOWS;%(PreprocessorDefinitions)
true
/d2Zi+ /bigobj %(AdditionalOptions)
MultiThreadedDLL
@@ -118,7 +118,7 @@
- ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CPUONLY;%(PreprocessorDefinitions)
+ ONNX_NAMESPACE=onnx;ONNX_ML=1;ONNX_V1_OPSCHEMA_COMPAT;CPUONLY;PLATFORM_WINDOWS;%(PreprocessorDefinitions)
$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)
$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)
@@ -151,9 +151,11 @@
+ IsUWP;%(PreprocessorDefinitions)
$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)
+ IsUWP;%(PreprocessorDefinitions)
$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)
@@ -175,51 +177,46 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -272,24 +269,23 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -299,6 +295,7 @@
+
@@ -318,6 +315,7 @@
+
@@ -345,12 +343,12 @@
-
-
+
+
-
+
diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
index 0a638d797..623451e4f 100644
--- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
+++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
@@ -35,35 +35,8 @@
-
- proto\onnx
-
-
- proto\onnx
-
-
- proto\onnx
-
-
- proto\onnx
-
-
- proto
-
-
- proto\onnx
-
-
- proto\onnx\core\common\logging
-
-
- proto\onnx\core\common
-
-
- proto\onnx\core\common\logging
-
proto\onnx\onnx_repo\onnx\defs\controlflow
@@ -94,15 +67,6 @@
proto\onnx\onnx_repo\onnx\defs\traditionalml
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
proto\onnx\onnx_repo\onnx\defs\logical
@@ -130,45 +94,6 @@
proto\onnx\onnx_repo\onnx\defs
-
- proto\onnx
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\common
-
-
- proto\onnx
-
-
- proto\onnx\core\framework
-
-
- proto\onnx\core\platform
-
-
- proto\onnx\core\platform
-
-
- proto\onnx\core\platform\windows
-
-
- proto\onnx\core\platform\windows
-
-
- proto\onnx\core\platform\windows
-
-
- proto\onnx\core\platform\windows
-
proto\onnx\onnx_repo\onnx\defs
@@ -187,8 +112,80 @@
proto\onnx\onnx_repo\onnx\shape_inference
-
- proto\onnx\core\graph
+
+ proto\onnx\onnxruntime\common\logging
+
+
+ proto\onnx\onnxruntime\common\logging
+
+
+ proto\onnx\onnxruntime\common
+
+
+ proto\onnx\onnxruntime\common
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\framework
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx\onnxruntime\platform
+
+
+ proto\onnx\onnxruntime\platform
+
+
+
+
+
+ proto\onnx\onnx_repo\onnx\defs\controlflow
+
+
+ proto\onnx\onnx_repo\onnx\defs\traditionalml
@@ -235,30 +232,12 @@
-
- proto\onnx
-
-
- proto\onnx
-
-
- proto\onnx
-
-
- proto\onnx
-
API
API
-
- proto\onnx
-
-
- proto\onnx\core\inc
-
proto\onnx\onnx_repo\onnx
@@ -286,135 +265,9 @@
proto\onnx\onnx_repo\onnx\common
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
proto\onnx\onnx_repo\onnx\common
-
- proto\onnx
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\graph
-
-
- proto\onnx\core\common
-
-
- proto\onnx\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common
-
-
- proto\onnx\core\include\core\common\logging
-
-
- proto\onnx\core\include\core\common\logging
-
-
- proto\onnx\core\include\core\common\logging
-
-
- proto\onnx\core\include\core\common\logging
-
-
- proto\onnx\core\include\core\common\logging
-
-
- proto\onnx\core\include\core\common\logging\sinks
-
-
- proto\onnx\core\include\core\common\logging\sinks
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\graph
-
-
- proto\onnx\core\include\core\inc
-
-
- proto\onnx\core\framework
-
-
- proto\onnx\core\include\core\platform
-
-
- proto\onnx\core\include\core\platform
-
-
- proto\onnx\core\platform
-
-
- proto\onnx\core\platform
-
-
- proto\onnx\core\platform\windows
-
-
- proto\onnx
-
proto\onnx\onnx_repo\onnx\defs
@@ -424,6 +277,135 @@
proto\onnx\onnx_repo\onnx\common
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common
+
+
+ proto\onnx\onnxruntime\include\common\logging
+
+
+ proto\onnx\onnxruntime\include\common\logging
+
+
+ proto\onnx\onnxruntime\include\common\logging
+
+
+ proto\onnx\onnxruntime\include\common\logging
+
+
+ proto\onnx\onnxruntime\include\common\logging
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\graph
+
+
+ proto\onnx\onnxruntime\include\inc
+
+
+ proto\onnx\onnxruntime\platform
+
+
+ proto\onnx\onnxruntime\platform
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\graph
+
+
+ proto\onnx\onnxruntime\common
+
+
+ proto\onnx\onnxruntime\common
+
+
+ proto\onnx\onnxruntime\framework
+
+
+ proto\onnx\onnxruntime\inc
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx
+
+
+ proto\onnx\onnxruntime\platform
+
+
+ proto\onnx\onnxruntime\platform
+
@@ -441,21 +423,6 @@
{ca68761d-44d4-41a9-b055-4b192402ed0b}
-
- {ac45f7f4-5f65-40d4-9163-46580266ae16}
-
-
- {3a706847-68f2-45a2-91bf-66deeac9a67b}
-
-
- {0bdf50b3-73a2-455b-9271-6f749b3cbb98}
-
-
- {c6e7230c-950a-4ecd-92da-0db3843d795c}
-
-
- {c18a3bd0-c2dc-4a3d-8820-7c9972f65a5f}
-
{9541e056-faf3-446e-a1cd-821fc16284fa}
@@ -498,50 +465,53 @@
{bc2e7e0d-8620-40a5-8e1f-1cdda8880dd3}
-
- {172ea174-5c72-4e82-baae-fc80eda6e3a0}
-
-
- {d462f397-47df-4cbe-ae8f-751825a70365}
-
-
- {ad17fa77-1bdb-4130-9363-cfb2fe08b3c5}
-
-
- {f594af27-d007-4a79-9616-c589227821d6}
-
-
- {8da0dc26-2ae2-4f78-8a5c-dd497e176e95}
-
-
- {8fcfe046-8edd-4a67-b494-aa2e968e25e0}
-
-
- {106e1174-345f-43bf-a124-4b5656ac3e33}
-
-
- {a468acb3-5520-4433-8ad1-1241a2e13e7c}
-
-
- {9b0d609a-31b4-4b5d-a47b-1d09ffc8459e}
-
-
- {122b6879-351d-4719-974c-1c1db04a8cff}
-
-
- {26599ed1-92ab-42f3-b835-3057768a502a}
-
-
- {938a6293-26e8-4aad-9aa3-200d9b96102b}
-
{b8ebfd65-98ba-44fb-b10d-ac1e7e8e5246}
+
+ {769cf5e4-cef4-47f0-9b29-f190e3731f26}
+
+
+ {45e51e13-29c8-48e4-b765-3dad6f25f52d}
+
+
+ {6666e70d-16b9-4d52-b305-abe70ab144b1}
+
+
+ {d1ad1f5d-18c6-4980-97a4-fe1819672029}
+
+
+ {3f8fc63d-dbcb-4e4d-96e8-b49da7b7d5e7}
+
+
+ {556e9414-303c-45a8-8ed3-f035458d3351}
+
+
+ {babbff64-1577-4c83-a81d-9ea90ec4b931}
+
+
+ {8ac97d45-37a9-4494-a728-8041e35d20dc}
+
+
+ {24483f0a-fe67-44dd-b1df-f5abb91dcc8d}
+
+
+ {955eafd1-4d93-455f-a1a7-137b6eed969d}
+
+
+ {32268a6a-3039-4568-92b4-9a9388e324d0}
+
+
+ {98847797-f8ba-4847-b382-b58e7986336d}
+
+
+ {90661e60-2fcf-4398-a8fc-62cd11bb6418}
+
+
+ {681310a9-13d1-4e99-87ea-4b342d35901e}
+
-
- proto
-
tensorboard
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
index f6a437a07..0b772567d 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
@@ -2654,7 +2654,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateLSTMNode(const FunctionPtr &src,
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
- onnxruntime::Node *lstmNode = graph->AddNode(nodeName, "LSTM", "", nodeInputs, nodeOutputs);
+ onnxruntime::Node *lstmNode = &graph->AddNode(nodeName, "LSTM", "", nodeInputs, nodeOutputs);
lstmNode->AddAttribute("activations", activations);
lstmNode->AddAttribute("direction", direction);
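Note: this graph->AddNode(...) to &graph->AddNode(...) change repeats throughout the file. It tracks an onnxruntime API change in which Graph::AddNode returns a Node& rather than a Node*, so each call site keeps its Node* local by taking the address of the returned reference. Sketch of the pattern (types abridged):

    // before: Node* Graph::AddNode(...)
    onnxruntime::Node* node = graph->AddNode(nodeName, "LSTM", "", nodeInputs, nodeOutputs);
    // after:  Node& Graph::AddNode(...)
    onnxruntime::Node* node = &graph->AddNode(nodeName, "LSTM", "", nodeInputs, nodeOutputs);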
@@ -2931,7 +2931,7 @@ onnxruntime::Node *CNTKToONNXHelper::CreateGRUNode(const FunctionPtr &src,
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
auto nodeName = ToLegacyString(ToUTF8(src->Uid()));
- onnxruntime::Node *gruNode = graph->AddNode(nodeName, "GRU", "", nodeInputs, nodeOutputs);
+ onnxruntime::Node *gruNode = &graph->AddNode(nodeName, "GRU", "", nodeInputs, nodeOutputs);
gruNode->AddAttribute("activations", activations);
gruNode->AddAttribute("direction", direction);
@@ -3119,7 +3119,7 @@ onnxruntime::Node *CNTKToONNXHelper::CreateRNNNode(const FunctionPtr &src,
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap, scanLoops, createLoopIndex);
auto nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
- onnxruntime::Node *rnnNode = graph->AddNode(nodeName, "RNN", "", nodeInputs, nodeOutputs);
+ onnxruntime::Node *rnnNode = &graph->AddNode(nodeName, "RNN", "", nodeInputs, nodeOutputs);
rnnNode->AddAttribute("activations", activations);
rnnNode->AddAttribute("direction", direction);
@@ -3217,7 +3217,7 @@ onnxruntime::NodeArg &CNTKToONNXHelper::CreateAddShapeNodeArg(Graph *graph, cons
onnxruntime::Node *CNTKToONNXHelper::AddReshapeNodeImpl(Graph *graph, const string &nodeName, NodeArg *input, NodeArg *output, const std::vector<int64_t> &newShape)
{
onnxruntime::NodeArg &shapeInputArg = CreateAddShapeNodeArg(graph, newShape, output->Name() + "_shape");
- auto reshapeNode1 = graph->AddNode(nodeName, "Reshape", "", { input, &shapeInputArg }, { output });
+ auto reshapeNode1 = &graph->AddNode(nodeName, "Reshape", "", { input, &shapeInputArg }, { output });
return reshapeNode1;
}
@@ -3248,7 +3248,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddSliceNode(onnxruntime::NodeArg &inputArg
const std::string &outArgName, onnxruntime::Graph* graph)
{
const TypeProto &inputTypeProto = *inputArg.TypeAsProto();
- onnx::TensorProto_DataType elemType = inputTypeProto.tensor_type().elem_type();
+ google::protobuf::int32 elemType = inputTypeProto.tensor_type().elem_type();
onnx::TypeProto outputTypeProto = MakeTypeProtoWithShape();
outputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
for (int i = 0, j = 0; i < inputTypeProto.tensor_type().shape().dim_size(); ++i) {
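Note: the recurring onnx::TensorProto_DataType to google::protobuf::int32 change in this file reflects the upgraded ONNX protobufs, in which TypeProto.Tensor declares elem_type as an int32 field instead of the TensorProto.DataType enum, so the generated elem_type() accessor now returns a plain int32. The enum constants still interoperate through the usual unscoped-enum-to-int conversion, e.g.:

    google::protobuf::int32 elemType = inputTypeProto.tensor_type().elem_type();
    outputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
    if (elemType == onnx::TensorProto_DataType_FLOAT) { /* still compiles and compares */ }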
@@ -3274,7 +3274,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddSliceNode(onnxruntime::NodeArg &inputArg
}
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputTypeProto);
- onnxruntime::Node* sliceNode = graph->AddNode(
+ onnxruntime::Node* sliceNode = &graph->AddNode(
outArgName + string("_slice"), "Slice", "", { &inputArg }, { &outputNodeArg });
sliceNode->AddAttribute("axes", axes);
sliceNode->AddAttribute("starts", starts);
@@ -3289,7 +3289,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddEyeLikeNode(onnxruntime::NodeArg &inputA
const TypeProto *inputTypeProto = inputArg.TypeAsProto();
onnx::TypeProto outputTypeProto(*inputTypeProto);
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputTypeProto);
- onnxruntime::Node* eyeLikeNode = graph->AddNode(
+ onnxruntime::Node* eyeLikeNode = &graph->AddNode(
outArgName + string("_eye_like"), "EyeLike", "", { &inputArg }, { &outputNodeArg });
return eyeLikeNode;
}
@@ -3302,7 +3302,7 @@ onnxruntime::Node* CNTKToONNXHelper::AddConstantLikeNode(onnxruntime::NodeArg& i
onnx::TypeProto outputTypeProto(*inputTypeProto);
onnxruntime::NodeArg& outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputTypeProto);
- onnxruntime::Node* constantLikeNode = graph->AddNode(
+ onnxruntime::Node* constantLikeNode = &graph->AddNode(
outArgName + string("_constant_like"), "ConstantLike", "", {&inputArg}, {&outputNodeArg});
constantLikeNode->AddAttribute("value", value);
return constantLikeNode;
@@ -3315,7 +3315,7 @@ onnxruntime::Node* CNTKToONNXHelper::AddPadNode(onnxruntime::NodeArg& inputArg,
const TypeProto* inputTypeProto = inputArg.TypeAsProto();
onnxruntime::NodeArg& outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputType);
- onnxruntime::Node* padNode = graph->AddNode(
+ onnxruntime::Node* padNode = &graph->AddNode(
outArgName + string("_pad"), "Pad", "", {&inputArg}, {&outputNodeArg});
padNode->AddAttribute("mode", mode);
@@ -3329,7 +3329,7 @@ onnxruntime::Node* CNTKToONNXHelper::AddSqueezeNode(onnxruntime::NodeArg& inputA
const std::string& outArgName, onnxruntime::Graph* graph)
{
const TypeProto* inputTypeProto = inputArg.TypeAsProto();
- onnx::TensorProto_DataType elemType = inputTypeProto->tensor_type().elem_type();
+ google::protobuf::int32 elemType = inputTypeProto->tensor_type().elem_type();
onnx::TypeProto outputTypeProto = MakeTypeProtoWithShape();
outputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
@@ -3345,7 +3345,7 @@ onnxruntime::Node* CNTKToONNXHelper::AddSqueezeNode(onnxruntime::NodeArg& inputA
}
onnxruntime::NodeArg& outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputTypeProto);
- onnxruntime::Node* squeezeNode = graph->AddNode(
+ onnxruntime::Node* squeezeNode = &graph->AddNode(
outArgName + string("_squeeze"), "Squeeze", "", {&inputArg}, {&outputNodeArg});
squeezeNode->AddAttribute("axes", axes);
return squeezeNode;
@@ -3357,12 +3357,12 @@ onnxruntime::Node *CNTKToONNXHelper::AddExpandNode(onnxruntime::NodeArg &inputAr
{
onnxruntime::NodeArg &shapeNodeArg = CreateAddShapeNodeArg(graph, newShape, outArgName + "_expand_shape");
- onnx::TensorProto_DataType elemType = inputArg.TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 elemType = inputArg.TypeAsProto()->tensor_type().elem_type();
onnx::TypeProto outputTypeProto = ToTypeProto(newShape, false);
outputTypeProto.mutable_tensor_type()->set_elem_type(elemType);
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(outArgName, &outputTypeProto);
- onnxruntime::Node* expandNode = graph->AddNode(
+ onnxruntime::Node* expandNode = &graph->AddNode(
outArgName + string("_expand"), "Expand", "", { &inputArg, &shapeNodeArg }, { &outputNodeArg });
return expandNode;
}
@@ -3371,7 +3371,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddReshapeNode(onnxruntime::NodeArg &nodeAr
onnxruntime::Graph *graph)
{
onnx::TypeProto typeProto = ToTypeProto(newShape, false);
- onnx::TensorProto_DataType elemType = nodeArg.TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 elemType = nodeArg.TypeAsProto()->tensor_type().elem_type();
typeProto.mutable_tensor_type()->set_elem_type(elemType);
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(outArgName, &typeProto);
@@ -3384,7 +3384,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddMatMulNode(onnxruntime::NodeArg &nodeArg
const std::string &out_arg_name)
{
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, nullptr);
- onnxruntime::Node* argMatMulNode = graph->AddNode(
+ onnxruntime::Node* argMatMulNode = &graph->AddNode(
nodeArg1.Name() + string("_matmul"), "MatMul", "", { &nodeArg1, &nodeArg2 }, { &outputArg });
return argMatMulNode;
}
@@ -3393,7 +3393,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddAddNode(onnxruntime::NodeArg &nodeArg1,
const std::string &out_arg_name)
{
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, nullptr);
- onnxruntime::Node* argMatMulNode = graph->AddNode(
+ onnxruntime::Node* argMatMulNode = &graph->AddNode(
nodeArg1.Name() + string("_add"), "Add", "", { &nodeArg1, &nodeArg2 }, { &outputArg });
return argMatMulNode;
}
@@ -3404,7 +3404,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddIdentityOp(onnxruntime::NodeArg &nodeArg
outputTypeProto.mutable_tensor_type()->set_elem_type(nodeArg.TypeAsProto()->tensor_type().elem_type());
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(out_arg_name, &outputTypeProto);
- onnxruntime::Node* identityNode = graph->AddNode(
+ onnxruntime::Node* identityNode = &graph->AddNode(
nodeArg.Name() + string("_identity"), "Identity", "", { &nodeArg}, { &outputArg });
return identityNode;
}
@@ -3413,7 +3413,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddArgMaxNode(onnxruntime::NodeArg &nodeArg
{
// onnxruntime::NodeArg inputArg(nodeArg.Name(), nullptr);
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeArg.Name() + "argmax_out", nullptr);
- onnxruntime::Node* argMaxNode = graph->AddNode(nodeArg.Name() + string("_argmax"), "ArgMax", "", { &nodeArg }, { &outputArg });
+ onnxruntime::Node* argMaxNode = &graph->AddNode(nodeArg.Name() + string("_argmax"), "ArgMax", "", { &nodeArg }, { &outputArg });
argMaxNode->AddAttribute("axis", (int64_t)axis);
return argMaxNode;
}
@@ -3425,7 +3425,7 @@ onnxruntime::Node *CNTKToONNXHelper::AddCastNode(onnxruntime::NodeArg &nodeArg,
outputTypeProto.mutable_tensor_type()->set_elem_type(toType);
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(nodeArg.Name() + "_cast_out_" + outputNodeArgName, &outputTypeProto);
- onnxruntime::Node* castNode = graph->AddNode(nodeArg.Name() + string("_cast_") + outputNodeArgName,
+ onnxruntime::Node* castNode = &graph->AddNode(nodeArg.Name() + string("_cast_") + outputNodeArgName,
"Cast", "", { &nodeArg }, { &outputArg });
castNode->AddAttribute("to", (int64_t)toType);
return castNode;
@@ -3463,8 +3463,8 @@ NodeArg& CNTKToONNXHelper::AddTransposeBatchSequenceAxesNode(onnxruntime::NodeAr
std::generate(perm.begin(), perm.end(), [axis = 0]() mutable { return axis++; });
std::swap(perm[0], perm[1]);
onnxruntime::Node* transposeNode = isInput ?
- graph->AddNode(nodeName, "Transpose", "", { &nodeArg }, { &otherArg }) :
- graph->AddNode(nodeName, "Transpose", "", { &otherArg }, { &nodeArg });
+ &graph->AddNode(nodeName, "Transpose", "", { &nodeArg }, { &otherArg }) :
+ &graph->AddNode(nodeName, "Transpose", "", { &otherArg }, { &nodeArg });
transposeNode->AddAttribute("perm", perm);
return otherArg;
}
@@ -3473,9 +3473,9 @@ onnxruntime::Node *CNTKToONNXHelper::AddTransposeNode(onnxruntime::NodeArg &node
const std::vector<int64_t> &perm, onnx::TypeProto& transposeOutputArgType, const std::string &outputNodeArgName)
{
onnxruntime::NodeArg &outputArg = graph->GetOrCreateNodeArg(outputNodeArgName, &transposeOutputArgType);
- onnx::TensorProto_DataType elementType = nodeArg.TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 elementType = nodeArg.TypeAsProto()->tensor_type().elem_type();
const_cast<TypeProto*>(outputArg.TypeAsProto())->mutable_tensor_type()->set_elem_type(elementType);
- onnxruntime::Node* transposeNode = graph->AddNode(nodeArg.Name() + string("_transpose"), "Transpose", "", { &nodeArg }, { &outputArg });
+ onnxruntime::Node* transposeNode = &graph->AddNode(nodeArg.Name() + string("_transpose"), "Transpose", "", { &nodeArg }, { &outputArg });
transposeNode->AddAttribute("perm", perm);
return transposeNode;
}
@@ -3605,7 +3605,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSoftmaxLikeNode(const FunctionPtr& sr
UpdateONNXType(src->Output().GetDataType(), softmaxLikeOutputArgType);
onnxruntime::NodeArg &innerSoftmaxLikeOutputArg = graph->GetOrCreateNodeArg(innerSoftmaxOutputNodeArgName, &softmaxLikeOutputArgType);
- onnxruntime::Node* softmaxLikeNode = graph->AddNode(nodeName, onnxOpName, "", { inputToInnerSoftmaxArgNode }, { &innerSoftmaxLikeOutputArg });
+ onnxruntime::Node* softmaxLikeNode = &graph->AddNode(nodeName, onnxOpName, "", { inputToInnerSoftmaxArgNode }, { &innerSoftmaxLikeOutputArg });
// always apply softmax on the last axis
softmaxLikeNode->AddAttribute("axis", (int64_t)onnxRank - 1);
@@ -3676,13 +3676,13 @@ onnxruntime::Node* CNTKToONNXHelper::CreatePastFutureValueNode(const FunctionPtr
Node * concatNode;
if (past)
{
- concatNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
+ concatNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
{ const_cast<NodeArg *>(initValueExpand->OutputDefs()[0]), const_cast<NodeArg *>(sliceNode->OutputDefs()[0]) },
outputs);
}
else
{
- concatNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
+ concatNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Concat", "",
{ const_cast<NodeArg *>(sliceNode->OutputDefs()[0]), const_cast<NodeArg *>(initValueExpand->OutputDefs()[0]) },
outputs);
}
@@ -3832,7 +3832,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateReconcileDynamicAxisNode(const Functi
inputNodeArg = inputs[0];
}
- onnxruntime::Node* elementWiseNode = graph->AddNode(nodeName + "_add", "Add", "",
+ onnxruntime::Node* elementWiseNode = &graph->AddNode(nodeName + "_add", "Add", "",
{ inputNodeArg, broadcastNodeArg }, outputs);
functionNodes.emplace(src, elementWiseNode);
@@ -3889,7 +3889,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceBroadcastAsNode(const Functio
inputNodeArg = inputs[0];
}
- onnxruntime::Node* elementWiseNode = graph->AddNode(nodeName + "_add", "Add", "",
+ onnxruntime::Node* elementWiseNode = &graph->AddNode(nodeName + "_add", "Add", "",
{ inputNodeArg, broadcastNodeArg }, outputs);
functionNodes.emplace(src, elementWiseNode);
@@ -3935,7 +3935,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceGatherNode(const FunctionPtr&
std::string nodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
std::string onnxOpName = "Compress";
- Node *compressNode = graph->AddNode(nodeName, onnxOpName, "", inputs, { &compressOutputNodeArg });
+ Node *compressNode = &graph->AddNode(nodeName, onnxOpName, "", inputs, { &compressOutputNodeArg });
int64_t sequenceAxis = 0;
compressNode->AddAttribute("axis", sequenceAxis);
@@ -3965,7 +3965,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceReduceElementsNode(const Func
std::vector<onnxruntime::NodeArg *> outputs;
ProcessOutputs(src, inputs, outputs, graph);
- Node *node = graph->AddNode(nodeName, onnxOpName, "", inputs, outputs);
+ Node *node = &graph->AddNode(nodeName, onnxOpName, "", inputs, outputs);
SetReduceElementsAttributes(br, node, true);
return node;
}
@@ -4014,7 +4014,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateNodeWithGatherPacked(const FunctionPt
std::vector<onnxruntime::NodeArg *>({ gatherPackedInputs[0], const_cast<NodeArg *>(castScreezeNode->OutputDefs()[0]) }),
outputs, graph);
- Node *compressNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Compress", "",
+ Node *compressNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), "Compress", "",
{ gatherPackedInputs[0], const_cast<NodeArg *>(castScreezeNode->OutputDefs()[0]) }, outputs);
int64_t sequenceAxis = 0;
compressNode->AddAttribute("axis", sequenceAxis);
@@ -4060,7 +4060,7 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
// prepare output NodeArg with shape of [sequence, batch]
onnx::TypeProto typeProto = MakeTypeProtoWithShape();
- onnx::TensorProto_DataType elemType = inputNodeArg->TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 elemType = inputNodeArg->TypeAsProto()->tensor_type().elem_type();
typeProto.mutable_tensor_type()->set_elem_type(elemType);
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(nodeName + "_output", &typeProto);
@@ -4071,7 +4071,7 @@ onnxruntime::Node* CNTKToONNXHelper::ExtractShapeWithDynamicAxes(onnxruntime::Gr
shapeProto.add_dim()->set_dim_value(BatchSizeProcessor::FreeBatchSize());
outputNodeArg.SetShape(shapeProto);
- Node *constantNode = graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
+ Node *constantNode = &graph->AddNode(nodeName + "_constant_like", "ConstantLike", "",
{ const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { &outputNodeArg });
constantNode->AddAttribute("value", (float)0);
return constantNode;
@@ -4136,7 +4136,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
std::generate(perm.begin(), perm.end(), [axis = 0]() mutable { return axis++; });
// transpose sequence and batch axes
std::swap(perm[0], perm[1]);
- Node* transposeNode = graph->AddNode(transposeNodeName, "Transpose", "", inputs, { outputs[0] });
+ Node* transposeNode = &graph->AddNode(transposeNodeName, "Transpose", "", inputs, { outputs[0] });
transposeNode->AddAttribute("perm", perm);
functionNodes.emplace(src, transposeNode);
@@ -4175,7 +4175,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
outputs[1]->SetShape(shapeProto);
}
- Node *constantNode = graph->AddNode(transposeNodeName + "_constant_like", "ConstantLike", "",
+ Node *constantNode = &graph->AddNode(transposeNodeName + "_constant_like", "ConstantLike", "",
{ const_cast<NodeArg *>(squeezeNode->OutputDefs().at(0)) }, { outputs[1] });
constantNode->AddAttribute("value", (float)1);
}
@@ -4185,7 +4185,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateDynamicAxisPackUnpackNode(const Funct
{
std::string identityNodeName = UniqueNodeNameStorage::GetUniqueNodeName(src);
// this is the original output of CNTK UnpackSequence op which is just an identity in ONNX.
- Node *identityNode = graph->AddNode(identityNodeName, "Identity", "", inputs, { outputs[0] });
+ Node *identityNode = &graph->AddNode(identityNodeName, "Identity", "", inputs, { outputs[0] });
functionNodes.emplace(src, identityNode);
return identityNode;
}
@@ -4325,7 +4325,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSequenceSliceNode(const FunctionPtr&
onnxruntime::NodeArg &outputNodeArg = graph->GetOrCreateNodeArg(sliceOutputName, &outputArgType);
const std::string & nodeName = ToLegacyString(ToUTF8(src->Name()));
- onnxruntime::Node *sequenceSliceNode = graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg });
+ onnxruntime::Node *sequenceSliceNode = &graph->AddNode(nodeName, "Slice", "", { inputs[inputs.size() - 1] }, { &outputNodeArg });
sequenceSliceNode->AddAttribute("axes", std::vector({ int64_t(0) }));
sequenceSliceNode->AddAttribute("ends", std::vector({ endIndex }));
sequenceSliceNode->AddAttribute("starts", std::vector({ beginIndex }));
@@ -4740,7 +4740,7 @@ bool CNTKToONNXHelper::ProcessLoopsAndCheckCNTKNodeContinueCreate(const Function
scanGraph.SetInputOrder(scanSubgraphOrderedInputs);
scanGraph.SetOutputOrder(scanSubgraphOrderedOutputs);
- Node *scanNode = graph->AddNode(scanNodeName, "Scan", "", input_args, output_args);
+ Node *scanNode = &graph->AddNode(scanNodeName, "Scan", "", input_args, output_args);
ResolveGraphAndSaveModel(scanSubModel.get());
@@ -5029,7 +5029,7 @@ bool TryMatchNodeArgType(onnx::TypeProto &argType, onnxruntime::Graph* graph, co
const NodeArg* inputNodeArg = graph->GetNodeArg(nodeArgName);
if (inputNodeArg)
{
- onnx::TensorProto_DataType inputType = inputNodeArg->TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 inputType = inputNodeArg->TypeAsProto()->tensor_type().elem_type();
argType.mutable_tensor_type()->set_elem_type(inputType);
return true;
@@ -5743,7 +5743,7 @@ void CNTKToONNXHelper::ProcessOutputs(const FunctionPtr& src,
outputArgNodeName + "_post_cast_input", &castInputArgType);
onnxruntime::NodeArg &castOutputArg = graph->GetOrCreateNodeArg(outputArgNodeName, &outputArgType);
- onnxruntime::Node* castNode = graph->AddNode(castInputArg.Name() + string("_cast_") + outputArgNodeName,
+ onnxruntime::Node* castNode = &graph->AddNode(castInputArg.Name() + string("_cast_") + outputArgNodeName,
"Cast", "", { &castInputArg }, { &castOutputArg });
castNode->AddAttribute("to", (int64_t)cntk_type);
@@ -5878,13 +5878,13 @@ void CNTKToONNXHelper::PostProcessGraph(onnxruntime::Graph* graph)
std::vector<NodeArg *> inputs;
std::vector<NodeArg *> outputs;
- maxPoolNode->ForEachDef([&inputs, &outputs](const NodeArg* def, bool isInput) {
- if (isInput) inputs.push_back(const_cast<NodeArg *>(def));
- else outputs.push_back(const_cast<NodeArg *>(def));
+ maxPoolNode->ForEachDef([&inputs, &outputs](const NodeArg& def, bool isInput) {
+ if (isInput) inputs.push_back(const_cast<NodeArg *>(&def));
+ else outputs.push_back(const_cast<NodeArg *>(&def));
});
outputs.push_back(&indicesOutputNodeArg);
- onnxruntime::Node* newMaxPoolNode = graph->AddNode(maxPoolNode->Name(), maxPoolNode->OpType(), maxPoolNode->Description(),
+ onnxruntime::Node* newMaxPoolNode = &graph->AddNode(maxPoolNode->Name(), maxPoolNode->OpType(), maxPoolNode->Description(),
inputs, outputs, &(maxPoolNode->GetAttributes()));
graph->RemoveNode(maxPoolNode->Index());
maxPoolNode = newMaxPoolNode;
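Note: this hunk adapts to another onnxruntime signature change: Node::ForEachDef now passes the visitor a const NodeArg& instead of a const NodeArg*, so the lambda takes a reference and the const_cast is applied to its address. Condensed form of the call-site pattern:

    maxPoolNode->ForEachDef([&](const NodeArg& def, bool isInput) {
        (isInput ? inputs : outputs).push_back(const_cast<NodeArg*>(&def));
    });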
@@ -6796,11 +6796,11 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
node = AddReshapeNodeImpl(graph, nodeName + "_output_reshape", &matMulOutputNodeArg, outputs[0], finalOutputShape);
}
else
- node = graph->AddNode(nodeName, ToOPName(src), "", { &inputOutput1Arg, &inputOutput2Arg }, outputs);
+ node = &graph->AddNode(nodeName, ToOPName(src), "", { &inputOutput1Arg, &inputOutput2Arg }, outputs);
}
else
{
- node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
+ node = &graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
}
}
else if (src->OpName() == L"LayerNormalization")
@@ -6819,26 +6819,26 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
onnx::TypeProto input0ArgType = ToTypeProto(src->Inputs()[operandIndexInCntkInputs].Shape(), src->Inputs()[operandIndexInCntkInputs].HasBatchAxis());
UpdateONNXType(src->Inputs()[operandIndexInCntkInputs].GetDataType(), input0ArgType);
onnxruntime::NodeArg &mvnTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_mvn_output0"), &input0ArgType);
- onnxruntime::Node* mvnNode = graph->AddNode(nodeName + string("_MVN"), "MeanVarianceNormalization",
+ onnxruntime::Node* mvnNode = &graph->AddNode(nodeName + string("_MVN"), "MeanVarianceNormalization",
"", { input0 }, { &mvnTensorOutputArg });
mvnNode->AddAttribute("across_channels", static_cast<int64_t>(1));
mvnNode->AddAttribute("normalize_variance", static_cast<int64_t>(1));
auto input1 = inputs[scaleIndexInOnnxInputs];
onnxruntime::NodeArg &mulTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_mul_output0"), &input0ArgType);
- onnxruntime::Node* mulNode = graph->AddNode(nodeName + string("_mul"), "Mul",
+ onnxruntime::Node* mulNode = &graph->AddNode(nodeName + string("_mul"), "Mul",
"", { &mvnTensorOutputArg, input1 }, { &mulTensorOutputArg });
auto input2 = inputs[biasIndexInOnnxInputs];
onnxruntime::NodeArg &addTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_Output_0"), &input0ArgType);
- node = graph->AddNode(nodeName + string("_add"), "Add",
+ node = &graph->AddNode(nodeName + string("_add"), "Add",
"", { &mulTensorOutputArg, input2 }, { &addTensorOutputArg });
}
else if (src->OpName() == L"LogPlus")
{
// CNTK LogPlus is the equivalent to numpy.logaddexp
// ONNX has a different but similar op: ReduceLogSumExp
- onnx::TensorProto_DataType tensorType = orderedInputs[0]->TypeAsProto()->tensor_type().elem_type();
+ google::protobuf::int32 tensorType = orderedInputs[0]->TypeAsProto()->tensor_type().elem_type();
std::vector<int64_t> broadcastShape = BroadcastInputs(orderedInputs, /*ignoreAxes=*/{}, src, graph);
// Now both inputs should have the same shape.
// Add another axis in front. This will be the axis to be reduced over later.
@@ -6852,7 +6852,7 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
onnx::TypeProto outputArgType = ToTypeProto(unsqueezeOutputShape, doReverseVec);
outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
onnxruntime::NodeArg &unsqueezeTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_unsqueeze" + std::to_string(inputIndex) + "_output0"), &outputArgType);
- onnxruntime::Node* unsqueezeNode = graph->AddNode(nodeName + string("_Unsqueeze") + std::to_string(inputIndex), "Unsqueeze", "", { orderedInputs[inputIndex] }, { &unsqueezeTensorOutputArg });
+ onnxruntime::Node* unsqueezeNode = &graph->AddNode(nodeName + string("_Unsqueeze") + std::to_string(inputIndex), "Unsqueeze", "", { orderedInputs[inputIndex] }, { &unsqueezeTensorOutputArg });
unsqueezeNode->AddAttribute("axes", std::vector<int64_t>(1, 0));
return unsqueezeTensorOutputArg;
};
@@ -6863,14 +6863,14 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
onnx::TypeProto concatOutputArgType = ToTypeProto(concatOutputShape, false);
concatOutputArgType.mutable_tensor_type()->set_elem_type(tensorType);
onnxruntime::NodeArg &concatTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_concat_output0"), &concatOutputArgType);
- onnxruntime::Node* concatNode = graph->AddNode(nodeName + string("_Concat"), "Concat", "", { &unsqueezeTensorOutputArg0, &unsqueezeTensorOutputArg1 },
+ onnxruntime::Node* concatNode = &graph->AddNode(nodeName + string("_Concat"), "Concat", "", { &unsqueezeTensorOutputArg0, &unsqueezeTensorOutputArg1 },
{ &concatTensorOutputArg });
concatNode->AddAttribute("axis", static_cast(0));
onnx::TypeProto outputArgType = ToTypeProto(broadcastShape, false);
outputArgType.mutable_tensor_type()->set_elem_type(tensorType);
onnxruntime::NodeArg &reduceLogSumExpTensorOutputArg = graph->GetOrCreateNodeArg(nodeName + string("_Output_0"), &outputArgType);
- node = graph->AddNode(nodeName + string("_reduce_log_sum_exp"), "ReduceLogSumExp", "", { &concatTensorOutputArg }, { &reduceLogSumExpTensorOutputArg });
+ node = &graph->AddNode(nodeName + string("_reduce_log_sum_exp"), "ReduceLogSumExp", "", { &concatTensorOutputArg }, { &reduceLogSumExpTensorOutputArg });
// reduce over the first axis.
node->AddAttribute("axes", std::vector(1, 0));
node->AddAttribute("keepdims", static_cast(0));
@@ -6881,11 +6881,11 @@ onnxruntime::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, onnxruntime
std::vector<int64_t> outputShape = ToINTS(*orderedInputs[1]->TypeAsProto());
onnxruntime::NodeArg &shapeInputArg = CreateAddShapeNodeArg(graph, outputShape, orderedInputs[1]->Name() + "_shape");
orderedInputs.push_back(&shapeInputArg);
- node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
+ node = &graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
}
else
{
- node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
+ node = &graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
}
}
@@ -7177,7 +7177,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForOneHotOp(const FunctionPt
bool needsTransposeNode = !(onehotAxis == -1 || onehotAxis == static_cast<int>(inputRank));
onnxruntime::NodeArg* oneHotOutputArg = needsTransposeNode ? &graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_onehot_out"),
nullptr) : outputs[0];
- onnxruntime::Node* oneHotNode = graph->AddNode(nodeName, ToOPName(src), "", { inputs[0] }, { oneHotOutputArg }, nullptr, "ai.onnx.ml");
+ onnxruntime::Node* oneHotNode = &graph->AddNode(nodeName, ToOPName(src), "", { inputs[0] }, { oneHotOutputArg }, nullptr, "ai.onnx.ml");
std::vector<int64_t> catsVector(numClass);
std::iota(catsVector.begin(), catsVector.end(), 0);
@@ -7200,7 +7200,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForOneHotOp(const FunctionPt
std::iota(permVector.begin(), permVector.end(), 0);
permVector.insert(permVector.begin() + onnxAxis, onnxOutputRank - 1);
onnxruntime::NodeArg &transposeOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_transpose_out"), nullptr);
- outputNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_Transpose"), "Transpose", "", { oneHotOutputArg }, { outputs[0] });
+ outputNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_Transpose"), "Transpose", "", { oneHotOutputArg }, { outputs[0] });
outputNode->AddAttribute("perm", permVector);
}
else
@@ -7233,11 +7233,11 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForStraightThrough(const Fun
onnxruntime::NodeArg& scalarZeroOutputArg = CreateScalarNode(graph, UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_zero_out"),
src->Inputs()[0].GetDataType(), 0.0);
onnxruntime::NodeArg& greaterOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_greater_out"), nullptr);
- onnxruntime::Node* greaterNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_greater"),
+ onnxruntime::Node* greaterNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_greater"),
"Greater", "", { inputs[0], &scalarZeroOutputArg }, { &greaterOutputArg });
onnxruntime::NodeArg& castOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_cast_out"), nullptr);
- onnxruntime::Node* castNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_cat"),
+ onnxruntime::Node* castNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_cat"),
"Cast", "", { &greaterOutputArg }, { &castOutputArg });
castNode->AddAttribute("to", static_cast(ConvertDataTypeCNTKToTensorProto(src->Inputs()[0].GetDataType())));
@@ -7245,13 +7245,13 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForStraightThrough(const Fun
src->Inputs()[0].GetDataType(), 2.0);
onnxruntime::NodeArg& mulOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul_out"), nullptr);
- onnxruntime::Node* mulNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul"),
+ onnxruntime::Node* mulNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul"),
"Mul", "", { &castOutputArg, &scalarTwoOutputArg }, { &mulOutputArg });
onnxruntime::NodeArg& scalarOneOutputArg = CreateScalarNode(graph, UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_one_out"),
src->Inputs()[0].GetDataType(), 1.0);
onnxruntime::NodeArg& subOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_sub_out"), nullptr);
- onnxruntime::Node* subNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_sub"),
+ onnxruntime::Node* subNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_sub"),
"Sub", "", { &mulOutputArg, &scalarOneOutputArg }, { outputs[0] });
functionNodes.emplace(src, subNode);
@@ -7367,7 +7367,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForOptimizedRNNStack(const F
// ==== Step 6. Add ONNX LSTM node ====
auto rnnOpNameLookup = Operators::OptimizedRnnToOnnxOpLookup();
auto rnnNodeName = UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(ToLegacyString(ToUTF8(src->Uid())) + std::to_string(i));
- functionNode = graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
+ functionNode = &graph->AddNode(rnnNodeName, rnnOpNameLookup[recurrentOp], "", inputs, outputs);
std::vector singleDirectionActivation;
if (recurrentOp == L"lstm")
@@ -7645,7 +7645,7 @@ onnxruntime::NodeArg* CNTKToONNXHelper::LSTMOutputShapeAdapter(onnxruntime::Node
}
UpdateONNXType(outputType, transposeOutputArgType);
onnxruntime::NodeArg &transposeOutputArg = graph->GetOrCreateNodeArg(adapterBasename + "_Transpose_Output", &transposeOutputArgType);
- auto transposeNode = graph->AddNode(adapterBasename + "_Transpose", "Transpose", "", { &inputArg }, { &transposeOutputArg });
+ auto transposeNode = &graph->AddNode(adapterBasename + "_Transpose", "Transpose", "", { &inputArg }, { &transposeOutputArg });
transposeNode->AddAttribute("perm", x);
// Reshape to combine last two axes, i.e. [S, B, numDirections, hiddenSize] --> [S, B, numDirections*hiddenSize]
@@ -7695,7 +7695,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateBatchNormalization(const FunctionPtr
if (spatial)
{
// input and output are in correct shape.
- node = graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs);
+ node = &graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs);
}
else
{
@@ -7719,7 +7719,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateBatchNormalization(const FunctionPtr
src->Inputs()[0].Shape().TotalSize());
NodeArg &xFlattenOutput = graph->GetOrCreateNodeArg(inputs[0]->Name() + "_flatten_output", &xFlattenOutputTypeProto);
- Node *xFlattenNode = graph->AddNode(inputs[0]->Name() + "_flatten", "Flatten", "", { inputs[0] }, { &xFlattenOutput });
+ Node *xFlattenNode = &graph->AddNode(inputs[0]->Name() + "_flatten", "Flatten", "", { inputs[0] }, { &xFlattenOutput });
int64_t flattenAxis = src->Inputs()[0].DynamicAxes().size();
xFlattenNode->AddAttribute("axis", flattenAxis);
inputs[0] = const_cast<NodeArg *>(xFlattenNode->OutputDefs()[0]);
@@ -7740,7 +7740,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateBatchNormalization(const FunctionPtr
// TypeProto of BN's output is the same as its first input
onnxruntime::NodeArg *bnOutput = &graph->GetOrCreateNodeArg(outputs[0]->Name() + "_BN_output",
inputs[0]->TypeAsProto());
- node = graph->AddNode(nodeName, "BatchNormalization", "", inputs, { bnOutput });
+ node = &graph->AddNode(nodeName, "BatchNormalization", "", inputs, { bnOutput });
// output shape and name are the same
std::vector<int64_t> finalOutputShape = ToINTS(*outputs[0]->TypeAsProto());
Node *postBNReshapeNode = AddReshapeNode(const_cast<NodeArg &>(*node->OutputDefs()[0]),
@@ -7749,7 +7749,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateBatchNormalization(const FunctionPtr
else
{
// input x is not flattened.
- node = graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs);
+ node = &graph->AddNode(nodeName, "BatchNormalization", "", inputs, outputs);
}
}
@@ -7793,12 +7793,12 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForTimesTranspose(const Func
int rightInputRank = inputs[0]->Shape()->dim_size() - 1;
onnxruntime::NodeArg &transposeOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_transpose_out"), nullptr);
- onnxruntime::Node* transposeNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_transpose"),
+ onnxruntime::Node* transposeNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_transpose"),
"Transpose", "", { inputs[0] }, { &transposeOutputArg });
transposeNode->AddAttribute("perm", ToINTS(rightInputRank == 2 ? vector({ 1, 2, 0 }) : vector({ 0, 1 })));
onnxruntime::NodeArg &matmulOutputArg = graph->GetOrCreateNodeArg(outputs[0]->Name(), nullptr);
- onnxruntime::Node* matmulNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_matmul"),
+ onnxruntime::Node* matmulNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_matmul"),
"MatMul", "", { inputs[1], &transposeOutputArg }, { &matmulOutputArg });
functionNodes.emplace(src, matmulNode);
@@ -7843,7 +7843,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForFlatten(const FunctionPtr
onnxruntime::Node* postReshapeNode = AddReshapeNodeImpl(graph, UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_post_reshape"),
&postReshapeInputArg, outputs[0], ToINTS(outputReshapeOut));
- onnxruntime::Node* flattenNode = graph->AddNode(nodeName, ToOPName(src), "", { &preReshapeOutputArg }, { &postReshapeInputArg });
+ onnxruntime::Node* flattenNode = &graph->AddNode(nodeName, ToOPName(src), "", { &preReshapeOutputArg }, { &postReshapeInputArg });
CopyAttributes(src, flattenNode);
@@ -7874,7 +7874,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateSpliceNode(const FunctionPtr &src,
int64_t axisIndex = ConvertAxisToOnnxForSpliceWithWithBroadcast(axis, src);
BroadcastInputs(inputs, { axisIndex }, src, graph);
- Node *node = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), ToOPName(src), "", inputs, outputs);
+ Node *node = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeName(src), ToOPName(src), "", inputs, outputs);
node->AddAttribute("axis", axisIndex);
@@ -7908,7 +7908,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr
// Add a Clip node equivalent to min(abs(flag), 1).
onnxruntime::NodeArg &clipOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_clip_out"), nullptr);
- onnxruntime::Node* clipNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_clip"),
+ onnxruntime::Node* clipNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_clip"),
"Clip", "", { &absOutputArg }, { &clipOutputArg });
clipNode->AddAttribute("min", 0.0f); // Should be unnecesary for ONNX, but currently required by CNTK.
clipNode->AddAttribute("max", 1.0f);
@@ -7921,7 +7921,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr
graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul_true"), "Mul", "", { &ceilOutputArg, inputs[1] }, { &mulTrueOutputArg });
onnxruntime::NodeArg &oneOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_one_out"), nullptr);
- onnxruntime::Node* oneNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_one"), "Constant", "", {}, { &oneOutputArg });
+ onnxruntime::Node* oneNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_one"), "Constant", "", {}, { &oneOutputArg });
onnx::TensorProto oneTensor;
oneTensor.set_data_type(onnx::TensorProto::FLOAT);
oneTensor.add_float_data(1.0f);
@@ -7933,7 +7933,7 @@ onnxruntime::Node* CNTKToONNXHelper::CreateONNXNodesForSelect(const FunctionPtr
onnxruntime::NodeArg &mulFalseOutputArg = graph->GetOrCreateNodeArg(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul_false_out"), nullptr);
graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_mul_false"), "Mul", "", { &oneSubOutputArg, inputs[2] }, { &mulFalseOutputArg });
- onnxruntime::Node* sumNode = graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_sum"), "Sum", "", { &mulTrueOutputArg, &mulFalseOutputArg }, { outputs[0] });
+ onnxruntime::Node* sumNode = &graph->AddNode(UniqueNodeNameStorage::GetUniqueNodeNameWithoutUid(nodeName + "_sum"), "Sum", "", { &mulTrueOutputArg, &mulFalseOutputArg }, { outputs[0] });
functionNodes.emplace(src, sumNode);
return sumNode;
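The mechanical change running through the hunks above is an onnxruntime API update: `Graph::AddNode` now returns a `Node&` owned by the graph rather than a `Node*`, so each call site takes the address of the returned reference. A minimal sketch of the before/after pattern (the helper name and attribute value are illustrative, not from this diff):

```cpp
#include <string>
#include <vector>
#include "core/graph/graph.h" // onnxruntime graph API, as used elsewhere in this diff

// Hypothetical helper showing the new AddNode call shape.
onnxruntime::Node* AddBatchNorm(onnxruntime::Graph* graph, const std::string& name,
                                const std::vector<onnxruntime::NodeArg*>& inputs,
                                const std::vector<onnxruntime::NodeArg*>& outputs)
{
    // Before: Node* node = graph->AddNode(name, "BatchNormalization", "", inputs, outputs);
    // After: AddNode returns a Node& and the graph keeps ownership.
    onnxruntime::Node& node = graph->AddNode(name, "BatchNormalization", "", inputs, outputs);
    node.AddAttribute("epsilon", 1e-5f);
    return &node; // taking the address is safe: the node outlives this call
}
```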
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
index 10769e047..49fb3c3e0 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
@@ -46,7 +46,7 @@ private:
static Constant CreateConstant(const onnx::TensorProto &valueProto, const std::string &nodeName,
const DeviceDescriptor &computeDevice);
template <typename TSrc>
- static const CNTK::Constant CreateConstantWithTensorData(CNTK::NDShape &shape, onnx::TensorProto_DataType tensorProtoDataType,
+ static const CNTK::Constant CreateConstantWithTensorData(CNTK::NDShape &shape, google::protobuf::int32 tensorProtoDataType,
CNTK::DataType cntkDataType, const TSrc *srcData, CNTK::NDShape &reversedShape,
const CNTK::DeviceDescriptor &computeDevice, const std::string &nodeName);
@@ -576,7 +576,7 @@ void CopyFromProto(const onnx::TensorProto &src, T &dst, int srcIndex, int dstIn
}
template <typename TSrc>
-const CNTK::Constant CNTK::ONNXToCNTKHelper::CreateConstantWithTensorData(CNTK::NDShape &shape, onnx::TensorProto_DataType tensorProtoDataType,
+const CNTK::Constant CNTK::ONNXToCNTKHelper::CreateConstantWithTensorData(CNTK::NDShape &shape, google::protobuf::int32 tensorProtoDataType,
CNTK::DataType cntkDataType, const TSrc *srcData, CNTK::NDShape &reversedShape, const CNTK::DeviceDescriptor &computeDevice, const std::string &nodeName)
{
auto totalSize = shape.TotalSize();
@@ -633,7 +633,7 @@ const Node *ONNXToCNTKHelper::GetChildNode(const Node *parentNode, const NodeArg
Node::NodeConstIterator itChildNode = parentNode->InputNodesBegin();
for (; itChildNode != parentNode->InputNodesEnd(); ++itChildNode)
{
- const Node *childNode = *itChildNode;
+ const Node *childNode = &(*itChildNode);
const ConstPointerContainer<std::vector<NodeArg*>> &childOutputDefs = childNode->OutputDefs();
nodeArgIndex = 0;
for (ConstPointerContainer<std::vector<NodeArg*>>::ConstIterator itChildOutput = childOutputDefs.begin();
@@ -3003,7 +3003,7 @@ std::pair<const Node *, int> FindParentAndChildIndex(const Node *node)
Node::NodeConstIterator it = node->OutputNodesBegin();
if (it != node->OutputNodesEnd())
{
- const Node *parent = *it;
+ const Node *parent = &(*it);
int index = 0;
for (auto nodeArg : parent->InputDefs())
{
@@ -3768,14 +3768,14 @@ std::pair> ONNXToCNTKHelper::CheckNodeBelongsToOp
Node::NodeConstIterator it = node->OutputNodesBegin();
if (it != node->OutputNodesEnd())
{
- firstParentNode = *it;
+ firstParentNode = &(*it);
}
if (firstParentNode != nullptr)
{
it = firstParentNode->OutputNodesBegin();
if (it != firstParentNode->OutputNodesEnd())
{
- grandParentNode = *it;
+ grandParentNode = &(*it);
}
}
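The `*it` to `&(*it)` edits follow from the same API revision: `Node::NodeConstIterator` now dereferences to a `const Node&` rather than yielding a `const Node*`. A minimal sketch, assuming only the iterator behavior visible in this diff:

```cpp
#include "core/graph/graph.h" // onnxruntime graph API

// Hypothetical helper: first downstream node, or nullptr if none.
const onnxruntime::Node* FirstOutputNode(const onnxruntime::Node& node)
{
    auto it = node.OutputNodesBegin();
    if (it == node.OutputNodesEnd())
        return nullptr;
    return &(*it); // dereference yields const Node&, so take its address
}
```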
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp
index 81ff6b795..42d235cd0 100644
--- a/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp
+++ b/Source/CNTKv2LibraryDll/proto/onnx/RNNHelper.cpp
@@ -3,7 +3,7 @@
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
-#include "proto/onnx/core/graph/model.h"
+#include "proto/onnx/onnxruntime/onnxruntime/core/graph/model.h"
#include "RNNHelper.h"
#include "Operators.h"
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc
deleted file mode 100644
index 93dbe9f8a..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/capture.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/common/logging/capture.h"
-#include "core/common/logging/logging.h"
-#include "gsl/span"
-#include "gsl/gsl_util"
-
-namespace onnxruntime {
-namespace logging {
-
-void Capture::CapturePrintf(msvc_printf_check const char* format, ...) {
- va_list arglist;
- va_start(arglist, format);
-
- ProcessPrintf(format, arglist);
-
- va_end(arglist);
-}
-
-// from https://github.com/KjellKod/g3log/blob/master/src/logcapture.cpp LogCapture::capturef
-// License: https://github.com/KjellKod/g3log/blob/master/LICENSE
-void Capture::ProcessPrintf(msvc_printf_check const char* format, va_list args) {
- static constexpr auto kTruncatedWarningText = "[...truncated...]";
- static const int kMaxMessageSize = 2048;
- char message_buffer[kMaxMessageSize];
- const auto message = gsl::make_span(message_buffer);
-
-#if (defined(WIN32) || defined(_WIN32) || defined(__WIN32__) && !defined(__GNUC__))
- const int nbrcharacters = vsnprintf_s(message.data(), message.size(), _TRUNCATE, format, args);
-#else
- const int nbrcharacters = vsnprintf(message.data(), message.size(), format, args);
-#endif
-
- if (nbrcharacters <= 0) {
- stream_ << "\n\tERROR LOG MSG NOTIFICATION: Failure to successfully parse the message";
- stream_ << '"' << format << '"' << std::endl;
- } else if (nbrcharacters > message.size()) {
- stream_ << message.data() << kTruncatedWarningText;
- } else {
- stream_ << message.data();
- }
-}
-
-Capture::~Capture() {
- if (logger_ != nullptr) {
- logger_->Log(*this);
- }
-}
-} // namespace logging
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc
deleted file mode 100644
index bc6521c07..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/logging.cc
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include
-#include
-#include
-
-#include "core/common/exceptions.h"
-#include "core/common/logging/isink.h"
-#include "core/common/logging/logging.h"
-
-#ifdef _WIN32
-#include <Windows.h>
-#else
-#include <unistd.h>
-#include <sys/syscall.h>
-#endif
-
-namespace onnxruntime {
-namespace logging {
-const char* Category::onnxruntime = "onnxruntime";
-const char* Category::System = "System";
-
-using namespace std::chrono;
-
-/*
-As LoggingManager can be a static, we need to wrap the default instance and mutex in functions
-to ensure they're initialized before use in LoggingManager::LoggingManager. If we don't, and
-a static LoggingManager is created at startup, the file scope statics here may not have been
-initialized.
-*/
-
-static std::atomic<LoggingManager*>& DefaultLoggerManagerInstance() noexcept {
- // this atomic is to protect against attempts to log being made after the default LoggingManager is destroyed.
- // Theoretically this can happen if a Logger instance is still alive and calls Log via its internal
- // pointer to the LoggingManager.
- // As the first thing LoggingManager::Log does is check the static DefaultLoggerManagerInstance() is not null,
- // any further damage should be prevented (in theory).
- static std::atomic<LoggingManager*> default_instance;
- return default_instance;
-}
-
-// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial
-// and should not have any destruction order issues via pragmas instead.
-// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html
-#ifdef _MSC_VER
-#pragma warning(push)
-#pragma warning(disable : 26426)
-#endif
-
-static std::mutex& DefaultLoggerMutex() noexcept {
- static std::mutex mutex;
- return mutex;
-}
-
-std::unique_ptr<Logger>& LoggingManager::GetDefaultLogger() noexcept {
- static std::unique_ptr<Logger> default_logger;
- return default_logger;
-}
-
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
-
-static minutes InitLocaltimeOffset(const time_point<system_clock>& epoch) noexcept;
-
-const LoggingManager::Epochs& LoggingManager::GetEpochs() noexcept {
- // we save the value from system clock (which we can convert to a timestamp) as well as the high_resolution_clock.
- // from then on, we use the delta from the high_resolution_clock and apply that to the
- // system clock value.
- static Epochs epochs{high_resolution_clock::now(),
- system_clock::now(),
- InitLocaltimeOffset(system_clock::now())};
- return epochs;
-}
-
-LoggingManager::LoggingManager(std::unique_ptr<ISink> sink, Severity default_min_severity, bool filter_user_data,
- const InstanceType instance_type, const std::string* default_logger_id,
- int default_max_vlog_level)
- : sink_{std::move(sink)},
- default_min_severity_{default_min_severity},
- default_filter_user_data_{filter_user_data},
- default_max_vlog_level_{default_max_vlog_level},
- owns_default_logger_{false} {
- if (!sink_) {
- throw std::logic_error("ISink must be provided.");
- }
-
- if (instance_type == InstanceType::Default) {
- if (default_logger_id == nullptr) {
- throw std::logic_error("default_logger_id must be provided if instance_type is InstanceType::Default");
- }
-
- // lock mutex to create instance, and enable logging
- // this matches the mutex usage in Shutdown
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
-
- if (DefaultLoggerManagerInstance().load() != nullptr) {
- throw std::logic_error("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time.");
- }
-
- // This assertion passes, so using the atomic to validate calls to Log should
- // be reasonably economical.
- // assert(DefaultLoggerManagerInstance().is_lock_free());
- DefaultLoggerManagerInstance().store(this);
-
- CreateDefaultLogger(*default_logger_id);
-
- owns_default_logger_ = true;
- }
-}
-
-LoggingManager::~LoggingManager() {
- if (owns_default_logger_) {
- // lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance.
- std::lock_guard<std::mutex> guard(DefaultLoggerMutex());
-
- DefaultLoggerManagerInstance().store(nullptr, std::memory_order::memory_order_release);
-
- GetDefaultLogger().reset();
- }
-}
-
-void LoggingManager::CreateDefaultLogger(const std::string& logger_id) {
- // this method is only called from ctor in scope where DefaultLoggerMutex() is already locked
-
- std::unique_ptr<Logger>& default_logger{GetDefaultLogger()};
-
- if (default_logger != nullptr) {
- throw std::logic_error("Default logger already set. ");
- }
-
- default_logger = CreateLogger(logger_id);
-}
-
-std::unique_ptr<Logger> LoggingManager::CreateLogger(std::string logger_id) {
- return CreateLogger(std::move(logger_id), default_min_severity_, default_filter_user_data_, default_max_vlog_level_);
-}
-
-std::unique_ptr<Logger> LoggingManager::CreateLogger(std::string logger_id,
- const Severity severity,
- bool filter_user_data,
- int vlog_level) {
- auto logger = std::make_unique<Logger>(*this, logger_id, severity, filter_user_data, vlog_level);
- return logger;
-}
-
-void LoggingManager::Log(const std::string& logger_id, const Capture& message) const {
- sink_->Send(GetTimestamp(), logger_id, message);
-}
-
-static minutes InitLocaltimeOffset(const time_point<system_clock>& epoch) noexcept {
- // convert the system_clock time_point (UTC) to localtime and gmtime to calculate the difference.
- // we do this once, and apply that difference in GetTimestamp().
- // NOTE: If we happened to be running over a period where the time changed (e.g. daylight saving started)
- // we won't pickup the change. Not worth the extra cost to be 100% accurate 100% of the time.
-
- const time_t system_time_t = system_clock::to_time_t(epoch);
- tm local_tm;
- tm utc_tm;
-
-#ifdef _WIN32
- localtime_s(&local_tm, &system_time_t);
- gmtime_s(&utc_tm, &system_time_t);
-#else
- localtime_r(&system_time_t, &local_tm);
- gmtime_r(&system_time_t, &utc_tm);
-#endif
-
- const double seconds = difftime(mktime(&local_tm), mktime(&utc_tm));
-
- // minutes should be accurate enough for timezone conversion
- return minutes{static_cast<int>(seconds / 60)};
-}
-
-std::exception LoggingManager::LogFatalAndCreateException(const char* category,
- const CodeLocation& location,
- const char* format_str, ...) {
- std::string exception_msg;
-
- // create Capture in separate scope so it gets destructed (leading to log output) before we throw.
- {
- ::onnxruntime::logging::Capture c{::onnxruntime::logging::LoggingManager::DefaultLogger(),
- ::onnxruntime::logging::Severity::kFATAL, category, ::onnxruntime::logging::DataType::SYSTEM, location};
- va_list args;
- va_start(args, format_str);
-
- c.ProcessPrintf(format_str, args);
- va_end(args);
-
- exception_msg = c.Message();
- }
-
- return OnnxRuntimeException(location, exception_msg);
-}
-
-unsigned int GetThreadId() {
-#ifdef _WIN32
- return static_cast<unsigned int>(GetCurrentThreadId());
-#else
- return static_cast<unsigned int>(syscall(SYS_gettid));
-#endif
-}
-
-//
-// Get current process id
-//
-unsigned int GetProcessId() {
-#ifdef _WIN32
- return static_cast<unsigned int>(GetCurrentProcessId());
-#else
- return static_cast<unsigned int>(syscall(SYS_getpid));
-#endif
-}
-
-} // namespace logging
-} // namespace onnxruntime
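The deleted `InitLocaltimeOffset` relies on a classic trick: render one `time_t` as both local time and UTC, then push both `tm` structs back through `mktime`, which interprets each as local time, so their difference is the timezone offset. A standalone sketch of the same computation:

```cpp
#include <chrono>
#include <ctime>

// Minimal sketch of the localtime-offset trick from the deleted logging.cc.
std::chrono::minutes LocalUtcOffsetSketch() {
    using namespace std::chrono;
    const time_t now = system_clock::to_time_t(system_clock::now());
    tm local_tm{};
    tm utc_tm{};
#ifdef _WIN32
    localtime_s(&local_tm, &now);
    gmtime_s(&utc_tm, &now);
#else
    localtime_r(&now, &local_tm);
    gmtime_r(&now, &utc_tm);
#endif
    // mktime treats both structs as local time; subtracting recovers the offset.
    const double seconds = difftime(mktime(&local_tm), mktime(&utc_tm));
    return minutes{static_cast<int>(seconds / 60)};
}
```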
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h
deleted file mode 100644
index 42577ba26..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/cerr_sink.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include <iostream>
-#include "core/common/logging/sinks/ostream_sink.h"
-
-namespace onnxruntime {
-namespace logging {
-/// <summary>
-/// A std::cerr based ISink
-/// </summary>
-/// <seealso cref="OStreamSink" />
-class CErrSink : public OStreamSink {
- public:
- CErrSink() : OStreamSink(std::cerr, /*flush*/ false) { // std::cerr isn't buffered so no flush required
- }
-};
-} // namespace logging
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h
deleted file mode 100644
index 9b0adf92f..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/clog_sink.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include <iostream>
-#include "core/common/logging/sinks/ostream_sink.h"
-
-namespace onnxruntime {
-namespace logging {
-/// <summary>
-/// A std::clog based ISink
-/// </summary>
-/// <seealso cref="OStreamSink" />
-class CLogSink : public OStreamSink {
- public:
- CLogSink() : OStreamSink(std::clog, /*flush*/ true) {
- }
-};
-} // namespace logging
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h
deleted file mode 100644
index f27abb9e6..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/composite_sink.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include
-#include
-
-#include "core/common/logging/isink.h"
-#include "core/common/logging/logging.h"
-
-namespace onnxruntime {
-namespace logging {
-/// <summary>
-/// Class that abstracts multiple ISink instances being written to.
-/// </summary>
-/// <seealso cref="ISink" />
-class CompositeSink : public ISink {
- public:
- /// <summary>
- /// Initializes a new instance of the <see cref="CompositeSink"/> class.
- /// Use AddSink to add sinks.
- /// </summary>
- CompositeSink() {}
-
- /// <summary>
- /// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value).
- /// </summary>
- /// <param name="sink">The sink.</param>
- /// <returns>This instance to allow chaining.</returns>
- CompositeSink& AddSink(std::unique_ptr<ISink> sink) {
- sinks_.push_back(std::move(sink));
- return *this;
- }
-
- private:
- void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override {
- for (auto& sink : sinks_) {
- sink->Send(timestamp, logger_id, message);
- }
- }
-
- std::vector<std::unique_ptr<ISink>> sinks_;
-};
-} // namespace logging
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h
deleted file mode 100644
index ba3ff3e0b..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/file_sink.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include <fstream>
-#include "core/common/logging/sinks/ostream_sink.h"
-
-namespace onnxruntime {
-namespace logging {
-/// <summary>
-/// ISink that writes to a file.
-/// </summary>
-/// <seealso cref="OStreamSink" />
-class FileSink : public OStreamSink {
- public:
- /// <summary>
- /// Initializes a new instance of the <see cref="FileSink" /> class.
- /// </summary>
- /// <param name="filename">The filename to write to.</param>
- /// <param name="append">If set to <c>true</c> [append to file]. Otherwise truncate.</param>
- /// <param name="filter_user_data">If set to <c>true</c> [removes user data].</param>
- /// <remarks>Filtering of user data can alternatively be done at the <see cref="LoggingManager" /> level.</remarks>
- FileSink(std::unique_ptr<std::ofstream> file, bool filter_user_data)
- : OStreamSink(*file, /*flush*/ true), file_(std::move(file)), filter_user_data_{filter_user_data} {
- }
-
- /// <summary>
- /// Initializes a new instance of the <see cref="FileSink" /> class.
- /// </summary>
- /// <param name="filename">The filename to write to.</param>
- /// <param name="append">If set to <c>true</c> [append to file]. Otherwise truncate.</param>
- /// <param name="filter_user_data">If set to <c>true</c> [removes user data].</param>
- /// <remarks>Filtering of user data can alternatively be done at the <see cref="LoggingManager" /> level.</remarks>
- FileSink(const std::string& filename, bool append, bool filter_user_data)
- : FileSink{std::make_unique<std::ofstream>(filename, std::ios::out | (append ? std::ios::app : std::ios::trunc)),
- filter_user_data} {
- }
-
- private:
- void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override {
- if (!filter_user_data_ || message.DataType() != DataType::USER) {
- OStreamSink::SendImpl(timestamp, logger_id, message);
- }
- }
-
- std::unique_ptr<std::ofstream> file_;
- bool filter_user_data_;
-};
-} // namespace logging
-} // namespace onnxruntime
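Taken together, the deleted sink classes were designed to compose: a `CompositeSink` fans one logger out to several destinations, and `AddSink` returns the sink itself so calls chain. A hypothetical usage sketch (the function name and file path are illustrative; per the deleted logging.cc, a `LoggingManager` would take ownership of the returned sink):

```cpp
#include <memory>

using namespace onnxruntime::logging;

std::unique_ptr<ISink> MakeConsoleAndFileSink() {
    auto composite = std::make_unique<CompositeSink>();
    composite->AddSink(std::make_unique<CLogSink>())              // buffered console output
             .AddSink(std::make_unique<FileSink>("run.log",
                                                 /*append*/ true,
                                                 /*filter_user_data*/ false));
    return composite; // unique_ptr<CompositeSink> converts to unique_ptr<ISink>
}
```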
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h
deleted file mode 100644
index bf5cec174..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/logging/sinks/ostream_sink.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include
-#include
-#include
-
-#include "core/common/logging/capture.h"
-#include "core/common/logging/isink.h"
-
-namespace onnxruntime {
-namespace logging {
-/// <summary>
-/// A std::ostream based ISink
-/// </summary>
-/// <seealso cref="ISink" />
-class OStreamSink : public ISink {
- protected:
- OStreamSink(std::ostream& stream, bool flush)
- : stream_{&stream}, flush_{flush} {
- }
-
- public:
- void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) override;
-
- private:
- std::ostream* stream_;
- const bool flush_;
-};
-} // namespace logging
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc
deleted file mode 100644
index 146671994..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.cc
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "profiler.h"
-
-namespace onnxruntime {
-namespace profiling {
-using namespace std::chrono;
-
-::onnxruntime::TimePoint profiling::Profiler::StartTime() const {
- return std::chrono::high_resolution_clock::now();
-}
-
-void Profiler::StartProfiling(const logging::Logger* session_logger, const std::string& file_name) {
- ONNXRUNTIME_ENFORCE(session_logger != nullptr);
- session_logger_ = session_logger;
- enabled_ = true;
- profile_stream_ = std::ofstream(file_name, std::ios::out | std::ios::trunc);
- profile_stream_file_ = file_name;
- profiling_start_time_ = StartTime();
-}
-
-void Profiler::EndTimeAndRecordEvent(EventCategory category,
- const std::string& event_name,
- TimePoint& start_time,
- std::unordered_map<std::string, std::string>&& event_args,
- bool /*sync_gpu*/) {
- if (!enabled_)
- return;
- //TODO: sync_gpu if needed.
- std::lock_guard<std::mutex> lock(mutex_);
- if (events_.size() < max_num_events_) {
- long long dur = TimeDiffMicroSeconds(start_time);
- long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);
- events_.emplace_back(category, logging::GetProcessId(),
- logging::GetThreadId(), event_name, ts, dur, std::move(event_args));
- } else {
- if (session_logger_ && !max_events_reached) {
- LOGS(*session_logger_, ERROR)
- << "Maximum number of events reached, could not record profile event.";
- max_events_reached = true;
- }
- }
-}
-
-std::string Profiler::WriteProfileData() {
- std::lock_guard<std::mutex> lock(mutex_);
- profile_stream_ << "[\n";
-
- for (size_t i = 0; i < events_.size(); ++i) {
- auto& rec = events_[i];
- profile_stream_ << R"({"cat" : ")" << event_categor_names_[rec.cat] << "\",";
- profile_stream_ << "\"pid\" :" << rec.pid << ",";
- profile_stream_ << "\"tid\" :" << rec.tid << ",";
- profile_stream_ << "\"dur\" :" << rec.dur << ",";
- profile_stream_ << "\"ts\" :" << rec.ts << ",";
- profile_stream_ << R"("ph" : "X",)";
- profile_stream_ << R"("name" :")" << rec.name << "\",";
- profile_stream_ << "\"args\" : {";
- bool is_first_arg = true;
- for (std::pair<std::string, std::string> event_arg : rec.args) {
- if (!is_first_arg) profile_stream_ << ",";
- profile_stream_ << "\"" << event_arg.first << "\" : \"" << event_arg.second << "\"";
- is_first_arg = false;
- }
- profile_stream_ << "}";
- if (i == events_.size() - 1) {
- profile_stream_ << "}\n";
- } else {
- profile_stream_ << "},\n";
- }
- }
- profile_stream_ << "]\n";
- profile_stream_.close();
- enabled_ = false; // will not collect profile after writing.
- return profile_stream_file_;
-}
-
-//
-// Conditionally sync the GPU if the syncGPU flag is set.
-//
-void ProfilerSyncGpu() {
- ONNXRUNTIME_NOT_IMPLEMENTED("Needs to implement only for gpus");
-}
-
-} // namespace profiling
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h
deleted file mode 100644
index 3470677f3..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/profiler.h
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include
-#include
-#include "core/common/logging/logging.h"
-
-namespace onnxruntime {
-
-namespace profiling {
-
-enum EventCategory {
- SESSION_EVENT = 0,
- NODE_EVENT,
- EVENT_CATEGORY_MAX
-};
-
-/*
-Event descriptions for the above session events.
-*/
-static constexpr const char* event_categor_names_[EVENT_CATEGORY_MAX] = {
- "Session",
- "Node"};
-
-/*
-Timing record for all events.
-*/
-struct EventRecord {
- EventRecord(EventCategory category,
- int process_id,
- int thread_id,
- std::string event_name,
- long long time_stamp,
- long long duration,
- std::unordered_map<std::string, std::string>&& event_args) : cat(category),
- pid(process_id),
- tid(thread_id),
- name(std::move(event_name)),
- ts(time_stamp),
- dur(duration),
- args(event_args) {}
- EventCategory cat;
- int pid;
- int tid;
- std::string name;
- long long ts;
- long long dur;
- std::unordered_map<std::string, std::string> args;
-};
-
-/*
-Main class for profiling. It continues to accumulate events and produce
-a corresponding "complete event (X)" in "chrome tracing" format.
-*/
-class Profiler {
- public:
- Profiler() noexcept {}; // turned off by default.
-
- /*
- Start profiler and record beginning time.
- */
- void StartProfiling(const logging::Logger* session_logger, const std::string& file_name);
-
- /*
- Produce current time point for any profiling action.
- */
- TimePoint StartTime() const;
-
- /*
- Record a single event. Time is measured till the call of this function from
- the start_time.
- */
- void EndTimeAndRecordEvent(EventCategory category,
- const std::string& event_name,
- TimePoint& start_time,
- std::unordered_map<std::string, std::string>&& event_args = std::unordered_map<std::string, std::string>(),
- bool sync_gpu = false);
-
- /*
- Write profile data to the given stream in chrome format defined below.
- https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#
- */
- std::string WriteProfileData();
-
- private:
- ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Profiler);
-
- // Mutex controlling access to profiler data
- std::mutex mutex_;
- bool enabled_{false};
- std::ofstream profile_stream_;
- std::string profile_stream_file_;
- const logging::Logger* session_logger_{nullptr};
- TimePoint profiling_start_time_;
- std::vector<EventRecord> events_;
- bool max_events_reached{false};
- static constexpr size_t max_num_events_ = 1000000;
-};
-
-} // namespace profiling
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc
deleted file mode 100644
index 85b486b81..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/status.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/common/status.h"
-#include "core/common/common.h"
-
-namespace onnxruntime {
-namespace common {
-Status::Status(StatusCategory category, int code, const std::string& msg) {
- // state_ will be allocated here causing the status to be treated as a failure
- ONNXRUNTIME_ENFORCE(code != static_cast<int>(MLStatus::OK));
-
- state_ = std::make_unique<State>(category, code, msg);
-}
-
-Status::Status(StatusCategory category, int code)
- : Status(category, code, EmptyString()) {
-}
-
-bool Status::IsOK() const noexcept {
- return (state_ == nullptr);
-}
-
-StatusCategory Status::Category() const noexcept {
- return IsOK() ? common::NONE : state_->category;
-}
-
-int Status::Code() const noexcept {
- return IsOK() ? static_cast<int>(common::OK) : state_->code;
-}
-
-const std::string& Status::ErrorMessage() const noexcept {
- return IsOK() ? EmptyString() : state_->msg;
-}
-
-std::string Status::ToString() const {
- if (state_ == nullptr) {
- return std::string("OK");
- }
-
- std::string result;
-
- if (common::SYSTEM == state_->category) {
- result += "SystemError";
- result += " : ";
- result += std::to_string(errno);
- } else if (common::ONNXRUNTIME == state_->category) {
- result += "[LotusError]";
- result += " : ";
- result += std::to_string(Code());
- std::string msg;
-
- result += " : ";
- result += MLStatusToString(static_cast(Code()));
- result += " : ";
- result += state_->msg;
- }
-
- return result;
-}
-
-// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial
-// and should not have any destruction order issues via pragmas instead.
-// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html
-#ifdef _MSC_VER
-#pragma warning(push)
-#pragma warning(disable : 26426)
-#endif
-const Status& Status::OK() noexcept {
- static Status s_ok;
- return s_ok;
-}
-
-const std::string& Status::EmptyString() noexcept {
- static std::string s_empty;
- return s_empty;
-}
-
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
-
-} // namespace common
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h b/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h
deleted file mode 100644
index 217c65189..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/common/task_thread_pool.h
+++ /dev/null
@@ -1,203 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/*
-Changed to use std::packaged_task instead of std::function so exceptions can be propagated.
-
-This also allows the task threadpool to be shared across multiple operators as the caller
-can keep a container of the packaged_task futures to check when they have completed. Calling
-WaitWorkComplete in that use case is invalid as there may be other concurrent usage of the
-threadpool.
-
-Example of that usage:
-
- std::vector<std::future<void>> task_results{};
-
- for (...) {
- std::packaged_task<void()> task{std::bind(lambda, i)};
- task_results.push_back(task.get_future());
- task_thread_pool.RunTask(std::move(task));
- }
-
- try {
- // wait for all and propagate any exceptions
- for (auto& future : task_results)
- future.get();
- } catch (const std::exception& ex) {
- ...
- throw;
- }
-
-*/
-
-#pragma once
-
-#include <condition_variable>
-#include <functional>
-#include <future>
-#include <mutex>
-#include <queue>
-#include <thread>
-#include <vector>
-
-#include "core/common/common.h"
-#include "core/common/logging/logging.h"
-
-namespace onnxruntime {
-
-class TaskThreadPool {
- private:
- struct task_element_t {
- bool run_with_id;
- std::packaged_task<void()> no_id;
- std::packaged_task<void(std::size_t)> with_id;
-
- task_element_t(task_element_t&& other) {
- run_with_id = other.run_with_id;
- no_id = std::move(other.no_id);
- with_id = std::move(other.with_id);
- }
-
- explicit task_element_t(std::packaged_task<void()>&& f)
- : run_with_id(false), no_id(std::move(f)) {}
-
- explicit task_element_t(std::packaged_task<void(std::size_t)>&& f)
- : run_with_id(true), with_id(std::move(f)) {}
- };
-
- std::queue<task_element_t> tasks_;
- std::vector<std::thread> threads_;
- std::mutex mutex_;
- std::condition_variable condition_;
- std::condition_variable completed_;
- bool running_;
- bool complete_;
- std::size_t available_;
- std::size_t total_;
-
- public:
- /// @brief Constructor.
- explicit TaskThreadPool(std::size_t pool_size)
- : threads_(pool_size), running_(true), complete_(true), available_(pool_size), total_(pool_size) {
- for (std::size_t i = 0; i < pool_size; ++i) {
- threads_[i] = std::thread(std::bind(&TaskThreadPool::MainLoop, this, i));
- }
- }
-
- /// @brief Destructor.
- ~TaskThreadPool() {
- // Set running flag to false then notify all threads.
- {
- std::unique_lock<std::mutex> lock(mutex_);
- running_ = false;
- condition_.notify_all();
- }
-
- try {
- for (auto& t : threads_) {
- t.join();
- }
- }
- // Suppress all exceptions.
- catch (const std::exception& ex) {
- LOGS_DEFAULT(ERROR) << "Exception joining threads in TaskThreadPool: " << ex.what();
- }
- }
-
- void RunTask(std::packaged_task<void()>&& task) {
- std::unique_lock<std::mutex> lock(mutex_);
-
- // Set task and signal condition variable so that a worker thread will
- // wake up and use the task.
- tasks_.push(task_element_t(std::move(task)));
- complete_ = false;
- condition_.notify_one();
- }
-
- void RunTaskWithID(std::packaged_task<void(std::size_t)>&& task) {
- std::unique_lock<std::mutex> lock(mutex_);
-
- // Set task and signal condition variable so that a worker thread will
- // wake up and use the task.
- tasks_.push(task_element_t(std::move(task)));
- complete_ = false;
- condition_.notify_one();
- }
-
- /// @brief Wait for queue to be empty
- void WaitWorkComplete() {
- std::unique_lock<std::mutex> lock(mutex_);
- while (!complete_)
- completed_.wait(lock);
- }
-
- private:
- ONNXRUNTIME_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TaskThreadPool);
-
- /// @brief Entry point for pool threads.
- void MainLoop(std::size_t index) {
- while (running_) {
- // Wait on condition variable while the task is empty and
- // the pool is still running.
- std::unique_lock<std::mutex> lock(mutex_);
- while (tasks_.empty() && running_) {
- condition_.wait(lock);
- }
-
- // If pool is no longer running, break out of loop.
- if (!running_) break;
-
- // Copy task locally and remove from the queue. This is
- // done within its own scope so that the task object is
- // destructed immediately after running the task. This is
- // useful in the event that the function contains
- // shared_ptr arguments bound via bind.
- {
- auto task = std::move(tasks_.front());
- tasks_.pop();
- // Decrement count, indicating thread is no longer available.
- --available_;
-
- lock.unlock();
-
- // Run the task.
- try {
- if (task.run_with_id) {
- task.with_id(index);
- } else {
- task.no_id();
- }
- } catch (const std::exception& /*ex*/) {
- // LOGS_DEFAULT(ERROR) << "Exception running TaskThreadPool task: " << ex.what();
- throw;
- }
-
- // Update status of empty, maybe
- // Need to recover the lock first
- lock.lock();
-
- // Increment count, indicating thread is available.
- ++available_;
- if (tasks_.empty() && available_ == total_) {
- complete_ = true;
- completed_.notify_one();
- }
- }
- } // while running_
- }
-};
-
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc
deleted file mode 100644
index 3889a1024..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#ifdef _WIN32
-//std::copy only works for the same type(input/output must have the same type)
-//TODO(@chasun): remove std::copy from DEFINE_UNPACK_TENSOR
-#pragma warning(disable : 4244)
-#endif
-#include "core/framework/tensorutils.h"
-#include "core/framework/allocator.h"
-
-#include <algorithm>
-
-#include "core/graph/onnx_protobuf.h"
-
-#include "gsl/pointers"
-#include "gsl/span"
-
-#include "core/inc/op_kernel_author.h"
-
-GSL_SUPPRESS(type .1) // allow use of reinterpret_cast for this special case
-inline bool IsLittleEndianOrder() noexcept {
- static int n = 1;
- return (*reinterpret_cast<char*>(&n) == 1);
-}
-
-template <typename T>
-static void UnpackTensorWithRawData(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data) {
- // allow this low level routine to be somewhat unsafe. assuming it's thoroughly tested and valid
- GSL_SUPPRESS(type) // type.1 reinterpret-cast; type.4 C-style casts; type.5 'T result;' is uninitialized;
- GSL_SUPPRESS(bounds .1) // pointer arithmetic
- GSL_SUPPRESS(f .23) // buff and temp_bytes never tested for nullness and could be gsl::not_null
- {
- auto& raw_data = tensor.raw_data();
- auto buff = raw_data.c_str();
- const size_t type_size = sizeof(T);
-
- if (IsLittleEndianOrder()) {
- memcpy((void*)p_data, (void*)buff, raw_data.size() * sizeof(char));
- } else {
- for (size_t i = 0; i < raw_data.size(); i += type_size, buff += type_size) {
- T result;
- const char* temp_bytes = reinterpret_cast<char*>(&result);
- for (size_t j = 0; j < type_size; ++j) {
- memcpy((void*)&temp_bytes[j], (void*)&buff[type_size - 1 - i], sizeof(char));
- }
- p_data[i] = result;
- }
- }
- }
-}
-
-namespace onnxruntime {
-namespace utils {
-#define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \
- template <> \
- Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, /*out*/ T* p_data, int64_t expected_size) { \
- if (nullptr == p_data) { \
- const size_t size = tensor.has_raw_data() ? tensor.raw_data().size() : tensor.field_size(); \
- if (size == 0) \
- return Status::OK(); \
- else \
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
- } \
- if (nullptr == p_data || Type != tensor.data_type()) { \
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \
- } \
- if (tensor.has_raw_data()) { \
- if (tensor.raw_data().size() != ((expected_size) * sizeof(T))) \
- return Status(common::ONNXRUNTIME, common::FAIL, \
- "UnpackTensor: the pre-allocated size does not match the raw data size"); \
- UnpackTensorWithRawData(tensor, p_data); \
- return Status::OK(); \
- } \
- if (tensor.field_size() != expected_size) \
- return Status(common::ONNXRUNTIME, common::FAIL, \
- "UnpackTensor: the pre-allocated size does not match the size in proto"); \
- const auto span = gsl::make_span(p_data, expected_size); \
- auto& data = tensor.field_name(); \
- std::copy(data.cbegin(), data.cend(), span.begin()); \
- return Status::OK(); \
- }
-
-//TODO: uint32 uint64 complex64 complex128
-//TODO: int16_t/uint16_t/float16 is confusing right now
-DEFINE_UNPACK_TENSOR(float, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, float_data, float_data_size)
-DEFINE_UNPACK_TENSOR(double, ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, double_data, double_data_size);
-DEFINE_UNPACK_TENSOR(uint8_t, ONNX_NAMESPACE::TensorProto_DataType_UINT8, int32_data, int32_data_size)
-DEFINE_UNPACK_TENSOR(int8_t, ONNX_NAMESPACE::TensorProto_DataType_INT8, int32_data, int32_data_size)
-DEFINE_UNPACK_TENSOR(int16_t, ONNX_NAMESPACE::TensorProto_DataType_INT16, int32_data, int32_data_size)
-DEFINE_UNPACK_TENSOR(uint16_t, ONNX_NAMESPACE::TensorProto_DataType_UINT16, int32_data, int32_data_size)
-DEFINE_UNPACK_TENSOR(int32_t, ONNX_NAMESPACE::TensorProto_DataType_INT32, int32_data, int32_data_size)
-DEFINE_UNPACK_TENSOR(int64_t, ONNX_NAMESPACE::TensorProto_DataType_INT64, int64_data, int64_data_size)
-DEFINE_UNPACK_TENSOR(uint64_t, ONNX_NAMESPACE::TensorProto_DataType_UINT64, uint64_data, uint64_data_size)
-DEFINE_UNPACK_TENSOR(uint32_t, ONNX_NAMESPACE::TensorProto_DataType_UINT32, uint64_data, uint64_data_size)
-
-template <>
-Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
- /*out*/ std::string* p_data,
- int64_t expected_size) {
- if (nullptr == p_data) {
- if (tensor.string_data_size() == 0)
- return Status::OK();
- else
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
- if (ONNX_NAMESPACE::TensorProto_DataType_STRING != tensor.data_type()) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
-
- if (tensor.string_data_size() != expected_size)
- return Status(common::ONNXRUNTIME, common::FAIL,
- "UnpackTensor: the pre-allocate size does not match the size in proto");
-
- const auto data = gsl::make_span(p_data, expected_size);
-
- auto& string_data = tensor.string_data();
- std::copy(string_data.cbegin(), string_data.cend(), data.begin());
-
- return Status::OK();
-}
-template <>
-Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
- /*out*/ bool* p_data,
- int64_t expected_size) {
- if (nullptr == p_data) {
- const size_t size = tensor.has_raw_data() ? tensor.raw_data().size() : tensor.int32_data_size();
- if (size == 0)
- return Status::OK();
- else
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
- if (ONNX_NAMESPACE::TensorProto_DataType_BOOL != tensor.data_type()) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
-
- if (tensor.has_raw_data()) {
- if (tensor.raw_data().size() != (expected_size) * sizeof(bool))
- return Status(common::ONNXRUNTIME, common::FAIL,
- "UnpackTensor: the pre-allocate size does not match the raw data size");
-
- UnpackTensorWithRawData(tensor, p_data);
- return Status::OK();
- }
-
- if (tensor.int32_data_size() != expected_size)
- return Status(common::ONNXRUNTIME, common::FAIL,
- "UnpackTensor: the pre-allocate size does not match the size in proto");
-
- const auto data = gsl::make_span(p_data, expected_size);
- std::copy(tensor.int32_data().cbegin(), tensor.int32_data().cend(), data.begin());
-
- return Status::OK();
-}
-template <>
-Status TensorUtils::UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
- /*out*/ MLFloat16* p_data,
- int64_t expected_size) {
- if (nullptr == p_data) {
- const size_t size = tensor.has_raw_data() ? tensor.raw_data().size() : tensor.int32_data_size();
- if (size == 0)
- return Status::OK();
- else
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
- if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 != tensor.data_type()) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
- }
-
- if (tensor.has_raw_data()) {
- if (tensor.raw_data().size() != (expected_size) * sizeof(uint16_t))
- return Status(common::ONNXRUNTIME, common::FAIL,
- "UnpackTensor: the pre-allocate size does not match the raw data size");
-
- UnpackTensorWithRawData(tensor, p_data);
- return Status::OK();
- }
-
- if (tensor.int32_data_size() != expected_size)
- return Status(common::ONNXRUNTIME, common::FAIL,
- "UnpackTensor: the pre-allocate size does not match the size in proto");
-
- const auto data = gsl::make_span(p_data, expected_size);
- for (int i = 0; i < expected_size; i++)
- data[i] = MLFloat16(gsl::narrow<uint16_t>(tensor.int32_data()[i]));
-
- return Status::OK();
-}
-
-#define CASE_PROTO_TRACE(X, Y) \
- case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
- if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(size, sizeof(Y), out)) { \
- return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto"); \
- } \
- break;
-
-template <size_t alignment>
-common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) {
- const auto& dims = tensor_proto.dims();
- size_t size = 1;
- for (int i = 0; i < dims.size(); ++i) {
- if (dims[i] < 0) {
- return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto");
- }
- if (!IAllocator::CalcMemSizeForArray(size, static_cast<size_t>(dims[i]), &size)) {
- return common::Status(common::ONNXRUNTIME, common::FAIL, "Invalid TensorProto");
- }
- }
- switch (tensor_proto.data_type()) {
- CASE_PROTO_TRACE(FLOAT, float);
- CASE_PROTO_TRACE(DOUBLE, double);
- CASE_PROTO_TRACE(BOOL, bool);
- CASE_PROTO_TRACE(INT8, int8_t);
- CASE_PROTO_TRACE(INT16, int16_t);
- CASE_PROTO_TRACE(INT32, int32_t);
- CASE_PROTO_TRACE(INT64, int64_t);
- CASE_PROTO_TRACE(UINT8, uint8_t);
- CASE_PROTO_TRACE(UINT16, uint16_t);
- CASE_PROTO_TRACE(UINT32, uint32_t);
- CASE_PROTO_TRACE(UINT64, uint64_t);
- CASE_PROTO_TRACE(FLOAT16, MLFloat16);
- case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING:
- default:
- return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
- }
- return Status::OK();
-}
-
-template common::Status GetSizeInBytesFromTensorProto<256>(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
-} // namespace utils
-} // namespace onnxruntime
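One detail in the deleted `UnpackTensorWithRawData` is worth flagging: the big-endian fallback indexes `buff` with `i` (a running byte offset) and writes `p_data[i]` rather than using an element index, which looks like a latent upstream bug; the little-endian fast path (a single `memcpy`) is what actually ran on supported platforms. A corrected sketch of the byte-swapping fallback, under that reading:

```cpp
#include <cstddef>
#include <cstring>

// Sketch: reverse the bytes of each element when the host is big-endian.
// 'raw' holds raw_size bytes of little-endian tensor data.
template <typename T>
void UnpackSwapped(const char* raw, std::size_t raw_size, T* out) {
    const std::size_t count = raw_size / sizeof(T);
    for (std::size_t i = 0; i < count; ++i) {
        char bytes[sizeof(T)];
        for (std::size_t j = 0; j < sizeof(T); ++j)
            bytes[j] = raw[i * sizeof(T) + (sizeof(T) - 1 - j)]; // reversed byte order
        std::memcpy(&out[i], bytes, sizeof(T)); // element index, not byte offset
    }
}
```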
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h b/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h
deleted file mode 100644
index a38e3c282..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/framework/tensorutils.h
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include
-#include
-
-#include "core/common/common.h"
-#include "core/common/status.h"
-
-namespace ONNX_NAMESPACE {
-class TensorProto;
-}
-namespace onnxruntime {
-namespace utils {
-//How much memory it will need for putting the content of this tensor into a plain array
-//string/complex64/complex128 tensors are not supported.
-//The output value could be zero or -1.
-template <size_t alignment>
-common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out);
-class TensorUtils {
- public:
- template
- static Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor,
- /*out*/ T* p_data,
- int64_t expected_size);
-
-}; // namespace Utils
-} // namespace utils
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc
deleted file mode 100644
index 263c5c380..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.cc
+++ /dev/null
@@ -1,214 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/graph/function_impl.h"
-#include "core/graph/graph.h"
-#include "core/graph/function_container.h"
-#include "onnx/shape_inference/implementation.h"
-
-namespace onnxruntime {
-void TypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_,
- std::unique_ptr<ONNX_NAMESPACE::OpSchema>& op_schema_,
- /*out*/
- std::unordered_map<std::string, int>& input_name_idx_map,
- std::unordered_map<std::string, int>& output_name_idx_map) {
- std::vector<std::pair<std::string, std::string>> input_types_list(onnx_func_proto_->input_size());
- std::vector<std::pair<std::string, std::string>> output_types_list(onnx_func_proto_->output_size());
- std::unordered_map<std::string, std::vector<std::string>> type_constraint_map;
- for (int i = 0; i < onnx_func_proto_->input_size(); ++i) {
- input_name_idx_map[onnx_func_proto_->input().Get(i)] = i;
- }
- for (int i = 0; i < onnx_func_proto_->output_size(); ++i) {
- output_name_idx_map[onnx_func_proto_->output().Get(i)] = i;
- }
- auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
- for (auto& node : onnx_func_proto_->node()) {
- const auto node_op_schema = schema_registry->GetSchema(node.op_type(), (int)onnx_func_proto_->since_version(), node.domain());
- for (int i = 0; i < node.input_size(); ++i) {
- auto& in_name = node.input().Get(i);
- if (input_name_idx_map.count(in_name)) {
- int idx = input_name_idx_map[in_name];
- const auto& p = node_op_schema->inputs().at(i);
- std::string type_str = p.GetTypeStr() + "in" + std::to_string(i);
- input_types_list[idx] = std::make_pair(in_name, type_str);
- if (!type_constraint_map.count(type_str)) {
- for (auto s : p.GetTypes()) {
- type_constraint_map[type_str].emplace_back(*s);
- }
- }
- }
- }
- for (int i = 0; i < node.output_size(); ++i) {
- auto& out_name = node.output().Get(i);
- if (output_name_idx_map.count(out_name)) {
- int idx = output_name_idx_map[out_name];
- const auto& p = node_op_schema->outputs().at(i);
- std::string type_str = p.GetTypeStr() + "out" + std::to_string(i);
- output_types_list[idx] = std::make_pair(out_name, type_str);
- if (!type_constraint_map.count(type_str)) {
- for (auto s : p.GetTypes()) {
- type_constraint_map[type_str].emplace_back(*s);
- }
- }
- }
- }
- }
-
- int i = 0;
- for (auto& input : input_types_list) {
- op_schema_->Input(i, input.first, "", input.second);
- ++i;
- }
- i = 0;
- for (auto& output : output_types_list) {
- op_schema_->Output(i, output.first, "", output.second);
- ++i;
- }
-
- for (auto& tc : type_constraint_map) {
- op_schema_->TypeConstraint(tc.first, tc.second, "");
- }
-}
-
-FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
- std::unique_ptr<IndexedSubGraph> customized_func)
- : parent_graph_(&graph) {
- customized_func_body_ = std::move(customized_func);
- auto meta_def = customized_func_body_->GetMetaDef();
- op_schema_ = std::make_unique<ONNX_NAMESPACE::OpSchema>();
- op_schema_->SetName(meta_def->name);
- op_schema_->SetDomain(meta_def->domain);
- op_schema_->SetDoc(meta_def->doc_string);
- op_schema_->SinceVersion(meta_def->since_version);
- int i = 0;
- for (auto& input : meta_def->inputs) {
- auto input_type = parent_graph_->GetNodeArg(input)->Type();
- op_schema_->Input(i, input, "", *input_type);
- ++i;
- }
- i = 0;
- for (auto& output : meta_def->outputs) {
- auto output_type = parent_graph_->GetNodeArg(output)->Type();
- op_schema_->Output(i, output, "", *output_type);
- ++i;
- }
- op_schema_->Finalize();
- //construct body
- body_ = std::make_unique("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
- IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), graph.DomainToVersionMap());
-
- auto& sub_graph = body_->MainGraph();
- //Add node and node args
- //TODO: for better performance, we could try to transfer the nodes in parent graph to sub-graph directly,
- //instead of create new nodes.
- for (auto& node_index : customized_func_body_->nodes) {
- auto node = parent_graph_->GetNode(node_index);
- std::vector<onnxruntime::NodeArg*> inputs, outputs;
- for (auto input : node->InputDefs()) {
- auto& n_input = sub_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
- inputs.push_back(&n_input);
- }
- for (auto output : node->OutputDefs()) {
- auto& n_output = sub_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
- outputs.push_back(&n_output);
- }
- sub_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
- }
- //TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it.
- ONNXRUNTIME_ENFORCE(sub_graph.Resolve().IsOK());
-}
-
-FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
- const onnxruntime::NodeIndex& node_index,
- const ONNX_NAMESPACE::FunctionProto* onnx_func_proto)
- : parent_graph_(&graph) {
- onnx_func_proto_ = onnx_func_proto;
- auto node_in_parent_graph = parent_graph_->GetNode(node_index);
- op_schema_ = std::make_unique<ONNX_NAMESPACE::OpSchema>();
- op_schema_->SetName(onnx_func_proto_->name());
- op_schema_->SetDomain(onnx_func_proto_->node().Get(0).domain());
- op_schema_->SetDoc(onnx_func_proto_->doc_string());
- op_schema_->SinceVersion((ONNX_NAMESPACE::OperatorSetVersion)onnx_func_proto_->since_version());
- std::unordered_map<std::string, int> input_name_idx_map;
- std::unordered_map<std::string, int> output_name_idx_map;
- TypeConstraintHelper(onnx_func_proto_, this->op_schema_, input_name_idx_map, output_name_idx_map);
-
- op_schema_->TypeAndShapeInferenceFunction(
- [this](ONNX_NAMESPACE::InferenceContext& ctx) {
- auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
- const ONNX_NAMESPACE::FunctionProto* func_ptr = this->GetFuncProto();
- if (nullptr != func_ptr) {
- ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*func_ptr, schema_registry, ctx);
- }
- });
-
- op_schema_->Finalize();
- //construct body
- std::unordered_map<std::string, int> domain_to_version;
- //TODO: set correct domain and version
- domain_to_version[onnxruntime::kOnnxDomain] = (int)onnx_func_proto_->since_version();
- body_ = std::make_unique<onnxruntime::Model>(onnx_func_proto_->name(), false, onnxruntime::ModelMetaData(),
- IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
- auto& sub_graph = body_->MainGraph();
- //Add node and node args into subgraph
- auto attr_map = node_in_parent_graph->GetAttributes();
- for (auto& node : onnx_func_proto_->node()) {
- std::vector<onnxruntime::NodeArg*> inputs, outputs;
- for (int idx = 0; idx < node.input_size(); ++idx) {
- std::string tensor_name = node.input().Get(idx);
- if (input_name_idx_map.count(tensor_name)) {
- ONNX_NAMESPACE::NodeProto temp_node_proto;
- node_in_parent_graph->ToProto(temp_node_proto);
- const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.input().Get(input_name_idx_map[tensor_name]));
- auto& n_input = sub_graph.GetOrCreateNodeArg(
- tensor_name, node_arg->TypeAsProto());
- inputs.push_back(&n_input);
- } else {
- auto& n_input = sub_graph.GetOrCreateNodeArg(
- tensor_name, nullptr);
- inputs.push_back(&n_input);
- }
- }
- for (int idx = 0; idx < node.output_size(); ++idx) {
- std::string tensor_name = node.output().Get(idx);
- auto& n_output = sub_graph.GetOrCreateNodeArg(tensor_name, nullptr);
- outputs.push_back(&n_output);
- }
-
- onnxruntime::NodeAttributes new_attr_map;
- for (auto& attr : node.attribute()) {
- if (attr.has_ref_attr_name()) {
- if (attr_map.count(attr.ref_attr_name())) {
- new_attr_map[attr.name()] = attr_map[attr.ref_attr_name()];
- }
- } else {
- new_attr_map[attr.name()] = attr;
- }
- }
- sub_graph.AddNode(node.name(), node.op_type(), node.doc_string(), inputs, outputs, &new_attr_map, node.domain());
- }
- auto status = sub_graph.Resolve();
- ONNXRUNTIME_ENFORCE(status.IsOK());
-}
-
-const ONNX_NAMESPACE::OpSchema& FunctionImpl::OpSchema() const {
- return *op_schema_;
-}
-
-const onnxruntime::Graph& FunctionImpl::Body() const {
- return body_->MainGraph();
-}
-
-const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const {
- return *customized_func_body_;
-}
-
-const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
- return onnx_func_proto_;
-}
-
-std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
- std::unique_ptr<IndexedSubGraph> customized_func) {
- return std::make_unique<FunctionImpl>(graph, std::move(customized_func));
-}
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.h
deleted file mode 100644
index b3161c2d3..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function.h
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/graph/indexed_sub_graph.h"
-
-namespace onnxruntime {
-class Graph;
-class Node;
-} // namespace onnxruntime
-
-namespace onnxruntime {
-
-// Function representation class.
-class Function {
- public:
- virtual ~Function() {}
- virtual const ONNX_NAMESPACE::OpSchema& OpSchema() const = 0;
-
- virtual const onnxruntime::Graph& Body() const = 0;
-
- virtual const IndexedSubGraph& GetIndexedSubGraph() const = 0;
-};
-
-std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
- std::unique_ptr<IndexedSubGraph> customized_func);
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_container.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_container.h
deleted file mode 100644
index f29b67a30..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_container.h
+++ /dev/null
@@ -1,14 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include <memory>
-#include <vector>
-#include "core/graph/function.h"
-//TODO: we need to make it a stand-alone header because both graph.cc and model.cc need to implement create instance of the graph object.
-//Right now only functions_ has issue because it use vector of unique-ptr, maybe we should extend this to GraphImpl later.
-namespace onnxruntime {
-struct FunctionContainer {
- std::vector<std::unique_ptr<Function>> functions_;
-};
-} // namespace onnxruntime
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h
deleted file mode 100644
index 0465bcaee..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_impl.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include "core/graph/function.h"
-#include "core/graph/graph_base.h"
-#include "core/graph/model.h"
-
-namespace onnxruntime {
-class Graph;
-class Node;
-} // namespace onnxruntime
-
-namespace onnxruntime {
-
-// Function representation class.
-class FunctionImpl final : public Function {
- public:
-  FunctionImpl(const onnxruntime::Graph& graph,
-               std::unique_ptr<IndexedSubGraph> customized_func);
-
- FunctionImpl(const onnxruntime::Graph& graph,
- const onnxruntime::NodeIndex& node_index,
- const ONNX_NAMESPACE::FunctionProto* onnx_func);
-
- const ONNX_NAMESPACE::OpSchema& OpSchema() const override;
-
- const onnxruntime::Graph& Body() const override;
-
- const IndexedSubGraph& GetIndexedSubGraph() const override;
-
- const ONNX_NAMESPACE::FunctionProto* GetFuncProto() const;
-
- private:
- const onnxruntime::Graph* const parent_graph_;
-  std::unique_ptr<IndexedSubGraph> customized_func_body_;
-  std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
-  std::unique_ptr<onnxruntime::Model> body_;
- const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_;
-};
-
-} // namespace onnxruntime
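FunctionImpl's two constructors match the two ways a function body arises: a fused IndexedSubGraph chosen for a set of nodes, or an ONNX function op expanded from a FunctionProto (the path VerifyNodeAndOpMatch takes below). A hedged sketch of the first path through the MakeFunction factory, assuming the declarations from the deleted headers:

    #include <memory>
    #include <utility>

    // FuseNodes is hypothetical; it only shows the intended call shape.
    std::unique_ptr<onnxruntime::Function> FuseNodes(
        const onnxruntime::Graph& graph,
        std::unique_ptr<onnxruntime::IndexedSubGraph> capability) {
      // MakeFunction wraps the sub-graph in a FunctionImpl, which synthesizes
      // an OpSchema so the fused node can be treated like any other op.
      return onnxruntime::MakeFunction(graph, std::move(capability));
    }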
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_inliner.h b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_inliner.h
deleted file mode 100644
index 1d277836c..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/function_inliner.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-#include "core/common/common.h"
-#include "core/graph/function.h"
-#include "core/graph/rewrite_rule.h"
-
-namespace onnxruntime {
-class Node;
-} // namespace onnxruntime
-
-namespace onnxruntime {
-
-// A function-inlining rewrite-rule.
-class FunctionInliner : public onnxruntime::RewriteRule {
- public:
- FunctionInliner(const std::string& name, const std::string& desc)
- : RewriteRule(name, desc) {}
-
- Status Apply(onnxruntime::GraphEditor /*graph_editor*/, onnxruntime::Node* /*node*/, bool* /*modified*/) override {
- return Status::OK();
- }
-};
-
-} // namespace onnxruntime
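FunctionInliner was never more than a placeholder: its Apply returns OK without touching the graph. For contrast, a sketch of the shape a working rule would take under the same RewriteRule interface; the body is illustrative, since the real inlining logic lived in Graph and FunctionImpl:

    // Hypothetical rule: detects nodes that carry a function body.
    class SketchInliner : public onnxruntime::RewriteRule {
     public:
      SketchInliner() : RewriteRule("SketchInliner", "Illustrative inlining rule") {}

      ::onnxruntime::common::Status Apply(onnxruntime::GraphEditor graph_editor,
                                          onnxruntime::Node* node, bool* modified) override {
        // A real implementation would splice node->GetFunctionBody()->Body()
        // into the surrounding graph via graph_editor and set *modified = true.
        *modified = false;
        return ::onnxruntime::common::Status::OK();
      }
    };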
diff --git a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc b/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc
deleted file mode 100644
index ba8f896fa..000000000
--- a/Source/CNTKv2LibraryDll/proto/onnx/core/graph/graph.cc
+++ /dev/null
@@ -1,2471 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#ifdef _WIN32
-// disable some warnings from protobuf to pass Windows build
-#pragma warning(disable : 4244)
-#endif
-
-#include <fstream>
-#include <iostream>
-#include <numeric>
-#include <stack>
-
-#include "gsl/pointers"
-#include "core/graph/function.h"
-#include "core/graph/function_impl.h"
-#include "core/graph/graph.h"
-#include "core/graph/indexed_sub_graph.h"
-#include "core/graph/op.h"
-#include "core/common/logging/logging.h"
-#include "onnx/checker.h"
-#include "core/graph/schema_registry.h"
-#include "core/graph/function_container.h"
-using namespace ONNX_NAMESPACE;
-using namespace ONNX_NAMESPACE::Utils;
-using namespace ONNX_NAMESPACE::checker;
-using namespace ::onnxruntime::common;
-
-namespace onnxruntime {
-
-#define NO_CHANGE_ON_SYNC_FLAG(...) \
- do { \
- const bool sync_needed = GraphProtoSyncNeeded(); \
- { __VA_ARGS__; } \
- GraphProtoSyncNeeded(sync_needed); \
- } while (0)
-
-static Status MergeShapeInfo(const std::string& output_name,
- const TypeProto_Tensor& source, TypeProto_Tensor& target) {
- try {
- ONNX_NAMESPACE::mergeInShapeInfo(source, target);
- } catch (const ONNX_NAMESPACE::InferenceError& ex) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output:", output_name, " ", ex.what());
- }
-
- return Status::OK();
-}
-
-static bool GraphLoadedFromModelFile(const GraphProto* graph_proto) {
- return graph_proto && (graph_proto->input_size() != 0 ||
- graph_proto->output_size() != 0 ||
- graph_proto->value_info_size() != 0);
-}
-
-NodeArg::NodeArg(const std::string& name,
- const TypeProto* p_node_arg_type) {
- node_arg_info_.set_name(name);
- // If the name is empty, it means the arg does not exist.
- exists_ = !(name.empty());
- if (nullptr != p_node_arg_type) {
- (*node_arg_info_.mutable_type()) = *p_node_arg_type;
- type_ = DataTypeUtils::ToType(node_arg_info_.type());
- } else {
- type_ = nullptr;
- }
-}
-
-const std::string& NodeArg::Name() const noexcept {
- return node_arg_info_.name();
-}
-
-DataType NodeArg::Type() const noexcept {
- return type_;
-}
-
-const TypeProto* NodeArg::TypeAsProto() const noexcept {
- if (node_arg_info_.has_type())
- return &node_arg_info_.type();
- else
- return nullptr;
-}
-
-const TensorShapeProto* NodeArg::Shape() const {
- if (!node_arg_info_.has_type()) {
- return nullptr;
- }
-
- const auto typeCase = node_arg_info_.type().value_case();
- switch (typeCase) {
- case TypeProto::kTensorType: {
- if (node_arg_info_.type().tensor_type().has_shape()) {
- return &(node_arg_info_.type().tensor_type().shape());
- } else {
- return nullptr;
- }
- }
- case TypeProto::kSparseTensorType: {
- if (node_arg_info_.type().sparse_tensor_type().has_shape()) {
- return &(node_arg_info_.type().sparse_tensor_type().shape());
- } else {
- return nullptr;
- }
- }
- case TypeProto::kSequenceType:
- case TypeProto::kMapType:
- case TypeProto::kOpaqueType:
- case TypeProto::VALUE_NOT_SET:
- default:
- return nullptr;
- }
-}
-
-void NodeArg::SetShape(const TensorShapeProto& shape) {
- if (!node_arg_info_.has_type()) {
- return;
- }
-
- const auto type_case = node_arg_info_.type().value_case();
- switch (type_case) {
- case TypeProto::kTensorType:
- *(node_arg_info_.mutable_type()->mutable_tensor_type()->mutable_shape()) = shape;
- break;
- case TypeProto::kSparseTensorType:
- *(node_arg_info_.mutable_type()->mutable_sparse_tensor_type()->mutable_shape()) = shape;
- break;
- case TypeProto::kSequenceType:
- case TypeProto::kMapType:
- case TypeProto::kOpaqueType:
- case TypeProto::VALUE_NOT_SET:
- default:
- return;
- }
-}
-
-common::Status NodeArg::UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type) {
- if (!node_arg_info_.has_type()) {
- *node_arg_info_.mutable_type() = input_type;
- type_ = DataTypeUtils::ToType(node_arg_info_.type());
- return Status::OK();
- }
-
- auto& current_type = *node_arg_info_.mutable_type();
- const auto current_type_case = current_type.value_case();
- const auto input_type_case = input_type.value_case();
-
- if (current_type_case != input_type_case)
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type mismatch. Current=",
- current_type_case, " Input=", input_type_case);
-
- switch (input_type_case) {
- case TypeProto::kTensorType: {
- const auto& input_tensor_type = input_type.tensor_type();
- const auto& input_tensor_elem_type = input_tensor_type.elem_type();
- const auto& current_tensor_elem_type = current_type.tensor_type().elem_type();
-
- if (input_tensor_elem_type != current_tensor_elem_type)
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Tensor element type mismatch. ",
- TensorProto_DataType_Name(input_tensor_elem_type), " != ",
- TensorProto_DataType_Name(current_tensor_elem_type));
-
- if (input_tensor_type.has_shape()) {
- auto& current_tensor_type = *current_type.mutable_tensor_type();
- if (current_tensor_type.has_shape()) {
- ONNXRUNTIME_RETURN_IF_ERROR(MergeShapeInfo(Name(), input_tensor_type, current_tensor_type));
- } else {
- current_tensor_type = input_tensor_type;
- }
- }
-
- break;
- }
- case TypeProto::kSparseTensorType: {
- const auto& input_tensor_type = input_type.sparse_tensor_type();
- const auto input_tensor_elem_type = input_tensor_type.elem_type();
- const auto current_tensor_elem_type = current_type.sparse_tensor_type().elem_type();
- if (input_tensor_elem_type != current_tensor_elem_type) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseTensor element type mismatch. ",
- TensorProto_DataType_Name(input_tensor_elem_type), " != ",
- TensorProto_DataType_Name(current_tensor_elem_type));
- }
- if (input_tensor_type.has_shape()) {
- auto& current_tensor_type = *current_type.mutable_sparse_tensor_type();
- if (current_tensor_type.has_shape()) {
- // TODO: Check if we need to merge shape here
- // if so we'd need to provide merging routine ONNX
- // mergeInShapeInfo(input_tensor_type, current_tensor_type);
- } else {
- current_tensor_type = input_tensor_type;
- }
- }
- } break;
- case TypeProto::kSequenceType:
- case TypeProto::kMapType:
- case TypeProto::kOpaqueType:
- case TypeProto::VALUE_NOT_SET:
- break;
- }
-
- return Status::OK();
-}
-
-common::Status NodeArg::UpdateTypeAndShape(const NodeArg& node_arg) {
- auto status = Status::OK();
-
- if (node_arg.node_arg_info_.has_type())
- status = UpdateTypeAndShape(node_arg.node_arg_info_.type());
-
- return status;
-}
-
-void NodeArg::SetType(DataType p_type) {
- if (nullptr == p_type) {
- return;
- }
-
- type_ = p_type;
- *(node_arg_info_.mutable_type()) = DataTypeUtils::ToTypeProto(p_type);
-}
-
-void NodeArg::SetType(const TypeProto& type_proto) {
- type_ = DataTypeUtils::ToType(type_proto);
- *(node_arg_info_.mutable_type()) = type_proto;
-}
-
-bool NodeArg::Exists() const noexcept {
- return exists_;
-}
-
-Node::EdgeEnd::EdgeEnd(const Node& node, const NodeArg& node_arg) noexcept
- : node_(&node), node_arg_(&node_arg) {
-}
-
-Node::EdgeEnd::EdgeEnd(const Node& node) noexcept
- : node_(&node), node_arg_(nullptr) {
-}
-
-const Node& Node::EdgeEnd::GetNode() const noexcept {
- return *node_;
-}
-
-const NodeArg* Node::EdgeEnd::GetNodeArg() const noexcept {
- return node_arg_;
-}
-
-Node::NodeConstIterator::NodeConstIterator(EdgeConstIterator p_iter) {
- m_iter = p_iter;
-}
-
-bool Node::NodeConstIterator::operator==(const NodeConstIterator& p_other) const {
- return m_iter == p_other.m_iter;
-}
-
-bool Node::NodeConstIterator::operator!=(const NodeConstIterator& p_other) const {
- return m_iter != p_other.m_iter;
-}
-
-void Node::NodeConstIterator::operator++() {
- ++m_iter;
-}
-
-void Node::NodeConstIterator::operator--() {
- --m_iter;
-}
-
-const Node* Node::NodeConstIterator::operator*() {
- return &((*m_iter).GetNode());
-}
-
-NodeIndex Node::Index() const noexcept {
- return index_;
-}
-
-const std::string& Node::Name() const noexcept {
- return name_;
-}
-
-const std::string& Node::OpType() const noexcept {
- return op_type_;
-}
-
-const std::string& Node::Description() const noexcept {
- return description_;
-}
-
-const std::string& Node::Domain() const noexcept {
- return domain_;
-}
-
-const OpSchema* Node::Op() const noexcept {
- return op_;
-}
-
-Node::Type Node::NodeType() const noexcept {
- return node_type_;
-}
-
-void Node::SetNodeType(Node::Type node_type) noexcept {
- node_type_ = node_type;
-}
-
-const ::onnxruntime::Function* Node::GetFunctionBody() const noexcept {
- return func_body_;
-}
-
-void Node::SetFunctionBody(const ::onnxruntime::Function& func) {
- func_body_ = &func;
- op_ = &func.OpSchema();
-}
-
-const std::string& Node::GetExecutionProviderType() const noexcept {
- return execution_provider_type_;
-}
-
-void Node::SetExecutionProviderType(onnxruntime::ProviderType execution_provider_type) {
- execution_provider_type_ = execution_provider_type;
-}
-
-void Node::ToProto(NodeProto& proto) const {
- // Set name.
- proto.set_name(name_);
- // Set op type.
- proto.set_op_type(op_type_);
- // Set op domain;
- proto.set_domain(domain_);
- // Set doc string.
- proto.set_doc_string(description_);
-
- // Set attributes.
- proto.clear_attribute();
- for (auto attribute : attributes_) {
-    const gsl::not_null<AttributeProto*> attr{proto.add_attribute()};
- *attr = attribute.second;
- }
-
- // Set inputs' definitions.
- proto.clear_input();
- for (auto& input_def : definitions_.input_defs) {
- proto.add_input(input_def->Name());
- }
-
- // Set outputs' definitions.
- proto.clear_output();
- for (auto& output_def : definitions_.output_defs) {
- proto.add_output(output_def->Name());
- }
-}
-
-void Node::Init(const std::string& name,
-                const std::string& op_type,
-                const std::string& description,
-                const std::vector<NodeArg*>& input_args,
-                const std::vector<NodeArg*>& output_args,
-                const NodeAttributes* attributes,
-                const std::string& domain) {
- name_ = name;
- op_type_ = op_type;
- description_ = description;
- definitions_.input_defs = input_args;
- definitions_.output_defs = output_args;
- domain_ = domain;
- if (kOnnxDomainAlias == domain_) {
- domain_ = kOnnxDomain;
- }
-
- // Set each arg count as 1 by default.
- // It could be adjusted when resolving the node with its operator
- // information.
- definitions_.input_arg_count.assign(input_args.size(), 1);
-
- if (attributes) {
- attributes_ = *attributes;
- }
-}
-
-Node::Definitions& Node::MutableDefinitions() noexcept {
- // someone fetching these is going to change something
- graph_->SetGraphResolveNeeded();
- graph_->SetGraphProtoSyncNeeded();
- return definitions_;
-}
-
-Node::Relationships& Node::MutableRelationships() noexcept {
- // someone fetching these is going to change something
- graph_->SetGraphResolveNeeded();
- graph_->SetGraphProtoSyncNeeded();
- return relationships_;
-}
-
-void Node::AddAttribute(const std::string& attr_name, const AttributeProto& value) {
- graph_->SetGraphResolveNeeded();
- graph_->SetGraphProtoSyncNeeded();
- attributes_[attr_name] = value;
-}
-
-#define ADD_BASIC_ATTR_IMPL(type, enumType, field) \
- void Node::AddAttribute(const std::string& attr_name, const type& value) { \
- graph_->SetGraphResolveNeeded(); \
- graph_->SetGraphProtoSyncNeeded(); \
- AttributeProto a; \
- a.set_name(attr_name); \
- a.set_type(enumType); \
- a.set_##field(value); \
- attributes_[attr_name] = a; \
- };
-
-#define ADD_ATTR_IMPL(type, enumType, field) \
- void Node::AddAttribute(const std::string& attr_name, const type& value) { \
- graph_->SetGraphResolveNeeded(); \
- graph_->SetGraphProtoSyncNeeded(); \
- AttributeProto a; \
- a.set_name(attr_name); \
- a.set_type(enumType); \
- *(a.mutable_##field()) = value; \
- attributes_[attr_name] = a; \
- };
-
-#define ADD_LIST_ATTR_IMPL(type, enumType, field)            \
-  void Node::AddAttribute(const std::string& attr_name,      \
-                          const std::vector<type>& values) { \
- graph_->SetGraphResolveNeeded(); \
- graph_->SetGraphProtoSyncNeeded(); \
- AttributeProto a; \
- a.set_name(attr_name); \
- a.set_type(enumType); \
- for (const auto& val : values) { \
- *(a.mutable_##field()->Add()) = val; \
- } \
- attributes_[attr_name] = a; \
- };
-
-ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT, f)
-ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INT, i)
-ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRING, s)
-ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR, t)
-ADD_ATTR_IMPL(GraphProto, AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH, g)
-ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS, floats)
-ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType::AttributeProto_AttributeType_INTS, ints)
-ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS, strings)
-ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType::AttributeProto_AttributeType_TENSORS, tensors)
-ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPHS, graphs)
-
-bool Node::ClearAttribute(const std::string& attr_name) {
- graph_->SetGraphResolveNeeded();
- graph_->SetGraphProtoSyncNeeded();
- return attributes_.erase(attr_name) > 0;
-}
-
-Status Node::UpdateInputArgCount() {
- // The node refers to a primitive operator.
- // Infer and verify node input arg type information.
- int total_arg_count = std::accumulate(definitions_.input_arg_count.cbegin(),
- definitions_.input_arg_count.cend(), 0);
-
-  if (total_arg_count < 0 || static_cast<size_t>(total_arg_count) != definitions_.input_defs.size()) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL,
- "The sum of input arg count is not equal to size of input defs in node (",
- name_, ")");
- }
-
- // op_ is always valid when this is called
- const ONNX_NAMESPACE::OpSchema& op = *Op();
-
- // Verify size of node arg count is same as input number in
- // operator definition.
- if (op.inputs().size() != definitions_.input_arg_count.size()) {
-    // Adjust the input arg count array to match the op definition.
-    // The adjustment works as follows:
-    // in total there are <total_arg_count> inputs, which
-    // will be split as <1, 1, 1, 1, ... 1, x> or
-    // <1, 1, 1, 1, ...1, 0, 0, ...0>. The final input
-    // arg count array will have the same number of elements
-    // as the op definition, and the sum of all elements will
-    // be equal to <total_arg_count>.
- auto& input_arg_count = definitions_.input_arg_count;
- input_arg_count.clear();
- size_t m = 0;
- auto arg_count_left = total_arg_count;
-
- if (!op.inputs().empty()) {
- for (; m < op.inputs().size() - 1; ++m) {
- if (arg_count_left > 0) {
- input_arg_count.push_back(1);
- arg_count_left--;
- } else {
- input_arg_count.push_back(0);
- }
- }
- }
-
- // Set the arg count for the last input formal parameter.
- // NOTE: in the case that there's no .input(...) defined
- // in op schema, all input args will be fed as one input
- // of the operator.
- input_arg_count.push_back(arg_count_left);
-
- graph_->SetGraphResolveNeeded();
- graph_->SetGraphProtoSyncNeeded();
- }
-
- return Status::OK();
-}
-
-const NodeAttributes& Node::GetAttributes() const noexcept {
- return attributes_;
-}
-
-void Node::ForEachDef(std::function<void(const onnxruntime::NodeArg*, bool is_input)> func) const {
- for (const auto* arg : InputDefs()) {
- if (arg->Exists())
- func(&*arg, true);
- }
-
- for (const auto* arg : ImplicitInputDefs()) {
- if (arg->Exists())
- func(&*arg, true);
- }
-
- for (const auto* arg : OutputDefs()) {
- if (arg->Exists())
- func(&*arg, false);
- }
-};
-
-void Node::ForEachInputDef(std::function<void(const onnxruntime::NodeArg*)> func) const {
- for (const auto* arg : InputDefs()) {
- if (!arg->Exists())
- continue;
- func(&*arg);
- }
-};
-
-void Node::ForEachOutputDef(std::function<void(const onnxruntime::NodeArg*)> func) const {
- for (const auto* arg : OutputDefs()) {
- if (!arg->Exists())
- continue;
- func(&*arg);
- }
-};
-
-void Node::ReplaceDefs(const std::map<const NodeArg*, NodeArg*>& replacements) {
-  std::vector<std::vector<NodeArg*>*> all_defs = {&definitions_.input_defs, &definitions_.output_defs};
-
- for (auto pair : replacements)
- for (auto* defs : all_defs)
- for (auto& def : *defs)
- if (def == pair.first)
- def = pair.second;
-}
-
-// Constructor: Given a <GraphProto> loaded from a model file, construct
-// a <Graph> object and Resolve() it.
-//Status Graph::LoadGraph(const GraphProto& graph_proto,
-//                        const std::unordered_map<std::string, int>& domain_to_version,
-// Version ir_version,
-// std::unique_ptr& new_graph) {
-// // create instance. need to call private ctor so can't use make_unique
-// GSL_SUPPRESS(r .11)
-// new_graph.reset(new Graph(nullptr, &graph_proto, domain_to_version, ir_version));
-//
-// // as we just loaded from file we want to fully initialize/Resolve, but not let that change
-// // the proto sync flag
-// auto status = new_graph->Resolve(/* no_proto_sync_required */ true);
-// return status;
-//}
-using google::protobuf::RepeatedPtrField;
-
-Graph::Graph(GraphProto* graph_proto,
-             const std::unordered_map<std::string, int>& domain_to_version,
-             Version ir_version,
-             IOnnxRuntimeOpSchemaCollectionPtr schema_registry)
-    : Graph(graph_proto, domain_to_version, ir_version, schema_registry, nullptr) {}
-
-Graph::Graph(GraphProto* graph_proto,
-             const std::unordered_map<std::string, int>& domain_to_version,
-             Version ir_version,
-             IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
-             Graph* parent_graph)
- : graph_proto_{graph_proto},
- graph_type_{Type::Main},
- schema_registry_(schema_registry),
-      function_container_(std::make_unique<FunctionContainer>()),
- graph_resolve_needed_(true),
- graph_proto_sync_needed_(false),
- domain_to_version_(domain_to_version),
- ir_version_(ir_version),
- parent_graph_{parent_graph} {
- ONNXRUNTIME_ENFORCE(graph_proto != nullptr, "graph_proto cannot be null");
- ArgNameToTypeMap name_to_type_map;
-
- // these are all empty unless we received a graph_proto as input
- if (graph_proto != nullptr) {
- // Copy constant nodes _value to name_to_initial_tensor_
- for (auto& node : graph_proto_->node()) {
- if (node.op_type() == kConstant) {
-        const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
- *tensor = node.attribute(0).t();
- *(tensor->mutable_name()) = node.output(0);
-
- // we remove the node and add it as an initializer, but still need it to appear in the
- // graph inputs to make the ONNX checker happy. add a new input due to that.
- auto graph_inputs = graph_proto_->mutable_input();
-
- ValueInfoProto* value_info = graph_inputs->Add();
- value_info->set_name(node.output(0));
- value_info->set_doc_string("Input to represent replaced Constant node");
-
- TypeProto t;
- t.mutable_tensor_type()->set_elem_type(tensor->data_type());
- auto shape = t.mutable_tensor_type()->mutable_shape();
- for (auto dim : tensor->dims())
- shape->add_dim()->set_dim_value(dim);
-
- (*value_info->mutable_type()) = t;
- }
- }
-
- // remove constant nodes
-    const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes{graph_proto_->mutable_node()};
- graph_mutable_nodes->erase(
- std::remove_if(graph_mutable_nodes->begin(), graph_mutable_nodes->end(),
- [](NodeProto& p) {
- return (p.op_type() == kConstant);
- }),
- graph_mutable_nodes->end());
-
- // Copy initial tensors to a map.
- for (auto& tensor : graph_proto_->initializer()) {
- name_to_initial_tensor_[tensor.name()] = &tensor;
- }
-
- // Collect all node arg name, type, shape information in the graph.
- // type/shape information will be assigned to each node arg when going
- // thru all nodes later.
- for (auto& graph_input : graph_proto_->input()) {
- if (graph_input.has_name() && graph_input.has_type()) {
- name_to_type_map[graph_input.name()] = graph_input.type();
- // always create a NodeArg for graph input in case its from an initializer
- GetOrCreateNodeArg(graph_input.name(), &graph_input.type());
- }
- }
-
- for (auto& graph_output : graph_proto_->output()) {
- if (graph_output.has_name() && graph_output.has_type()) {
- auto& name = graph_output.name();
- name_to_type_map[name] = graph_output.type();
- // always create NodeArg for graph output, in case it's from initializer
- GetOrCreateNodeArg(name, &graph_output.type());
- }
- }
-
- for (auto& node_arg : graph_proto_->value_info()) {
- if (node_arg.has_name() && node_arg.has_type()) {
- name_to_type_map[node_arg.name()] = node_arg.type();
- }
- }
-
- for (auto node_proto : graph_proto_->node()) {
- AddNode(node_proto, name_to_type_map);
- }
- }
-}
-
-Graph::Graph(Graph& parent_graph, ONNX_NAMESPACE::GraphProto& subgraph_proto)
- : Graph(&subgraph_proto,
- parent_graph.DomainToVersionMap(), parent_graph.IrVersion(), parent_graph.schema_registry_,
- &parent_graph) {
-}
-
-Status Graph::VerifyNoDuplicateName() {
-  const std::unordered_set<std::string>& inputs_and_initializers = resolve_context_.inputs_and_initializers;
-  std::unordered_map<std::string, Node*>& output_args = resolve_context_.output_args;
-  std::unordered_map<std::string, NodeIndex>& node_name_to_index = resolve_context_.node_name_to_index;
-
- output_args.clear();
- node_name_to_index.clear();
- // inputs_and_initializers: this is passed in as a parameter, since functions don't have initializers
- // but graphs have them.
-
- for (auto& node : Nodes()) {
- // Verify node name should be unique.
- auto& node_name = node.Name();
-
- if (!node_name.empty() && node_name_to_index.end() != node_name_to_index.find(node_name)) {
-      // The node has a name, and that name was already used by another node.
- Status status(ONNXRUNTIME, FAIL,
- "Error: two nodes with same node name (" + node_name + ").");
- return status;
- }
-
- node_name_to_index[node_name] = node.Index();
-
- // Verify node outputs' name should be unique.
- for (const auto* output_def : node.OutputDefs()) {
- if (output_def->Exists()) {
- auto& output_arg_name = output_def->Name();
- if (inputs_and_initializers.count(output_arg_name)) {
- Status status(ONNXRUNTIME, FAIL,
- "Error: Duplicate definition of name (" + output_arg_name + ").");
- return status;
- }
- auto result = output_args.insert({output_arg_name, &node});
- if (!result.second) {
-          // Two outputs with the same name, so the insertion fails.
- Status status(ONNXRUNTIME, FAIL,
- "Error: Duplicate definition of name (" + output_arg_name + ").");
- return status;
- }
- }
- }
- }
- return Status::OK();
-}
-
-// Recurse into any subgraphs to update the list of NodeArg values in outer scope.
-// This information is needed to resolve any dependencies on outer scope values.
-common::Status Graph::SetOuterScopeNodeArgs(const std::unordered_set<std::string>& outer_scope_node_args) {
- resolve_context_.outer_scope_node_args = outer_scope_node_args;
-
- if (!resolve_context_.node_to_subgraphs_map.empty()) {
- // Build the list of NodeArg's that are valid for a subgraph of this GraphBase instance:
- // - outer scope for this graph
- // - any inputs/initializers from this graph
- // - any outputs from nodes in this graph
- //
-    // NOTE: We must add the outermost NodeArgs first, and then the local NodeArgs, as a local value
-    // should override an outer scope value of the same name.
- //
- // We provide outputs from all nodes in this graph at this stage.
- // BuildConnections will link the node with the subgraph to any outer scope Node/NodeArgs it consumes.
- // PerformTopologicalSortAndCheckIsAcyclic will validate these links.
-    std::unordered_set<std::string> node_args_in_scope_for_subgraph = outer_scope_node_args;
-
- node_args_in_scope_for_subgraph.insert(resolve_context_.inputs_and_initializers.cbegin(),
- resolve_context_.inputs_and_initializers.cend());
-
-    std::transform(resolve_context_.output_args.cbegin(), resolve_context_.output_args.cend(),
-                   std::inserter(node_args_in_scope_for_subgraph, node_args_in_scope_for_subgraph.end()),
-                   [](const std::pair<const std::string, Node*>& entry) { return entry.first; });
-
- for (auto node_subgraphs : resolve_context_.node_to_subgraphs_map) {
- for (auto* subgraph : node_subgraphs.second) {
- auto status = subgraph->SetOuterScopeNodeArgs(node_args_in_scope_for_subgraph);
- ONNXRUNTIME_RETURN_IF_ERROR(status);
- }
- }
- }
-
- return Status::OK();
-}
-
-const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const {
- const NodeArg* node_arg = GetNodeArg(node_arg_name);
-
- if (!node_arg && parent_graph_) {
- node_arg = parent_graph_->GetNodeArgIncludingParentGraphs(node_arg_name);
- }
-
- return node_arg;
-}
-
-void Graph::AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, const NodeArg& node_arg) {
- if (nodes_.size() <= src_node_index ||
- nodes_.size() <= dst_node_index ||
- nullptr == nodes_[src_node_index] ||
- nullptr == nodes_[dst_node_index]) {
- // Invalid node indexes specified.
- ONNXRUNTIME_THROW("Invalid node indexes specified when adding edge.");
- }
-  // First, verify that node_arg is an output of src and an input of dst.
- bool valid = false;
- for (auto arg : nodes_[src_node_index]->OutputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- ONNXRUNTIME_ENFORCE(valid);
- valid = false;
- for (auto arg : nodes_[dst_node_index]->InputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- for (auto arg : nodes_[dst_node_index]->ImplicitInputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- ONNXRUNTIME_ENFORCE(valid);
- nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index], node_arg));
- nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index], node_arg));
-}
-
-void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, const NodeArg& node_arg) {
- if (nodes_.size() <= src_node_index ||
- nodes_.size() <= dst_node_index ||
- nullptr == nodes_[src_node_index] ||
- nullptr == nodes_[dst_node_index]) {
- // Invalid node indexes specified.
- ONNXRUNTIME_THROW("Invalid node indexes specified when removing edge.");
- }
-  // First, verify that node_arg is an output of src and an input of dst.
- bool valid = false;
- for (auto arg : nodes_[src_node_index]->OutputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- ONNXRUNTIME_ENFORCE(valid);
- valid = false;
- for (auto arg : nodes_[dst_node_index]->InputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- for (auto arg : nodes_[dst_node_index]->ImplicitInputDefs()) {
- if (arg == &node_arg) {
- valid = true;
- break;
- }
- }
- ONNXRUNTIME_ENFORCE(valid);
- nodes_[dst_node_index]->MutableRelationships().input_edges.erase(Node::EdgeEnd(*nodes_[src_node_index], node_arg));
- nodes_[src_node_index]->MutableRelationships().output_edges.erase(Node::EdgeEnd(*nodes_[dst_node_index], node_arg));
-}
-
-GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
-Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
-  const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
-  std::unordered_set<const Node*> inner_nodes;
-
- // recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
- if (!resolve_context_.node_to_subgraphs_map.empty()) {
- for (auto nodeid_to_subgraphs : resolve_context_.node_to_subgraphs_map) {
- for (auto* subgraph : nodeid_to_subgraphs.second) {
-        std::vector<std::string> node_args_consumed;
- subgraph->BuildConnections(node_args_consumed);
-
- for (auto& node_arg_name : node_args_consumed) {
- const auto* node_arg = GetNodeArg(node_arg_name);
-
- if (node_arg == nullptr) {
- // it's a node arg from outside this graph's scope, so add that to the list we return
- // so that we can add the dependency at the next level up
- if (node_arg_name == "ElementTimes1147_Output_0")
- std::cout << "";
- outer_scope_node_args_consumed.push_back(node_arg_name);
-
- if (!parent_graph_) {
- return ONNXRUNTIME_MAKE_STATUS(
- ONNXRUNTIME, INVALID_GRAPH,
- "At top level graph without matching NodeArg that subgraph consumes. Name=",
- node_arg_name,
- " Graph may not conform to the ONNX spec and contain initializers that are not graph inputs.");
- }
-
- node_arg = parent_graph_->GetNodeArgIncludingParentGraphs(node_arg_name);
-
- if (!node_arg) {
- return ONNXRUNTIME_MAKE_STATUS(
- ONNXRUNTIME, INVALID_GRAPH,
- "Failed to find NodeArg in all parent graphs. Name=", node_arg_name,
- " Graph may not conform to the ONNX spec and contain initializers that are not graph inputs.");
- }
- }
-
- // add it to the Node's list of implicit inputs
- auto& node = *GetNode(nodeid_to_subgraphs.first);
-
- if (node_arg->Name() == "ElementTimes1147_Output_0")
- std::cout << "";
- node.MutableDefinitions().implicit_input_defs.push_back(node_arg);
-
- if (resolve_context_.inputs_and_initializers.find(node_arg_name) !=
- resolve_context_.inputs_and_initializers.cend()) {
- // no connection required
- } else {
- // if it's an output nodearg in this graph we need to create a link to the node the output is coming from
- auto entry = resolve_context_.output_args.find(node_arg_name);
- ONNXRUNTIME_ENFORCE(entry != resolve_context_.output_args.end());
-
- // Create relationship between this node (node), and the node providing the output (output_node).
- Node& output_node = *entry->second;
- AddEdge(output_node.Index(), node.Index(), *node_arg);
-
- inner_nodes.insert(&output_node);
- }
- }
- }
- }
- }
-
- // now build connections within this Graph instance
- for (auto& node : Nodes()) {
- // Need mutable input defs to be able to set any outer scope NodeArg implicit inputs
- auto& input_args = node.MutableInputDefs();
-
- if (input_args.size() > 0) {
- // This node needs inputs.
-
- for (const auto* input_arg : input_args) {
- if (!input_arg->Exists()) {
- // This input could be optional and it does not exist in this case.
- continue;
- }
-
- auto output_arg_iter = resolve_context_.output_args.find(input_arg->Name());
- if (resolve_context_.output_args.end() == output_arg_iter) {
- // No such output_arg matching this input_arg.
- // This input arg should be fed when running evaluation.
- // See if it's present in the outer scope. If so it will be 'fed' by the execution frame
- // providing access to the MLValue from the outer scope. Pass the name back up so nodes can
- // be linked correctly at that level.
- if (outer_scope_node_args.find(input_arg->Name()) != outer_scope_node_args.cend()) {
- if (input_arg->Name() == "ElementTimes1147_Output_0")
- std::cout << "";
-
- outer_scope_node_args_consumed.push_back(input_arg->Name());
- }
-
- continue;
- }
-
- // Create relationship between this node (node), and the node providing the output (output_node).
- Node& output_node = *output_arg_iter->second;
- AddEdge(output_node.Index(), node.Index(), *input_arg);
-
- inner_nodes.insert(&output_node);
- }
- } else if (node.OutputDefs().size() <= 0) {
- // This is a useless node.
- // It has no input/output.
- RemoveNode(node.Index());
- }
- }
-
- return Status::OK();
-}
-
-void Graph::ReverseDFSFrom(const std::vector<NodeIndex>& from,
-                           const std::function<void(const Node*)>& enter,
-                           const std::function<void(const Node*)>& leave,
-                           const std::function<bool(const Node*, const Node*)>& comp) const {
-  std::vector<const Node*> node_vec;
- for (auto i : from) {
- node_vec.push_back(GetNode(i));
- }
-
- ReverseDFSFrom(node_vec, enter, leave, comp);
-}
-
-void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
-                           const std::function<void(const Node*)>& enter,
-                           const std::function<void(const Node*)>& leave,
-                           const std::function<bool(const Node*, const Node*)>& comp) const {
-  using WorkEntry = std::pair<const Node*, bool>;  // bool represents leave or not
-  std::vector<WorkEntry> stack(from.size());
- for (size_t i = 0; i < from.size(); i++) {
- stack[i] = WorkEntry(from[i], false);
- }
-
-  std::vector<bool> visited(MaxNodeIndex(), false);
- while (!stack.empty()) {
- const WorkEntry last_entry = stack.back();
- stack.pop_back();
- const Node& n = *last_entry.first;
- if (last_entry.second) {
- // leave node
- leave(&n);
- continue;
- }
-
- if (visited[n.Index()]) continue;
-
- visited[n.Index()] = true;
-
- if (enter) enter(&n);
-
- if (leave) stack.emplace_back(&n, true);
-
- if (comp) {
-      std::vector<const Node*> sorted_nodes;
- for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
- sorted_nodes.push_back((*iter));
- }
- std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
- for (const auto* in : sorted_nodes) {
- const NodeIndex idx = in->Index();
- if (!visited[idx]) {
- stack.emplace_back(in, false);
- }
- }
- } else {
- for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
- const NodeIndex idx = (*iter)->Index();
- if (!visited[idx]) {
- stack.emplace_back(GetNode(idx), false);
- }
- }
- }
- }
-}
-
-GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
-Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
- nodes_in_topological_order_.clear();
- // nodes that have been processed and added to nodes_in_topological_order.
-  std::unordered_set<NodeIndex> processed_nodes;
-  std::unordered_set<NodeIndex> output_nodes;
-  std::unordered_set<NodeIndex> nodes_added_for_processing;
-  std::stack<NodeIndex> stack;
-
- // push the top level nodes into nodes_in_topological_order in the order they were added
- // to ensure that is consistent.
- auto& nodes_in_original_order = Nodes();
- std::for_each(nodes_in_original_order.cbegin(), nodes_in_original_order.cend(),
- [&](const Node& node) {
- auto index = node.Index();
-
- // find the top level nodes in the graph.
- // need to also consider nodes that only have Constants as inputs as top level nodes,
- // as the constant will get replaced by an initializer.
- auto input_edges = node.GetRelationships().input_edges;
- auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), [](const Node::EdgeEnd& edge) {
- return edge.GetNode().OpType() != kConstant;
- });
-
- if (!has_inputs) {
- // add to the topological list, and ensure we skip these nodes when walking the graph
- nodes_in_topological_order_.push_back(index);
- processed_nodes.insert(index);
-
- // mark this as added as we've fully processed it and don't need to do it again later
- nodes_added_for_processing.insert(index);
- }
- });
-
- // start at the bottom and work our way up the graph
- for (auto iter = Nodes().begin(); iter != Nodes().end(); ++iter) {
- if (0 == iter->relationships_.output_edges.size()) {
- // This is a leaf node.
- stack.push(iter->Index());
- }
- }
-
- while (!stack.empty()) {
- const NodeIndex current = stack.top();
- stack.pop();
-
- if (processed_nodes.find(current) != processed_nodes.end()) {
- continue;
- }
-
- if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
- // we popped the stack and are back to a node that was added previously,
- // so we know all the upstream nodes from it have been fully processed,
- nodes_in_topological_order_.push_back(current);
- processed_nodes.insert(current);
- output_nodes.erase(current);
- continue;
- }
-
- const Node* node = GetNode(current);
- if (!node) {
- continue;
- }
-
- stack.push(current);
- output_nodes.insert(current);
-
- for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) {
- const NodeIndex idx = (*iter)->Index();
- if (output_nodes.find(idx) != output_nodes.end()) {
- Status status(ONNXRUNTIME, FAIL, "Error: the graph is not acyclic.");
- return status;
- }
-
- // avoid re-processing nodes
- if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
- stack.push(idx);
- }
- }
-
- nodes_added_for_processing.insert(current);
- }
-
-  if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order_.size()) {
- return Status::OK();
- } else {
- return Status(ONNXRUNTIME, FAIL, "Error: the graph is not acyclic.");
- }
-}
-
-bool FullyDefinedType(const TypeProto& type_proto) {
- switch (type_proto.value_case()) {
- case TypeProto::kTensorType: {
- auto& tensor_type = type_proto.tensor_type();
- return tensor_type.has_elem_type() && (tensor_type.elem_type() != TensorProto::UNDEFINED);
- }
- case TypeProto::kSparseTensorType: {
- auto& tensor_type = type_proto.sparse_tensor_type();
- return tensor_type.has_elem_type() && (tensor_type.elem_type() != TensorProto::UNDEFINED);
- }
- case TypeProto::kSequenceType: {
- auto& seq_type = type_proto.sequence_type();
- return seq_type.has_elem_type() && FullyDefinedType(seq_type.elem_type());
- }
- case TypeProto::kMapType: {
- auto& map_type = type_proto.map_type();
- return map_type.has_key_type() &&
- (map_type.key_type() != TensorProto::UNDEFINED) &&
- map_type.has_value_type() &&
- FullyDefinedType(map_type.value_type());
- }
- case TypeProto::kOpaqueType:
- return true;
- case TypeProto::VALUE_NOT_SET:
- default:
- return false;
- }
-}
-
-// function to handle type/shape inferencing of a subgraph.
-// parameters are the Graph instance for the subgraph, the input types from the control flow node that contains
-// the subgraph, and the vector to write the output from the inferencing.
-using SubgraphInferencingFunc =
-    std::function<Status(const Node&, Graph&, const std::vector<const TypeProto*>&, std::vector<const TypeProto*>&)>;
-
-class GraphInferencerImpl : public ONNX_NAMESPACE::GraphInferencer {
- public:
- GraphInferencerImpl(const Node& node, Graph& graph, SubgraphInferencingFunc& inferencing_func)
- : node_{node}, graph_{graph}, inferencing_func_{inferencing_func} {
- }
-
- // Perform inferencing on the graph contained in GraphInferencer.
- // Returns the graph output types post-inferencing.
- // We ignore input_data currently. Re-consider if InferenceContextImpl::getInputData gets implemented
-  std::vector<const TypeProto*> doInferencing(const std::vector<const TypeProto*>& input_types,
-                                              const std::vector<const TensorProto*>& /*input_data*/) override {
-    std::vector<const TypeProto*> output_types;
-
- auto status = inferencing_func_(node_, graph_, input_types, output_types);
-
- if (status != Status::OK()) {
- fail_type_inference("Graph attribute inferencing failed: ", status.ErrorMessage());
- }
-
- return output_types;
- }
-
- private:
- const Node& node_;
- Graph& graph_;
- SubgraphInferencingFunc& inferencing_func_;
-};
-
-// An implementation of the InferenceContext interface required by operator-specific
-// shape inference for onnxruntime graphs.
-class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
-  using AttributeGraphMap = std::unordered_map<std::string, Graph*>;
-
- public:
- InferenceContextImpl(Node& node,
- const AttributeGraphMap* subgraphs = nullptr,
- SubgraphInferencingFunc* subgraph_inferencing_func = nullptr) noexcept
- : node_(node),
- attr_to_subgraph_map_{subgraphs},
- subgraph_inferencing_func_{subgraph_inferencing_func} {
- node_output_types_.resize(node.OutputDefs().size());
- }
-
- void RunInferencing() {
- auto schema = node_.Op();
- if (nullptr != schema) {
- schema->GetTypeAndShapeInferenceFunction()(*this);
- }
- }
-
-  const std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; }
-
- const AttributeProto* getAttribute(const std::string& name) const override {
- auto& attribute_value_map = node_.GetAttributes();
- auto iter = attribute_value_map.find(name);
- if (iter == attribute_value_map.end()) {
- return nullptr;
- } else {
- return &iter->second;
- }
- }
-
- size_t getNumInputs() const noexcept override {
- return node_.InputDefs().size();
- }
-
- const TypeProto* getInputType(size_t index) const override {
- auto p_node_arg = node_.InputDefs().at(index);
- if ((nullptr != p_node_arg) && p_node_arg->Exists()) {
- return p_node_arg->TypeAsProto();
- // auto p_type_proto = p_node_arg->TypeAsProto();
- //if ((p_type_proto != nullptr) && p_type_proto->has_tensor_type()) {
- // return &p_type_proto->tensor_type();
- //}
- }
- return nullptr;
- }
-
- size_t getNumOutputs() const noexcept override {
- return node_output_types_.size();
- }
-
- TypeProto* getOutputType(size_t index) override {
- return &node_output_types_[index];
- }
-
- const TensorProto* getInputData(size_t) const override {
- // TODO: this interface should be implemented with initializers
- // so that more accurate shape inference could be done.
- return nullptr;
- }
-
- GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) override {
- GraphInferencer* graph_inferencer = nullptr;
-
- if (attr_to_subgraph_map_ && subgraph_inferencing_func_) {
- auto attr_to_subgraph = attr_to_subgraph_map_->find(attribute_name);
- if (attr_to_subgraph != attr_to_subgraph_map_->cend()) {
-        auto inferencer = std::make_unique<GraphInferencerImpl>(node_, *attr_to_subgraph->second,
- *subgraph_inferencing_func_);
- graph_inferencer = inferencer.get();
- graph_inferencers_.push_back(std::move(inferencer));
- } else {
- fail_type_inference("No Graph instance was found for attribute ",
- attribute_name, " in node ", node_.Name());
- }
- }
-
- return graph_inferencer;
- }
-
- private:
- Node& node_;
- // node_output_types_ will be populated by the operator-specific shape inference.
-  std::vector<TypeProto> node_output_types_;
- const AttributeGraphMap* attr_to_subgraph_map_;
- SubgraphInferencingFunc* subgraph_inferencing_func_;
-  std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_;
-};
-
-Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
-                                          const std::vector<const TypeProto*>& input_types,
-                                          std::vector<const TypeProto*>& output_types) {
- auto status = Status::OK();
-
- output_types.clear();
-
- auto& subgraph_inputs = subgraph.GetInputs();
- auto num_subgraph_inputs = subgraph_inputs.size();
-
- if (num_subgraph_inputs != input_types.size()) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Size mismatch validating subgraph inputs. Got ",
- input_types.size(), " inputs but subgraph requires ", subgraph_inputs.size());
- }
-
- // apply type/shape info to the subgraph's inputs
- for (size_t i = 0; i < num_subgraph_inputs; ++i) {
- const auto& input_type = *input_types[i];
- const auto& subgraph_input = *subgraph_inputs[i];
-
- NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
- status = mutable_nodearg->UpdateTypeAndShape(input_type);
- if (!status.IsOK()) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
- }
- }
-
- // Apply any current input type/shape information to the Nodes in the subgraph that are implicitly
- // consuming NodeArg's from this scope or higher.
- // The NodeArg's that implicit_input_defs point to would have any type/shape inferencing applied to them
- // by now. As the subgraph is referring to the outer scope NodeArg, we simply replace any information in
- // the subgraph with the details from the outer scope NodeArg.
- auto implicit_input_defs = node.GetDefinitions().implicit_input_defs;
- for (const auto* implicit_node_arg : implicit_input_defs) {
- auto subgraph_nodearg = subgraph.GetNodeArg(implicit_node_arg->Name());
-
- // the implicit input defs may be for a nested subgraph, so it won't necessarily match here.
- // if that is the case, we will update the type/shape information when we descend into the
- // nested subgraph later.
- if (!subgraph_nodearg)
- continue;
-
- status = subgraph_nodearg->UpdateTypeAndShape(*implicit_node_arg);
- if (!status.IsOK()) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
- }
-
- // all values above us should have a type by now due to ONNX requirements.
- if (subgraph_nodearg->Type() == nullptr)
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Subgraph input missing type.");
- }
-
- // now that we have handled the input types, do the type/shape inferencing for the subgraph
- // to flow the type/shape info through it
- status = subgraph.PerformTypeAndShapeInferencing();
- ONNXRUNTIME_RETURN_IF_ERROR(status);
-
- auto& subgraph_outputs = subgraph.GetOutputs();
- for (const auto* output : subgraph_outputs) {
- output_types.push_back(output->TypeAsProto());
- }
-
- return Status::OK();
-}
-
-// Implementation of type-inference and type-checking for a single node
-GSL_SUPPRESS(f .23) // spurious warning about inferred_type never being checked for null
-Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op) {
- auto& node_name = node.Name();
-
- // if we're building a graph we permit outer scope node args to have no type
- // as the 'real' Resolve at runtime will have type inferencing
- auto is_outer_scope_nodearg = [this](const std::string& name) {
- return outer_scope_node_arg_names_.find(name) != outer_scope_node_arg_names_.cend();
- };
-
- // index used to navigate node->InputDefs().
- int k = 0;
-  std::unordered_map<std::string, DataType> type_parameter_to_type_map;
-
- for (size_t i = 0; i < node.InputArgCount().size(); ++i) {
- // Number of inputs corresponding to the i-th argument.
- const int arg_count = node.InputArgCount()[i];
- // The i-th formal parameter definition.
- auto op_formal_parameter = op.inputs()[i];
-
- // Check all actual parameters (corresponding to the k-th input)
- // match the formal parameter definition (i-th argument).
- for (int j = 0; j < arg_count; ++j, ++k) {
- auto& input_def = node.MutableDefinitions().input_defs[k];
- if (!input_def->Exists())
- continue;
-
- if (input_def->Type() == nullptr) {
- // if we are building a subgraph that uses outer scope values,
- // allow an empty type as it will be copied from the outer scope graph at runtime
- if (is_outer_scope_nodearg(input_def->Name()))
- continue;
-
- // Logic error: This should not happen if we properly checked that every use has
- // a corresponding def, for which type-inference already produced a valid type
- Status status(ONNXRUNTIME, FAIL,
- "Node (" + node_name + ") input arg (" +
- input_def->Name() + ") does not have type information set by parent node.");
- return status;
- }
-
- // Verify that the actual parameter's type is one of permitted types of the formal parameter
- DataType input_type = input_def->Type();
- auto& permitted_types = op_formal_parameter.GetTypes();
- if (0 == permitted_types.count(input_type)) {
- std::string null_pointer("(null)");
- if (input_type == nullptr) input_type = &null_pointer;
- // Type error in input model/graph.
-
- Status status(ONNXRUNTIME, INVALID_GRAPH,
- "Type Error: Type '" + *input_type + "' of input parameter (" + input_def->Name() +
- ") of operator (" + op.Name() + ") in node (" + node_name + ") is invalid.");
- return status;
- }
-
- // Check that type-parameters are bound to the same value:
- auto param_to_type_iter = type_parameter_to_type_map.find(op_formal_parameter.GetTypeStr());
- if (type_parameter_to_type_map.end() == param_to_type_iter) {
- // Bind the corresponding type-parameter's value to the actual type:
- type_parameter_to_type_map[op_formal_parameter.GetTypeStr()] = input_type;
- } else if (param_to_type_iter->second != input_type) {
- // Type error in input model/graph:
- // The type-parameter T is bound to different values for different inputs.
- // E.g., Add(A,B) where A is of type "tensor(int32)" and B is of type "tensor(float)".
- // NOTE: for variadic arguments, this verification rule is currently applicable:
- // e.g., Concat/Max/Mean/Min/Sum all require all input tensors to be of same type.
- // However, this will need to be extended to handle the If-Then-Else and Loop
- // constructs in future which will have variadic inputs and outputs of different types.
-
- Status status(ONNXRUNTIME, FAIL,
- "Type Error: Type parameter (" + op_formal_parameter.GetTypeStr() +
- ") bound to different types (" + *(param_to_type_iter->second) +
- " and " + *(input_def->Type()) +
- " in node (" + node_name + ").");
- return status;
- }
- }
- }
-
- // Apply ONNX's type/shape inference to this node.
- // This will call InferAndVerifySubgraphTypes if the ONNX level type/shape inferencing for the Node attempts
- // to do subgraph type/shape inferencing (Scan/If/Loop nodes).
- // InferAndVerifySubgraphTypes will call PerformTypeAndShapeInferencing for the subgraph, which will recursively
- // handle type/shape inferencing for it.
- // Once that completes, the outputs from the node containing the subgraph will be updated, and the final values
- // returned here.
- SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes);
- auto node_subgraphs = subgraph_map_.find(node.Index());
- auto* subgraphs = node_subgraphs != subgraph_map_.cend() ? &node_subgraphs->second : nullptr;
- InferenceContextImpl context(node, subgraphs, &func);
-
- try {
- context.RunInferencing();
- } catch (const std::exception& ex) {
- return Status(ONNXRUNTIME, FAIL, ex.what());
- }
-
- const auto& onnx_inferred_types{context.InferredOutputTypes()};
-
- // Infer and verify node output arg type information.
- int i = -1;
- for (auto& output_def : node.MutableDefinitions().output_defs) {
- ++i;
- if (!output_def->Exists()) continue;
-
- // if the number of actual parameters exceeds the number of formal parameters,
- // then the op has variadic outputs and the trailing extra actual parameters
- // correspond to the last formal parameter. (The ONNX schema verification check
- // would have checked that the corresponding formal parameter is variadic.)
-
-    const int num_formal_params = gsl::narrow_cast<int>(op.outputs().size());
- auto operand_index = std::min(i, num_formal_params - 1);
- auto op_formal_parameter = op.outputs().at(operand_index);
-
- const TypeProto& onnx_inferred_type = onnx_inferred_types[i];
- DataType existing_type = output_def->Type();
- DataType inferred_type = nullptr;
-
- // Infer output arg type if it is constrained to be of the same type as some input:
- // For example, the output of "Abs" is of the same type as its input.
- auto input_types_iter = type_parameter_to_type_map.find(op_formal_parameter.GetTypeStr());
- if (type_parameter_to_type_map.end() != input_types_iter) {
- inferred_type = input_types_iter->second;
- } else if (1 == op_formal_parameter.GetTypes().size()) {
- // Infer output arg type if operator definition specifies unique output type:
- inferred_type = *(op_formal_parameter.GetTypes().begin());
- } else if (FullyDefinedType(onnx_inferred_type)) {
- // Use output type inferred by ONNX inference
- inferred_type = DataTypeUtils::ToType(onnx_inferred_type);
- } else if (existing_type != nullptr) {
- inferred_type = existing_type;
- } else {
- // This should not happen: indicates incompleteness in ONNX inference.
- Status status(ONNXRUNTIME, FAIL,
- "Node (" + node_name + ") output arg (" + output_def->Name() + ") type inference failed");
- return status;
- }
-
- if ((existing_type != inferred_type) && (existing_type != nullptr)) {
- // A type exists for this output but does not match the inferred type.
- return Status(ONNXRUNTIME, FAIL,
- "Type Error: Type (" + *existing_type + ") of output arg (" +
- output_def->Name() + ") of node (" + node_name +
- ") does not match expected type (" + *inferred_type + ").");
- }
-
- if (existing_type == nullptr)
- output_def->SetType(inferred_type);
-
- // Update output-shape if it was inferred:
- if (onnx_inferred_type.has_tensor_type()) {
- auto& tensor_type = onnx_inferred_type.tensor_type();
- if (tensor_type.has_shape()) {
- if (output_def->Shape() == nullptr) {
- output_def->SetShape(tensor_type.shape());
- } else {
- // we need to merge the shapes as a subgraph may have placeholder dimensions to represent the rank
- // that have no values.
- TypeProto_Tensor merge_target;
- (*merge_target.mutable_shape()) = *output_def->Shape();
- auto status = MergeShapeInfo(output_def->Name(), tensor_type, merge_target);
- if (!status.IsOK()) {
- return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node_name, " ", status.ErrorMessage());
- }
-
- output_def->SetShape(merge_target.shape());
- }
- }
- }
- }
-
- return Status::OK();
-}
-
-// Apply type-inference and type-checking to all inputs and initializers:
-common::Status Graph::TypeCheckInputsAndInitializers() {
- // Check that the type of every input is specified:
- for (auto* graph_input : GetInputs()) {
- if (nullptr == graph_input->Type()) {
- Status status(ONNXRUNTIME, FAIL, "Model input (" + graph_input->Name() + ") does not have type information.");
- return status;
- }
- }
-
- // Note: The ONNX spec requires every initializer to be included in the graph input,
- // but onnxruntime relaxes this requirement for various reasons.
-
- // Infer/check type and shape for all initializers from their values
- for (auto& initializer_pair : name_to_initial_tensor_) {
- const std::string& name = initializer_pair.first;
- auto* node_arg = GetNodeArg(name);
- // If node_arg is null, we ignore this as a potentially unused initializer here
- if (nullptr != node_arg) {
- const TensorProto* tensor_proto = initializer_pair.second;
- TypeProto tensor_type;
- tensor_type.mutable_tensor_type()->set_elem_type(tensor_proto->data_type());
- auto inferred_type = DataTypeUtils::ToType(tensor_type);
- auto existing_type = node_arg->Type();
- if (nullptr == existing_type)
- node_arg->SetType(inferred_type);
- else if (inferred_type != existing_type) {
- return Status(ONNXRUNTIME, FAIL,
- "Type Error: Value of initializer " + name + " does not match its type.");
- }
-
- // Set shape accordingly.
- TensorShapeProto inferred_shape;
- for (auto dim : tensor_proto->dims()) {
- inferred_shape.add_dim()->set_dim_value(dim);
- }
- const TensorShapeProto* p_existing_shape = node_arg->Shape();
- if (nullptr == p_existing_shape)
- node_arg->SetShape(inferred_shape);
- else {
- if (p_existing_shape->dim_size() != tensor_proto->dims_size())
- return Status(ONNXRUNTIME, FAIL,
- "Type Error: Shape of initializer " + name + " does not match its type.");
- for (int i = 0; i < p_existing_shape->dim_size(); ++i) {
- auto& d = p_existing_shape->dim(i);
- if (d.has_dim_value() && (d.dim_value() != tensor_proto->dims(i)))
- return Status(ONNXRUNTIME, FAIL,
- "Type Error: Shape of initializer " + initializer_pair.first + " does not match its type.");
- }
- }
- }
- }
- return Status::OK();
-}
-
-Status Graph::VerifyNodeAndOpMatch() {
- CheckerContext ctx;
-  ctx.set_ir_version(gsl::narrow_cast<int>(IrVersion()));
- ctx.set_opset_imports(DomainToVersionMap());
- ctx.set_schema_registry(schema_registry_.get());
-
- LexicalScopeContext lsc{resolve_context_.inputs_and_initializers};
-
- // technically we could add values from Node.GetDefinitions().implicit_input_defs on a per-node basis inside
- // the below loop so that we only check against the specific outer dependencies of the node.
- // doing that requires lots of copies of LexicalScopeContext.output_names to clear out the per-Node values
- // after each loop. instead add all the outer scope values upfront so we can just accumulate new inner scope values
- // during each loop iteration.
- lsc.output_names.insert(resolve_context_.outer_scope_node_args.cbegin(),
- resolve_context_.outer_scope_node_args.cend());
-
- for (auto node_index : nodes_in_topological_order_) {
- // Node verification.
- auto& node = *GetNode(node_index);
-
- NodeProto node_proto;
- node.ToProto(node_proto);
- auto& node_name = node.Name();
- auto& domain = node.Domain();
-
- if (!node.Op()) {
- try {
- checker::check_node(node_proto, ctx, lsc);
- } catch (const std::exception& ex) {
- return Status(ONNXRUNTIME, INVALID_GRAPH, ex.what());
- }
-
- auto maxInclusiveVersion = DomainToVersionMap().find(domain)->second;
- node.op_ = schema_registry_->GetSchema(node.OpType(), maxInclusiveVersion, node.Domain());
-
- if (node.op_ && node.op_->Deprecated()) {
- node.op_ = nullptr;
- }
-
- if (!node.op_) {
- ONNX_NAMESPACE::FunctionBuilderRegistry& function_registry =
- FunctionBuilderRegistry::OnnxInstance();
- auto onnx_function_proto = function_registry.GetFunction(node.OpType(), maxInclusiveVersion, ONNX_DOMAIN);
- if (!onnx_function_proto) {
- return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op");
- }
- auto func_ptr = std::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), onnx_function_proto);
- function_container_->functions_.push_back(std::move(func_ptr));
- node.SetFunctionBody(*function_container_->functions_.back());
- }
- }
-
- ONNXRUNTIME_RETURN_IF_ERROR(node.UpdateInputArgCount());
-
- // Currently an Op is required by ValidateVersion, so we use gsl::not_null to validate that.
- // This may change in the future to allow a null Op.
- const gsl::not_null<const OpSchema*> p_op{node.Op()};
-
- // Attribute verification, and filling of node attributes with the default values
- // specified in the operator definition when an optional attribute is missing.
- auto node_attributes = node.GetAttributes();
- for (auto attr_def : p_op->attributes()) {
- auto node_attr_iter = node_attributes.find(attr_def.first);
- if (node_attributes.end() == node_attr_iter) {
- // The attribute was not specified in the node.
- if (!attr_def.second.required) {
- if (attr_def.second.default_value.has_name()) {
- // Set default value to the node attributes.
- node.AddAttribute(attr_def.first, attr_def.second.default_value);
- }
- // TODO: Handle optional attribute but no default value specified in op definition.
- } else {
- Status status(ONNXRUNTIME, FAIL,
- "Node (" + node_name + ") attribute (" + attr_def.first +
- ") is required but not specified.");
- return status;
- }
- }
- }
-
- NO_CHANGE_ON_SYNC_FLAG(ONNXRUNTIME_RETURN_IF_ERROR(InferAndVerifyTypeMatch(node, *p_op)));
-
- // Accumulate output names of the iterated Node
- for (auto& output_name : node_proto.output()) {
- lsc.output_names.insert(output_name);
- }
- }
-
- return Status::OK();
-}
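-
- // Illustration of the default-attribute handling above (hypothetical op schema): if the
- // schema declares an optional float attribute "alpha" with default 0.01f and a node omits
- // it, the loop copies the schema default into the node's attributes, so later code can
- // read "alpha" unconditionally. Omitting a *required* attribute instead fails Resolve
- // with the "is required but not specified" error.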
-
-Graph* Graph::GetMutableSubgraph(const NodeIndex node_index, const std::string& attribute_name) {
- const Graph* subgraph = GetSubgraph(node_index, attribute_name);
- return const_cast<Graph*>(subgraph);
-}
-
-const Graph* Graph::GetSubgraph(const NodeIndex node_index, const std::string& attribute_name) const {
- Graph* subgraph = nullptr;
-
- auto entry = subgraph_map_.find(node_index);
-
- if (entry != subgraph_map_.cend()) {
- auto& name_to_subgraph_map = entry->second;
- auto subgraph_iter = name_to_subgraph_map.find(attribute_name);
- if (subgraph_iter != name_to_subgraph_map.cend()) {
- subgraph = subgraph_iter->second;
- }
- }
-
- return subgraph;
-}
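-
- // Usage sketch (hypothetical If node): fetch the subgraph attached to a control-flow
- // node's GraphProto attribute.
- //   const Graph* then_branch = graph.GetSubgraph(if_node_index, "then_branch");
- //   if (then_branch != nullptr) { /* inspect the branch */ }
- // A nullptr result means the node has no subgraph registered under that attribute name.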
-
-Status Graph::CreateSubgraphs() {
- Status status = Status::OK();
-
- // Don't use NodesInTopologicalOrder here: CreateSubgraphs must be able to recursively create
- // subgraphs without depending on PerformTopologicalSortAndCheckIsAcyclic having been called
- // previously to populate NodesInTopologicalOrder.
- for (auto& node : Nodes()) {
- auto node_index = node.Index();
- if (subgraph_map_.find(node_index) != subgraph_map_.cend()) {
- // if we have an existing entry we have processed this node previously.
- // as the subgraph is loaded from a static GraphProto we assume nothing in
- // it could have changed and there's no point re-creating it.
- continue;
- }
-
- // check attributes of all nodes looking for GraphProto attributes, and create
- // the Graph instance for the subgraph contained in the GraphProto.
- for (auto& attr : node.attributes_) {
- bool has_subgraph = attr.second.has_g();
- if (has_subgraph) {
- auto& attr_name = attr.first;
- auto entry = subgraph_map_.find(node_index);
-
- // Make sure this entry is new; it's an internal logic error if not, hence the ONNXRUNTIME_ENFORCE.
- if (entry != subgraph_map_.cend()) {
- const auto& existing_entries = entry->second;
- ONNXRUNTIME_ENFORCE(existing_entries.find(attr_name) == existing_entries.cend(),
- "Entry exists in node ", node_index, " for attribute ", attr_name);
- }
-
- auto& graph_proto = *attr.second.mutable_g();
-
- // Create the instance. This needs the private ctor, so make_unique can't be used.
- GSL_SUPPRESS(r .11)
- std::unique_ptr<Graph> subgraph{new Graph(*this, graph_proto)};
-
- // Recursively create any further subgraphs
- status = subgraph->CreateSubgraphs();
- ONNXRUNTIME_RETURN_IF_ERROR(status);
-
- subgraph_map_[node_index][attr_name] = subgraph.get();
- subgraphs_.push_back(std::move(subgraph));
- }
- }
- }
-
- return Status::OK();
-}
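-
- // Example of the mapping built above (hypothetical If node at index 3 with two
- // GraphProto attributes):
- //   subgraph_map_[3]["then_branch"] -> Graph* for the then-branch
- //   subgraph_map_[3]["else_branch"] -> Graph* for the else-branch
- // The owning unique_ptrs live in subgraphs_, so the raw pointers in the map stay valid
- // for the lifetime of this Graph, and repeated Resolve calls skip re-creation.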
-
-Status Graph::VerifyInputAndInitializerNames() {
- std::unordered_set<std::string>& inputs_and_initializers = resolve_context_.inputs_and_initializers;
-
- for (auto* input : GetInputs()) {
- auto result = inputs_and_initializers.insert(input->Name());
- if (!result.second) {
- Status status(ONNXRUNTIME, FAIL,
- "Error: Duplicate definition-site for (" + input->Name() + ").");
- return status;
- }
- }
-
- for (auto& initializer_pair : name_to_initial_tensor_) {
- GSL_SUPPRESS(es .84)
- inputs_and_initializers.insert(initializer_pair.first);
- // Initializers are expected to be included in inputs (according to ONNX spec).
- // onnxruntime relaxes this constraint. No duplicate-name check here.
- }
-
- return Status::OK();
-}
-
-Status Graph::InitInputsInitializersOutputs() {
- resolve_context_.Clear();
-
- // clear the previous relationships, as we re-create them when resolving.
- // same applies to the implicit input defs as they are built from any subgraphs within this graph.
- for (auto& node : Nodes()) {
- node.MutableRelationships().Clear();
- node.MutableDefinitions().implicit_input_defs.clear();
- }
-
- // add the subgraph pointers to the resolve context.
- for (auto& nodeid_to_subgraphs : subgraph_map_) {
- resolve_context_.node_to_subgraphs_map[nodeid_to_subgraphs.first] = {};
-
- for (auto& attr_name_to_subgraph : nodeid_to_subgraphs.second) {
- resolve_context_.node_to_subgraphs_map[nodeid_to_subgraphs.first].push_back(attr_name_to_subgraph.second);
- }
- }
-
- ONNXRUNTIME_RETURN_IF_ERROR(SetGraphInputsOutputs());
- ONNXRUNTIME_RETURN_IF_ERROR(VerifyInputAndInitializerNames());
- ONNXRUNTIME_RETURN_IF_ERROR(VerifyNoDuplicateName());
-
- return Status::OK();
-}
-
-Status Graph::PerformTypeAndShapeInferencing() {
- ONNXRUNTIME_RETURN_IF_ERROR(TypeCheckInputsAndInitializers());
-
- // type/shape inferencing on the nodes is done recursively as we need subgraph outputs
- // to be applied to Node outputs for the node containing the subgraph.
- // Call path is
- // VerifyNodeAndOpMatch
- // Iterates Nodes
- // Runs ONNX type/shape inferencing for each Node
- // - If it hits a node with a subgraph, InferenceContext::getGraphAttributeInferencer is called
- // by the ONNX level type/shape inferencing, which updates the subgraph inputs using GraphInferencerImpl
- // - GraphInferencerImpl::doInferencing calls PerformTypeShapeInferencing to execute type/shape inferencing
- // for all nodes in the subgraph. This leads to recursively handling all subgraphs contained in the node.
- // - Once we finish processing the subgraph(s), we apply the resultant type/shape information
- // to the outputs of the node that contains the subgraph.
- ONNXRUNTIME_RETURN_IF_ERROR(VerifyNodeAndOpMatch());
-
- return Status::OK();
-}
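-
- // Concrete trace of the recursion described above (hypothetical model with a Loop whose
- // body contains an If):
- //   PerformTypeAndShapeInferencing()        // outer graph
- //     VerifyNodeAndOpMatch()                // reaches the Loop node
- //       GraphInferencerImpl::doInferencing  // Loop body subgraph
- //         PerformTypeAndShapeInferencing()  // recurses; eventually covers the If branches
- // after which the Loop body's inferred outputs are copied onto the Loop node's outputs.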
-
-Status Graph::ForThisAndAllSubgraphs(std::function<Status(Graph&)> func) {
- auto status = func(*this);
- ONNXRUNTIME_RETURN_IF_ERROR(status);
-
- for (auto& subgraph : subgraphs_) {
- status = func(*subgraph);
- ONNXRUNTIME_RETURN_IF_ERROR(status);
- }
- return status;
-}
-
-Status Graph::Resolve() {
- return Resolve(false);
-}
-
-Status Graph::Resolve(bool no_proto_sync_required) {
- if (parent_graph_) {
- // Resolve must start at the top-level graph in order to handle outer scope
- // connections correctly, so recurse up to that level to start.
- auto status = parent_graph_->Resolve(no_proto_sync_required);
- return status;
- }
-
- bool subgraphs_need_resolve = std::any_of(subgraphs_.cbegin(), subgraphs_.cend(),
- [](const std::unique_ptr<Graph>& graph) {
- return graph->GraphResolveNeeded();
- });
-
- if (!GraphResolveNeeded() && !subgraphs_need_resolve) {
- return Status::OK();
- }
-
- // Create the Graph instances for the subgraph(s) of any nodes containing GraphProto attributes (Scan/If/Loop).
- // Do this upfront so we can recurse into them when building connections and doing type/shape inferencing.
- // Recursively creates any nested subgraphs.
- ONNXRUNTIME_RETURN_IF_ERROR(CreateSubgraphs());
-
- // init all graph/subgraphs. non-recursive.
- auto init_func = [](Graph& graph) { return graph.InitInputsInitializersOutputs(); };
- ONNXRUNTIME_RETURN_IF_ERROR(ForThisAndAllSubgraphs(init_func));
-
- // recursively set the outer scope node args.
- ONNXRUNTIME_RETURN_IF_ERROR(SetOuterScopeNodeArgs(resolve_context_.outer_scope_node_args));
-
- std::vector<std::string> outer_scope_node_args_consumed;
-
- // recursively build connections between nodes in this graph and all subgraphs
- ONNXRUNTIME_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed));
- ONNXRUNTIME_ENFORCE(outer_scope_node_args_consumed.empty(),
- "Shouldn't be possible to have NodeArgs that haven't been handled already.");
-
- // topological sort of this and any subgraphs is non-recursive
- auto topo_sort_func = [](Graph& graph) { return graph.PerformTopologicalSortAndCheckIsAcyclic(); };
- ONNXRUNTIME_RETURN_IF_ERROR(ForThisAndAllSubgraphs(topo_sort_func));
-
- // type/shape validation and inferencing on this and any subgraphs
- // recurses into subgraphs via the ONNX checker, which descends into the GraphProto in node attributes
- // which define a subgraph.
- ONNXRUNTIME_RETURN_IF_ERROR(PerformTypeAndShapeInferencing());
-
- // perform the final steps for this graph and all subgraphs
- auto finalize_func = [&no_proto_sync_required](Graph& graph) {
- graph.CleanUnusedInitializers();
- graph.GraphResolveNeeded(false);
-
- // if we are resolving immediately after loading from a GraphProto, we don't need to
- // do a proto sync
- if (no_proto_sync_required) {
- graph.GraphProtoSyncNeeded(false);
- }
-
- return Status::OK(); };
-
- ONNXRUNTIME_RETURN_IF_ERROR(ForThisAndAllSubgraphs(finalize_func));
-
- return Status::OK();
-}
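-
- // Usage sketch (hypothetical; `model` is assumed to be a successfully loaded Model):
- //   Graph& graph = model.MainGraph();
- //   ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());
- // After a successful Resolve, nodes_in_topological_order_ is populated and every NodeArg
- // carries checked/inferred type and shape information, in this graph and all subgraphs.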
-
-const std::string& Graph::Name() const noexcept {
- return graph_proto_->name();
-}
-
-void Graph::SetName(const std::string& name) {
- graph_proto_->set_name(name);
-}
-
-const std::string& Graph::Description() const noexcept {
- return graph_proto_->doc_string();
-}
-
-void Graph::SetDescription(const std::string& description) {
- graph_proto_->set_doc_string(description);
-}
-
-void Graph::AddInitializedTensor(const TensorProto& tensor) {
- if (name_to_initial_tensor_.end() != name_to_initial_tensor_.find(tensor.name())) {
- return;
- }
-
- const gsl::not_null<TensorProto*> tensor_added{graph_proto_->add_initializer()};
- *(tensor_added) = tensor;
- name_to_initial_tensor_[tensor.name()] = tensor_added;
-
- if (!GraphLoadedFromModelFile(graph_proto_)) {
- // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs will add it to the graph inputs
- TypeProto t;
- t.mutable_tensor_type()->set_elem_type(tensor.data_type());
- auto shape = t.mutable_tensor_type()->mutable_shape();
- for (auto dim : tensor.dims())
- shape->add_dim()->set_dim_value(dim);
-
- ONNXRUNTIME_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t));
- }
-
- SetGraphProtoSyncNeeded();
- SetGraphResolveNeeded();
-}
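-
- // Usage sketch (hypothetical 2x2 float weight):
- //   TensorProto weight;
- //   weight.set_name("W");
- //   weight.set_data_type(TensorProto_DataType_FLOAT);
- //   weight.add_dims(2); weight.add_dims(2);
- //   for (int i = 0; i < 4; ++i) weight.add_float_data(1.0f);
- //   graph.AddInitializedTensor(weight);  // silently ignored if "W" was already added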
-
-void Graph::RemoveInitializedTensor(const std::string& tensor_name) {
- auto iter = name_to_initial_tensor_.find(tensor_name);
- if (name_to_initial_tensor_.end() != iter) {
- name_to_initial_tensor_.erase(tensor_name);
- SetGraphProtoSyncNeeded();
- SetGraphResolveNeeded();
- }
-}
-
-bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const {
- auto iter = name_to_initial_tensor_.find(tensor_name);
- if (name_to_initial_tensor_.end() == iter) {
- value = nullptr;
- return false;
- }
- value = iter->second;
- return true;
-}
-
-void Graph::CleanAllInitializedTensors() noexcept {
- name_to_initial_tensor_.clear();
- removed_initializer_indexes_.clear();
-
- // Clearing RepeatedPtrFields does not free objects' memory. The memory is retained
- // and can be reused. Need to explicitly release the cleared objects and free the
- // memory.
- graph_proto_->mutable_initializer()->Clear();
- const int num_cleared = graph_proto_->initializer().ClearedCount();
- for (int i = 0; i < num_cleared; i++) {
- delete graph_proto_->mutable_initializer()->ReleaseCleared();
- }
-}
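-
- // Why the ReleaseCleared loop above: RepeatedPtrField::Clear() only marks elements as
- // cleared and caches the TensorProto objects for reuse. ReleaseCleared() transfers
- // ownership of one cached object back to the caller, so deleting each released pointer
- // is what actually frees the (potentially large) initializer payloads.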
-
-const InitializedTensorSet& Graph::GetAllInitializedTensors() const noexcept {
- return name_to_initial_tensor_;
-}
-
-const std::vector<const NodeArg*>& Graph::GetValueInfo() const noexcept {
- return value_info_;
-}
-
-std::vector<NodeArg*> Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField<std::string>& names,
- const ArgNameToTypeMap& name_to_type_map) {
- const auto name_to_type_map_end = name_to_type_map.end();
- std::vector results;
- results.reserve(names.size());
-
- for (auto& name : names) {
- const TypeProto* type = nullptr;
-
- auto name_to_type_iter = name_to_type_map.find(name);
- if (name_to_type_iter != name_to_type_map_end) {
- // The type/shape for this node input arg exists in the graph proto;
- // assign that information to the node input arg.
- type = &(name_to_type_iter->second);
- }
-
- auto node_arg = &GetOrCreateNodeArg(name, type);
- results.push_back(node_arg);
- }
-
- return results;
-}
-
-Node* Graph::AddNode(const Node& other) {
- const auto& definitions = other.GetDefinitions();
-
- auto new_node = AddNode(other.Name(), other.OpType(), other.Description(),
- definitions.input_defs,
- definitions.output_defs,
- &other.GetAttributes(),
- other.Domain());
-
- return new_node;
-}
-
-Node* Graph::AddNode(const NodeProto& node_proto,
- const ArgNameToTypeMap& name_to_type_map) {
- auto input_defs = CreateNodeArgs(node_proto.input(), name_to_type_map);
- auto output_defs = CreateNodeArgs(node_proto.output(), name_to_type_map);
-
- const int num_attributes = node_proto.attribute_size();
- NodeAttributes attributes;
- attributes.reserve(num_attributes);
-
- for (int i = 0; i < num_attributes; ++i) {
- auto& attr = node_proto.attribute(i);
- attributes[attr.name()] = attr;
- }
-
- return AddNode(node_proto.name(),
- node_proto.op_type(),
- node_proto.doc_string(),
- input_defs,
- output_defs,
- &attributes,
- node_proto.domain());
-}
-
-std::string Graph::GenerateNodeArgName(const std::string& base_name) {
- std::string new_name;
- do {
- std::ostringstream str;
- str << base_name << "_" << name_generator_++;
- new_name = str.str();
- } while (node_args_.find(new_name) != node_args_.end());
- return new_name;
-}
-
-std::string Graph::GenerateNodeName(const std::string& base_name) {
- std::string new_name;
- bool keep_going = true;
-
- do {
- std::ostringstream str;
- str << base_name << "_" << name_generator_++;
- new_name = str.str();
-
- keep_going = std::find_if(nodes_.cbegin(), nodes_.cend(), [&new_name](const std::unique_ptr<Node>& n) {
- return (n != nullptr) && (n->Name() == new_name);
- }) != nodes_.cend();
- } while (keep_going);
-
- return new_name;
-}
-
-Node* Graph::AddNode(const std::string& name,
- const std::string& op_type,
- const std::string& description,
- const std::vector<NodeArg*>& input_args,
- const std::vector<NodeArg*>& output_args,
- const NodeAttributes* attributes,
- const std::string& domain) {
- std::vector<NodeArg*> inputs, outputs;
- inputs.resize(input_args.size());
- outputs.resize(output_args.size());
- int i = 0;
- for (auto input_arg : input_args) {
- inputs[i++] = &GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
- }
- i = 0;
- for (auto output_arg : output_args) {
- outputs[i++] = &GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
- }
-
- const gsl::not_null<Node*> node = AllocateNode();
- node->Init(name, op_type, description, inputs, outputs, attributes, domain);
- if (0 != op_type.compare(kNoOp)) {
- graph_proto_sync_needed_ = true;
- }
-
- return node;
-}
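-
- // Usage sketch (hypothetical args; float_type is a TypeProto describing a float tensor):
- //   NodeArg& x = graph.GetOrCreateNodeArg("X", &float_type);
- //   NodeArg& y = graph.GetOrCreateNodeArg("Y", nullptr);  // type inferred at Resolve
- //   graph.AddNode("relu_1", "Relu", "example node", {&x}, {&y}, nullptr, "");
- // The input/output NodeArgs are re-resolved through GetOrCreateNodeArg above, so the
- // node always references NodeArg instances owned by this graph.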
-
-bool Graph::RemoveNode(NodeIndex p_index) {
- return ReleaseNode(p_index);
-}
-
-bool Graph::AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index) {
- if (nodes_.size() <= src_node_index ||
- nodes_.size() <= dst_node_index ||
- nullptr == nodes_[src_node_index] ||
- nullptr == nodes_[dst_node_index]) {
- // Invalid node indexes specified.
- return false;
- }
-
- GSL_SUPPRESS(es .84) { // ignoring return from insert()
- nodes_[src_node_index]->MutableRelationships().output_edges.insert(Node::EdgeEnd(*nodes_[dst_node_index]));
- nodes_[dst_node_index]->MutableRelationships().input_edges.insert(Node::EdgeEnd(*nodes_[src_node_index]));
- nodes_[dst_node_index]->MutableRelationships().control_inputs.insert(nodes_[src_node_index]->Name());
- }
-
- return true;
-}
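-
- // Usage sketch (hypothetical indices): force `init` to execute before `consumer` even
- // though no tensor flows between them.
- //   bool ok = graph.AddControlEdge(init_node_index, consumer_node_index);
- //   // ok == false means one of the indices was out of range or already removed.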
-
-const GraphProto& Graph::ToGraphProto() {
- if (!GraphProtoSyncNeeded()) {
- return *graph_proto_;
- }
-
- // Nodes.
- graph_proto_->clear_node();
- GraphViewer graph_viewer(*this);
- // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec.
- for (auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
- const gsl::not_null<NodeProto*> node_proto{graph_proto_->add_node()};
- const gsl::not_null<const Node*> p_node{GetNode(node_idx)};
- p_node->ToProto(*node_proto);
- }
-
- if (!removed_initializer_indexes_.empty()) {
- // Move initializers.
- std::sort(removed_initializer_indexes_.begin(), removed_initializer_indexes_.end());
- int lastInUseInitializerIndex = graph_proto_->initializer_size() - 1;
- int start = 0, end = gsl::narrow_cast<int>(removed_initializer_indexes_.size()) - 1;
- int lastRemovedInitializerIndex = removed_initializer_indexes_[end];
-
- for (; start <= end; start++) {
- // Find a lastInUseInitializer.
- while (start <= end && lastInUseInitializerIndex == lastRemovedInitializerIndex) {
- graph_proto_->mutable_initializer()->RemoveLast();
- lastInUseInitializerIndex--;
- end--;
- if (start <= end) {
- lastRemovedInitializerIndex = removed_initializer_indexes_[end];
- }
- }
-
- if (start <= end) {
- // Copy the initializer in use to the