Adding ONNX format support to CNTK.
This commit is contained in:
Родитель
0b384cba0d
Коммит
d57ad1d673
|
@ -234,6 +234,10 @@ bindings/python/doc/_build
|
|||
Source/CNTKv2LibraryDll/proto/CNTK.pb.cc
|
||||
Source/CNTKv2LibraryDll/proto/CNTK.pb.h
|
||||
|
||||
# Auto-generated sources from ONNX proto
|
||||
Source/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.cc
|
||||
Source/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.h
|
||||
|
||||
bindings/python/cntk/c_plus_c.mod
|
||||
bindings/python/cntk/i_plus_c_0.mod
|
||||
bindings/python/cntk/i_plus_i_0.mod
|
||||
|
|
21
Makefile
21
Makefile
|
@ -493,6 +493,27 @@ CNTKLIBRARY_COMMON_SRC =\
|
|||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/tensorboard.pb.cc \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardFileWriter.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardUtils.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.cc \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/experiments/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/generator/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/logical/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/math/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/nn/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/reduction/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/rnn/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/tensor/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/constants.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/status.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/utils.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/opsignature.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/op.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/shape_inference.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/model.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNX.cpp \
|
||||
|
||||
CNTKLIBRARY_SRC =\
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/ComputeInputStatistics.cpp \
|
||||
|
|
23
README.md
23
README.md
|
@ -2,6 +2,29 @@
|
|||
|
||||
## Latest news
|
||||
|
||||
***2017-10-10.*** Preview: CNTK ONNX Format Support
|
||||
Update CNTK to support load and save ONNX format from https://github.com/onnx/onnx, please try it and provide feedback. We only support ONNX OPs. This is a preview, and we expect a breaking change in the future.
|
||||
|
||||
* Support loading a model saved in ONNX format.
|
||||
* Support saving a model in ONNX format, not all CNTK models are currently supported. Only a subset of CNTK models are supported and no RNN. We will add more in the future.
|
||||
|
||||
To load an ONNX model, simply specify the format parameter for the load function.
|
||||
```
|
||||
import cntk as C
|
||||
|
||||
C.Function.load(<path of your ONNX model>, format=C.ModelFormat.ONNX)
|
||||
```
|
||||
|
||||
To save a CNTK graph as ONNX model, simply specify the format in the save function.
|
||||
|
||||
```
|
||||
import cntk as C
|
||||
|
||||
x = C.input_variable(<input shape>)
|
||||
z = create_model(x)
|
||||
z.save(<path of where to save your ONNX model>, format=C.ModelFormat.ONNX)
|
||||
```
|
||||
|
||||
***2017-09-25.*** CNTK September interation plan posted [here](https://github.com/Microsoft/CNTK/issues/2410).
|
||||
|
||||
***2017-09-24.*** CNTK R-binding now available [here](https://github.com/Microsoft/CNTK-R).
|
||||
|
|
|
@ -1797,6 +1797,7 @@ namespace CNTK
|
|||
friend class Trainer;
|
||||
friend class PrimitiveFunction;
|
||||
friend class Utils;
|
||||
friend class CNTKToONNXHelper;
|
||||
|
||||
template <typename T>
|
||||
friend struct std::hash;
|
||||
|
@ -1874,6 +1875,14 @@ namespace CNTK
|
|||
///
|
||||
bool IsPlaceholder() const { return Kind() == VariableKind::Placeholder; }
|
||||
|
||||
///
|
||||
/// Returns a boolean value indicating if 'this' variable has a batch axis or not.
|
||||
///
|
||||
bool HasBatchAxis() const {
|
||||
return std::any_of(DynamicAxes().begin(), DynamicAxes().end(),
|
||||
[](const Axis& axis) { return (axis == Axis::DefaultBatchAxis()); });
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the name of 'this' variable
|
||||
///
|
||||
|
@ -2995,6 +3004,24 @@ namespace CNTK
|
|||
};
|
||||
typedef std::shared_ptr<BackPropState> BackPropStatePtr;
|
||||
|
||||
///
|
||||
/// List of supported disk formats for CNTK model.
|
||||
///
|
||||
enum class ModelFormat
|
||||
{
|
||||
///
|
||||
/// Default CNTK version 2 format, support all CNTK features.
|
||||
///
|
||||
CNTKv2,
|
||||
|
||||
///
|
||||
/// Open Neural Network Exchange format from https://github.com/onnx/onnx
|
||||
/// ONNX support limited subset of CNTK.
|
||||
///
|
||||
ONNX,
|
||||
};
|
||||
|
||||
|
||||
///
|
||||
/// How are Parameters handled when cloning a Function
|
||||
///
|
||||
|
@ -3406,7 +3433,7 @@ namespace CNTK
|
|||
///
|
||||
/// Save this Function graph into a model file.
|
||||
///
|
||||
CNTK_API void Save(const std::wstring& filepath);
|
||||
CNTK_API void Save(const std::wstring& filepath, ModelFormat format = ModelFormat::CNTKv2);
|
||||
|
||||
///
|
||||
/// Restore the models parameters (in-place) from a model file
|
||||
|
@ -3417,7 +3444,8 @@ namespace CNTK
|
|||
/// Load a Function from a model file
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(const std::wstring& filepath,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
ModelFormat format = ModelFormat::CNTKv2);
|
||||
|
||||
///
|
||||
/// Load a Function from a memory buffer
|
||||
|
@ -3894,6 +3922,17 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Compute the element wise maximum operation between the given operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementMax(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name);
|
||||
|
||||
|
||||
///
|
||||
/// Compute the element wise minimum operation between the given operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementMin(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name);
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise equality comparison operation on specified tensor input operands.
|
||||
///
|
||||
|
@ -4270,7 +4309,6 @@ namespace CNTK
|
|||
///
|
||||
/// TODO:
|
||||
///
|
||||
// TODO: Do we need a separate "spatial" parameter or can it be inferred from the tensor dimensions
|
||||
CNTK_API FunctionPtr BatchNormalization(const Variable& operand,
|
||||
const Variable& scale,
|
||||
const Variable& bias,
|
||||
|
@ -4284,6 +4322,17 @@ namespace CNTK
|
|||
bool useCuDNNEngine = true,
|
||||
const std::wstring& name = L"");
|
||||
|
||||
//
|
||||
// Local response normalization as described in http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks
|
||||
//
|
||||
CNTK_API FunctionPtr LocalResponseNormalization(const Variable& operand,
|
||||
size_t depthRadius,
|
||||
double bias,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in OptimizedRNNStack operation on specified input operands
|
||||
///
|
||||
CNTK_API FunctionPtr OptimizedRNNStack(const Variable& operand, const Variable& weights, size_t hiddenSize, size_t numLayers, bool bidirectional = false, const std::wstring& recurrentOp = L"lstm", const std::wstring& name = L"");
|
||||
|
@ -4322,6 +4371,12 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockOpName, const std::wstring& blockName = L"");
|
||||
|
||||
///
|
||||
/// Creates a Block Function that encapsulates a composite to create an opaque Function object that
|
||||
/// appears as any other primitive Function
|
||||
///
|
||||
CNTK_API FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, Dictionary&& attributes, const std::wstring& blockOpName, const std::wstring& blockName);
|
||||
|
||||
///
|
||||
/// Creates a new Function instance which output its input as it is and previent any gradient contribution from its output.
|
||||
///
|
||||
|
@ -4352,7 +4407,7 @@ namespace CNTK
|
|||
///
|
||||
/// Create an instance of the CNTK built-in elementwise scaled exponential linear unit operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr SELU(const Variable& operand, double scale = 1.0507009873554804934193349852946, double alpha = 1.6732632423543772848170429916717, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr SELU(const Variable& operand, double gamma = 1.0507009873554804934193349852946, double alpha = 1.6732632423543772848170429916717, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise leaky linear rectifier operation with the specified input operand.
|
||||
|
|
|
@ -87,7 +87,7 @@
|
|||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
|
@ -103,7 +103,7 @@
|
|||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<PreprocessorDefinitions>CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
|
||||
|
@ -119,7 +119,9 @@
|
|||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
<PreprocessorDefinitions>CPUONLY;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CPUONLY;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_CpuOnly|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<DelayLoadDLLs>Cntk.Math$(OutputSuffix)-$(CntkComponentVersion).dll; msmpi.dll;</DelayLoadDLLs>
|
||||
|
@ -130,6 +132,9 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -146,6 +151,12 @@
|
|||
xcopy /D /I /Y "$(TargetPath)" "$(TargetDir).."
|
||||
</Command>
|
||||
</PostBuildEvent>
|
||||
<ClCompile>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_UWP|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<ClCompile>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug_UWP|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="API\CNTKLibrary.h" />
|
||||
|
@ -161,6 +172,18 @@
|
|||
<ClInclude Include="MinibatchSource.h" />
|
||||
<ClInclude Include="PrimitiveFunction.h" />
|
||||
<ClInclude Include="PrimitiveOpType.h" />
|
||||
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
|
||||
<ClInclude Include="proto\onnx\core\constants.h" />
|
||||
<ClInclude Include="proto\onnx\core\graph.h" />
|
||||
<ClInclude Include="proto\onnx\core\model.h" />
|
||||
<ClInclude Include="proto\onnx\core\op.h" />
|
||||
<ClInclude Include="proto\onnx\core\opsignature.h" />
|
||||
<ClInclude Include="proto\onnx\core\shape_inference.h" />
|
||||
<ClInclude Include="proto\onnx\core\status.h" />
|
||||
<ClInclude Include="proto\onnx\core\utils.h" />
|
||||
<ClInclude Include="proto\onnx\ONNX.h" />
|
||||
<ClInclude Include="proto\onnx\ONNXToCNTK.h" />
|
||||
<ClInclude Include="proto\onnx\Operators.h" />
|
||||
<ClInclude Include="Serialization.h" />
|
||||
<ClInclude Include="tensorboard\TensorBoardUtils.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
|
@ -192,6 +215,27 @@
|
|||
<ClCompile Include="NDMask.cpp" />
|
||||
<ClCompile Include="PrimitiveFunction.cpp" />
|
||||
<ClCompile Include="proto\CNTK.pb.cc.VS_wrapper.cpp" />
|
||||
<ClCompile Include="proto\onnx\CNTKToONNX.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\constants.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\graph.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\model.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\op.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\opsignature.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\shape_inference.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\status.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\utils.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\experiments\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\generator\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\logical\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\math\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\nn\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\reduction\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\rnn\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\tensor\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\ONNX.cpp" />
|
||||
<ClCompile Include="proto\onnx\ONNXToCNTK.cpp" />
|
||||
<ClCompile Include="proto\onnx\Operators.cpp" />
|
||||
<ClCompile Include="proto\onnx\protobuf\graph.pb.cc.VS_wrapper.cpp" />
|
||||
<ClCompile Include="Serialization.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader>Create</PrecompiledHeader>
|
||||
|
@ -209,6 +253,7 @@
|
|||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Proto Include="proto\CNTK.proto" />
|
||||
<Proto Include="proto\onnx\protobuf\graph.proto" />
|
||||
<Proto Include="tensorboard\tensorboard.proto" />
|
||||
</ItemGroup>
|
||||
<Target Name="ProtoGen" Inputs="@(Proto)" Outputs="@(Proto->'%(RelativeDir)%(Filename).pb.cc')">
|
||||
|
|
|
@ -37,6 +37,69 @@
|
|||
<ClCompile Include="ProgressWriter.cpp" />
|
||||
<ClCompile Include="Evaluator.cpp" />
|
||||
<ClCompile Include="UserDefinedFunction.cpp" />
|
||||
<ClCompile Include="proto\onnx\protobuf\graph.pb.cc.VS_wrapper.cpp">
|
||||
<Filter>proto\onnx\protobuf</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\CNTKToONNX.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\ONNX.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\ONNXToCNTK.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\Operators.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\constants.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\graph.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\model.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\op.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\opsignature.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\shape_inference.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\status.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\utils.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\experiments\defs.cpp">
|
||||
<Filter>proto\onnx\defs\experiments</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\generator\defs.cpp">
|
||||
<Filter>proto\onnx\defs\generator</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\logical\defs.cpp">
|
||||
<Filter>proto\onnx\defs\logical</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\math\defs.cpp">
|
||||
<Filter>proto\onnx\defs\math</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\nn\defs.cpp">
|
||||
<Filter>proto\onnx\defs\nn</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\reduction\defs.cpp">
|
||||
<Filter>proto\onnx\defs\reduction</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\rnn\defs.cpp">
|
||||
<Filter>proto\onnx\defs\rnn</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\tensor\defs.cpp">
|
||||
<Filter>proto\onnx\defs\tensor</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="stdafx.h" />
|
||||
|
@ -69,6 +132,42 @@
|
|||
<ClInclude Include="Variable.h" />
|
||||
<ClInclude Include="UserFunctionFactory.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
<ClInclude Include="proto\onnx\CNTKToONNX.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\ONNX.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\ONNXToCNTK.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\Operators.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\constants.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\graph.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\model.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\op.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\opsignature.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\shape_inference.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\status.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\utils.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Filter Include="API">
|
||||
|
@ -80,6 +179,42 @@
|
|||
<Filter Include="tensorboard">
|
||||
<UniqueIdentifier>{4242f3a9-0e06-4bf5-b2c2-85b292fe0b43}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx">
|
||||
<UniqueIdentifier>{ca68761d-44d4-41a9-b055-4b192402ed0b}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\protobuf">
|
||||
<UniqueIdentifier>{77b30c82-208d-4c0d-825c-9b0eda61caa6}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\core">
|
||||
<UniqueIdentifier>{ac45f7f4-5f65-40d4-9163-46580266ae16}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs">
|
||||
<UniqueIdentifier>{cb1e39c1-bd5e-4d7f-8c83-39de28e70307}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\experiments">
|
||||
<UniqueIdentifier>{6168d648-8e32-4ad7-b16d-78d6a2e7a461}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\generator">
|
||||
<UniqueIdentifier>{e52f27a1-b8c4-4d67-874f-51aa43b04815}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\logical">
|
||||
<UniqueIdentifier>{e8f5863d-409a-4b02-814b-a93241e4e56d}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\math">
|
||||
<UniqueIdentifier>{6a99c5c8-2686-4007-83ed-988b1e1ae852}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\nn">
|
||||
<UniqueIdentifier>{7a757445-6a05-4bd2-b37d-6d95da3112fc}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\reduction">
|
||||
<UniqueIdentifier>{cd1c3786-8e2a-44bf-a7fc-28114019cf2d}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\rnn">
|
||||
<UniqueIdentifier>{4bc38915-7be4-42e2-951b-9445abf8038c}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\tensor">
|
||||
<UniqueIdentifier>{deb73515-13a1-4926-b5fc-e8c6e97f7784}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Proto Include="proto\CNTK.proto">
|
||||
|
@ -88,5 +223,8 @@
|
|||
<Proto Include="tensorboard\tensorboard.proto">
|
||||
<Filter>tensorboard</Filter>
|
||||
</Proto>
|
||||
<Proto Include="proto\onnx\protobuf\graph.proto">
|
||||
<Filter>proto\onnx\protobuf</Filter>
|
||||
</Proto>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -11,6 +11,7 @@
|
|||
#include "Utils.h"
|
||||
#include "UserFunctionFactory.h"
|
||||
#include "TrainingNodes.h"
|
||||
#include "proto/onnx/ONNX.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -467,15 +468,32 @@ namespace CNTK
|
|||
vectorBuf.assign(s.begin(), s.end());
|
||||
}
|
||||
|
||||
void Function::Save(const std::wstring& filepath)
|
||||
void Function::Save(const std::wstring& filepath, ModelFormat format)
|
||||
{
|
||||
switch (format)
|
||||
{
|
||||
case ModelFormat::CNTKv2:
|
||||
{
|
||||
Dictionary model = Serialize();
|
||||
auto stream = GetFstream(filepath, false);
|
||||
*stream << model;
|
||||
stream->flush();
|
||||
break;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
|
||||
case ModelFormat::ONNX:
|
||||
{
|
||||
ONNXFormat::Save(RootFunction(), filepath);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice, ModelFormat format)
|
||||
{
|
||||
switch (format)
|
||||
{
|
||||
case ModelFormat::CNTKv2:
|
||||
{
|
||||
auto stream = GetFstream(filepath, true);
|
||||
if (!Internal::IsLegacyModel(*stream))
|
||||
|
@ -488,6 +506,15 @@ namespace CNTK
|
|||
{
|
||||
return Internal::LoadLegacyModel(filepath, computeDevice); // throw an exception if deserializer != nullptr?
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case ModelFormat::ONNX:
|
||||
return ONNXFormat::Load(filepath, computeDevice);
|
||||
break;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
|
||||
|
@ -1145,10 +1172,13 @@ namespace CNTK
|
|||
if (!axis.IsStaticAxis() && (axis != Axis::AllStaticAxes()))
|
||||
LogicError("Softmax: support only static axes.");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAxis] = axis;
|
||||
|
||||
if (((operand.Shape().Rank() == 1) && (axis.StaticAxisIndex() == 0)) ||
|
||||
(axis == Axis::AllStaticAxes()))
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Softmax, operand, Dictionary(), name);
|
||||
return UnaryOp(PrimitiveOpType::Softmax, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1157,7 +1187,7 @@ namespace CNTK
|
|||
auto expOperandDelta = Exp(operandDelta);
|
||||
auto result = ElementDivide(expOperandDelta, ReduceSum(expOperandDelta, axis));
|
||||
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"Softmax", name);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"Softmax", name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1416,6 +1446,28 @@ namespace CNTK
|
|||
return ElementTimes(leftOperand, Reciprocal(rightOperand), name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementMax(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
auto leftOperandPlaceholder = PlaceholderVariable();
|
||||
auto rightOperandPlaceholder = PlaceholderVariable();
|
||||
|
||||
auto result = ElementSelect(Greater(leftOperandPlaceholder, rightOperandPlaceholder),
|
||||
leftOperandPlaceholder,
|
||||
rightOperandPlaceholder);
|
||||
return AsBlock(std::move(result), { { leftOperandPlaceholder, leftOperand },{ rightOperandPlaceholder, rightOperand } }, L"ElementMax", name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementMin(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
auto leftOperandPlaceholder = PlaceholderVariable();
|
||||
auto rightOperandPlaceholder = PlaceholderVariable();
|
||||
|
||||
auto result = ElementSelect(Less(leftOperandPlaceholder, rightOperandPlaceholder),
|
||||
leftOperandPlaceholder,
|
||||
rightOperandPlaceholder);
|
||||
return AsBlock(std::move(result), { { leftOperandPlaceholder, leftOperand },{ rightOperandPlaceholder, rightOperand } }, L"ElementMin", name);
|
||||
}
|
||||
|
||||
FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Equal, leftOperand, rightOperand, Dictionary(), name);
|
||||
|
@ -1966,6 +2018,26 @@ namespace CNTK
|
|||
name);
|
||||
}
|
||||
|
||||
FunctionPtr LocalResponseNormalization(const Variable& operand, size_t depthRadius, double bias, double alpha, double beta, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameDepthRadius] = depthRadius;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameBias] = bias;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = alpha;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameBeta] = beta;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto operandSquare = Square(operandPlaceholder);
|
||||
operandSquare = Reshape(operandSquare, { NDShape::InferredDimension, 1 }, Axis(2), Axis(3));
|
||||
auto weights = Constant({ 1, 1, 2 * depthRadius + 1, 1 }, operand.GetDataType(), alpha / (2 * depthRadius + 1));
|
||||
auto convResult = Convolution(weights, operandSquare);
|
||||
convResult = Reshape(convResult, { NDShape::InferredDimension }, Axis(2), Axis(4));
|
||||
auto denom = Exp(ElementTimes(Constant::Scalar(operand.GetDataType(), beta), Log(Plus(Constant::Scalar(operand.GetDataType(), bias), convResult))));
|
||||
|
||||
auto result = ElementDivide(operandPlaceholder, denom);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"LocalResponseNormalization", name);
|
||||
}
|
||||
|
||||
FunctionPtr Clip(const Variable& operand, const Variable& min, const Variable& max, const std::wstring& name)
|
||||
{
|
||||
std::vector<Variable> operands = { operand, min, max };
|
||||
|
@ -1999,11 +2071,16 @@ namespace CNTK
|
|||
}
|
||||
|
||||
FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockOpName, const std::wstring& blockName)
|
||||
{
|
||||
return AsBlock(std::move(composite), argumentsMap, Dictionary(), blockOpName, blockName);
|
||||
}
|
||||
|
||||
FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, Dictionary&& attributes, const std::wstring& blockOpName, const std::wstring& blockName)
|
||||
{
|
||||
if (!composite->IsComposite())
|
||||
InvalidArgument("Composite argument '%S' to AsBlock is not a composite Function.", composite->AsString().c_str());
|
||||
|
||||
return AsComposite(MakeSharedObject<BlockFunction>(std::move(composite), argumentsMap, blockOpName, Dictionary(), blockName), blockName);
|
||||
return AsComposite(MakeSharedObject<BlockFunction>(std::move(composite), argumentsMap, blockOpName, std::move(attributes), blockName), blockName);
|
||||
}
|
||||
|
||||
FunctionPtr AsComposite(const FunctionPtr& rootFunction, const std::wstring& name)
|
||||
|
@ -2027,26 +2104,33 @@ namespace CNTK
|
|||
return UnaryOp(PrimitiveOpType::ELU, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr SELU(const Variable& operand, double scale, double alpha, const std::wstring& name)
|
||||
FunctionPtr SELU(const Variable& operand, double gamma, double alpha, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameGamma] = gamma;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = alpha;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto lessThanZero = Less(operandPlaceholder, Constant::Scalar(operand.GetDataType(), 0.0));
|
||||
auto result = ElementSelect(lessThanZero,
|
||||
ElementTimes(Constant::Scalar(operand.GetDataType(), alpha), ELU(operandPlaceholder)),
|
||||
operandPlaceholder);
|
||||
result = ElementTimes(Constant::Scalar(operand.GetDataType(), scale), result);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"SELU", name);
|
||||
result = ElementTimes(Constant::Scalar(operand.GetDataType(), gamma), result);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"SELU", name);
|
||||
}
|
||||
|
||||
FunctionPtr LeakyReLU(const Variable& operand, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = 0.01;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto lessThanZero = Less(operandPlaceholder, Constant::Scalar(operand.GetDataType(), 0.0));
|
||||
auto result = ElementSelect(lessThanZero,
|
||||
ElementTimes(Constant::Scalar(operand.GetDataType(), 0.01), operandPlaceholder),
|
||||
operandPlaceholder);
|
||||
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"LeakyReLU", name);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"LeakyReLU", name);
|
||||
}
|
||||
|
||||
FunctionPtr PReLU(const Variable& alpha, const Variable& operand, const std::wstring& name)
|
||||
|
@ -2530,6 +2614,7 @@ namespace CNTK
|
|||
additionalProperties[PrimitiveFunction::AttributeNameUpperPad] = NDShape({0});
|
||||
additionalProperties[PrimitiveFunction::AttributeNameTranspose] = transpose;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameOutputShape] = outputShape;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameKernelShape] = NDShape({0});
|
||||
additionalProperties[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples] = maxTempMemSizeInSamples;
|
||||
|
||||
return BinaryOp(PrimitiveOpType::Convolution, convolutionMap, operand, std::move(additionalProperties), name);
|
||||
|
|
|
@ -111,6 +111,13 @@ namespace CNTK
|
|||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingFoot = L"paddingFoot";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingMode = L"paddingMode";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingConstantValue = L"paddingConstantValue";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAlpha = L"alpha";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBeta = L"beta";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameGamma = L"gamma";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameKernelShape = L"kernelShape";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBias = L"bias";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameDepthRadius = L"depthRadius";
|
||||
|
||||
|
||||
/*static*/ DataType PrimitiveFunction::GetOutputDataType(PrimitiveOpType op, std::vector<Variable>& inputs, bool inferDimensions)
|
||||
{
|
||||
|
@ -876,6 +883,7 @@ namespace CNTK
|
|||
m_attributes[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing);
|
||||
m_attributes[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding);
|
||||
m_attributes[PrimitiveFunction::AttributeNameDilation] = dilation;
|
||||
m_attributes[PrimitiveFunction::AttributeNameKernelShape] = kernelShape;
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
|
|
|
@ -281,6 +281,12 @@ namespace CNTK
|
|||
static const std::wstring AttributeNamePaddingFoot;
|
||||
static const std::wstring AttributeNamePaddingMode;
|
||||
static const std::wstring AttributeNamePaddingConstantValue;
|
||||
static const std::wstring AttributeNameAlpha;
|
||||
static const std::wstring AttributeNameBeta;
|
||||
static const std::wstring AttributeNameGamma;
|
||||
static const std::wstring AttributeNameKernelShape;
|
||||
static const std::wstring AttributeNameBias;
|
||||
static const std::wstring AttributeNameDepthRadius;
|
||||
|
||||
protected:
|
||||
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
|
||||
|
|
|
@ -0,0 +1,840 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "CNTKToONNX.h"
|
||||
#include "proto/onnx/core/model.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
|
||||
#include "Utils.h"
|
||||
#include "Operators.h"
|
||||
#include "BlockFunction.h"
|
||||
#include <vector>
|
||||
|
||||
using namespace CNTK::ONNX;
|
||||
using namespace CNTK;
|
||||
|
||||
//
// Return a copy of any iterable container with the element order inverted.
// The argument is taken by value, so the caller's container is left untouched.
//
template<typename ItrType>
ItrType reverse(ItrType v)
{
    auto front = std::begin(v);
    auto back = std::end(v);
    // Swap symmetric elements inward until the two cursors meet.
    while ((front != back) && (front != --back))
    {
        std::iter_swap(front, back);
        ++front;
    }
    return v;
}
|
||||
|
||||
//
|
||||
// Helper function to reduce the rank of a shape.
|
||||
//
|
||||
ONNXIR::TypeProto ReduceRank(const ONNXIR::TypeProto::TensorShapeProto* inputShape, int reductionRank, bool rightReduction)
|
||||
{
|
||||
assert(inputShape != nullptr);
|
||||
|
||||
int inputRank = inputShape->dim_size();
|
||||
assert(inputRank > reductionRank);
|
||||
|
||||
ONNXIR::TypeProto newShape;
|
||||
int64_t reduceDim = 1;
|
||||
|
||||
if (rightReduction)
|
||||
{
|
||||
for (int index = 0; index < (inputRank - reductionRank); index++)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(inputShape->dim(index).dim_value());
|
||||
|
||||
for (int index = (inputRank - reductionRank); index < inputRank; index++)
|
||||
reduceDim *= inputShape->dim(index).dim_value();
|
||||
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(reduceDim);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int index = 0; index < reductionRank; index++)
|
||||
reduceDim *= inputShape->dim(index).dim_value();
|
||||
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(reduceDim);
|
||||
|
||||
for (int index = reductionRank; index < inputRank; index++)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(inputShape->dim(index).dim_value());
|
||||
}
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
namespace CNTK
{
    // Stateless helper that converts a CNTK computation graph into an ONNX graph.
    // All members are static; the class only groups the conversion routines.
    class CNTKToONNXHelper
    {
    public:
        //
        // Copy the entire CNTK graph to ONNX graph.
        //
        static void Copy(const FunctionPtr& src, ONNXIR::Graph* dst);

    private:
        //
        // Recursively create ONNX nodes corresponding to each CNTK node.
        // functionNodes/variableNodes memoize already-converted functions and leaves.
        //
        static ONNXIR::Node* CreateNode(const FunctionPtr& src,
                                        ONNXIR::Graph* graph,
                                        std::unordered_map<FunctionPtr, ONNXIR::Node*>& functionNodes,
                                        std::unordered_map<Variable, ONNXIR::Node*>& variableNodes,
                                        const std::unordered_map<Variable, Variable>& compositeOutputsMap);

        //
        // Traverse the entire graph and collect variable mapping between graph inside and outside the block.
        //
        static void TraverseGraph(const FunctionPtr& src,
                                  std::set<FunctionPtr>& visited,
                                  std::unordered_map<Variable, Variable>& compositeOutputsMap);

        //
        // Copy the content of NDArrayView to TensorProto, and do the needed
        // data type conversion.
        //
        static void CopyTensor(const NDArrayViewPtr src, ONNXIR::TensorProto& dst);

        //
        // Copy supported attributes from CNTK node to corresponding ONNX node.
        //
        static void CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node);

        //
        // Convert Axis object to actual tensor index (batch axis maps to 0).
        //
        static int ToIndex(const Axis& axis);

        //
        // Convert NDShape and various std::vector types to TensorShape.
        //
        static ONNXIR::TypeProto ToTypeProto(const NDShape& shape, bool hasBatchAxis = false);
        static ONNXIR::TypeProto ToTypeProto(const std::vector<bool>& shape);
        static ONNXIR::TypeProto ToTypeProto(const std::vector<int>& shape);
        static ONNXIR::TypeProto ToTypeProto(const std::vector<Axis>& axes);

        //
        // Convert TypeProto, NDShape and various std::vector types to std::vector.
        //
        static std::vector<int64_t> ToINTS(const ONNXIR::TypeProto& shape);
        static std::vector<int64_t> ToINTS(const NDShape& shape, bool hasBatchAxis = false);
        static std::vector<int64_t> ToINTS(const std::vector<bool>& shape);
        static std::vector<int64_t> ToINTS(const std::vector<int>& shape);
        static std::vector<int64_t> ToINTS(const std::vector<Axis>& axes);

        //
        // Convert data types from CNTK to ONNX.
        //
        static void UpdateONNXType(DataType dataType, ONNXIR::TypeProto& type);

        //
        // Map CNTK OP names to ONNX OP Names.
        //
        static std::string ToOPName(const FunctionPtr& src);

        //
        // Check that the CNTK variable is compatible with ONNX.
        //
        static void ValidateVariable(const Variable& v);

        //
        // Which input to ignore during converting a CNTK block to a primitive OP in ONNX.
        //
        static bool FilterInput(const FunctionPtr& src, const CNTK::Variable& input, size_t inputIndex);

        //
        // Argument orders between CNTK and ONNX aren't always the same.
        //
        static std::vector<ONNXIR::NodeArg> MapInputsOrderToONNX(const FunctionPtr& src, const std::vector<ONNXIR::NodeArg>& inputs);

        //
        // Add current CNTK node to ONNX graph.
        //
        static ONNXIR::Node* AddNode(const FunctionPtr& src, ONNXIR::Graph* graph, const std::vector<ONNXIR::NodeArg>& inputs, const std::vector<ONNXIR::NodeArg>& outputs);
    };
}
|
||||
|
||||
std::unique_ptr<ONNXIR::Model> CNTKToONNX::CreateModel(const FunctionPtr& src)
|
||||
{
|
||||
std::unique_ptr<ONNXIR::Model> model(new ONNXIR::Model("CNTKGraph", true));
|
||||
auto dstGraph = model->MainGraph();
|
||||
CNTKToONNXHelper::Copy(src, dstGraph);
|
||||
ONNXIR::Status status = dstGraph->Resolve();
|
||||
if (!status.Ok())
|
||||
LogicError("%s", status.ErrorMsg().c_str());
|
||||
return model;
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::Copy(const FunctionPtr& src, ONNXIR::Graph* dst)
|
||||
{
|
||||
std::set<FunctionPtr> visited;
|
||||
std::unordered_map<Variable, Variable> compositeOutputsMap;
|
||||
std::unordered_map<FunctionPtr, ONNXIR::Node*> functionNodes;
|
||||
std::unordered_map<Variable, ONNXIR::Node*> variableNodes;
|
||||
|
||||
//
|
||||
// Traverse the graph and collect some information.
|
||||
//
|
||||
TraverseGraph(src, visited, compositeOutputsMap);
|
||||
|
||||
//
|
||||
// Iterate through each node in CNTK graph and create an equivalent node
|
||||
// in ONNX graph.
|
||||
//
|
||||
CreateNode(src, dst, functionNodes, variableNodes, compositeOutputsMap);
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, ONNXIR::TensorProto& dst)
|
||||
{
|
||||
auto dataType = src->GetDataType();
|
||||
auto srcTemp = src->DeepClone();
|
||||
auto srcShape = srcTemp->Shape();
|
||||
auto totalSize = srcShape.TotalSize();
|
||||
|
||||
// This is our own copy so move it to the CPU.
|
||||
srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice());
|
||||
|
||||
switch (dataType)
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
dst.set_data_type(ONNXIR::TensorProto_DataType_FLOAT);
|
||||
auto data = srcTemp->DataBuffer<float>();
|
||||
for (size_t index = 0; index < totalSize; index++)
|
||||
*(dst.mutable_float_data()->Add()) = data[index];
|
||||
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
dst.set_data_type(ONNXIR::TensorProto_DataType_DOUBLE);
|
||||
auto data = srcTemp->DataBuffer<double>();
|
||||
for (size_t index = 0; index < totalSize; index++)
|
||||
*(dst.mutable_double_data()->Add()) = data[index];
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
auto dimensions = reverse(srcShape.Dimensions());
|
||||
for (auto dim : dimensions)
|
||||
*(dst.mutable_dims()->Add()) = dim;
|
||||
}
|
||||
|
||||
int CNTKToONNXHelper::ToIndex(const Axis& axis)
|
||||
{
|
||||
if ((axis == Axis::AllAxes()) || (axis == Axis::AllStaticAxes()))
|
||||
LogicError("AllAxes and AllStaticAxes are currently not supported.");
|
||||
|
||||
if (axis.IsSequenceAxis())
|
||||
LogicError("Sequence axis are currently not supported.");
|
||||
|
||||
if (axis.IsBatchAxis())
|
||||
return 0;
|
||||
|
||||
return axis.StaticAxisIndex() + 1;
|
||||
}
|
||||
|
||||
// Convert an NDShape into an ONNX TypeProto tensor shape. Dimensions are
// reversed to match ONNX's expected ordering; if hasBatchAxis is set, a leading
// singleton batch dimension is prepended. Unbound dimensions are rejected.
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const NDShape& shape, bool hasBatchAxis)
{
    if (shape.HasUnboundDimension())
        LogicError("Inferred and FreeDimension aren't currently supported.");

    ONNXIR::TypeProto result;
    auto* dstShape = result.mutable_tensor_type()->mutable_shape();

    if (hasBatchAxis)
        dstShape->add_dim()->set_dim_value(1);

    for (auto dim : reverse(shape.Dimensions()))
        dstShape->add_dim()->set_dim_value(dim);

    return result;
}
|
||||
|
||||
// Convert a vector of flags into a reversed TypeProto shape of 1/0 dimensions.
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<bool>& shape)
{
    ONNXIR::TypeProto result;
    auto* dstShape = result.mutable_tensor_type()->mutable_shape();

    for (auto flag : reverse(shape))
        dstShape->add_dim()->set_dim_value(flag ? 1 : 0);

    return result;
}
|
||||
|
||||
// Convert a vector of ints into a reversed TypeProto shape.
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<int>& shape)
{
    ONNXIR::TypeProto result;
    auto* dstShape = result.mutable_tensor_type()->mutable_shape();

    for (auto dim : reverse(shape))
        dstShape->add_dim()->set_dim_value(dim);

    return result;
}
|
||||
|
||||
// Convert a list of axes into a TypeProto whose dimensions are the sorted
// ONNX indices of those axes.
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<Axis>& axes)
{
    std::vector<int> indices;
    for (const auto& axis : axes)
        indices.push_back(ToIndex(axis));
    std::sort(indices.begin(), indices.end());

    ONNXIR::TypeProto result;
    auto* dstShape = result.mutable_tensor_type()->mutable_shape();
    for (auto index : indices)
        dstShape->add_dim()->set_dim_value(index);

    return result;
}
|
||||
|
||||
// Flatten a TypeProto tensor shape into a vector of int64 dimension values.
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const ONNXIR::TypeProto& shape)
{
    const auto& tensorShape = shape.tensor_type().shape();

    std::vector<int64_t> result;
    result.reserve(tensorShape.dim_size());
    for (int i = 0; i < tensorShape.dim_size(); i++)
        result.push_back((int64_t)tensorShape.dim(i).dim_value());

    return result;
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const NDShape& shape, bool hasBatchAxis)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape, hasBatchAxis));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<bool>& shape)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<int>& shape)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<Axis>& axes)
|
||||
{
|
||||
return ToINTS(ToTypeProto(axes));
|
||||
}
|
||||
|
||||
// Set the ONNX element type on `type` from the CNTK DataType.
// Only Float and Double are supported; anything else hits NOT_IMPLEMENTED.
void CNTKToONNXHelper::UpdateONNXType(DataType dataType, ONNXIR::TypeProto& type)
{
    ONNXIR::TensorProto_DataType onnxElemType;
    switch (dataType)
    {
    case DataType::Float:
        onnxElemType = ONNXIR::TensorProto_DataType_FLOAT;
        break;
    case DataType::Double:
        onnxElemType = ONNXIR::TensorProto_DataType_DOUBLE;
        break;
    default:
        NOT_IMPLEMENTED;
    }
    type.mutable_tensor_type()->set_elem_type(onnxElemType);
}
|
||||
|
||||
std::string CNTKToONNXHelper::ToOPName(const FunctionPtr& src)
|
||||
{
|
||||
auto lookup = Operators::CntkToONNXLookup();
|
||||
assert(lookup.count(src->OpName()) != 0);
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (lookup.count(src->OpName()) == 1)
|
||||
{
|
||||
auto attributesMap = lookup.find(src->OpName())->second.map;
|
||||
opName = attributesMap[src->OpName()];
|
||||
}
|
||||
else
|
||||
{
|
||||
// Some nodes map one to many.
|
||||
if (src->OpName() == L"Convolution")
|
||||
{
|
||||
auto transpose = (bool)src->Attributes()[L"transpose"].Value<bool>();
|
||||
if (transpose)
|
||||
opName = "ConvTranspose";
|
||||
else
|
||||
opName = "Conv";
|
||||
}
|
||||
else if (src->OpName() == L"Pooling")
|
||||
{
|
||||
PoolingType poolingType = (PoolingType)src->Attributes()[L"poolingType"].Value<size_t>();
|
||||
if (poolingType == PoolingType::Max)
|
||||
opName = "MaxPool";
|
||||
else
|
||||
opName = "AveragePool";
|
||||
}
|
||||
}
|
||||
|
||||
return opName;
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::ValidateVariable(const Variable& v)
|
||||
{
|
||||
if ((v.HasBatchAxis() && (v.DynamicAxes().size() > 1)) ||
|
||||
(!v.HasBatchAxis() && (v.DynamicAxes().size() > 0)))
|
||||
{
|
||||
LogicError("Sequence and user defined dynamic axis are currently not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
// Decide whether a block-function input should be skipped during conversion.
// CNTK block functions expose every constant inside the block; for blocks that
// map directly onto a single ONNX OP, those internal constants must be dropped.
bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable& input, size_t inputIndex)
{
    if (!input.IsConstant())
        return false;
    return !Operators::IsValidInputs(src->OpName(), inputIndex);
}
|
||||
|
||||
//
// This is the main workhorse: it navigates the CNTK graph recursively (DFS) while keeping
// track of all visited nodes and variables, and creates the corresponding ONNX graph.
//
ONNXIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
    ONNXIR::Graph* graph,
    std::unordered_map<FunctionPtr, ONNXIR::Node*>& functionNodes,
    std::unordered_map<Variable, ONNXIR::Node*>& variableNodes,
    const std::unordered_map<Variable, Variable>& compositeOutputsMap)
{
    // Memoization: if this function was already converted, reuse its ONNX node.
    auto iter = functionNodes.find(src);
    if (iter != functionNodes.end())
        return iter->second;

    ONNXIR::Node* functionNode = nullptr;
    std::string opName = ToString(src->OpName()); // NOTE(review): unused below; src->OpName() is used directly — candidate for removal.

    //
    // If this block node is equivalent to a primitive ONNX OP, then treat it as such.
    // And just map its arguments to the ONNX node.
    //
    if (src->IsBlock() &&
        (!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName())))
    {
        functionNode = CreateNode(src->BlockRoot(), graph, functionNodes, variableNodes, compositeOutputsMap);
    }
    //
    // For compatibility with other frameworks that support ONNX, we limit the list of OPs to the ones
    // supported by ONNX https://github.com/onnx/onnx/tree/master/onnx/defs.
    //
    else if (Operators::IsSupportedCNTKOP(src->OpName()))
    {
        std::vector<ONNXIR::NodeArg> inputs;
        std::vector<ONNXIR::NodeArg> outputs;

        // One output arg per CNTK output variable, named by the variable's Uid.
        for (const auto& output : src->Outputs())
        {
            ValidateVariable(output);

            auto outputArgType = ToTypeProto(output.Shape(), output.HasBatchAxis());
            UpdateONNXType(output.GetDataType(), outputArgType);

            ONNXIR::NodeArg outputArg(ToString(output.Uid()), &outputArgType);
            outputs.push_back(outputArg);
        }

        for (size_t inputIndex = 0; inputIndex < src->Inputs().size(); ++inputIndex)
        {
            auto input = src->Inputs()[inputIndex];

            // Placeholders must resolve to a concrete variable through the block mapping.
            if (input.IsPlaceholder())
            {
                input = input.BlockFunctionVariableMapping();
                if (input.IsPlaceholder())
                    LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
            }
            ValidateVariable(input);

            // Skip block-internal constants that the matching ONNX OP does not expect.
            if (src->IsBlock() && FilterInput(src, input, inputIndex))
                continue;

            //
            // Use user defined name if available otherwise use our internal unique name ID.
            //
            std::string inputName = ToString(input.Uid());
            auto inputItr = compositeOutputsMap.find(input);
            if (inputItr != compositeOutputsMap.end())
                inputName = ToString(inputItr->second.Uid());

            auto inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis());
            UpdateONNXType(input.GetDataType(), inputArgType);
            ONNXIR::NodeArg inputArg(inputName, &inputArgType);

            inputs.push_back(inputArg);

            //
            // Leaf nodes are data entry to the graph and need their own node with only output arg.
            //
            if ((input.IsParameter() || input.IsConstant()) &&
                !Operators::IgnoreConstantAndParameter(src->OpName(), inputIndex))
            {
                if (variableNodes.find(input) == variableNodes.end())
                {
                    std::vector<ONNXIR::NodeArg> varInputs;
                    std::vector<ONNXIR::NodeArg> varOutputs;

                    varOutputs.push_back({ inputArg });
                    ONNXIR::Node* variableNode = nullptr;
                    if (input.IsParameter() || input.IsConstant())
                    {
                        // Materialize the parameter/constant as an ONNX "Constant" node carrying the tensor value.
                        variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs);
                        auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();

                        ONNXIR::TensorProto dstTensor;
                        CopyTensor(srcTensor, dstTensor);

                        variableNode->AddAttribute("value", dstTensor);
                        variableNodes.emplace(input, variableNode);
                    }
                }
            }
            //
            // If this input is an output, then it is the output of an upstream node. Recursively add
            // all upstream nodes. Pretty much, we are doing DFS.
            //
            else if (input.IsOutput())
                CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
        }

        //
        // Finally add a new node to ONNX graph.
        //
        functionNode = AddNode(src, graph, inputs, outputs);
    }
    else
        LogicError("Node '%S': Unsupported node.", src->AsString().c_str());

    functionNodes.emplace(src, functionNode);
    return functionNode;
}
|
||||
|
||||
void CNTKToONNXHelper::TraverseGraph(const FunctionPtr& src,
|
||||
std::set<FunctionPtr>& visited,
|
||||
std::unordered_map<Variable, Variable>& compositeOutputsMap)
|
||||
{
|
||||
auto iter = visited.find(src);
|
||||
if (iter != visited.end())
|
||||
return;
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (src->IsBlock())
|
||||
{
|
||||
if (!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName()))
|
||||
{
|
||||
auto blockSrc = dynamic_cast<BlockFunction*>(src.get());
|
||||
for (auto map : blockSrc->CompositeOutputsMap())
|
||||
compositeOutputsMap.insert(map);
|
||||
TraverseGraph(src->BlockRoot(), visited, compositeOutputsMap);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto input : src->Inputs())
|
||||
{
|
||||
if (input.IsPlaceholder())
|
||||
{
|
||||
input = input.BlockFunctionVariableMapping();
|
||||
if (input.IsPlaceholder())
|
||||
LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
|
||||
}
|
||||
|
||||
if (input.IsOutput())
|
||||
TraverseGraph(input.Owner(), visited, compositeOutputsMap);
|
||||
}
|
||||
}
|
||||
|
||||
visited.emplace(src);
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node)
|
||||
{
|
||||
auto lookup = Operators::CntkToONNXLookup();
|
||||
assert(lookup.count(src->OpName()) != 0);
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (lookup.count(src->OpName()) == 1)
|
||||
{
|
||||
auto attributesMap = lookup.find(src->OpName())->second.map;
|
||||
opName = attributesMap[src->OpName()];
|
||||
|
||||
if (src->OpName() == L"BatchNormalization")
|
||||
{
|
||||
auto spatial = (int64_t)((bool)src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0);
|
||||
auto normalizationTimeConstant = (float)src->Attributes()[L"normalizationTimeConstant"].Value<double>();
|
||||
// auto blendTimeConstant = (float)src->Attributes()[L"blendTimeConstant"].Value<double>();
|
||||
auto epsilon = (float)src->Attributes()[L"epsilon"].Value<double>();
|
||||
|
||||
//
|
||||
// onnx: running_mean = running_mean * momentum + mean * (1 - momentum)
|
||||
// cntk: expAvgFactor * MB stats + (1-expAvgFactor) * prev running stats
|
||||
//
|
||||
auto momentum = 0.0f;
|
||||
if (!isfinite(normalizationTimeConstant))
|
||||
momentum = 1.0f;
|
||||
else if (normalizationTimeConstant > 0)
|
||||
momentum = 1.0f + expm1(-48.0f / normalizationTimeConstant);
|
||||
|
||||
node->AddAttribute(attributesMap[L"spatial"], spatial);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
node->AddAttribute(attributesMap[L"epsilon"], epsilon);
|
||||
node->AddAttribute("momentum", momentum);
|
||||
}
|
||||
else if (src->OpName() == L"LocalResponseNormalization")
|
||||
{
|
||||
auto depthRadius = (int64_t)src->Attributes()[L"depthRadius"].Value<size_t>();
|
||||
auto bias = (float)src->Attributes()[L"bias"].Value<double>();
|
||||
auto alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
auto beta = (float)src->Attributes()[L"beta"].Value<double>();
|
||||
|
||||
node->AddAttribute(attributesMap[L"depthRadius"], depthRadius);
|
||||
node->AddAttribute(attributesMap[L"bias"], bias);
|
||||
node->AddAttribute(attributesMap[L"alpha"], alpha);
|
||||
node->AddAttribute(attributesMap[L"beta"], beta);
|
||||
}
|
||||
else if ((src->OpName() == L"LeakyReLU") || (src->OpName() == L"ELU"))
|
||||
{
|
||||
auto alpha = 0.01f;
|
||||
if (src->Attributes().Contains(L"alpha"))
|
||||
alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
node->AddAttribute("alpha", 0.01f);
|
||||
}
|
||||
else if (src->OpName() == L"SELU")
|
||||
{
|
||||
auto alpha = 1.6732f;
|
||||
if (src->Attributes().Contains(L"alpha"))
|
||||
alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
|
||||
auto gamma = 1.0507f;
|
||||
if (src->Attributes().Contains(L"gamma"))
|
||||
gamma = (float)src->Attributes()[L"gamma"].Value<double>();
|
||||
|
||||
node->AddAttribute("alpha", alpha);
|
||||
node->AddAttribute("gamma", gamma);
|
||||
}
|
||||
else if (src->OpName() == L"Dropout")
|
||||
{
|
||||
auto dropoutRate = (float)src->Attributes()[L"dropoutRate"].Value<double>();
|
||||
node->AddAttribute(attributesMap[L"dropoutRate"], dropoutRate);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
}
|
||||
else if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom") ||
|
||||
(src->OpName() == L"UniformRandomLike") || (src->OpName() == L"NormalRandomLike"))
|
||||
{
|
||||
auto randomArgs = AsVector<double>(src->Attributes()[L"randomDistributionArgs"].Value<std::vector<DictionaryValue>>());
|
||||
auto seed = (int64_t)src->Attributes()[L"rngSeed"].Value<int>();
|
||||
|
||||
if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"UniformRandomLike"))
|
||||
{
|
||||
node->AddAttribute("low", (float)randomArgs[0]);
|
||||
node->AddAttribute("high", (float)randomArgs[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
node->AddAttribute("mean", (float)randomArgs[0]);
|
||||
node->AddAttribute("scale", (float)randomArgs[1]);
|
||||
}
|
||||
|
||||
node->AddAttribute(attributesMap[L"rngSeed"], seed);
|
||||
if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom"))
|
||||
{
|
||||
auto shape = (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
|
||||
node->AddAttribute(attributesMap[L"newShape"], ToINTS(shape));
|
||||
}
|
||||
}
|
||||
else if ((src->OpName() == L"ReduceMax") || (src->OpName() == L"ReduceMin") ||
|
||||
(src->OpName() == L"ReduceSum") || (src->OpName() == L"ReduceMean") ||
|
||||
(src->OpName() == L"ReduceProd") || (src->OpName() == L"ReduceLogSum") ||
|
||||
(src->OpName() == L"Argmax") || (src->OpName() == L"Argmin"))
|
||||
{
|
||||
auto keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
|
||||
std::vector<Axis> reductionAxes;
|
||||
if (src->Attributes().Contains(L"axisVec"))
|
||||
reductionAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
else if (src->Attributes().Contains(L"axis"))
|
||||
reductionAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
|
||||
|
||||
node->AddAttribute(attributesMap[L"reductionKeepDimensions"], keepReducedDimensions);
|
||||
node->AddAttribute("axes", ToINTS(reductionAxes));
|
||||
}
|
||||
else if (src->OpName() == L"Transpose")
|
||||
{
|
||||
std::vector<Axis> perm = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
node->AddAttribute(attributesMap[L"axisVec"], ToINTS(perm));
|
||||
}
|
||||
else if (src->OpName() == L"Reshape")
|
||||
{
|
||||
auto shape = (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
|
||||
node->AddAttribute(attributesMap[L"newShape"], ToINTS(shape, true));
|
||||
}
|
||||
else if (src->OpName() == L"Splice")
|
||||
{
|
||||
Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
node->AddAttribute(attributesMap[L"axis"], (int64_t)ToIndex(axis));
|
||||
}
|
||||
else if (src->OpName() == L"Slice")
|
||||
{
|
||||
std::vector<Axis> sliceAxes;
|
||||
std::vector<int> beginIndex;
|
||||
std::vector<int> endIndex;
|
||||
|
||||
if (src->Attributes().Contains(L"axisVec"))
|
||||
{
|
||||
sliceAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
beginIndex = AsVector<int>(src->Attributes()[L"beginIndexVec"].Value<std::vector<DictionaryValue>>());
|
||||
endIndex = AsVector<int>(src->Attributes()[L"endIndexVec"].Value<std::vector<DictionaryValue>>());
|
||||
}
|
||||
else if (src->Attributes().Contains(L"axis"))
|
||||
{
|
||||
sliceAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
|
||||
beginIndex.push_back((int)(src->Attributes()[L"beginIndex"].Value<int>()));
|
||||
endIndex.push_back((int)(src->Attributes()[L"endIndex"].Value<int>()));
|
||||
}
|
||||
|
||||
node->AddAttribute(attributesMap[L"axes"], ToINTS(sliceAxes));
|
||||
node->AddAttribute(attributesMap[L"starts"], ToINTS(beginIndex));
|
||||
node->AddAttribute(attributesMap[L"ends"], ToINTS(endIndex));
|
||||
}
|
||||
else if (src->OpName() == L"Softmax")
|
||||
{
|
||||
Axis axis = Axis(0);
|
||||
if (src->Attributes().Contains(L"axis"))
|
||||
axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
node->AddAttribute(attributesMap[L"axis"], (int64_t)ToIndex(axis));
|
||||
}
|
||||
else if ((src->OpName() == L"Plus") || (src->OpName() == L"Minus") ||
|
||||
(src->OpName() == L"ElementTimes") || (src->OpName() == L"ElementDivide"))
|
||||
{
|
||||
node->AddAttribute("broadcast", (int64_t)1);
|
||||
// node->AddAttribute("axis", (int64_t)1);
|
||||
}
|
||||
else if (src->OpName() == L"Times")
|
||||
{
|
||||
size_t outputRank = src->Attributes()[L"outputRank"].Value<size_t>();
|
||||
if (outputRank > 1)
|
||||
LogicError("Output rank other than 1 is not supported.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Some nodes map one to many.
|
||||
if (src->OpName() == L"Convolution")
|
||||
{
|
||||
auto kernelShape = (NDShape)src->Attributes()[L"kernelShape"].Value<NDShape>();
|
||||
auto strides = (NDShape)src->Attributes()[L"strides"].Value<NDShape>();
|
||||
auto autoPadding = AsVector<bool>(src->Attributes()[L"autoPadding"].Value<std::vector<DictionaryValue>>());
|
||||
auto dilations = (NDShape)src->Attributes()[L"dilation"].Value<NDShape>();
|
||||
auto transpose = (bool)src->Attributes()[L"transpose"].Value<bool>();
|
||||
|
||||
//
|
||||
// Remove the channel part for ONNX.
|
||||
//
|
||||
kernelShape = kernelShape.SubShape(0, kernelShape.Rank() - 1);
|
||||
strides = strides.SubShape(0, strides.Rank() - 1);
|
||||
autoPadding.pop_back();
|
||||
dilations = dilations.SubShape(0, dilations.Rank() - 1);
|
||||
|
||||
node->AddAttribute("kernel_shape", ToINTS(kernelShape));
|
||||
node->AddAttribute("strides", ToINTS(strides));
|
||||
node->AddAttribute("pads", ToINTS(autoPadding));
|
||||
node->AddAttribute("dilations", ToINTS(dilations));
|
||||
node->AddAttribute("group", (int64_t)1);
|
||||
|
||||
if (transpose)
|
||||
{
|
||||
auto outputShape = (NDShape)src->Attributes()[L"outputShape"].Value<NDShape>();
|
||||
node->AddAttribute("output_shape", ToINTS(outputShape, src->Inputs()[1].HasBatchAxis()));
|
||||
}
|
||||
}
|
||||
else if (src->OpName() == L"Pooling")
|
||||
{
|
||||
auto kernelShape = (NDShape)src->Attributes()[L"poolingWindowShape"].Value<NDShape>();
|
||||
auto strides = (NDShape)src->Attributes()[L"strides"].Value<NDShape>();
|
||||
auto autoPadding = AsVector<bool>(src->Attributes()[L"autoPadding"].Value<std::vector<DictionaryValue>>());
|
||||
|
||||
node->AddAttribute("kernel_shape", ToINTS(kernelShape));
|
||||
node->AddAttribute("strides", ToINTS(strides));
|
||||
node->AddAttribute("pads", ToINTS(autoPadding));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ONNXIR::NodeArg> CNTKToONNXHelper::MapInputsOrderToONNX(const FunctionPtr& src, const std::vector<ONNXIR::NodeArg>& inputs)
|
||||
{
|
||||
if (Operators::HasInputIndexMap(src->OpName()))
|
||||
{
|
||||
std::vector<ONNXIR::NodeArg> orderedInputs;
|
||||
std::map<int, ONNXIR::NodeArg> orderedInputsMap;
|
||||
auto map = Operators::ToONNXInputIndexMap(src->OpName());
|
||||
|
||||
for (size_t inputIndex = 0; inputIndex < inputs.size(); ++inputIndex)
|
||||
{
|
||||
if (map[inputIndex] >= 0)
|
||||
orderedInputsMap.insert(std::pair<int, ONNXIR::NodeArg>(map[inputIndex], inputs[inputIndex]));
|
||||
}
|
||||
|
||||
for (const auto& item : orderedInputsMap)
|
||||
orderedInputs.push_back(item.second);
|
||||
|
||||
return orderedInputs;
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
//
// Create the ONNX node for <src> in <graph>, using the given input/output args.
// Returns the newly added node (never null on the paths below).
//
ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* graph, const std::vector<ONNXIR::NodeArg>& inputs, const std::vector<ONNXIR::NodeArg>& outputs)
{
    ONNXIR::Node* node = nullptr;
    // Remap input positions for ops whose CNTK and ONNX argument orders differ.
    auto orderedInputs = MapInputsOrderToONNX(src, inputs);
    // Prefer the user-assigned name; fall back to the auto-generated Uid.
    auto nodeName = src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name());

    //
    // CNTK Times OP is way more flexible for ONNX, so depend on the inputs and output shape,
    // we will need to insert some reshapes.
    //
    if (src->OpName() == L"Times")
    {
        auto input1Shape = orderedInputs[0].Shape();
        auto input2Shape = orderedInputs[1].Shape();
        auto outputShape = outputs[0].Shape();

        // Each reduced (contracted) axis appears once in each input but not in
        // the output, hence the rank arithmetic below.
        int input1Rank = input1Shape->dim_size();
        int input2Rank = input2Shape->dim_size();
        int outputRank = outputShape->dim_size();
        int reductionRank = (input1Rank + input2Rank - outputRank) / 2;

        if (reductionRank > 1) // We need to insert reshape.
        {
            // Collapse the multiple reduction axes down so Times maps onto a
            // plain ONNX matrix product.
            auto input1Reshape = ReduceRank(input1Shape, reductionRank, true);
            auto input2Reshape = ReduceRank(input2Shape, reductionRank, false);

            // NOTE: orderedInputs is swapped relative to src->Inputs() because
            // the Times input index map is { 1, 0 } (see Operators.cpp), so
            // orderedInputs[0] corresponds to src->Inputs()[1] and vice versa.
            UpdateONNXType(src->Inputs()[1].GetDataType(), input1Reshape);
            UpdateONNXType(src->Inputs()[0].GetDataType(), input2Reshape);

            // Intermediate args connecting each Reshape to the main node.
            ONNXIR::NodeArg inputOutput1Arg(orderedInputs[0].Name() + string("_reshape0"), &input1Reshape);
            ONNXIR::NodeArg inputOutput2Arg(orderedInputs[1].Name() + string("_reshape1"), &input2Reshape);

            auto reshapeNode1 = graph->AddNode(nodeName + string("_reshape0"), "Reshape", "", { orderedInputs[0] }, { inputOutput1Arg });
            auto reshapeNode2 = graph->AddNode(nodeName + string("_reshape1"), "Reshape", "", { orderedInputs[1] }, { inputOutput2Arg });

            reshapeNode1->AddAttribute("shape", ToINTS(input1Reshape));
            reshapeNode2->AddAttribute("shape", ToINTS(input2Reshape));

            node = graph->AddNode(nodeName, ToOPName(src), "", { inputOutput1Arg , inputOutput2Arg }, outputs);
        }
        else
            node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
    }
    else
        node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);

    //
    // Copy and validate attributes.
    //
    CopyAttributes(src, node);

    return node;
}
|
|
@ -0,0 +1,23 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once

#include "stdafx.h"
#include "CNTKLibrary.h"

// Forward declaration so this header does not pull in the ONNXIR model headers.
namespace ONNXIR
{
    class Model;
}

namespace CNTK
{
    // Converter from a CNTK Function graph to an in-memory ONNX model.
    class CNTKToONNX
    {
    public:
        // Build an ONNXIR::Model equivalent to the CNTK graph rooted at <src>.
        // Ownership of the model passes to the caller.
        static std::unique_ptr<ONNXIR::Model> CreateModel(const FunctionPtr& src);
    };
}
|
|
@ -0,0 +1,75 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "ONNX.h"
|
||||
#include "CNTKToONNX.h"
|
||||
#include "proto/onnx/core/model.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
#include "Utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ONNXToCNTK.h"
|
||||
|
||||
using namespace CNTK;
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
// Debug helper: dump the graph rooted at <function> to stdout, one edge per
// line, indenting each recursion level with '.' characters. When <useName> is
// true, user-assigned names are printed instead of auto-generated Uids.
static void PrintGraph(FunctionPtr function, int spaces, bool useName = false)
{
    // Hoist Inputs() — the original called it three times per visit.
    const auto inputs = function->Inputs();

    // Leaf: print the node by itself and stop.
    if (inputs.empty())
    {
        std::cout << std::string(spaces, '.') + "(" + ToString(useName ? function->Name() : function->Uid()) + ")" + ToString(function->AsString()) << std::endl;
        return;
    }

    // One line per incoming edge.
    for (const auto& input : inputs)
    {
        std::cout << std::string(spaces, '.') + "(" + ToString(useName ? function->Name() : function->Uid()) + ")" + "->" +
            "(" + ToString(useName ? input.Name() : input.Uid()) + ")" + ToString(input.AsString()) << std::endl;
    }

    // Recurse into the producer of each input, if it has one.
    for (const auto& input : inputs)
    {
        if (input.Owner() != nullptr)
            PrintGraph(input.Owner(), spaces + 4, useName);
    }
}
|
||||
}
|
||||
|
||||
// Serialize the CNTK graph <src> as an ONNX model file at <filepath>.
void ONNXFormat::Save(const FunctionPtr& src, const std::wstring& filepath)
{
    auto model = CNTKToONNX::CreateModel(src);
    // ONNXIR::Model::Save takes a wide path on Windows and a narrow
    // (converted) path elsewhere.
#ifdef _WIN32
    ONNXIR::Model::Save(*model, filepath);
#else
    ONNXIR::Model::Save(*model, ToString(filepath));
#endif
}
|
||||
|
||||
// Deserialize the ONNX model at <filepath> and convert it into a CNTK
// Function graph targeting <computeDevice>.
FunctionPtr ONNXFormat::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
{
    ONNXIR::ModelProto modelProto;

    // Wide path on Windows, narrow (converted) path elsewhere.
#ifdef _WIN32
    bool loadStatus = ONNXIR::Model::Load(filepath, &modelProto);
#else
    bool loadStatus = ONNXIR::Model::Load(ToString(filepath), &modelProto);
#endif
    // NOTE(review): the load result is currently ignored — the bare
    // expression below only silences the unused-variable warning, and the
    // failure check is commented out. A failed load falls through to
    // Resolve(), which presumably surfaces its own error; confirm whether
    // the explicit check should be restored.
    loadStatus;
    //if (!loadStatus)
    //    LogicError("Failed to load the model.");

    ONNXIR::Model model(modelProto);
    // Resolve() validates the loaded graph; abort with the reported message on failure.
    auto status = model.MainGraph()->Resolve();
    if (!status.Ok())
        LogicError("%s", status.ErrorMsg().c_str());

    FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(model.MainGraph(), computeDevice);
    return cntkFunction;
}
|
|
@ -0,0 +1,20 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once

#include "stdafx.h"
#include "CNTKLibrary.h"


namespace CNTK
{
    // Public entry points for saving/loading CNTK Function graphs in ONNX format.
    class ONNXFormat
    {
    public:
        // Serialize <src> to an ONNX model file at <filepath>.
        static void Save(const FunctionPtr& src, const std::wstring& filepath);
        // Load an ONNX model file and convert it to a CNTK Function placed on
        // <computeDevice> (defaults to the process-wide default device).
        static FunctionPtr Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
    };
}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,27 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once

#include "stdafx.h"
#include "CNTKLibrary.h"

// Forward declaration so this header does not pull in the ONNXIR graph headers.
namespace ONNXIR
{
    class Graph;
}

namespace CNTK
{
    // Converter from an ONNX graph to a CNTK Function graph.
    class ONNXToCNTK
    {
    public:
        //
        // Create a CNTK graph (Function) given an ONNX graph. The function is created to use the
        // specified computing device.
        //
        static FunctionPtr CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor& computeDevice);
    };
}
|
|
@ -0,0 +1,297 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "Operators.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
#include "Utils.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
namespace ONNX
|
||||
{
|
||||
//
|
||||
// Support ONNX OPs from https://github.com/onnx/onnx/tree/master/onnx/defs
|
||||
//
|
||||
// The format of the below structure is simply a key which is CNTK OpName and a corresponding
|
||||
// lookup table, the corrsponding lookup table map the OpName and all its attributes from CNTK
|
||||
// to ONNX.
|
||||
//
|
||||
// Eventually, it would be good to change CNTK OpName to match ONNX in order to avoid the need
|
||||
// of the below table.
|
||||
//
|
||||
std::unordered_multimap<std::wstring, AttributesMapping> Operators::_cntkToONNXOpName = {
|
||||
// From nn
|
||||
{ L"Pooling", { {
|
||||
{ L"Pooling", "AveragePool" },
|
||||
{ L"poolingWindowShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
} } },
|
||||
{ L"Pooling", { {
|
||||
{ L"Pooling", "MaxPool" },
|
||||
{ L"poolingWindowShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
} } },
|
||||
{ L"Convolution", { {
|
||||
{ L"Convolution", "Conv" },
|
||||
{ L"kernelShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
{ L"dilation", "dilations" },
|
||||
// { L"", "group" },
|
||||
} } },
|
||||
{ L"Convolution", { {
|
||||
{ L"ConvolutionTranspose", "ConvTranspose" },
|
||||
{ L"kernelShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
{ L"dilation", "dilations" },
|
||||
{ L"outputShape", "output_shape" },
|
||||
} } },
|
||||
{ L"GlobalMaxPooling", { {
|
||||
{ L"GlobalMaxPooling", "GlobalMaxPool" },
|
||||
} } },
|
||||
{ L"GlobalAveragePooling", { {
|
||||
{ L"GlobalAveragePooling", "GlobalAveragePool" },
|
||||
} } },
|
||||
{ L"BatchNormalization", { {
|
||||
{ L"BatchNormalization", "BatchNormalization" },
|
||||
{ L"spatial", "spatial" },
|
||||
// { L"", "is_test" },
|
||||
{ L"epsilon", "epsilon" },
|
||||
// { L"", "momentum" },
|
||||
} } },
|
||||
// from ONNX experiament, added to test Caffe models
|
||||
// TODO: set key as BatchNormalization instead of BatchNormalizationCaffe
|
||||
{ L"BatchNormalizationCaffe",{ {
|
||||
{ L"BatchNormalization", "SpatialBN" },
|
||||
{ L"spatial", "spatial" },
|
||||
// { L"", "is_test" },
|
||||
{ L"epsilon", "epsilon" },
|
||||
// { L"", "momentum" },
|
||||
} } },
|
||||
{ L"LocalResponseNormalization",{ {
|
||||
{ L"LocalResponseNormalization", "LRN" },
|
||||
{ L"depthRadius", "size" },
|
||||
{ L"bias", "bias" },
|
||||
{ L"alpha", "alpha" },
|
||||
{ L"beta", "beta" },
|
||||
} } },
|
||||
{ L"Dropout", { {
|
||||
{ L"Dropout", "Dropout" },
|
||||
{ L"dropoutRate", "ratio" },
|
||||
// { L"", "is_test" },
|
||||
} } },
|
||||
{ L"Flatten",{ {
|
||||
{ L"Flatten", "Flatten" },
|
||||
} } },
|
||||
|
||||
// From Generator
|
||||
{ L"UniformRandom", { {
|
||||
{ L"UniformRandom", "RandomUniform" },
|
||||
// { L"", "low" },
|
||||
// { L"", "high" },
|
||||
{ L"rngSeed", "seed" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"NormalRandom", { {
|
||||
{ L"NormalRandom", "RandomNormal" },
|
||||
// { L"", "mean" },
|
||||
// { L"", "scale" },
|
||||
{ L"rngSeed", "seed" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"UniformRandomLike", { {
|
||||
{ L"UniformRandomLike", "RandomUniformLike" },
|
||||
// { L"", "low" },
|
||||
// { L"", "high" },
|
||||
{ L"rngSeed", "seed" },
|
||||
} } },
|
||||
{ L"NormalRandomLike", { {
|
||||
{ L"NormalRandomLike", "RandomNormalLike" },
|
||||
// { L"", "mean" },
|
||||
// { L"", "scale" },
|
||||
{ L"rngSeed", "seed" },
|
||||
} } },
|
||||
|
||||
// From Math
|
||||
{ L"Plus", { {
|
||||
{ L"Plus", "Add" },
|
||||
} } },
|
||||
{ L"Minus", { {
|
||||
{ L"Minus", "Sub" },
|
||||
} } },
|
||||
{ L"ElementTimes", { {
|
||||
{ L"ElementTimes", "Mul" },
|
||||
} } },
|
||||
{ L"ElementDivide", { {
|
||||
{ L"ElementDivide", "Div" },
|
||||
} } },
|
||||
{ L"Negate", { {
|
||||
{ L"Negate", "Neg" },
|
||||
} } },
|
||||
{ L"Abs", { {
|
||||
{ L"Abs", "Abs" },
|
||||
} } },
|
||||
{ L"Reciprocal", { {
|
||||
{ L"Reciprocal", "Reciprocal" },
|
||||
} } },
|
||||
{ L"Floor", { {
|
||||
{ L"Floor", "Floor" },
|
||||
} } },
|
||||
{ L"Ceil", { {
|
||||
{ L"Ceil", "Ceil" },
|
||||
} } },
|
||||
{ L"Sqrt", { {
|
||||
{ L"Sqrt", "Sqrt" },
|
||||
} } },
|
||||
{ L"ReLU", { {
|
||||
{ L"ReLU", "Relu" },
|
||||
} } },
|
||||
{ L"LeakyReLU", { {
|
||||
{ L"LeakyReLU", "LeakyRelu" },
|
||||
{ L"alpha", "alpha" },
|
||||
} } },
|
||||
{ L"SELU", { {
|
||||
{ L"SELU", "Selu" },
|
||||
{ L"alpha", "alpha" },
|
||||
{ L"gamma", "gamma" },
|
||||
} } },
|
||||
{ L"ELU", { {
|
||||
{ L"ELU", "Elu" },
|
||||
// { L"", "alpha" },
|
||||
} } },
|
||||
{ L"Exp", { {
|
||||
{ L"Exp", "Exp" },
|
||||
} } },
|
||||
{ L"Log", { {
|
||||
{ L"Log", "Log" },
|
||||
} } },
|
||||
{ L"Tanh", { {
|
||||
{ L"Tanh", "Tanh" },
|
||||
} } },
|
||||
{ L"Pow", { {
|
||||
{ L"Pow", "Pow" },
|
||||
// { L"", "exponent" },
|
||||
} } },
|
||||
{ L"Times", { {
|
||||
{ L"Times", "Dot" },
|
||||
} } },
|
||||
{ L"PReLU", { {
|
||||
{ L"PReLU", "PRelu" },
|
||||
} } },
|
||||
{ L"StableSigmoid", { {
|
||||
{ L"StableSigmoid", "Sigmoid" },
|
||||
} } },
|
||||
{ L"ElementMax", { {
|
||||
{ L"ElementMax", "Max" },
|
||||
} } },
|
||||
{ L"ElementMax", { {
|
||||
{ L"ElementMax", "Min" },
|
||||
} } },
|
||||
// { L"", "Sum" },
|
||||
{ L"Softmax", { {
|
||||
{ L"Softmax", "Softmax" },
|
||||
{ L"axis", "axis" },
|
||||
} } },
|
||||
|
||||
// From reduction
|
||||
{ L"ReduceMax", { {
|
||||
{ L"ReduceMax", "ReduceMax" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceMin", { {
|
||||
{ L"ReduceMin", "ReduceMin" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceSum", { {
|
||||
{ L"ReduceSum", "ReduceSum" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceMean", { {
|
||||
{ L"ReduceMean", "ReduceMean" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceProd", { {
|
||||
{ L"ReduceProd", "ReduceProd" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceLogSum", { {
|
||||
{ L"ReduceLogSum", "ReduceLogSumExp" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"Argmax", { {
|
||||
{ L"Argmax", "ArgMax" },
|
||||
{ L"axis", "axes" },
|
||||
// { L"", "keepdims" },
|
||||
} } },
|
||||
{ L"Argmin", { {
|
||||
{ L"Argmin", "ArgMin" },
|
||||
{ L"axis", "axes" },
|
||||
// { L"", "keepdims" },
|
||||
} } },
|
||||
|
||||
// From tensor
|
||||
// { L"", "Cast" },
|
||||
{ L"Reshape", { {
|
||||
{ L"Reshape", "Reshape" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"Splice", { {
|
||||
{ L"Splice", "Concat" },
|
||||
{ L"axis", "axis" },
|
||||
} } },
|
||||
// { L"", "Split" },
|
||||
{ L"Slice", { {
|
||||
{ L"Slice", "Slice" },
|
||||
{ L"beginIndexVec", "starts" },
|
||||
{ L"endIndexVec", "ends" },
|
||||
} } },
|
||||
{ L"Transpose", { {
|
||||
{ L"Transpose", "Transpose" },
|
||||
{ L"axisVec", "perm" },
|
||||
} } },
|
||||
{ L"GatherOp", { {
|
||||
{ L"GatherOp", "Gather" },
|
||||
} } },
|
||||
// { L"", "Squeeze" },
|
||||
};
|
||||
|
||||
// For CNTK block (composite) ops: positional input indices that are internal
// to the block (constants/parameters) and must be filtered out during export.
// An empty set means every input of that op is a real graph input.
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
    { L"LeakyReLU", { 0, 1 } },
    { L"SELU", { 0, 1, 2 } },
    { L"PReLU", { 0 } },
    { L"ElementMax", {} },
    { L"ElementMin", {} },
    { L"Softmax", {} },
    { L"LocalResponseNormalization", { 0, 1, 2 } }
};
|
||||
|
||||
// Per-op remap from CNTK input position (vector index) to ONNX input
// position (value). A value of -1 drops that input during export
// (see IgnoreConstantAndParameter / MapInputsOrderToONNX).
std::unordered_map<std::wstring, std::vector<int>> Operators::_cntkToONNXInputIndices = {
    { L"Convolution", { 1, 0 } },
    { L"ConvolutionTranspose", { 1, 0 } },
    { L"BatchNormalization", { 0, 1, 2, 3, 4, -1 } },
    { L"Times", { 1, 0 } },
};
|
||||
|
||||
//
|
||||
// CNTK Layer API needs to be treated specially.
|
||||
//
|
||||
// Ops created through the CNTK Layers API; these arrive wrapped in block
// functions and are handled specially during export.
std::set<std::wstring> Operators::_cntkLayerOPName = {
    { L"Convolution" },
    { L"ConvolutionTranspose" },
    { L"BatchNormalization" },
    { L"Dropout" },
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class Graph;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
namespace ONNX
|
||||
{
|
||||
|
||||
// Name mapping from CNTK (wide strings) to ONNX (narrow strings). As used in
// Operators.cpp, the first entry of each table maps the CNTK op name to the
// ONNX op name and the remaining entries map attribute names.
struct AttributesMapping
{
    std::unordered_map<std::wstring, std::string> map;
};
|
||||
|
||||
class Operators
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Check if opName is one of the supported ONNX OP.
|
||||
//
|
||||
static inline bool IsSupportedCNTKOP(const std::wstring& opName)
|
||||
{
|
||||
return _cntkToONNXOpName.find(opName) != _cntkToONNXOpName.end();
|
||||
}
|
||||
|
||||
//
|
||||
// Layer APIs use block function as a wrapper, so we need to handle them with care.
|
||||
//
|
||||
static inline bool IsLayerCNTKOP(const std::wstring& opName)
|
||||
{
|
||||
return _cntkLayerOPName.find(opName) != _cntkLayerOPName.end();
|
||||
}
|
||||
|
||||
//
|
||||
// Return a lookup table which is keyed on CNTK OP, and the value is another table
|
||||
// that contain name mapping from CNTK to ONNX.
|
||||
//
|
||||
static inline const std::unordered_multimap<std::wstring, AttributesMapping>& CntkToONNXLookup()
|
||||
{
|
||||
return _cntkToONNXOpName;
|
||||
}
|
||||
|
||||
//
|
||||
// Because in CNTK block, we can't filtered out the external inputs to the block.
|
||||
// We need a way to filter out leaf input from its subgraph.
|
||||
//
|
||||
static inline bool IsValidInputs(const std::wstring& opName, size_t index)
|
||||
{
|
||||
assert(_cntkBlockOPInvalidIndices.find(opName) != _cntkBlockOPInvalidIndices.end());
|
||||
|
||||
auto invalidIndices = _cntkBlockOPInvalidIndices[opName];
|
||||
return invalidIndices.find(index) == invalidIndices.end();
|
||||
}
|
||||
|
||||
//
|
||||
// The positional of the argument between CNTK and ONNX aren't the same.
|
||||
// The below function return true, if we need a remap.
|
||||
//
|
||||
static inline bool HasInputIndexMap(const std::wstring& opName)
|
||||
{
|
||||
return _cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end();
|
||||
}
|
||||
|
||||
//
|
||||
// If we need a remap, the below function return a remapping map.
|
||||
//
|
||||
static inline const std::vector<int>& ToONNXInputIndexMap(const std::wstring& opName)
|
||||
{
|
||||
assert(_cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end());
|
||||
return _cntkToONNXInputIndices[opName];
|
||||
}
|
||||
|
||||
//
|
||||
// For block function with internal constant or parameter, we don't want to create
|
||||
// the corresponding ONNX tensor for some of the parameters.
|
||||
//
|
||||
static inline bool IgnoreConstantAndParameter(const std::wstring& opName, size_t index)
|
||||
{
|
||||
if (_cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end())
|
||||
{
|
||||
auto indexMap = _cntkToONNXInputIndices[opName];
|
||||
assert (index < indexMap.size());
|
||||
|
||||
return (indexMap[index] < 0);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
|
||||
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
|
||||
static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
|
||||
static std::set<std::wstring> _cntkLayerOPName;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
#include "constants.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
// Construct-on-first-use accessor for the shared TypesWrapper instance
// (see the note in constants.h about op registration ordering).
// The instance is heap-allocated and never deleted, presumably to keep it
// valid during static destruction.
TypesWrapper& TypesWrapper::GetTypesWrapper()
{
    static TypesWrapper* types = new TypesWrapper();
    return *types;
}
|
||||
|
||||
// Construct-on-first-use set of every data-type string this IR accepts.
// Like GetTypesWrapper(), the set is heap-allocated and never deleted.
std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes()
{
    static std::unordered_set<std::string>* allowedDataTypes =
        new std::unordered_set<std::string>({
        c_float16, c_float, c_double,
        c_int8, c_int16, c_int32, c_int64,
        c_uint8, c_uint16, c_uint32, c_uint64,
        c_complex64, c_complex128,
        c_string, c_bool });
    return *allowedDataTypes;
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
// Fix: this header had no include guard; a second inclusion would redefine
// class TypesWrapper.
#pragma once

#include <string>
#include <unordered_set>

namespace ONNXIR
{
    // Well-known operator and attribute names.
    static const std::string c_noOp = "NoOp";
    static const std::string c_constantOp = "Constant";
    static const std::string c_constantValue = "value";

    // Singleton wrapper around allowed data types.
    // This implements construct on first use which is needed to ensure
    // static objects are initialized before use. Ops registration does not work
    // properly without this.
    class TypesWrapper
    {
    public:
        // Accessor for the process-wide instance.
        static TypesWrapper& GetTypesWrapper();

        // DataType strings. These should match the DataTypes defined in Data.proto
        const std::string c_float16 = "float16";
        const std::string c_float = "float";
        const std::string c_double = "double";
        const std::string c_int8 = "int8";
        const std::string c_int16 = "int16";
        const std::string c_int32 = "int32";
        const std::string c_int64 = "int64";
        const std::string c_uint8 = "uint8";
        const std::string c_uint16 = "uint16";
        const std::string c_uint32 = "uint32";
        const std::string c_uint64 = "uint64";
        const std::string c_complex64 = "complex64";
        const std::string c_complex128 = "complex128";
        const std::string c_string = "string";
        const std::string c_bool = "bool";

        // The set containing all of the type strings above.
        std::unordered_set<std::string>& GetAllowedDataTypes();

        ~TypesWrapper() = default;
        // Non-copyable: a single shared instance is accessed via GetTypesWrapper().
        TypesWrapper(const TypesWrapper&) = delete;
        void operator=(const TypesWrapper&) = delete;
    private:
        TypesWrapper() = default;
    };
}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,604 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#ifndef CORE_GRAPH_GRAPH_H
|
||||
#define CORE_GRAPH_GRAPH_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "constants.h"
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "status.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// Index type for nodes within a graph.
typedef size_t NODEINDEX;
// Version number type.
typedef int64_t VERSION;
// Attribute name -> attribute value.
typedef std::unordered_map<std::string, AttributeProto> NodeAttributes;
// Name, type and shape of one node argument.
typedef ValueInfoProto NodeArgInfo;
// Initializer tensor name -> tensor data.
typedef std::unordered_map<std::string, TensorProto> InitialTensorSet;
// Argument name -> argument type.
typedef std::unordered_map<std::string, TypeProto> ArgNameToTypeMap;

// Forward declarations.
class Graph;
class Node;
class OpSignature;
|
||||
|
||||
// Node argument definition, for both input and output,
|
||||
// including arg name, arg type (contains both type and shape).
|
||||
//
|
||||
// Design Question: in my (Ke's) opinion, shape should not be part of type.
|
||||
// We may align the protobuf design with our operator registry interface,
|
||||
// which has type specified for each operator, but no shape. Well, shape
|
||||
// should be inferred with a separate shape inference function given
|
||||
// input shapes, or input tensor data sometimes.
|
||||
// With shape as part of type (current protobuf design),
|
||||
// 1) we'll have to split the "TypeProto" into type and shape in this internal
|
||||
// representation interface so that it could be easily used when doing type
|
||||
// inference and matching with operator registry.
|
||||
// 2) SetType should be always called before SetShape, otherwise, SetShape()
|
||||
// will fail. Because shape is located in a TypeProto.
|
||||
// Thoughts?
|
||||
//
|
||||
class NodeArg
{
public:

    // Constructor by specifying node arg name and type&shape which is
    // optional. This is called when loading a <Graph> from <GraphProto>
    // normally.
    NodeArg(const std::string& p_name,
        const TypeProto* p_argType);

    // Get node arg name.
    const std::string& Name() const;

    // Get node arg type.
    const PTYPE& Type() const;

    // Get node arg shape.
    // Return null pointer if there's no shape specified.
    const TypeProto::TensorShapeProto* Shape() const;

    // Set node arg shape.
    // Shape could only be set after setting type since shape information
    // now is part of TypeProto.
    void SetShape(const TypeProto::TensorShapeProto& p_shape);

    // Get node arg info proto.
    const NodeArgInfo& ToProto() const;

private:

    // Only Node and Graph may mutate a NodeArg's type.
    friend class Node;
    friend class Graph;

    // Set node arg type, either as an already-registered PTYPE or from a TypeProto.
    void SetType(PTYPE p_type);
    void SetType(const TypeProto& p_typeProto);

    // Node arg PType.
    PTYPE m_type;

    // Node arg name, type and shape.
    NodeArgInfo m_nodeArgInfo;
};
|
||||
|
||||
// Function representation.
|
||||
// It could present two cases of functions.
|
||||
// 1. Function without instantiation (No Node* sent to constructor). This
|
||||
// may be used in pure function optimization.
|
||||
// Function body (subgraph) should not be able to executed since no real
|
||||
// tensor binded with inputs/outputs of the function.
|
||||
// 2. Function with instantiation (A non-empty Node* sent to constructor).
|
||||
// Function body (subgraph) should be able to be executed since all
|
||||
// input/output names among nodes inside are refering real tensor names.
|
||||
// 2_a. Function with template type parameter.
|
||||
// 2_b. Function without template type parameter.
|
||||
// Function definition (FunctionDefProto) will be synced only when its body
|
||||
// is changed. Meanwhile, in 2_a case above, the function definition name
|
||||
// will be appended with real type string.
|
||||
class Function
{
public:

    // Get function body - a subgraph.
    // Returned pointer owned by <*this> Function.
    Graph* Body();

    // Get function name.
    // A function's name could be either its function definition name
    // m_functionDefProto.name(), or m_functionDefProto.name() + template
    // argument value.
    const std::string& Name();

    // Get the protobuf representation of <*this> function.
    const FunctionDefProto& ToProto();

private:

    // Only Graph may create Function instances (constructors are private
    // and the default one is deleted).
    friend class Graph;

    Function() = delete;

    // Constructor.
    // <p_node> specifies the node that refers to <*this> function. It's
    // used to instantiate <p_funcProto> if <p_funcProto> is a function
    // template.
    // <p_funcProto> specifies a function definition that a node refers to.
    Function(Node* p_node,
        const FunctionDefProto& p_funcProto);

    // Function body which is a SubGraph.
    std::unique_ptr<Graph> m_body;
};
|
||||
|
||||
// A node representation class.
|
||||
class Node {
|
||||
|
||||
public:
|
||||
|
||||
// An edge end. It could be input or output edge end of a node.
|
||||
// For node's input edge end, it's the source end, as the destination
|
||||
// end is the node itself.
|
||||
// For node's ouput edge end, it's the destination end, as the source
|
||||
// end is the node itself.
|
||||
class EdgeEnd
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor.
|
||||
// An EdgeEnd contains a Node pointer, a NodeArg pointer.
|
||||
// NOTE: it does not own the Node pointer and NodeArg pointer.
|
||||
EdgeEnd(const Node& p_node, const NodeArg& p_nodeArg);
|
||||
|
||||
// Get the <Node*> that this edge end refers to.
|
||||
const Node* GetNode() const;
|
||||
|
||||
// Get the <NodeArg*> that this edge end refers to.
|
||||
const NodeArg* GetNodeArg() const;
|
||||
|
||||
private:
|
||||
|
||||
const Node* m_node;
|
||||
|
||||
const NodeArg* m_nodeArg;
|
||||
};
|
||||
|
||||
// An iterator helper class for iterating a Node's neighbour nodes.
|
||||
class NodeConstIterator
|
||||
{
|
||||
public:
|
||||
|
||||
NodeConstIterator(std::set<const Node*>::const_iterator p_iter);
|
||||
|
||||
bool operator==(const NodeConstIterator& p_other) const;
|
||||
|
||||
bool operator!=(const NodeConstIterator& p_other) const;
|
||||
|
||||
void operator++();
|
||||
|
||||
const Node* operator*();
|
||||
|
||||
private:
|
||||
|
||||
std::set<const Node*>::const_iterator m_iter;
|
||||
};
|
||||
|
||||
// Get node index.
|
||||
NODEINDEX Index() const;
|
||||
|
||||
// Get node name.
|
||||
const std::string& Name() const;
|
||||
|
||||
// Get node operator type.
|
||||
const std::string& OpType() const;
|
||||
|
||||
// Get node description.
|
||||
const std::string& Description() const;
|
||||
|
||||
// Read/Write <*this> node's input args' definition, including name,
|
||||
// type and shape.
|
||||
const std::vector<NodeArg>& InputDefs() const;
|
||||
std::vector<NodeArg>& Mutable_InputDefs();
|
||||
|
||||
const std::vector<int>& InputArgCount() const;
|
||||
std::vector<int>& Mutable_InputArgCount();
|
||||
|
||||
// Read/Write <*this> node's output args' definition, including name,
|
||||
// type and shape.
|
||||
const std::vector<NodeArg>& OutputDefs() const;
|
||||
std::vector<NodeArg>& Mutable_OutputDefs();
|
||||
|
||||
// Functions defined to traverse a Graph as below.
|
||||
// Read all input nodes of <*this>.
|
||||
Node::NodeConstIterator InputNodes_begin() const;
|
||||
Node::NodeConstIterator InputNodes_end() const;
|
||||
// Read all output nodes of <*this>.
|
||||
Node::NodeConstIterator OutputNodes_begin() const;
|
||||
Node::NodeConstIterator OutputNodes_end() const;
|
||||
// Given input arg, get the source end of an input edge.
|
||||
bool InputEdgeSrcEnd(NodeArg* p_inputArg,
|
||||
/*out*/const EdgeEnd** p_inputEdgeSrcEnd) const;
|
||||
|
||||
// Add a node attribute with specified attribute name and value.
|
||||
bool AddAttribute(const std::string& p_attrName, const AttributeProto& p_value);
|
||||
|
||||
#define ADD_ATTR_INTERFACES(TypeName) \
|
||||
bool AddAttribute(const std::string& p_attrName, \
|
||||
const TypeName& p_value); \
|
||||
bool AddAttribute(const std::string& p_attrName, \
|
||||
const std::vector<TypeName>& p_values); \
|
||||
|
||||
ADD_ATTR_INTERFACES(int64_t)
|
||||
ADD_ATTR_INTERFACES(float)
|
||||
ADD_ATTR_INTERFACES(std::string)
|
||||
ADD_ATTR_INTERFACES(TensorProto)
|
||||
ADD_ATTR_INTERFACES(GraphProto)
|
||||
ADD_ATTR_INTERFACES(TypeProto)
|
||||
ADD_ATTR_INTERFACES(TypeProto::TensorShapeProto)
|
||||
|
||||
// Clear specified node attribute.
|
||||
bool ClearAttribute(const std::string& p_attrName);
|
||||
|
||||
// Get node attributes.
|
||||
const NodeAttributes& GetAttributes() const;
|
||||
|
||||
// Indicates on which we will run this node in runtime.
|
||||
// Executor will decide which device that this node will run against
|
||||
// and set it properly.
|
||||
// TODO: may change the return value type to be an ENUM.
|
||||
const std::string& Device() const;
|
||||
void SetDevice(const std::string& p_device);
|
||||
|
||||
// Get the corresponding <NodeProto>.
|
||||
void ToProto(NodeProto& p_proto) const;
|
||||
|
||||
private:
|
||||
|
||||
friend class Graph;
|
||||
|
||||
// Node could ONLY be constructed and owned by a <Graph>.
|
||||
Node() {}
|
||||
Node(NODEINDEX p_index, Graph* p_graph)
|
||||
: m_index(p_index),
|
||||
m_graph(p_graph) {}
|
||||
Node(const Node& p_other);
|
||||
|
||||
// Init node per <NodeProto>.
|
||||
// <p_nameToValueInfoMap> specifies the node's inputs'/outputs' value information,
|
||||
// including name, type and shape.
|
||||
void Init(const NodeProto& p_nodeProto,
|
||||
const ArgNameToTypeMap& p_nameToType);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<int>& p_inputArgCount,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
|
||||
// Node index.
|
||||
NODEINDEX m_index;
|
||||
|
||||
// Node name.
|
||||
std::string m_name;
|
||||
|
||||
// Node operator type.
|
||||
std::string m_opType;
|
||||
|
||||
// Node doc string.
|
||||
std::string m_description;
|
||||
|
||||
// Node inputs' definition.
|
||||
std::vector<NodeArg> m_inputDefs;
|
||||
// The number of inputs for each argument of the operator or function which
|
||||
// this node refers.
|
||||
// For example, <m_inputDefs> has 10 elements (inputs), and <m_inputArgCount>
|
||||
// is {4, 6}. This means that 4 elements (inputs) of <m_inputDefs> map to the
|
||||
// first argument of the operator or function, and the other 6 map to the
|
||||
// second argument.
|
||||
std::vector<int> m_inputArgCount;
|
||||
|
||||
// Node outputs' definition.
|
||||
std::vector<NodeArg> m_outputDefs;
|
||||
|
||||
// Node inputs' instantiation.
|
||||
std::unordered_map<const NodeArg*, EdgeEnd> m_inputs;
|
||||
// Node input nodes, besides input nodes mentioned in <m_inputs> above,
|
||||
// it also contains all control input nodes;
|
||||
std::set<const Node*> m_inputNodes;
|
||||
// Control input nodes' names.
|
||||
std::set<std::string> m_controlInputs;
|
||||
// Node's output nodes.
|
||||
std::set<const Node*> m_outputNodes;
|
||||
|
||||
// Device.
|
||||
std::string m_device;
|
||||
|
||||
// Map from attribute name to attribute.
|
||||
// This allows attribute adding and removing.
|
||||
NodeAttributes m_attributes;
|
||||
|
||||
Graph* m_graph;
|
||||
};
|
||||
|
||||
// A graph representation class.
|
||||
// A graph representation class.
// Owns a set of <Node> objects plus the initial tensors, function
// definitions and bookkeeping needed to validate and serialize them.
class Graph
{
public:

    // An iterator helper to access graph nodes without copy.
    // The iterator itself does not own any data.
    class NodeIterator
    {
    public:

        // Construct an iterator positioned at <p_currentNodeIndex>
        // within <p_graph>.
        NodeIterator(NODEINDEX p_currentNodeIndex, Graph* p_graph)
            : m_graph(p_graph),
            m_currentNodeIndex(p_currentNodeIndex)
        {
        }

        bool operator==(const NodeIterator& p_other) const;

        bool operator!=(const NodeIterator& p_other) const;

        void operator++();

        Node* operator*();

    private:

        // The graph being iterated. Not owned.
        Graph* m_graph;

        // It's the Node index in <m_nodes> of the <m_graph>.
        NODEINDEX m_currentNodeIndex;
    };

    // Constructor from scratch.
    // <p_isONNX> is a special flag to indicate whether it's
    // going to construct an ONNX graph. With an ONNX graph, strict
    // type checking will be skipped.
    Graph(const std::string& p_name, bool p_isONNX = false);
    Graph(const std::string& p_name, const std::string& p_docString);

    // Constructor: Given a <GraphProto> loaded from model file, construct
    // a <Graph> object.
    Graph(const GraphProto& p_graphProto);

    // Constructor: Given a function definition and a node which refers to
    // the function, construct a <Graph> object.
    // Normally the <p_name> could be the parent node name and the
    // <p_version> could be the parent graph's version.
    // Question: will a node defined in a function refer to another function
    // please? I (Ke) am assuming we don't allow such case here for now.
    Graph(Node* p_node,
        const FunctionDefProto& p_functionProto);

    // Resolve <*this> graph to ensure it's in a good shape with full
    // functionality.
    // 1. Run through all validation rules.
    //    a. Node name and node output's names should be unique.
    //    b. Attribute match between node and op definition.
    //    c. Input/Output match between node and op definition.
    //    d. Graph is acyclic.
    // 2. Check & Setup inner nodes' dependency.
    // 3. Cleanup function definition lists.
    // Returns resolving status.
    Status Resolve();

    // Getter and Setter for graph name.
    const std::string& Name() const;
    void SetName(const std::string& p_name);

    // Add/Remove/Get initial tensors for some graph inputs.
    void AddInitialTensor(const TensorProto& p_tensor);
    void RemoveInitialTensor(const std::string& p_tensorName);
    bool GetInitialTensor(const std::string& p_tensorName,
        TensorProto& p_value) const;
    const InitialTensorSet& GetAllInitialTensors() const;

    // Add or Remove a function definition.
    bool AddFunctionDef(const FunctionDefProto& p_function);
    void RemoveFunctionDef(const std::string& p_functionName);

    // Get node given specific node index.
    Node* GetNode(NODEINDEX p_nodeIndex);

    // Get node iterator to access all effective nodes in the graph.
    Graph::NodeIterator Nodes_begin();
    Graph::NodeIterator Nodes_end();

    // Max Node Index.
    NODEINDEX MaxNodeIndex() const;

    // Number of nodes in the <Graph>.
    // This is smaller than MaxNodeIndex(), since there may be nodes
    // removed during optimization.
    int NumberOfNodes() const;

    // Add, remove node from <*this> graph.
    Node* AddNode(const std::string& p_name,
        const std::string& p_opType,
        const std::string& p_description,
        const std::vector<NodeArg>& p_inputArgs,
        const std::vector<NodeArg>& p_outputArgs);
    Node* AddNode(const std::string& p_name,
        const std::string& p_opType,
        const std::string& p_description,
        const std::vector<NodeArg>& p_inputArgs,
        const std::vector<int>& p_inputArgCount,
        const std::vector<NodeArg>& p_outputArgs);
    Node* AddNode(const std::string& p_name,
        const std::string& p_opType,
        const std::string& p_description,
        const std::vector<NodeArg>& p_outputArgs);
    Node* AddNode(const Node& p_other);
    bool RemoveNode(NODEINDEX p_nodeIndex);

    // Convenience method for adding a constant op
    Node* AddConstantNode(const std::string& p_name,
        const std::string& p_description,
        const std::vector<NodeArg>& p_outputArgs,
        const TensorProto& p_tensor);

    // Add control edge into <*this> graph.
    // The <p_dstNodeIndex> node does not consume any data output by
    // <p_srcNodeIndex>, but it's designed to be executed behind.
    bool AddControlEdge(NODEINDEX p_srcNodeIndex, NODEINDEX p_dstNodeIndex);

    // Try to get function with specified <p_nodeIndex>. Return true if the
    // specified node refers to a function, and <p_function> will be the
    // function; false otherwise, and <p_function> will be unchanged.
    bool TryGetFunction(NODEINDEX p_nodeIndex,
        /*out*/ Function** p_function);

    // Serialize the <Graph> into <GraphProto>.
    const GraphProto& ToGraphProto();

    // Serialize the <Graph> into <FunctionDefProto>.
    // This is used when the graph is a subgraph of a main graph.
    const FunctionDefProto& ToFuncProto();

    // Inline all functions in <*this> and construct <p_graph>
    // without any functions. <p_graph> owned by caller.
    bool InlineAllFunctions(/*out*/Graph* p_graph) const;

    bool IsSourceNode(NODEINDEX p_index) const;
    bool IsSinkNode(NODEINDEX p_index) const;

    const Node* SourceNode() const;
    const Node* SinkNode() const;

    Status GetNodesInTopologicalOrder(std::vector<NODEINDEX>** nodes);

private:

    enum Type
    {
        // A main graph.
        Main = 1,
        // A sub graph (function).
        Sub = 2,
        // A graph with strict type checking.
        Strict = 4,
    };

    friend class Node;

    Node* AllocateNode();
    void ReleaseNode(NODEINDEX p_nodeIndex);

    // Add node with specified <p_nodeProto>.
    Node* AddNode(const NodeProto& p_nodeProto,
        const ArgNameToTypeMap& p_nameToType);

    Status VerifyNoDuplicateName(
        /*out*/ std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
        /*out*/ std::unordered_map<std::string, NODEINDEX>& p_nodeNameToIndex);

    // Build and verify node connection (edges).
    // Verify NodeArg name/type/shape matching correctly.
    Status BuildConnections(
        const std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
        const std::unordered_map<std::string, NODEINDEX>& p_nodeNameToIndex);

    // Check whether <*this> graph is acyclic.
    // Depth-first going thru the graph and check whether there's any back
    // edge.
    // <p_nodesInToplogicalOrder> returns nodes' indexes in topological
    // order if <Status> returned is "OK", otherwise it's undefined.
    Status CheckIsAcyclic(
        /*out*/std::vector<NODEINDEX>& p_nodesInToplogicalOrder);

    // Depth-first graph access.
    // <p_ancestors> specifies all ancestor nodes of <p_current> node.
    // <p_current> specifies current node being accessed.
    // <p_visitedNodes> specifies nodes already visited.
    // <p_nodesInToplogicalOrder> returns nodes' indexes in topological
    // order if the graph is acyclic.
    Status DepthFirstAccess(std::unordered_set<NODEINDEX> p_ancestors,
        NODEINDEX p_current,
        /*in | out*/std::unordered_set<NODEINDEX>& p_visitedNodes,
        /*out*/std::vector<NODEINDEX>& p_nodesInToplogicalOrder);

    // Given nodes in topological order, infer and set type information
    // across <*this> graph if needed, and verify type/attribute
    // information match between node and op.
    Status VerifyNodeAndOpMatch(
        const std::vector<NODEINDEX>& p_nodesInToplogicalOrder,
        std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
        /*out*/ std::set<std::string>& p_funcDefNames);

    Status InferAndVerifyTypeMatch(Node* p_node,
        const OpSignature* p_op,
        const std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs);

    // Clean function definition map.
    // Remove function definitions not referred to by any node.
    void CleanFunctionDefMap(const std::set<std::string>& p_funcDefNames);

    // Add source/sink nodes to <*this> graph.
    void AddSourceSinkNodes();

    // Set graph inputs/outputs when serializing to proto.
    void SetGraphInputsOutputs();

    // Graph nodes.
    // Element in <m_nodes> may be nullptr due to graph optimization.
    std::vector<std::unique_ptr<Node>> m_nodes;

    // Number of nodes.
    // Normally this is smaller than the size of <m_nodes>, as some
    // elements in <m_nodes> may be removed when doing graph optimization,
    // or some elements may be merged, etc.
    int m_numOfNodes;

    NODEINDEX m_sourceNodeIndex;
    NODEINDEX m_sinkNodeIndex;

    // GraphProto to store name, version, initializer.
    // When serializing <*this> Graph to a GraphProto, the nodes and
    // functions in <Graph> will also be fed into <m_graphProto> so that
    // it's consistent with <*this> graph.
    GraphProto m_graphProto;
    FunctionDefProto m_funcDefProto;

    // The node which refers to <*this> graph (Function).
    Node* m_node;

    // Graph function instantiations.
    std::unordered_map<std::string,
        std::unique_ptr<Function>> m_functionMap;

    // Graph function definitions.
    std::unordered_map<std::string, FunctionDefProto> m_funcDefMap;

    InitialTensorSet m_nameToInitialTensor;

    // A flag indicates whether <*this> graph needs to be resolved.
    bool m_graphResolveNeeded;

    bool m_graphProtoSyncNeeded;

    int m_graphType = 0;

    // The topological order of node indexes.
    std::vector<NODEINDEX> m_nodesInTopologicalOrder;
};
|
||||
}
|
||||
|
||||
#endif // CORE_GRAPH_GRAPH_H
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,288 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456 4189 4996)
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <google/protobuf/io/coded_stream.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#ifdef _WIN32
|
||||
#include <io.h>
|
||||
#else
|
||||
#include <sys/io.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include "model.h"
|
||||
|
||||
namespace
{
#ifdef _WIN32
    // Open <p_path> for sequential binary reads.
    // Returns the file descriptor, or -1 on failure.
    inline int FileOpenRd(const std::wstring& p_path)
    {
        int fd = -1;
        _wsopen_s(&fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
        return fd;
    }

    // Open (creating or truncating) <p_path> for sequential binary writes.
    // Returns the file descriptor, or -1 on failure.
    inline int FileOpenWr(const std::wstring& p_path)
    {
        int fd = -1;
        // _O_TRUNC is required: without it, serializing a model that is
        // shorter than an existing file leaves stale trailing bytes behind,
        // producing a corrupt file. This also matches the POSIX path below.
        _wsopen_s(&fd, p_path.c_str(), _O_CREAT | _O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
        return fd;
    }
#endif

    // Open <p_path> for reads. Returns the file descriptor, or -1 on failure.
    inline int FileOpenRd(const std::string& p_path)
    {
#ifdef _WIN32
        int fd = -1;
        _sopen_s(&fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
        return fd;
#else
        return open(p_path.c_str(), O_RDONLY);
#endif
    }

    // Open (creating or truncating) <p_path> for writes.
    // Returns the file descriptor, or -1 on failure.
    inline int FileOpenWr(const std::string& p_path)
    {
#ifdef _WIN32
        int fd = -1;
        // _O_TRUNC added for the same reason as the wstring overload above.
        _sopen_s(&fd, p_path.c_str(), _O_CREAT | _O_TRUNC | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
        return fd;
#else
        return open(p_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
#endif
    }
}
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
Model::Model(const std::string& p_graphName, bool p_isONNX)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_isONNX));
|
||||
}
|
||||
|
||||
Model::Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_graphDocString));
|
||||
}
|
||||
|
||||
Model::Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString,
|
||||
VERSION p_irVersion,
|
||||
const std::string& p_producerName,
|
||||
const std::string& p_producerVersion,
|
||||
const std::string& p_domain,
|
||||
VERSION p_modelVersion,
|
||||
const std::string& p_docString,
|
||||
const std::string& p_modelAuthor,
|
||||
const std::string& p_modelLicense)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_graphDocString));
|
||||
m_modelProto.set_ir_version(p_irVersion);
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
m_modelProto.set_domain(p_domain);
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
m_modelProto.set_model_author(p_modelAuthor);
|
||||
m_modelProto.set_model_license(p_modelLicense);
|
||||
}
|
||||
|
||||
Model::Model(const ModelProto& p_modelProto)
|
||||
{
|
||||
m_modelProto = p_modelProto;
|
||||
if (m_modelProto.has_graph())
|
||||
{
|
||||
m_graph.reset(new Graph(m_modelProto.graph()));
|
||||
}
|
||||
}
|
||||
|
||||
VERSION Model::IrVersion() const
|
||||
{
|
||||
if (m_modelProto.has_ir_version())
|
||||
{
|
||||
return m_modelProto.ir_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
void Model::SetIrVersion(VERSION p_irVersion)
|
||||
{
|
||||
m_modelProto.set_ir_version(p_irVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerName() const
|
||||
{
|
||||
return m_modelProto.producer_name();
|
||||
}
|
||||
|
||||
void Model::SetProducerName(const std::string& p_producerName)
|
||||
{
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerVersion() const
|
||||
{
|
||||
return m_modelProto.producer_version();
|
||||
}
|
||||
|
||||
void Model::SetProducerVersion(const std::string& p_producerVersion)
|
||||
{
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::Domain() const
|
||||
{
|
||||
return m_modelProto.domain();
|
||||
}
|
||||
|
||||
void Model::SetDomain(const std::string& p_domain)
|
||||
{
|
||||
m_modelProto.set_domain(p_domain);
|
||||
}
|
||||
|
||||
VERSION Model::ModelVersion() const
|
||||
{
|
||||
if (m_modelProto.has_model_version())
|
||||
{
|
||||
return m_modelProto.model_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
void Model::SetModelversion(VERSION p_modelVersion)
|
||||
{
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::DocString() const
|
||||
{
|
||||
return m_modelProto.doc_string();
|
||||
}
|
||||
|
||||
void Model::SetDocString(const std::string& p_docString)
|
||||
{
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
}
|
||||
|
||||
const std::string& Model::ModelAuthor() const
|
||||
{
|
||||
return m_modelProto.model_author();
|
||||
}
|
||||
|
||||
void Model::SetModelAuthor(const std::string& p_modelAuthor)
|
||||
{
|
||||
m_modelProto.set_model_author(p_modelAuthor);
|
||||
}
|
||||
|
||||
const std::string& Model::ModelLicense() const
|
||||
{
|
||||
return m_modelProto.model_license();
|
||||
}
|
||||
|
||||
void Model::SetModelLicense(const std::string& p_modelLicense)
|
||||
{
|
||||
m_modelProto.set_model_license(p_modelLicense);
|
||||
}
|
||||
|
||||
Graph* Model::MainGraph()
|
||||
{
|
||||
return m_graph.get();
|
||||
}
|
||||
|
||||
const ModelProto& Model::ToProto()
|
||||
{
|
||||
*(m_modelProto.mutable_graph()) = m_graph->ToGraphProto();
|
||||
return m_modelProto;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
bool Model::Load(const std::wstring& p_filePath, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath), p_modelProto);
|
||||
}
|
||||
std::shared_ptr<Model> Model::Load(const std::wstring& p_filePath)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath));
|
||||
}
|
||||
bool Model::Save(Model& p_model, const std::wstring& p_filePath)
|
||||
{
|
||||
return Save(p_model.ToProto(), FileOpenWr(p_filePath));
|
||||
}
|
||||
bool Model::Save(const ModelProto& p_modelProto, const std::wstring& p_filePath)
|
||||
{
|
||||
return Save(p_modelProto, FileOpenWr(p_filePath));
|
||||
}
|
||||
#endif
|
||||
|
||||
bool Model::Load(const std::string& p_filePath, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath), p_modelProto);
|
||||
}
|
||||
std::shared_ptr<Model> Model::Load(const std::string& p_filePath)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath));
|
||||
}
|
||||
bool Model::Save(Model& p_model, const std::string& p_filePath)
|
||||
{
|
||||
return Save(p_model.ToProto(), FileOpenWr(p_filePath));
|
||||
}
|
||||
bool Model::Save(const ModelProto& p_modelProto, const std::string& p_filePath)
|
||||
{
|
||||
return Save(p_modelProto, FileOpenWr(p_filePath));
|
||||
}
|
||||
|
||||
using ::google::protobuf::io::ZeroCopyInputStream;
|
||||
using ::google::protobuf::io::FileInputStream;
|
||||
using ::google::protobuf::io::CodedInputStream;
|
||||
bool Model::Load(int p_fd, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
if (nullptr == p_modelProto || p_fd < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(p_fd));
|
||||
std::unique_ptr<CodedInputStream> coded_input(
|
||||
new CodedInputStream(raw_input.get()));
|
||||
// Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB.
|
||||
coded_input->SetTotalBytesLimit(INT_MAX, INT_MAX);
|
||||
bool result = p_modelProto->ParseFromCodedStream(coded_input.get());
|
||||
coded_input.reset();
|
||||
raw_input.reset();
|
||||
close(p_fd);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<Model> Model::Load(int p_fd)
|
||||
{
|
||||
ModelProto modelProto;
|
||||
bool result = Load(p_fd, &modelProto);
|
||||
if (!result || p_fd < 0)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
auto model = std::shared_ptr<Model>(new Model(modelProto));
|
||||
auto status = model->MainGraph()->Resolve();
|
||||
|
||||
close(p_fd);
|
||||
if (status.Ok())
|
||||
{
|
||||
return model;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool Model::Save(const ModelProto& p_modelProto, int p_fd)
|
||||
{
|
||||
if (p_fd < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool result = p_modelProto.SerializeToFileDescriptor(p_fd);
|
||||
close(p_fd);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,121 @@
|
|||
#ifndef CORE_GRAPH_MODEL_H
|
||||
#define CORE_GRAPH_MODEL_H
|
||||
|
||||
#include "graph.h"
|
||||
|
||||
namespace ONNXIR
{
    // A machine learning model representation class.
    // Besides a main <Graph>, it also holds basic information, say,
    // version, domain, model author, license etc.
    class Model
    {
    public:

        // Sentinel returned by IrVersion()/ModelVersion() when the
        // underlying proto does not carry the corresponding field.
        const VERSION c_noVersion = INT64_MAX;

        // <p_isONNX> is a special flag to indicate whether it's
        // going to construct an ONNX graph. With an ONNX graph, strict
        // type checking will be skipped.
        Model(const std::string& p_graphName, bool p_isONNX = false);

        Model(const std::string& p_graphName,
            const std::string& p_graphDocString);

        // Fully-specified constructor: also fills the model metadata
        // (IR version, producer, domain, model version, doc string,
        // author and license).
        Model(const std::string& p_graphName,
            const std::string& p_graphDocString,
            VERSION p_irVersion,
            const std::string& p_producerName,
            const std::string& p_producerVersion,
            const std::string& p_domain,
            VERSION p_modelVersion,
            const std::string& p_modelDocString,
            const std::string& p_modelAuthor,
            const std::string& p_modelLicense);

        // Wrap an already-parsed ModelProto.
        Model(const ModelProto& p_modelProto);

        // Get model's IR version.
        // Return <c_noVersion> if not specified.
        VERSION IrVersion() const;
        // Set model's IR version.
        void SetIrVersion(VERSION p_irVersion);

        // Get model's producer name.
        // Returns an empty string if not specified.
        const std::string& ProducerName() const;
        // Set model's producer name.
        void SetProducerName(const std::string& p_producerName);

        // Get model's producer version.
        // Returns an empty string if not specified.
        const std::string& ProducerVersion() const;
        // Set model's producer version.
        void SetProducerVersion(const std::string& p_producerVersion);

        // Get model's domain.
        // Returns an empty string if not specified.
        const std::string& Domain() const;
        // Set model's domain.
        void SetDomain(const std::string& p_domain);

        // Get model's version.
        // Return <c_noVersion> if not specified.
        VERSION ModelVersion() const;
        // Set model's version.
        // (Lower-case 'v' kept as-is for source compatibility with callers.)
        void SetModelversion(VERSION p_modelVersion);

        // Get model's doc string.
        // Returns an empty string if not specified.
        const std::string& DocString() const;
        // Set model's doc string.
        void SetDocString(const std::string& p_docString);

        // Get model's author.
        // Returns an empty string if not specified.
        const std::string& ModelAuthor() const;
        // Set model's author.
        void SetModelAuthor(const std::string& p_modelAuthor);

        // Get model's license.
        // Returns an empty string if not specified.
        const std::string& ModelLicense() const;
        // Set model's license.
        void SetModelLicense(const std::string& p_modelLicense);

        // Get model's main graph.
        // The returned pointer is owned by <*this> model and may be null
        // when the model was constructed from a proto without a graph.
        Graph* MainGraph();

        // Get model's serialization proto data.
        const ModelProto& ToProto();

#ifdef _WIN32
        // wstring versions for Windows only.
        static bool Save(const ModelProto& p_modelProto, const std::wstring& p_filePath);
        static bool Save(Model& p_model, const std::wstring& p_filePath);
        // Load a ModelProto from a file.
        static bool Load(const std::wstring& p_filePath, /*out*/ ModelProto* p_modelProto);
        static std::shared_ptr<Model> Load(const std::wstring& p_filePath);
#endif
        // Save a ModelProto to a file.
        static bool Save(const ModelProto& p_modelProto, const std::string& p_filePath);
        static bool Save(Model& p_model, const std::string& p_filePath);
        // Save to an open file descriptor; the descriptor is closed on return.
        static bool Save(const ModelProto& p_modelProto, int p_fd);
        // Load a ModelProto from a file.
        static bool Load(const std::string& p_filePath, /*out*/ ModelProto* p_modelProto);
        static std::shared_ptr<Model> Load(const std::string& p_filePath);
        // Load from an open file descriptor; the descriptor is closed on return.
        static bool Load(int p_fd, /*out*/ ModelProto* p_modelProto);
        static std::shared_ptr<Model> Load(int p_fd);

    private:

        // Model data.
        ModelProto m_modelProto;

        // Main graph of the model.
        std::unique_ptr<Graph> m_graph;
    };
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,354 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#include "op.h"
|
||||
#include "opsignature.h"
|
||||
#include "utils.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
    // Returns the operator name carried by this schema's signature.
    const std::string& OperatorSchema::GetName() const
    {
        return m_opSignature.GetName();
    }

    // Read-only access to the full operator signature (inputs, outputs,
    // attributes, type constraints).
    const OpSignature& OperatorSchema::GetOpSignature() const
    {
        return m_opSignature;
    }

    // Returns the registered shape inference callback, if any was set.
    ShapeInferenceFunc OperatorSchema::GetShapeInferenceFn() const
    {
        return m_shapeInferenceFunc;
    }

    // Returns the registered attribute parser callback, if any was set.
    AttributeParser OperatorSchema::GetAttributeParser() const
    {
        return m_attrParser;
    }
|
||||
|
||||
    // Sets the operator name.
    // All setters return *this to allow fluent chaining during registration.
    OperatorSchemaSetter&
        OperatorSchemaSetter::Name(const std::string& p_opName)
    {
        m_opSchema.m_opSignature.m_name = p_opName;
        return *this;
    }

    // Sets the operator's documentation string.
    OperatorSchemaSetter&
        OperatorSchemaSetter::Description(const std::string& p_description)
    {
        m_opSchema.m_opSignature.m_description = p_description;
        return *this;
    }

    // Records a formal input parameter as a (name, description, type) tuple;
    // it is materialized into the signature later (see RegisterOnce), once
    // all type constraints are known.
    OperatorSchemaSetter&
        OperatorSchemaSetter::Input(const std::string& p_inputName,
            const std::string& p_description,
            const std::string& p_type)
    {
        m_inputs.push_back(std::make_tuple(p_inputName, p_description, p_type));
        return *this;
    }

    // Records a formal output parameter; same deferred handling as Input().
    OperatorSchemaSetter&
        OperatorSchemaSetter::Output(const std::string& p_outputName,
            const std::string& p_description,
            const std::string& p_type)
    {
        m_outputs.push_back(std::make_tuple(p_outputName, p_description, p_type));
        return *this;
    }

    // Declares an attribute without a default value.
    // NOTE(review): <required> is currently ignored — the Attribute pushed
    // here does not record it. Confirm whether requiredness should be
    // propagated into OpSignature::Attribute.
    OperatorSchemaSetter&
        OperatorSchemaSetter::Attr(const std::string& p_attrName,
            const std::string& p_description,
            AttrType p_attrType, bool required)
    {
        m_opSchema.m_opSignature.m_attributes.push_back(
            OpSignature::Attribute(p_attrName, p_attrType, p_description));

        return *this;
    }
|
||||
|
||||
    // Generates an Attr() overload that declares an attribute whose default
    // is a single scalar, stored in AttributeProto field <field>.
#define ATTR_SETTER_BASIC_IMPL(type, field) \
    OperatorSchemaSetter& \
    OperatorSchemaSetter::Attr(const std::string& p_attrName, \
        const std::string& p_description, \
        AttrType p_attrType, \
        const type& p_defaultValue) \
    { \
        AttributeProto a; \
        a.set_name(p_attrName); \
        a.set_##field(p_defaultValue); \
\
        m_opSchema.m_opSignature.m_attributes.push_back( \
            OpSignature::Attribute(p_attrName, \
                p_attrType, \
                p_description, \
                a)); \
\
        return *this; \
    } \

    // Generates an Attr() overload that declares an attribute whose default
    // is a list, stored in repeated AttributeProto field <field>.
#define ATTR_SETTER_LIST_IMPL(type, field) \
    OperatorSchemaSetter& \
    OperatorSchemaSetter::Attr(const std::string& p_attrName, \
        const std::string& p_description, \
        AttrType p_attrType, \
        const std::vector<type>& p_defaultValue) \
    { \
        AttributeProto a; \
        a.set_name(p_attrName); \
        for (const auto& v : p_defaultValue) \
        { \
            a.add_##field(v); \
        } \
\
        m_opSchema.m_opSignature.m_attributes.push_back( \
            OpSignature::Attribute(p_attrName, \
                p_attrType, \
                p_description, \
                a)); \
        return *this; \
    } \

    // Instantiate the scalar-default overloads (AttributeProto .i/.f/.s).
    ATTR_SETTER_BASIC_IMPL(int64_t, i)
    ATTR_SETTER_BASIC_IMPL(float, f)
    ATTR_SETTER_BASIC_IMPL(std::string, s)
    // Instantiate the list-default overloads (.ints/.floats/.strings).
    ATTR_SETTER_LIST_IMPL(int64_t, ints)
    ATTR_SETTER_LIST_IMPL(float, floats)
    ATTR_SETTER_LIST_IMPL(std::string, strings)
||||
|
||||
    // Records a named type constraint as a (name, allowed types, description)
    // tuple; resolved against inputs/outputs at registration time
    // (see RegisterOnce).
    OperatorSchemaSetter&
        OperatorSchemaSetter::TypeConstraint(const std::string& p_typeName,
            const std::vector<std::string>& p_constraints,
            const std::string& p_description)
    {
        m_constraints.push_back(std::make_tuple(p_typeName, p_constraints, p_description));
        return *this;
    }

    // Installs the shape inference callback for this operator.
    OperatorSchemaSetter&
        OperatorSchemaSetter::SetShapeInferenceFunc(
            ShapeInferenceFunc p_shapeInferFunc)
    {
        m_opSchema.m_shapeInferenceFunc = p_shapeInferFunc;
        return *this;
    }

    // Installs the attribute parser callback for this operator.
    OperatorSchemaSetter&
        OperatorSchemaSetter::SetAttributeParser(
            AttributeParser p_attrParser)
    {
        m_opSchema.m_attrParser = p_attrParser;
        return *this;
    }
|
||||
|
||||
    // Finalizes an OperatorSchemaSetter and registers the resulting schema.
    // Type constraints are processed first so that inputs/outputs can
    // resolve their constraint names against the populated map.
    OperatorSchemaRegistry::RegisterOnce::RegisterOnce(
        OperatorSchemaSetter& p_opSchemaSetter)
    {
        auto& opSchema = p_opSchemaSetter.m_opSchema;
        // Process type constraints.
        for (const auto& constraint : p_opSchemaSetter.m_constraints)
        {
            std::string name;
            std::vector<std::string> types;
            std::string desc;
            std::tie(name, types, desc) = constraint;

            auto it = opSchema.m_opSignature.m_typeConstraintMap.find(name);
            if (it == opSchema.m_opSignature.m_typeConstraintMap.end())
            {
                DataTypeSet d;
                for (const auto& t : types)
                {
                    d.insert(Utils::OpUtils::ToType(t));
                }
                opSchema.m_opSignature.m_typeConstraintMap.insert(std::make_pair(name, std::make_pair(d, desc)));
            }
            else
            {
                // Already a constraint with the same name.
                // NOTE(review): the duplicate definition is silently dropped
                // today; consider surfacing an error instead.
            }
        }

        // Materialize recorded inputs into formal parameters.
        // Setter tuples are stored as (name, description, type) — see Input().
        opSchema.m_opSignature.m_inputs.reserve(p_opSchemaSetter.m_inputs.size());
        for (const auto& input : p_opSchemaSetter.m_inputs)
        {
            std::string name;
            std::string type;
            std::string desc;
            std::tie(name, desc, type) = input;
            opSchema.m_opSignature.m_inputs.push_back(
                OpSignature::FormalParameter(name, type, desc, opSchema.m_opSignature.m_typeConstraintMap));
        }

        // Materialize recorded outputs, same tuple layout as inputs.
        opSchema.m_opSignature.m_outputs.reserve(p_opSchemaSetter.m_outputs.size());
        for (const auto& output : p_opSchemaSetter.m_outputs)
        {
            std::string name;
            std::string type;
            std::string desc;
            std::tie(name, desc, type) = output;
            opSchema.m_opSignature.m_outputs.push_back(
                OpSignature::FormalParameter(name, type, desc,
                    opSchema.m_opSignature.m_typeConstraintMap));
        }

        auto& opSignature = p_opSchemaSetter.m_opSchema.m_opSignature;
        // If no inputs were declared explicitly, synthesize placeholder
        // parameters up to the declared minimum input count.
        if (0 == opSignature.m_inputs.size())
        {
            for (int i = 0; i < opSignature.m_onnxMinInput; ++i)
            {
                std::string name = "p" + std::to_string(i);
                std::string desc = "Input Parameter " + std::to_string(i);
                opSignature.m_inputs.push_back(
                    OpSignature::FormalParameter(name, "", desc, opSignature.m_typeConstraintMap));
            }
        }

        // Same placeholder synthesis for outputs.
        if (0 == opSignature.m_outputs.size())
        {
            for (int i = 0; i < opSignature.m_onnxMinOutput; ++i)
            {
                std::string name = "p" + std::to_string(i);
                std::string desc = "Output Result " + std::to_string(i);
                opSignature.m_outputs.push_back(
                    OpSignature::FormalParameter(name, "", desc, opSignature.m_typeConstraintMap));
            }
        }
        OperatorSchemaRegistry::Get()->Register(p_opSchemaSetter.m_opSchema);
    }
|
||||
|
||||
bool OperatorSchemaRegistry::TryGetOp(const std::string& p_name,
|
||||
const OperatorSchema** p_opSchema) const
|
||||
{
|
||||
if (nullptr == p_opSchema)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto iter = m_opNameToOpSchemaMap.find(p_name);
|
||||
if (m_opNameToOpSchemaMap.end() == iter)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
*p_opSchema = &(iter->second);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Add <p_opSchema> to the registry, keyed by operator name.
// Fails with an error Status (registry unchanged) when a schema with the
// same name is already registered.
Status OperatorSchemaRegistry::Register(
    const OperatorSchema& p_opSchema)
{
    const auto& opName = p_opSchema.GetName();
    if (m_opNameToOpSchemaMap.count(opName) > 0)
    {
        return Status(false,
            "Error: operator schema with same name ("
            + opName + ") exists.");
    }

    m_opNameToOpSchemaMap[opName] = p_opSchema;
    return Status::OK();
}
|
||||
|
||||
// Accessor for the global schema registry singleton.
// The instance is intentionally heap-allocated and never freed so it is not
// destroyed during static destruction while registrations may still run.
OperatorSchemaRegistry* OperatorSchemaRegistry::Get()
{
    static OperatorSchemaRegistry* const s_instance = new OperatorSchemaRegistry();
    return s_instance;
}
|
||||
|
||||
// Derive the AttrType of <p_attr> from whichever single value field is set.
// Params:  p_attr -- attribute proto to classify (must be valid per
//                    OpSignature::IsValidAttribute).
//          p_type -- out: the classified attribute type; set to
//                    AttrType::NONE when no value field is populated.
// Returns: Status::OK() on success; an error Status when the proto is
//          invalid or carries no recognized value field.
Status TypeUtils::GetType(const AttributeProto& p_attr, AttrType& p_type)
{
    if (!OpSignature::IsValidAttribute(p_attr))
    {
        return Status(false, "Invalid AttributeProto.");
    }

    if (p_attr.has_f())
    {
        p_type = AttrType::FLOAT;
    }
    else if (p_attr.has_i())
    {
        p_type = AttrType::INT;
    }
    else if (p_attr.has_s())
    {
        p_type = AttrType::STRING;
    }
    else if (p_attr.has_t())
    {
        p_type = AttrType::TENSOR;
    }
    else if (p_attr.has_g())
    {
        p_type = AttrType::GRAPH;
    }
    else if (p_attr.floats_size())
    {
        p_type = AttrType::FLOATS;
    }
    else if (p_attr.ints_size())
    {
        p_type = AttrType::INTS;
    }
    else if (p_attr.strings_size())
    {
        p_type = AttrType::STRINGS;
    }
    else if (p_attr.tensors_size())
    {
        p_type = AttrType::TENSORS;
    }
    else if (p_attr.graphs_size())
    {
        p_type = AttrType::GRAPHS;
    }
    else if (p_attr.has_type())
    {
        p_type = AttrType::TYPE;
    }
    else if (p_attr.types_size())
    {
        p_type = AttrType::TYPES;
    }
    else if (p_attr.has_shape())
    {
        p_type = AttrType::SHAPE;
    }
    else if (p_attr.shapes_size())
    {
        // BUGFIX: this branch previously re-tested has_shape() (a copy-paste
        // of the branch above), which made AttrType::SHAPES unreachable and
        // classified repeated-shape attributes as invalid. Test the repeated
        // field instead, matching IsValidAttribute's field enumeration.
        p_type = AttrType::SHAPES;
    }
    else
    {
        p_type = AttrType::NONE;
        return Status(false, "Invalid AttributeProto.");
    }

    return Status::OK();
}
|
||||
|
||||
// Replace every occurrence of <from> in <s> with <to>, in place.
// The search resumes just past each inserted replacement, so text introduced
// by <to> is never itself rescanned. Returns the number of replacements.
size_t ReplaceAll(std::string& s, const char* from, const char* to)
{
    size_t count = 0;
    const std::string::size_type fromLen = std::strlen(from);
    const std::string::size_type toLen = std::strlen(to);

    std::string::size_type pos = s.find(from);
    while (pos != std::string::npos)
    {
        s.replace(pos, fromLen, to);
        ++count;
        pos = s.find(from, pos + toLen);
    }
    return count;
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,264 @@
|
|||
#ifndef CORE_GRAPH_OP_H
|
||||
#define CORE_GRAPH_OP_H
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "opsignature.h"
|
||||
#include "shape_inference.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class OpSignature;
|
||||
class OperatorSchemaSetter;
|
||||
typedef OperatorSchemaSetter OpSchema;
|
||||
|
||||
// Static helpers for mapping AttributeProto contents to AttrType values.
class TypeUtils
{
public:

    // Get attribute type given attribute proto data.
    // Returns an error Status when the proto is invalid or no value
    // field is populated.
    static Status GetType(const AttributeProto& p_attr, AttrType& p_type);

};
|
||||
|
||||
// An attribute parser - it's specified when registering an operator.
|
||||
// The parser is designed and used in two ways.
|
||||
// 1) It will be used to verify whether a Node's attributes match the
|
||||
// operator's definition.
|
||||
// 2) It will be used to parse a Node's attributes into a <T> object,
|
||||
// which makes it be easier to access node attributes.
|
||||
// TODO: to implement the 2nd point above, NodeAttributes should be changed
|
||||
// to contain a <T> field, which is structured attributes.
|
||||
typedef std::function<Status(const NodeAttributes&)> AttributeParser;
|
||||
|
||||
// An operator's complete definition: its signature (inputs/outputs/attrs)
// plus optional shape-inference and attribute-parsing callbacks.
// Instances are populated by OperatorSchemaSetter and stored in
// OperatorSchemaRegistry; to callers the type is read-only.
class OperatorSchema
{
public:

    // Operator name (the registry key).
    const std::string& GetName() const;
    const OpSignature& GetOpSignature() const;
    ShapeInferenceFunc GetShapeInferenceFn() const;
    AttributeParser GetAttributeParser() const;

private:

    // Only the setter and registry may mutate a schema.
    friend class OperatorSchemaSetter;
    friend class OperatorSchemaRegistry;

    OpSignature m_opSignature;
    ShapeInferenceFunc m_shapeInferenceFunc;
    AttributeParser m_attrParser;
};
|
||||
|
||||
// (name, description, type) of a formal input/output parameter.
typedef std::tuple<std::string, std::string, std::string> InputOutputParam;
// (name, description, type, default value) of an attribute.
typedef std::tuple<std::string, std::string, AttrType, AttributeProto> AttrParam;
// (constraint name, allowed type strings, description).
typedef std::tuple<std::string, std::vector<std::string>, std::string> TypeConstraintParam;

// Declares the pair of typed Attr() overloads (scalar default value and
// list-of-default-values) for one default-value type.
#define ATTR_SETTER_INTERFACE(TypeName) \
OperatorSchemaSetter& Attr(const std::string& p_attrName, \
const std::string& p_description, \
AttrType p_attrType, \
const TypeName& p_defaultValue); \
OperatorSchemaSetter& Attr(const std::string& p_attrName, \
const std::string& p_description, \
AttrType p_attrType, \
const std::vector<TypeName>& p_defaultValues); \
|
||||
|
||||
// Operator registry setter helper.
|
||||
// This is used in "OPERATOR_DEFINITION" macro, to separate setters from getters
|
||||
// in OpSignature.
|
||||
// Fluent builder for OperatorSchema. Every setter returns *this so
// registrations chain: Name(...).Input(...).Output(...).Attr(...).
class OperatorSchemaSetter
{
public:

    OperatorSchemaSetter() = default;

    // Operator name; used as the key in OperatorSchemaRegistry.
    OperatorSchemaSetter& Name(const std::string& p_opName);

    OperatorSchemaSetter& Description(const std::string& p_description);

    // Declare an input formal parameter. <p_type> may be a concrete type
    // string or the name of a TypeConstraint registered on this setter;
    // empty means unconstrained.
    OperatorSchemaSetter& Input(const std::string& p_inputName,
        const std::string& p_description,
        const std::string& p_type = "");

    OperatorSchemaSetter& Output(const std::string& p_outputName,
        const std::string& p_description,
        const std::string& p_type = "");

    // Declare an attribute without a default value.
    OperatorSchemaSetter& Attr(const std::string& p_attrName,
        const std::string& p_description,
        AttrType p_attrType, bool required = false);

    // Typed Attr() overloads taking default value(s); see
    // ATTR_SETTER_INTERFACE for the declared signatures.
    ATTR_SETTER_INTERFACE(int64_t)
    ATTR_SETTER_INTERFACE(float)
    ATTR_SETTER_INTERFACE(std::string)
    ATTR_SETTER_INTERFACE(TensorProto)
    ATTR_SETTER_INTERFACE(GraphProto)
    ATTR_SETTER_INTERFACE(TypeProto)
    ATTR_SETTER_INTERFACE(TypeProto::TensorShapeProto)

    // Name a set of allowed types that Input()/Output() type strings may
    // reference by <p_typeName>.
    OperatorSchemaSetter& TypeConstraint(const std::string& p_typeName,
        const std::vector<std::string>& p_constraints,
        const std::string& p_description);

    // Shape inference function will be used to infer outputs' shape with
    // inputs' shape.
    OperatorSchemaSetter& SetShapeInferenceFunc(
        ShapeInferenceFunc p_shapeInferFunc);

    // Attribute parser will be used to parse Node's attributes to see
    // whether Node attributes are matching operator attributes definition.
    OperatorSchemaSetter& SetAttributeParser(
        AttributeParser p_attrParser);

    enum class SupportType {
        COMMON,
        EXPERIMENTAL,
    };
    // Methods added for compatibility with ONNX OpSchema registration API.
    // Several of these are intentional no-ops kept only so that upstream
    // ONNX registration code compiles unchanged.
    OpSchema& NumInputs(int n)
    {
        return NumInputs(n, n);
    }
    OpSchema& NumInputs(int min, int max)
    {
        m_opSchema.m_opSignature.m_onnxMinInput = min;
        m_opSchema.m_opSignature.m_onnxMaxInput = max;
        return *this;
    }
    OpSchema& NumInputs(std::set<int> allowed_input_nums)
    {
        // Captures the set by value inside the stored predicate.
        return NumInputs([allowed_input_nums](int n)-> bool {
            return allowed_input_nums.count(n) > 0;
        });
    }
    OpSchema& NumInputs(std::function<bool(int)> func)
    {
        m_opSchema.m_opSignature.m_onnxNumInputsAllowed = func;
        return *this;
    }
    OpSchema& NumOutputs(int n) {
        return NumOutputs(n, n);
    }
    OpSchema& NumOutputs(int min, int max)
    {
        m_opSchema.m_opSignature.m_onnxMinOutput = min;
        m_opSchema.m_opSignature.m_onnxMaxOutput = max;
        return *this;
    }
    OpSchema& NumOutputs(std::set<int> allowed_output_nums)
    {
        return NumOutputs([allowed_output_nums](int n)-> bool {
            return allowed_output_nums.count(n) > 0;
        });
    }
    OpSchema& NumOutputs(std::function<bool(int)> func)
    {
        m_opSchema.m_opSignature.m_onnxNumOutputsAllowed = func;
        return *this;
    }
    OpSchema& NumInputsOutputs(std::function<bool(int, int)> func)
    {
        m_opSchema.m_opSignature.m_onnxNumInputsOutputsAllowed = func;
        return *this;
    }
    // No-op compatibility shims: accepted but not recorded.
    OpSchema& OutputCalculator(std::function<int(int)> calc) { return *this; }
    OpSchema& SameNumberOfOutput() { return *this; }
    OpSchema& AllowConsumed(std::function<std::pair<bool, int>(int)> inplace) { return *this; }
    OpSchema& AllowConsumed(std::unordered_map<int, int> inplace) { return *this; }
    OpSchema& AllowOneToOneConsumed() { return *this; }
    OpSchema& EnforceConsumed(std::function<std::pair<bool, int>(int)> inplace) { return *this; }
    OpSchema& EnforceConsumed(std::unordered_map<int, int> inplace) { return *this; }
    OpSchema& EnforceOneToOneConsumed() { return *this; }
    OpSchema& SetSupportLevel(SupportType) { return *this; }
    OpSchema& AllowUncheckedAttributes() { return *this; }
    // Apply an external population function to this setter (ONNX idiom for
    // sharing registration boilerplate between ops).
    OpSchema& FillUsing(std::function<void(OpSchema&)> populator)
    {
        if (populator)
        {
            populator(*this);
        }
        return *this;
    }
    // ONNX-style indexed Input/Output: the index is ignored; declaration
    // order determines position.
    OpSchema& Input(const int, const char* name, const char* description)
    {
        return Input(name, description);
    }
    OpSchema& Output(const int, const char* name, const char* description)
    {
        return Output(name, description);
    }
    OpSchema& SetDoc(const std::string& doc)
    {
        return Description(doc);
    }

private:

    //friend class OpSignature;
    friend class OperatorSchemaRegistry;

    // Schema under construction; finalized by RegisterOnce.
    OperatorSchema m_opSchema;

    // Operator input formal parameters.
    std::vector<InputOutputParam> m_inputs;

    // Operator output formal parameters.
    std::vector<InputOutputParam> m_outputs;

    // Operator type constraints.
    std::vector<TypeConstraintParam> m_constraints;
};
|
||||
|
||||
// Operator schema registry. A singleton registry to manage all operator
|
||||
// schemas.
|
||||
// Singleton registry mapping operator name -> OperatorSchema.
// Populated at static-initialization time via RegisterOnce; queried at
// graph-validation time via TryGetOp.
class OperatorSchemaRegistry
{
public:

    // Helper function providing a way to call
    // OpSignatureFactory::Register().
    // Constructing a RegisterOnce finalizes the setter's schema and
    // registers it with the global registry (see REGISTER_OPERATOR_SCHEMA).
    class RegisterOnce
    {
    public:

        RegisterOnce(OperatorSchemaSetter& p_opRegistry);
    };

    // Try to get operator with specified operator name.
    // Returns false (and leaves *p_opRegistry untouched) when the name is
    // unknown or the out-pointer is null.
    bool TryGetOp(const std::string& p_name,
        const OperatorSchema** p_opRegistry) const;

    // Register an operator. Fails with an error Status on duplicate name.
    Status Register(const OperatorSchema& p_opSchema);

    // Get the global operator registry factory instance.
    static OperatorSchemaRegistry* Get();

private:

    // Construction only through Get().
    OperatorSchemaRegistry() = default;

    // An operator name to operator definition data map.
    std::unordered_map<std::string, OperatorSchema> m_opNameToOpSchemaMap;
};
|
||||
|
||||
// utility function used by ONNX v1 op registration defs.
// Replaces every occurrence of <from> in <s> with <to>, in place; returns
// the number of replacements performed.
size_t ReplaceAll(std::string& s, const char* from, const char* to);

// Registers an operator schema at static-initialization time. The
// __COUNTER__ double-expansion produces a unique static variable name per
// macro use in a translation unit.
#define REGISTER_OPERATOR_SCHEMA(OpName) OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, OpName)
#define OPERATOR_SCHEMA_UNIQ_HELPER(Counter, OpName) OPERATOR_SCHEMA_UNIQ(Counter, OpName)
#define OPERATOR_SCHEMA_UNIQ(Counter, OpName) \
static OperatorSchemaRegistry::RegisterOnce op_##Counter \
= OperatorSchemaSetter().Name(#OpName)

// Operator registration example.
// OPERATOR_DEFINITION(Add).Description("An operator to sum two float numbers.")
// .Input("input_1", "docstr for input_1.", "T")
// .Input("input_2", "docstr for input_2.", "T")
// .Output("output_1", "docstr for output_1.", "T")
// .TypeConstraint("T", { "float16", "float32", "float64" }, "Constrain input and output types to floats.");
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,151 @@
|
|||
#include "opsignature.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// Build a formal parameter. When <p_type> names a registered constraint,
// the parameter adopts that constraint's whole allowed type set; otherwise
// a non-empty <p_type> is treated as a single concrete type. An empty
// <p_type> leaves the allowed set empty (unconstrained).
OpSignature::FormalParameter::FormalParameter(
    const std::string& p_name, const std::string& p_type,
    const std::string& p_description,
    const TypeConstraintMap& p_constraintMap)
    : m_name(p_name), m_typeStr(p_type), m_description(p_description)
{
    const auto match = p_constraintMap.find(p_type);
    if (match != p_constraintMap.end())
    {
        // Constraint name: copy the constraint's DataTypeSet.
        m_types = match->second.first;
    }
    else if (!p_type.empty())
    {
        // Concrete type string: intern it as a single allowed type.
        m_types.emplace(Utils::OpUtils::ToType(m_typeStr));
    }
}
|
||||
|
||||
// Get formal parameter name.
const std::string& OpSignature::FormalParameter::GetName() const
{
    return m_name;
}

// Get the set of data types this parameter accepts.
const DataTypeSet& OpSignature::FormalParameter::GetTypes() const
{
    return m_types;
}

// Get the raw type string supplied at registration: either a concrete
// type or a type-constraint name.
const std::string& OpSignature::FormalParameter::GetTypeStr() const
{
    return m_typeStr;
}

// Get formal parameter description.
const std::string& OpSignature::FormalParameter::GetDescription() const
{
    return m_description;
}
|
||||
|
||||
// Attribute definition carrying a default value. The default is stored as
// the first (and only) entry of m_allowedValues, per the class contract.
OpSignature::Attribute::Attribute(
    const std::string& p_attrName,
    AttrType p_type,
    const std::string& p_description,
    const AttributeProto& p_defaultVal)
    : m_name(p_attrName), m_type(p_type), m_description(p_description),
    m_hasDefaultValue(true)
{
    m_allowedValues.emplace_back(p_defaultVal);
}
|
||||
|
||||
// Attribute definition without a default value.
OpSignature::Attribute::Attribute(
    const std::string& p_attrName,
    AttrType p_type,
    const std::string& p_description)
    : m_name(p_attrName), m_type(p_type), m_description(p_description),
    m_hasDefaultValue(false)
{
}

// Get attribute name.
const std::string& OpSignature::Attribute::GetName() const
{
    return m_name;
}

// Get attribute type.
AttrType OpSignature::Attribute::GetType() const
{
    return m_type;
}
|
||||
|
||||
bool OpSignature::Attribute::HasDefaultValue(
|
||||
const AttributeProto** p_value) const
|
||||
{
|
||||
if (m_hasDefaultValue
|
||||
&& nullptr != p_value)
|
||||
{
|
||||
*p_value = &(m_allowedValues[0]);
|
||||
}
|
||||
|
||||
return m_hasDefaultValue;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Get operator name.
const std::string& OpSignature::GetName() const
{
    return m_name;
}

// Get operator description.
const std::string& OpSignature::GetDescription() const
{
    return m_description;
}

// Get input formal parameters, in declaration order.
const std::vector<OpSignature::FormalParameter>&
    OpSignature::GetInputs() const
{
    return m_inputs;
}

// Get output formal parameters, in declaration order.
const std::vector<OpSignature::FormalParameter>&
    OpSignature::GetOutputs() const
{
    return m_outputs;
}

// Get attribute definitions.
const std::vector<OpSignature::Attribute>&
    OpSignature::GetAttributes() const
{
    return m_attributes;
}

// Get the map from type-constraint name to its allowed type set.
const TypeConstraintMap& OpSignature::GetTypeConstraintMap() const
{
    return m_typeConstraintMap;
}
|
||||
|
||||
bool OpSignature::IsValidAttribute(const AttributeProto& p_attr)
|
||||
{
|
||||
if (p_attr.name().empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
int num_fields =
|
||||
p_attr.has_f() +
|
||||
p_attr.has_i() +
|
||||
p_attr.has_s() +
|
||||
p_attr.has_t() +
|
||||
p_attr.has_g() +
|
||||
(p_attr.floats_size() > 0) +
|
||||
(p_attr.ints_size() > 0) +
|
||||
(p_attr.strings_size() > 0) +
|
||||
(p_attr.tensors_size() > 0) +
|
||||
(p_attr.graphs_size() > 0) +
|
||||
p_attr.has_type() +
|
||||
(p_attr.types_size() > 0) +
|
||||
p_attr.has_shape() +
|
||||
(p_attr.shapes_size() > 0);
|
||||
|
||||
if (num_fields == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#ifndef CORE_GRAPH_OPSCHEMA_H
|
||||
#define CORE_GRAPH_OPSCHEMA_H
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// Kinds of values an operator attribute may carry; mirrors the value
// fields of AttributeProto (NONE means no field populated).
enum class AttrType {
    NONE,
    FLOAT,
    INT,
    STRING,
    GRAPH,
    TENSOR,
    TYPE,
    SHAPE,
    FLOATS,
    INTS,
    STRINGS,
    GRAPHS,
    TENSORS,
    TYPES,
    SHAPES
};

// This string array should exactly match the AttrType defined above.
// NOTE(review): the array omits NONE, so indexing by a raw AttrType value
// is off by one -- call sites presumably subtract one; confirm.
static const std::string c_attrTypeStr[14] =
{
    "FLOAT",
    "INT",
    "STRING",
    "GRAPH",
    "TENSOR",
    "TYPE",
    "SHAPE",
    "FLOATS",
    "INTS",
    "STRINGS",
    "GRAPHS",
    "TENSORS",
    "TYPES",
    "SHAPES"
};

// Set of interned type pointers (PTYPE) a parameter may accept.
typedef std::unordered_set<PTYPE> DataTypeSet;
// Constraint name -> (allowed types, description).
typedef std::unordered_map<std::string, std::pair<DataTypeSet, std::string>> TypeConstraintMap;
|
||||
|
||||
// Operator signature declaration.
|
||||
// It defines input formal parameter, output formal parameters and
|
||||
// attributes.
|
||||
// Once an operator signature created, it's "Read-Only".
|
||||
class OpSignature
{
public:

    // Formal parameter representation, including parameter name, type.
    class FormalParameter
    {
    public:

        // Constructor. <p_type> may be a concrete type string or a
        // constraint name resolved via <p_constraintMap>.
        explicit FormalParameter(const std::string& p_name,
            const std::string& p_type,
            const std::string& p_description,
            const TypeConstraintMap& p_constraintMap = TypeConstraintMap());

        // Get formal parameter name.
        const std::string& GetName() const;

        // Get supported data types.
        const DataTypeSet& GetTypes() const;

        // Get formal parameter type string.
        const std::string& GetTypeStr() const;

        // Get formal parameter description.
        const std::string& GetDescription() const;

    private:

        FormalParameter() {}

        // Formal parameter name.
        std::string m_name;

        // A set of data types supported for <*this> formal parameter.
        // It should contain at least one element if this formal parameter
        // is good.
        DataTypeSet m_types;

        // The <parameter type> string specified when registering an op.
        // It could be a supported data type or a type constraint key, which
        // maps to a set of supported data types.
        std::string m_typeStr;

        // Formal parameter description.
        std::string m_description;

    };

    // Attribute representation, including name, type, and allowed values.
    // The first element of allowed values (if specified) is the default
    // value.
    class Attribute
    {
    public:

        // Constructor (no default value).
        explicit Attribute(const std::string& p_attrName,
            AttrType p_type,
            const std::string& p_description);

        // Constructor with default value.
        explicit Attribute(const std::string& p_attrName,
            AttrType p_type,
            const std::string& p_description,
            const AttributeProto& p_defaultVal);

        // Get attribute name.
        const std::string& GetName() const;

        // Get attribute type.
        AttrType GetType() const;

        // Get to know whether this attribute has default value,
        // if yes, <p_value> will be assigned to be the default value.
        bool HasDefaultValue(const AttributeProto** p_value) const;

    private:

        Attribute() {}

        // Attribute name.
        std::string m_name;

        // Attribute type.
        AttrType m_type;

        // Attribute description.
        std::string m_description;

        // Flag indicates whether a default value specified.
        // If it's true, the first element of <m_allowedValues> is the
        // default value.
        bool m_hasDefaultValue;

        // Allowed attribute values.
        std::vector<AttributeProto> m_allowedValues;
    };

    // True iff <p_attribute> has a name and exactly one populated value
    // field.
    static bool IsValidAttribute(const AttributeProto& p_attribute);

    // Constructor.
    OpSignature() = default;

    // Get operator name.
    const std::string& GetName() const;

    // Get operator description.
    const std::string& GetDescription() const;

    // Get input formal parameters.
    const std::vector<FormalParameter>& GetInputs() const;

    // Get output formal parameters.
    const std::vector<FormalParameter>& GetOutputs() const;

    // Get attributes.
    const std::vector<Attribute>& GetAttributes() const;

    // Get type constraint map.
    const TypeConstraintMap& GetTypeConstraintMap() const;

    // To support ONNX variable input/output compatibility.
    // Min and Max num arguments of last input/output.
    int GetOnnxMinInput() const { return m_onnxMinInput; }
    int GetOnnxMaxInput() const { return m_onnxMaxInput; }
    int GetOnnxMinOutput() const { return m_onnxMinOutput; }
    int GetOnnxMaxOutput() const { return m_onnxMaxOutput; }
    std::function<bool(int)> GetOnnxNumInputsAllowedFunc() const
    {
        return m_onnxNumInputsAllowed;
    }
    std::function<bool(int)> GetOnnxNumOutputsAllowedFunc() const
    {
        return m_onnxNumOutputsAllowed;
    }
    std::function<bool(int, int)> GetOnnxNumInputsOutputsAllowedFunc() const
    {
        return m_onnxNumInputsOutputsAllowed;
    }

private:

    friend class OperatorSchemaSetter;
    friend class OperatorSchemaRegistry;

    // Operator name.
    std::string m_name;

    // Operator description.
    std::string m_description;

    // Operator input formal parameters.
    std::vector<FormalParameter> m_inputs;

    // Operator output formal parameters.
    std::vector<FormalParameter> m_outputs;

    // Operator attributes' definitions.
    std::vector<Attribute> m_attributes;

    // Map from constraint name to DataTypeSet.
    TypeConstraintMap m_typeConstraintMap;

    // To support ONNX variable input/output compatibility.
    // Min and Max num arguments of last input/output.
    // Predicates default to accepting any count.
    int m_onnxMinInput = 0;
    int m_onnxMaxInput = std::numeric_limits<int>::max();
    int m_onnxMinOutput = 0;
    int m_onnxMaxOutput = std::numeric_limits<int>::max();
    std::function<bool(int)> m_onnxNumInputsAllowed =
        [](int) { return true; };
    std::function<bool(int)> m_onnxNumOutputsAllowed =
        [](int) { return true; };
    std::function<bool(int, int)> m_onnxNumInputsOutputsAllowed =
        [](int, int) { return true; };
};
|
||||
}
|
||||
#endif
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,39 @@
|
|||
#include "shape_inference.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// Bundle a node and its operator signature for a shape-inference call.
// Both pointers are stored non-owning; NOTE(review): they are assumed to
// outlive this context -- confirm with callers.
InferenceContext::InferenceContext(Node* p_node,
    const OpSignature* p_opSchema)
    : m_node(p_node),
    m_opSignature(p_opSchema)
{
}

// The node whose output shapes are being inferred (may be null).
const Node* InferenceContext::GetNode() const
{
    return m_node;
}

// The operator signature supplied at construction (may be null).
const OpSignature* InferenceContext::GetOp() const
{
    return m_opSignature;
}
|
||||
|
||||
const std::vector<NodeArg>* InferenceContext::GetInputs() const
|
||||
{
|
||||
if (nullptr == m_node)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return &(m_node->InputDefs());
|
||||
}
|
||||
|
||||
std::vector<NodeArg>* InferenceContext::Mutable_Outputs()
|
||||
{
|
||||
if (nullptr == m_node)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return &(m_node->Mutable_OutputDefs());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
#ifndef CORE_GRAPH_SHAPEINFERENCE_H
|
||||
#define CORE_GRAPH_SHAPEINFERENCE_H
|
||||
|
||||
#include "graph.h"
|
||||
#include "opsignature.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
// A context to contain information for shape inference function.
|
||||
// It includes the operator registry, input arguments definition,
|
||||
// and mutable output arguments, whose shapes needs to be filled.
|
||||
// A context to contain information for one shape-inference invocation:
// the node, its operator signature, read-only input definitions, and
// mutable output definitions whose shapes get filled in.
class InferenceContext
{
public:

    // TODO: Add input tensors into constructor.
    // TODO: An abstract tensor interface will be needed.
    // In some cases, node evaluation will be needed to get output shapes.
    InferenceContext(Node* p_node,
        const OpSignature* p_opSchema);

    const Node* GetNode() const;

    const OpSignature* GetOp() const;

    // Null when no node is attached.
    const std::vector<NodeArg>* GetInputs() const;

    // Null when no node is attached.
    std::vector<NodeArg>* Mutable_Outputs();

private:

    // Non-owning; NOTE(review): assumed to outlive this context.
    Node* m_node;

    // Non-owning.
    const OpSignature* m_opSignature;
};
|
||||
|
||||
// Shape inference function define.
|
||||
typedef std::function<Status(InferenceContext&)> ShapeInferenceFunc;
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,32 @@
|
|||
#include "status.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// Construct a Status with success flag <p_ok> and error text <p_errMsg>
// (conventionally empty on success).
// Uses a member-initializer list instead of body assignment: the original
// default-constructed m_errMsg and then copy-assigned it.
Status::Status(bool p_ok, const std::string& p_errMsg)
    : m_ok(p_ok),
    m_errMsg(p_errMsg)
{
}
|
||||
|
||||
// Copy constructor. Member-initializer list instead of body assignment,
// for the same reason as the main constructor. (The hand-written copy is
// kept because it is declared in status.h; it matches the default.)
Status::Status(const Status& p_other)
    : m_ok(p_other.m_ok),
    m_errMsg(p_other.m_errMsg)
{
}
|
||||
|
||||
// Getter of <m_ok>: true when the operation succeeded.
bool Status::Ok() const
{
    return m_ok;
}

// Getter of <m_errMsg>; empty for a successful Status.
const std::string& Status::ErrorMsg() const
{
    return m_errMsg;
}

// Canonical success value, returned by copy from a function-local static.
Status Status::OK()
{
    static Status ok(true, "");
    return ok;
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
#ifndef CORE_GRAPH_STATUS_H
|
||||
#define CORE_GRAPH_STATUS_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
// Evaluate <expr> (which must yield a Status) and return it from the
// enclosing function when it indicates failure. The do/while(0) wrapper
// keeps the macro safe as a single statement (e.g. in unbraced if/else).
#define RETURN_IF_ERROR(expr) \
do { \
auto status = (expr); \
if ((!status.Ok())) return status; \
} while (0)

// Lightweight success/error result: a boolean flag plus an error message.
class Status
{
public:
    // No default state: a Status is always explicitly ok or failed.
    Status() = delete;

    // Constructor.
    Status(bool p_ok, const std::string& p_errMsg);

    // Copy constructor.
    Status(const Status& p_other);

    // Getter of <m_ok>.
    bool Ok() const;

    // Getter of <m_errMsg>.
    const std::string& ErrorMsg() const;

    // Canonical success value.
    static Status OK();

private:

    bool m_ok;
    std::string m_errMsg;
};
|
||||
}
|
||||
|
||||
#endif // !CORE_GRAPH_STATUS_H
|
|
@ -0,0 +1,422 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#include <cctype>
|
||||
#include <iterator>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "constants.h"
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
namespace Utils
|
||||
{
|
||||
// Process-wide interning table from canonical type string to TypeProto.
// Heap-allocated and intentionally never freed, so PTYPE pointers into the
// map's keys remain valid for the life of the process.
std::unordered_map<std::string, TypeProto>& OpUtils::GetTypeStrToProtoMap()
{
    static auto* s_typeStrToProtoMap = new std::unordered_map<std::string, TypeProto>();
    return *s_typeStrToProtoMap;
}
|
||||
|
||||
// Intern <p_type>: PTYPE is a pointer to the canonical type string stored
// in the process-wide map, so equal types share one pointer.
PTYPE OpUtils::ToType(const TypeProto& p_type)
{
    auto typeStr = ToString(p_type);
    // emplace is a no-op when the key is already present; either way the
    // returned iterator references the canonical entry. This replaces the
    // original find + operator[] + find sequence (three hash lookups, plus
    // a default-constructed TypeProto) with a single lookup.
    auto result = GetTypeStrToProtoMap().emplace(typeStr, p_type);
    return &(result.first->first);
}
|
||||
|
||||
// Intern a type given by its string form: parse it into a TypeProto and
// delegate to the proto overload.
PTYPE OpUtils::ToType(const std::string& p_type)
{
    TypeProto parsed;
    FromString(p_type, parsed);
    return ToType(parsed);
}
|
||||
|
||||
// Reverse lookup: canonical type-string pointer -> its full TypeProto.
// Throws std::invalid_argument for a PTYPE that was never interned.
const TypeProto& OpUtils::ToTypeProto(const PTYPE& p_type)
{
    const auto it = GetTypeStrToProtoMap().find(*p_type);
    if (it == GetTypeStrToProtoMap().end())
    {
        throw std::invalid_argument("PTYPE not found: " + *p_type);
    }
    return it->second;
}
|
||||
|
||||
std::string OpUtils::ToString(const TypeProto& p_type)
|
||||
{
|
||||
switch (p_type.value_case())
|
||||
{
|
||||
case TypeProto::ValueCase::kTensorType:
|
||||
return ToString(p_type.tensor_type().elem_type());
|
||||
case TypeProto::ValueCase::kSparseTensorType:
|
||||
return "sparse(" + ToString(p_type.sparse_tensor_type().elem_type()) + ")";
|
||||
case TypeProto::ValueCase::kSeqType:
|
||||
return "seq(" + ToString(p_type.seq_type().elem_type()) + ")";
|
||||
case TypeProto::ValueCase::kTupleType:
|
||||
{
|
||||
int size = p_type.tuple_type().elem_type_size();
|
||||
std::string tuple_str("tuple(");
|
||||
for (int i = 0; i < size - 1; i++)
|
||||
{
|
||||
tuple_str = tuple_str + ToString(p_type.tuple_type().elem_type(i)) + ",";
|
||||
}
|
||||
tuple_str += ToString(p_type.tuple_type().elem_type(size - 1));
|
||||
tuple_str += ")";
|
||||
return tuple_str;
|
||||
}
|
||||
case TypeProto::ValueCase::kMapType:
|
||||
{
|
||||
std::string map_str("map(");
|
||||
map_str = map_str + ToString(p_type.map_type().key_type()) + ","
|
||||
+ ToString(p_type.map_type().value_type()) + ")";
|
||||
return map_str;
|
||||
}
|
||||
case TypeProto::ValueCase::kHandleType:
|
||||
return "handle";
|
||||
default:
|
||||
throw std::invalid_argument("Unknown TypeProto");
|
||||
}
|
||||
}
|
||||
|
||||
// Translate a tensor element type into its canonical name as defined by
// TypesWrapper. Throws std::invalid_argument for any unrecognized value.
std::string OpUtils::ToString(const TensorProto::DataType& p_type)
{
    TypesWrapper& t = TypesWrapper::GetTypesWrapper();
    if (p_type == TensorProto::DataType::TensorProto_DataType_BOOL)
        return t.c_bool;
    if (p_type == TensorProto::DataType::TensorProto_DataType_STRING)
        return t.c_string;
    if (p_type == TensorProto::DataType::TensorProto_DataType_FLOAT16)
        return t.c_float16;
    if (p_type == TensorProto::DataType::TensorProto_DataType_FLOAT)
        return t.c_float;
    if (p_type == TensorProto::DataType::TensorProto_DataType_DOUBLE)
        return t.c_double;
    if (p_type == TensorProto::DataType::TensorProto_DataType_INT8)
        return t.c_int8;
    if (p_type == TensorProto::DataType::TensorProto_DataType_INT16)
        return t.c_int16;
    if (p_type == TensorProto::DataType::TensorProto_DataType_INT32)
        return t.c_int32;
    if (p_type == TensorProto::DataType::TensorProto_DataType_INT64)
        return t.c_int64;
    if (p_type == TensorProto::DataType::TensorProto_DataType_UINT8)
        return t.c_uint8;
    if (p_type == TensorProto::DataType::TensorProto_DataType_UINT16)
        return t.c_uint16;
    if (p_type == TensorProto::DataType::TensorProto_DataType_UINT32)
        return t.c_uint32;
    if (p_type == TensorProto::DataType::TensorProto_DataType_UINT64)
        return t.c_uint64;
    if (p_type == TensorProto::DataType::TensorProto_DataType_COMPLEX64)
        return t.c_complex64;
    if (p_type == TensorProto::DataType::TensorProto_DataType_COMPLEX128)
        return t.c_complex128;

    throw std::invalid_argument("Unknown DataType");
}
|
||||
|
||||
|
||||
void OpUtils::FromString(const std::string& p_src, TypeProto& p_type)
|
||||
{
|
||||
StringRange s(p_src);
|
||||
s.LAndRStrip();
|
||||
p_type.Clear();
|
||||
|
||||
if (s.LStrip("seq("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
FromString(std::string(s.Data(), s.Size()), *p_type.mutable_seq_type()->mutable_elem_type());
|
||||
}
|
||||
else if (s.LStrip("tuple("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
std::istringstream types(std::string(s.Data(), s.Size()));
|
||||
std::string type;
|
||||
while (std::getline(types, type, ','))
|
||||
{
|
||||
FromString(type, *p_type.mutable_tuple_type()->mutable_elem_type()->Add());
|
||||
}
|
||||
}
|
||||
else if (s.LStrip("map("))
|
||||
{
|
||||
size_t key_size = s.Find(',');
|
||||
StringRange k(s.Data(), key_size);
|
||||
std::string key = std::string(k.Data(), k.Size());
|
||||
s.LStrip(key_size);
|
||||
s.LStrip(",");
|
||||
size_t val_size = s.Find(')');
|
||||
StringRange v(s.Data(), val_size);
|
||||
std::string val = std::string(v.Data(), v.Size());
|
||||
|
||||
TensorProto::DataType key_type;
|
||||
FromString(key, key_type);
|
||||
TensorProto::DataType val_type;
|
||||
FromString(val, val_type);
|
||||
p_type.mutable_map_type()->set_key_type(key_type);
|
||||
p_type.mutable_map_type()->set_value_type(val_type);
|
||||
}
|
||||
else if (s.LStrip("handle"))
|
||||
{
|
||||
p_type.mutable_handle_type();
|
||||
}
|
||||
else if (s.LStrip("sparse("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
TensorProto::DataType e;
|
||||
FromString(std::string(s.Data(), s.Size()), e);
|
||||
p_type.mutable_sparse_tensor_type()->set_elem_type(e);
|
||||
}
|
||||
else
|
||||
{
|
||||
// dense tensor
|
||||
TensorProto::DataType e;
|
||||
FromString(std::string(s.Data(), s.Size()), e);
|
||||
p_type.mutable_tensor_type()->set_elem_type(e);
|
||||
}
|
||||
}
|
||||
|
||||
bool OpUtils::IsValidDataTypeString(const std::string& p_dataType)
|
||||
{
|
||||
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
|
||||
return (t.GetAllowedDataTypes().find(p_dataType) != t.GetAllowedDataTypes().end());
|
||||
}
|
||||
|
||||
void OpUtils::FromString(const std::string& p_typeStr, TensorProto::DataType& p_type)
|
||||
{
|
||||
if (!IsValidDataTypeString(p_typeStr))
|
||||
{
|
||||
throw std::invalid_argument("Unknown DataType: " + p_typeStr);
|
||||
}
|
||||
|
||||
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
|
||||
if (p_typeStr == t.c_bool)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_BOOL;
|
||||
}
|
||||
else if (p_typeStr == t.c_float)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_FLOAT;
|
||||
}
|
||||
else if (p_typeStr == t.c_float16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_FLOAT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_double)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_DOUBLE;
|
||||
}
|
||||
else if (p_typeStr == t.c_int8)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT8;
|
||||
}
|
||||
else if (p_typeStr == t.c_int16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_int32)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT32;
|
||||
}
|
||||
else if (p_typeStr == t.c_int64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT64;
|
||||
}
|
||||
else if (p_typeStr == t.c_string)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_STRING;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint8)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT8;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint32)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT32;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT64;
|
||||
}
|
||||
else if (p_typeStr == t.c_complex64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_COMPLEX64;
|
||||
}
|
||||
else if (p_typeStr == t.c_complex128)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_COMPLEX128;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UNDEFINED;
|
||||
}
|
||||
}
|
||||
|
||||
StringRange::StringRange()
|
||||
: m_data(""), m_size(0)
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const char* p_data, size_t p_size)
|
||||
: m_data(p_data), m_size(p_size)
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const std::string& p_str)
|
||||
: m_data(p_str.data()), m_size(p_str.size())
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const char* p_data)
|
||||
: m_data(p_data), m_size(strlen(p_data))
|
||||
{}
|
||||
|
||||
const char* StringRange::Data() const
|
||||
{
|
||||
return m_data;
|
||||
}
|
||||
|
||||
size_t StringRange::Size() const
|
||||
{
|
||||
return m_size;
|
||||
}
|
||||
|
||||
bool StringRange::Empty() const
|
||||
{
|
||||
return m_size == 0;
|
||||
}
|
||||
|
||||
char StringRange::operator[](size_t p_idx) const
|
||||
{
|
||||
return m_data[p_idx];
|
||||
}
|
||||
|
||||
void StringRange::Reset()
|
||||
{
|
||||
m_data = "";
|
||||
m_size = 0;
|
||||
}
|
||||
|
||||
void StringRange::Reset(const char* p_data, size_t p_size)
|
||||
{
|
||||
m_data = p_data;
|
||||
m_size = p_size;
|
||||
}
|
||||
|
||||
void StringRange::Reset(const std::string& p_str)
|
||||
{
|
||||
m_data = p_str.data();
|
||||
m_size = p_str.size();
|
||||
}
|
||||
|
||||
bool StringRange::StartsWith(const StringRange& p_str) const
|
||||
{
|
||||
return ((m_size >= p_str.m_size) && (memcmp(m_data, p_str.m_data, p_str.m_size) == 0));
|
||||
}
|
||||
|
||||
bool StringRange::EndsWith(const StringRange& p_str) const
|
||||
{
|
||||
return ((m_size >= p_str.m_size) &&
|
||||
(memcmp(m_data + (m_size - p_str.m_size), p_str.m_data, p_str.m_size) == 0));
|
||||
}
|
||||
|
||||
bool StringRange::LStrip() {
|
||||
size_t count = 0;
|
||||
const char* ptr = m_data;
|
||||
while (count < m_size && isspace(*ptr)) {
|
||||
count++;
|
||||
ptr++;
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
return LStrip(count);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool StringRange::LStrip(size_t p_size)
|
||||
{
|
||||
if (p_size <= m_size)
|
||||
{
|
||||
m_data += p_size;
|
||||
m_size -= p_size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::LStrip(StringRange p_str)
|
||||
{
|
||||
if (StartsWith(p_str)) {
|
||||
return LStrip(p_str.m_size);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip() {
|
||||
size_t count = 0;
|
||||
const char* ptr = m_data + m_size - 1;
|
||||
while (count < m_size && isspace(*ptr)) {
|
||||
++count;
|
||||
--ptr;
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
return RStrip(count);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip(size_t p_size)
|
||||
{
|
||||
if (m_size >= p_size)
|
||||
{
|
||||
m_size -= p_size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip(StringRange p_str)
|
||||
{
|
||||
if (EndsWith(p_str)) {
|
||||
return RStrip(p_str.m_size);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::LAndRStrip()
|
||||
{
|
||||
return LStrip() || RStrip();
|
||||
}
|
||||
|
||||
size_t StringRange::Find(const char p_ch) const
|
||||
{
|
||||
size_t idx = 0;
|
||||
while (idx < m_size)
|
||||
{
|
||||
if (m_data[idx] == p_ch)
|
||||
{
|
||||
return idx;
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,64 @@
|
|||
#ifndef ONNXIR_UTILS_H
#define ONNXIR_UTILS_H

#include <unordered_map>
#include <unordered_set>
#include <string>

// Forward declarations of the protobuf-generated message types.
// NOTE(review): OpUtils below names the nested type TensorProto::DataType,
// which needs the complete TensorProto definition - this header appears to
// rely on graph.pb.h being included before it; confirm the include order.
class TensorProto;
class TypeProto;

namespace ONNXIR
{
    // Interned type handle: pointer to the canonical type string stored in
    // OpUtils' registry. Equal type strings map to the same PTYPE.
    typedef const std::string* PTYPE;

    namespace Utils
    {
        // Conversions between TypeProto messages, their canonical string
        // spelling (e.g. "float", "seq(float)", "map(int64,string)"), and
        // interned PTYPE handles.
        class OpUtils
        {
        public:
            // Intern a TypeProto and return its canonical handle.
            static PTYPE ToType(const TypeProto& p_type);
            // Parse a type string, then intern it.
            static PTYPE ToType(const std::string& p_type);
            // Look up the TypeProto registered for a handle; throws
            // std::invalid_argument for an unknown handle.
            static const TypeProto& ToTypeProto(const PTYPE& p_type);
            // Render a TypeProto or element type as its canonical string.
            static std::string ToString(const TypeProto& p_type);
            static std::string ToString(const TensorProto::DataType& p_type);
            // Parse canonical strings back into protobuf form.
            static void FromString(const std::string& p_src, TypeProto& p_type);
            static void FromString(const std::string& p_src, TensorProto::DataType& p_type);
            // True iff p_dataType names an allowed primitive data type.
            static bool IsValidDataTypeString(const std::string &p_dataType);
        private:
            // Registry mapping canonical type strings to their TypeProto.
            static std::unordered_map<std::string, TypeProto>& GetTypeStrToProtoMap();
        };

        // Non-owning view over a character range (the viewed buffer must
        // outlive the StringRange), with strip/search helpers used by the
        // type-string parser.
        class StringRange
        {
        public:
            StringRange();
            StringRange(const char* p_data, size_t p_size);
            StringRange(const std::string& p_str);
            StringRange(const char* p_data);
            const char* Data() const;
            size_t Size() const;
            bool Empty() const;
            char operator[](size_t p_idx) const;
            // Rebind the view to a new buffer (or to empty).
            void Reset();
            void Reset(const char* p_data, size_t p_size);
            void Reset(const std::string& p_str);
            bool StartsWith(const StringRange& p_str) const;
            bool EndsWith(const StringRange& p_str) const;
            // Strip variants: each returns true iff the view was modified.
            bool LStrip();
            bool LStrip(size_t p_size);
            bool LStrip(StringRange p_str);
            bool RStrip();
            bool RStrip(size_t p_size);
            bool RStrip(StringRange p_str);
            bool LAndRStrip();
            // Index of the first occurrence of p_ch, or std::string::npos.
            size_t Find(const char p_ch) const;

        private:
            const char* m_data;   // start of the viewed range (not owned)
            size_t m_size;        // number of characters in the view
        };
    }
}

#endif // ! ONNXIR_UTILS_H
|
|
@ -0,0 +1,329 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
using SupportType = ONNXIR::OpSchema::SupportType;
|
||||
namespace ONNXIR {
|
||||
REGISTER_OPERATOR_SCHEMA(ConstantFill)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(0, 1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
The operator fills the elements of the output tensor with a constant value
|
||||
specified by the 'value' argument.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message. If the 'dtype' argument is not provided, the data type of
|
||||
'value' is used.
|
||||
|
||||
The output tensor shape is specified by the 'shape' argument. If the number of
|
||||
input is 1, the shape will be identical to that of the input at run time with
|
||||
optional additional dimensions appended at the end as specified by 'extra_shape'
|
||||
argument. In that case the 'shape' argument should not be set.
|
||||
|
||||
If input_as_shape is set to true, then the input should be a 1D tensor
|
||||
containing the desired output shape (the dimensions specified in extra_shape
|
||||
will also be appended)
|
||||
|
||||
NOTE: Currently, it supports data type of float, int32, int64, and bool.
|
||||
)DOC")
|
||||
.Attr("value",
|
||||
"The value for the elements of the output tensor.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"The data type for the elements of the output tensor."
|
||||
"Strictly must be one of the types from DataType enum in TensorProto.",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"shape",
|
||||
"The shape of the output tensor."
|
||||
"Cannot set the shape argument and pass in an input at the same time.",
|
||||
AttrType::INTS)
|
||||
.Attr(
|
||||
"extra_shape",
|
||||
"The additional dimensions appended at the end of the shape indicated"
|
||||
"by the input blob."
|
||||
"Cannot set the extra_shape argument when there is no input blob.",
|
||||
AttrType::INTS)
|
||||
.Attr(
|
||||
"input_as_shape",
|
||||
"1D tensor containing the desired output shape. First input must be in "
|
||||
"CPU context.",
|
||||
AttrType::INT)
|
||||
.Input(0, "input", "Input tensor (optional) to provide shape information.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of constant values specified by 'value'"
|
||||
"argument and its type is specified by the 'dtype' argument");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Caffe2ConvTranspose)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
The transposed convolution consumes an input vector, the filter blob, and
|
||||
the bias blob, and computes the output. Note that other parameters, such as
|
||||
the stride and kernel size, or the pads' sizes in each direction are not
|
||||
necessary for input because they are provided by the
|
||||
ConvTransposeUnpoolOpBase operator. Various dimension checks are done
|
||||
implicitly, and the sizes are specified in the Input docs for this operator.
|
||||
As is expected, the filter is deconvolved with a subset of the
|
||||
image and the bias is added; this is done throughout the image data and the
|
||||
output is computed. As a side note on the implementation layout:
|
||||
conv_transpose_op_impl.h is the templated implementation of the
|
||||
conv_transpose_op.h file, which is why they are separate files.
|
||||
)DOC")
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"Input data blob from previous layer; has size "
|
||||
"(N x C x H x W), where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the NCHW usage. On "
|
||||
"the other hand, the NHWC Op has a different set of dimension constraints.")
|
||||
.Input(
|
||||
1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the transposed "
|
||||
"convolution; has size (M x C x kH x kW), where C is the number of channels,"
|
||||
" and kH and kW are the height and width of the kernel.")
|
||||
.Input(
|
||||
2,
|
||||
"bias",
|
||||
"The 1D bias blob that is added through the convolution;"
|
||||
"has size (C)")
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Output data blob that contains the result of the "
|
||||
"transposed convolution. The output dimensions are functions of the kernel"
|
||||
" size, stride size, and pad lengths.")
|
||||
.Attr("pads", "", AttrType::INTS)
|
||||
.Attr("kernel_shape", "", AttrType::INTS)
|
||||
.Attr("dilations", "", AttrType::INTS)
|
||||
.Attr("group", "", AttrType::INT)
|
||||
.Attr("strides", "", AttrType::INTS);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(SpatialBN)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(5)
|
||||
.NumOutputs({ 1, 5 })
|
||||
.EnforceConsumed({ {3, 1}, {4, 2} })
|
||||
.SetDoc(R"DOC(
|
||||
Carries out batch normalization as described in the paper
|
||||
https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
|
||||
there are multiple cases for the number of outputs, which we list below:
|
||||
|
||||
Output case #1: Y, mean, var, saved_mean, saved_var (training mode)
|
||||
Output case #2: Y (test mode)
|
||||
)DOC")
|
||||
.Attr("is_test",
|
||||
"If set to nonzero, run spatial batch normalization in test mode.",
|
||||
AttrType::INT)
|
||||
.Attr("epsilon",
|
||||
"The epsilon value to use to avoid division by zero.",
|
||||
AttrType::FLOAT)
|
||||
.Attr("momentum",
|
||||
"Factor used in computing the running mean and variance."
|
||||
"e.g., running_mean = running_mean * momentum + mean * (1 - momentum)",
|
||||
AttrType::FLOAT)
|
||||
.Input(0,
|
||||
"X",
|
||||
"The input 4-dimensional tensor of shape NCHW.")
|
||||
.Input(1,
|
||||
"scale",
|
||||
"The scale as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(2,
|
||||
"bias",
|
||||
"The bias as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(3,
|
||||
"mean",
|
||||
"The running mean (training) or the estimated mean (testing) "
|
||||
"as a 1-dimensional tensor of size C.")
|
||||
.Input(4,
|
||||
"var",
|
||||
"The running variance (training) or the estimated "
|
||||
"variance (testing) as a 1-dimensional tensor of size C.")
|
||||
.Output(0, "Y", "The output 4-dimensional tensor of the same shape as X.")
|
||||
.Output(1,
|
||||
"mean",
|
||||
"The running mean after the spatial BN operator. Must be in-place "
|
||||
"with the input mean. Should not be used for testing.")
|
||||
.Output(2,
|
||||
"var",
|
||||
"The running variance after the spatial BN operator. Must be "
|
||||
"in-place with the input var. Should not be used for testing.")
|
||||
.Output(3,
|
||||
"saved_mean",
|
||||
"Saved mean used during training to speed up gradient "
|
||||
"computation. Should not be used for testing.")
|
||||
.Output(4,
|
||||
"saved_var",
|
||||
"Saved variance used during training to speed up "
|
||||
"gradient computation. Should not be used for testing.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(LRN)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1, 2)
|
||||
.Attr("size", "", AttrType::INT)
|
||||
.Attr("alpha", "", AttrType::FLOAT)
|
||||
.Attr("beta", "", AttrType::FLOAT)
|
||||
.Attr("bias", "", AttrType::FLOAT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(GivenTensorFill)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(0, 1)
|
||||
.NumOutputs(1)
|
||||
.Input(0, "shape", "The shape of filled tensor")
|
||||
.Output(0, "X", "The filled tensor")
|
||||
.Attr("values", "", AttrType::FLOATS)
|
||||
.Attr("shape", "", AttrType::INTS)
|
||||
.Attr("input_as_shape", "", AttrType::INT)
|
||||
.Attr("extra_shape", "", AttrType::INTS)
|
||||
.AllowConsumed({ {0, 0} });
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(FC)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Computes the result of passing an input vector X into a fully
|
||||
connected layer with 2D weight matrix W and 1D bias vector b. That is,
|
||||
the layer computes Y = X * W^T + b, where X has size (M x K),
|
||||
W has size (N x K), b has size (N), and Y has size (M x N),
|
||||
where M is often the batch size.
|
||||
NOTE: X does not need to explicitly be a 2D vector; rather, it will be
|
||||
coerced into one. For an arbitrary n-dimensional tensor
|
||||
X \in [a_0, a_1, ...,a_{k-1}, a_k, ..., a_{n-1}] where a_i \in N+ and k is
|
||||
the axis provided, then X will be coerced into a 2-dimensional tensor with
|
||||
dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
|
||||
case where axis=1, this means the X tensor will be coerced into a 2D tensor
|
||||
of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
|
||||
In this situation, we must have a_0 = M and a_1 * ... * a_{n-1} = K.
|
||||
Lastly, even though b is a 1D vector of size N, it is copied/resized to
|
||||
be size (M x N) implicitly and added to each vector in the batch.
|
||||
Each of these dimensions must be matched correctly, or else the operator
|
||||
will throw errors.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"axis",
|
||||
"(int32_t) default to 1; describes the axis of the inputs; "
|
||||
"defaults to one because the 0th axis most likely describes "
|
||||
"the batch_size",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"axis_w",
|
||||
"(int32_t) default to 1; describes the axis of the weights; "
|
||||
"defaults to one because the 0th axis most likely describes "
|
||||
"the batch_size",
|
||||
AttrType::INT)
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"input tensor that's coerced into a 2D matrix of size (MxK) "
|
||||
"as described above")
|
||||
.Input(
|
||||
1,
|
||||
"W",
|
||||
"2D blob of size (KxN) containing fully connected weight "
|
||||
"matrix")
|
||||
.Input(2, "b", "1D blob containing bias vector")
|
||||
.Output(0, "Y", "2D output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Normalize)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Given a matrix, apply L2-normalization along the last dimension.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Scale)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Scale takes one input data (Tensor<float>) and produces one output data
|
||||
(Tensor<float>) whose value is the input data tensor scaled element-wise.
|
||||
)DOC")
|
||||
.Attr("scale",
|
||||
"(float, default 1.0) the scale to apply.",
|
||||
AttrType::FLOAT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ChannelShuffle)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis",
|
||||
AttrType::INTS)
|
||||
.Attr("group",
|
||||
"Number of channel groups",
|
||||
AttrType::INT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RecurrentNetwork)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(2, INT_MAX)
|
||||
.SetDoc(R"DOC(
|
||||
Run the input network in a recurrent fashion. This can be used to
|
||||
implement fairly general recurrent neural networks (RNNs).
|
||||
The operator proceeds as follows.
|
||||
- First, initialized the states from the input recurrent states
|
||||
- For each timestep T, apply the links (that map offsets from input/output
|
||||
tensors into the inputs/outputs for the `step` network)
|
||||
- Finally, alias the recurrent states to the specified output blobs.
|
||||
This is a fairly special-case meta-operator, and so the implementation
|
||||
is somewhat complex. It trades of generality (and frankly usability)
|
||||
against performance and control (compared to e.g. TF
|
||||
dynamic_rnn, Theano scan, etc).
|
||||
See the usage examples for a flavor of how to use it.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(GRUUnit)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(4)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
GRUUnit computes the activations of a standard GRU,
|
||||
in a sequence-length aware fashion.
|
||||
Concretely, given the (fused) inputs X (TxNxD), the previous hidden
|
||||
state (NxD), and the sequence lengths (N), computes the GRU
|
||||
activations, avoiding computation if the input is invalid (as in, the
|
||||
value at X[t][n] >= seqLengths[n].
|
||||
)DOC")
|
||||
.Attr(
|
||||
"drop_states",
|
||||
"Bool to determine if hidden state is zeroes or passed "
|
||||
"along for timesteps past the given sequence_length.",
|
||||
AttrType::INT)
|
||||
.Input(0, "hidden_prev", "The previous GRU hidden state.")
|
||||
.Input(
|
||||
1,
|
||||
"gates",
|
||||
"Unactivated gate outputs from forget, update, "
|
||||
"and output gates, pre-activation.")
|
||||
.Input(
|
||||
2,
|
||||
"seq_lengths",
|
||||
"Array of sequence lengths. "
|
||||
"len(seq_lengths) should equal batch size N.")
|
||||
.Input(3, "t", "The timestep for this operation.")
|
||||
.Output(0, "hidden", "The new GRU hidden state calculated by this op.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ATen)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.AllowUncheckedAttributes()
|
||||
.SetDoc(R"DOC(
|
||||
Experimental allowing ATen operations to be accessed directly from Caffe2
|
||||
to allow for quick prototyping when ONNX is missing standard versions of
|
||||
and op)DOC");
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
// Schema registrations for tensor-generating operators.
namespace ONNXIR
{
    // A tensor whose value is baked into the 'value' attribute.
    REGISTER_OPERATOR_SCHEMA(Constant)
        .NumInputs(0)
        .NumOutputs(1)
        .SetDoc(R"DOC(A constant tensor.)DOC")
        .Attr("value",
              "The value for the elements of the output tensor.",
              AttrType::TENSOR)
        .Output(
            0,
            "output",
            "Output tensor containing the same value of the provided tensor.");

    // Uniform random tensor; shape given by the 'shape' attribute.
    REGISTER_OPERATOR_SCHEMA(RandomUniform)
        .NumInputs(0)
        .NumOutputs(1)
        .SetDoc(R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is specified by the `shape` argument and the range by `low` and `high`.

The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC")
        .Attr(
            "low",
            "Lower boundary of the output values.",
            AttrType::FLOAT)
        .Attr(
            "high",
            "Upper boundary of the output values.",
            AttrType::FLOAT)
        .Attr(
            "seed",
            "(Optional) Seed to the random generator, if not specified we will auto generate one.",
            AttrType::FLOAT)
        .Attr(
            "dtype",
            "The data type for the elements of the output tensor.",
            AttrType::INT)
        .Attr(
            "shape",
            "The shape of the output tensor.",
            AttrType::INTS)
        .Output(
            0,
            "output",
            "Output tensor of random values drawn from uniform distribution");

    // Normal (Gaussian) random tensor; shape given by the 'shape' attribute.
    REGISTER_OPERATOR_SCHEMA(RandomNormal)
        .NumInputs(0)
        .NumOutputs(1)
        .SetDoc(R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
specified by `mean` and `scale`.

The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC")
        .Attr(
            "mean",
            "The mean of the normal distribution.",
            AttrType::FLOAT)
        .Attr(
            "scale",
            "The standard deviation of the normal distribution.",
            AttrType::FLOAT)
        .Attr(
            "seed",
            "(Optional) Seed to the random generator, if not specified we will auto generate one.",
            AttrType::FLOAT)
        .Attr(
            "dtype",
            "The data type for the elements of the output tensor.",
            AttrType::INT)
        .Attr(
            "shape",
            "The shape of the output tensor.",
            AttrType::INTS)
        .Output(
            0,
            "output",
            "Output tensor of random values drawn from normal distribution");

    // Uniform random tensor whose shape is taken from the input tensor.
    REGISTER_OPERATOR_SCHEMA(RandomUniformLike)
        .NumInputs(1)
        .NumOutputs(1)
        .SetDoc(R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is computed from the input argument and the range by `low` and `high`.

The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC")
        .Attr(
            "low",
            "Lower boundary of the output values.",
            AttrType::FLOAT)
        .Attr(
            "high",
            "Upper boundary of the output values.",
            AttrType::FLOAT)
        .Attr(
            "seed",
            "(Optional) Seed to the random generator, if not specified we will auto generate one.",
            AttrType::FLOAT)
        .Attr(
            "dtype",
            "(Optional) The data type for the elements of the output tensor, if not specified, we will use"
            "the data type of the input tensor.",
            AttrType::INT)
        .Input(
            0,
            "input",
            "Input tensor to provide shape information.")
        .Output(
            0,
            "output",
            "Output tensor of random values drawn from uniform distribution");

    // Normal random tensor whose shape is taken from the input tensor.
    REGISTER_OPERATOR_SCHEMA(RandomNormalLike)
        .NumInputs(1)
        .NumOutputs(1)
        .SetDoc(R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is computed from the input argument and the parameter of the normal distribution
specified by `mean` and `scale`.

The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC")
        .Attr(
            "mean",
            "The mean of the normal distribution.",
            AttrType::FLOAT)
        .Attr(
            "scale",
            "The standard deviation of the normal distribution.",
            AttrType::FLOAT)
        .Attr(
            "seed",
            "(Optional) Seed to the random generator, if not specified we will auto generate one.",
            AttrType::FLOAT)
        .Attr(
            "dtype",
            "(Optional) The data type for the elements of the output tensor, if not specified, we will use"
            "the data type of the input tensor.",
            AttrType::INT)
        .Input(
            0,
            "input",
            "Input tensor to provide shape information.")
        .Output(
            0,
            "output",
            "Output tensor of random values drawn from normal distribution");

}
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
std::function<void(OpSchema&)> BinaryLogicDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the `{name} than` elementwise logical operation between `left` and `right` input tensor.
|
||||
The result is a tensor of type integer in which `0` mean false and `1` mean true.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.NumInputs(2);
|
||||
schema.NumOutputs(1);
|
||||
schema.SetDoc(doc);
|
||||
schema.Input(0, "left", "Left input tensor for the logical operator.");
|
||||
schema.Input(1, "right", "Right input tensor for the logical operator.");
|
||||
schema.Output(0, "output", "Result tensor of type `int`, 0 mean False and 1 mean True.");
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ONNXIR
|
|
@ -0,0 +1,385 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
#include <functional>
|
||||
|
||||
using AttrType = ONNXIR::AttrType;
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
const char* kBroadcastDoc = R"DOC(
|
||||
If necessary the right-hand-side argument will be broadcasted to match the
|
||||
shape of left-hand-side argument. When broadcasting is specified, the second
|
||||
tensor can either be of size 1 (a scalar value), or having its shape as a
|
||||
contiguous subset of the first tensor's shape. The starting of the mutually
|
||||
equal shape is specified by the argument "axis", and if it is not set, suffix
|
||||
matching is assumed. 1-dim expansion doesn't work yet.
|
||||
|
||||
For example, the following tensor shapes are supported (with broadcast=1):
|
||||
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (5,)
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0
|
||||
|
||||
Attribute `broadcast=1` needs to be passed to enable broadcasting.
|
||||
)DOC";
|
||||
|
||||
std::function<void(OpSchema&)> MathDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Performs element-wise binary {name} (with limited broadcast support).
|
||||
{broadcast_doc})DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("broadcast",
|
||||
"Pass 1 to enable broadcasting",
|
||||
AttrType::INT);
|
||||
schema.Attr("axis",
|
||||
"If set, defines the broadcast dimensions. See doc for details.",
|
||||
AttrType::INT);
|
||||
schema.Input(
|
||||
0,
|
||||
"A",
|
||||
"First operand, should share the type with the second operand.");
|
||||
schema.Input(
|
||||
1,
|
||||
"B",
|
||||
"Second operand. With broadcasting can be of smaller size than A. "
|
||||
"If broadcasting is disabled it should be of the same size.");
|
||||
schema.Output(0, "C", "Result, has same dimensions and type as A");
|
||||
};
|
||||
}
|
||||
|
||||
// Binary arithmetic operators. Each takes two inputs, yields one output, and
// may reuse either input's buffer for the result (AllowConsumed); the shared
// documentation (including broadcasting rules) comes from MathDocGenerator.
REGISTER_OPERATOR_SCHEMA(Add)
    .NumInputs(2).NumOutputs(1)
    .AllowConsumed({ {0, 0}, {1, 0} })
    .FillUsing(MathDocGenerator("addition"));

REGISTER_OPERATOR_SCHEMA(Sub)
    .NumInputs(2).NumOutputs(1)
    .AllowConsumed({ {0, 0}, {1, 0} })
    .FillUsing(MathDocGenerator("subtraction"));

REGISTER_OPERATOR_SCHEMA(Mul)
    .NumInputs(2).NumOutputs(1)
    .AllowConsumed({ {0, 0}, {1, 0} })
    .FillUsing(MathDocGenerator("multiplication"));

REGISTER_OPERATOR_SCHEMA(Div)
    .NumInputs(2).NumOutputs(1)
    .AllowConsumed({ {0, 0}, {1, 0} })
    .FillUsing(MathDocGenerator("division"));
|
||||
|
||||
// Elementwise unary operators. Each takes one tensor X, produces one tensor Y
// of the same shape, and may reuse the input buffer in place (AllowConsumed).
REGISTER_OPERATOR_SCHEMA(Neg)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Neg takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where each element flipped sign, y = -x, is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Abs)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Absolute takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the absolute is, y = abs(x), is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Reciprocal)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Reciprocal takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the reciprocal is, y = 1/x, is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Floor)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Floor takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the floor is, y = floor(x), is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Ceil)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Ceil takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the ceil is, y = ceil(x), is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Sqrt)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Square root takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the square root is, y = x^0.5, is applied to
the tensor elementwise. If x is negative, then it will return NaN.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Relu)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Relu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");
|
||||
|
||||
// Parameterized rectifier variants; both operate elementwise and allow the
// input buffer to be reused for the output.
REGISTER_OPERATOR_SCHEMA(LeakyRelu)
    .NumInputs(1).NumOutputs(1)
    .Attr("alpha", "Coefficient of leakage", AttrType::FLOAT)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one
output data (Tensor<T>) where the function `f(x) = alpha * x for x < 0`,
`f(x) = x for x >= 0`, is applied to the data tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");

REGISTER_OPERATOR_SCHEMA(Selu)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .Attr("alpha", "Coefficient of SELU default to 1.6732.", AttrType::FLOAT)
    .Attr("gamma", "Coefficient of SELU default to 1.0507.", AttrType::FLOAT)
    .SetDoc(R"DOC(
Selu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the scaled exponential linear unit function,
`y = gamma * (alpha * e^x - alpha) for x <= 0`, `f(x) = gamma * x for x > 0`,
is applied to the tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");
|
||||
|
||||
// Elu: exponential linear unit, applied elementwise.
// NOTE(review): the doc string references `alpha`, but no "alpha" attribute
// is registered on this schema -- confirm whether one should be added.
REGISTER_OPERATOR_SCHEMA(Elu)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(

Elu takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the function `f(x) = alpha * (exp(x) - 1.) for x <
0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise.

)DOC")
    .Input(0, "X", "1D input tensor")
    // Fixed copy-paste: the output was previously described as an "input" tensor.
    .Output(0, "Y", "1D output tensor");
|
||||
|
||||
// Elementwise transcendental operators; each may run in place by providing
// the same blob as input and output (AllowConsumed).
REGISTER_OPERATOR_SCHEMA(Exp)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Calculates the exponential of the given input tensor, element-wise. This
operation can be done in an in-place fashion too, by providing the same input
and output blobs.
)DOC")
    .Input(0, "input", "Input tensor")
    .Output(0, "output",
            "The exponential of the input tensor computed "
            "element-wise");

REGISTER_OPERATOR_SCHEMA(Log)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Calculates the natural log of the given input tensor, element-wise. This
operation can be done in an in-place fashion too, by providing the same input
and output blobs.
)DOC")
    .Input(0, "input", "Input tensor")
    .Output(0, "output",
            "The natural log of the input tensor computed "
            "element-wise");

REGISTER_OPERATOR_SCHEMA(Tanh)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Calculates the hyperbolic tangent of the given input tensor element-wise. This
operation can be done in an in-place fashion too, by providing the same input
and output blobs.
)DOC")
    .Input(0, "input", "1-D input tensor")
    .Output(0, "output",
            "The hyperbolic tangent values of the input tensor "
            "computed element-wise");
|
||||
|
||||
// Pow: elementwise power with a fixed scalar exponent attribute.
REGISTER_OPERATOR_SCHEMA(Pow)
    .NumInputs(1).NumOutputs(1)
    .Attr("exponent", "The exponent of the power function.", AttrType::FLOAT)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Pow takes input data (Tensor<T>) and an argument exponent, and
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
is applied to the data tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor of any shape")
    .Output(0, "Y", "Output tensor (same size as X)");

// Dot: generalized dot product following numpy.dot semantics.
REGISTER_OPERATOR_SCHEMA(Dot)
    .NumInputs(2).NumOutputs(1)
    .SetDoc(R"DOC(
Apply dot product between 2 tensors. Similar to numpy implementation:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html
)DOC")
    .Input(0, "X", "Input tensor of any shape")
    .Input(1, "Y", "Input tensor of any shape")
    .Output(0, "Z", "Output tensor the dot product between X and Y.");
|
||||
|
||||
// PRelu: parametric ReLU where the slope for negative inputs is supplied as a
// second input tensor (scalar slope is broadcast across channels).
REGISTER_OPERATOR_SCHEMA(PRelu)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(

PRelu takes input data (Tensor<T>) and slope tensor as input, and produces one
output data (Tensor<T>) where the function `f(x) = slope * x for x < 0`,
`f(x) = x for x >= 0`., is applied to the data tensor elementwise.

)DOC")
    .Input(0, "X", "Input tensor")
    .Input(1, "Slope",
           // Fixed: the adjacent literals previously concatenated to "sharedacross".
           "Slope tensor. If `Slope` is of size 1, the value is shared "
           "across different channels")
    // Fixed copy-paste: the output was previously described as "Input tensor".
    .Output(0, "Y", "Output tensor");
|
||||
|
||||
// Sigmoid: logistic function applied elementwise; may run in place.
REGISTER_OPERATOR_SCHEMA(Sigmoid)
    .NumInputs(1).NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Sigmoid takes one input data (Tensor<T>) and produces one output data
(Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
tensor elementwise.
)DOC")
    .Input(0, "X", "Input tensor")
    .Output(0, "Y", "Output tensor");
|
||||
|
||||
// Max: variadic elementwise maximum over all inputs; input0 may be reused
// in place as the output.
REGISTER_OPERATOR_SCHEMA(Max)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Element-wise max of each of the input tensors. The first input tensor can be
used in-place as the output tensor, in which case the max will be done in
place and results will be accumulated in input0. All inputs and outputs must
have the same shape and data type.
)DOC")
    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
    .Output(0, "max", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
// Min: variadic elementwise minimum over all inputs; input0 may be reused
// in place as the output.
// Fixed copy-paste from the Max schema: the doc said "the max will be done"
// and the output was named/described as "max".
REGISTER_OPERATOR_SCHEMA(Min)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Element-wise min of each of the input tensors. The first input tensor can be
used in-place as the output tensor, in which case the min will be done in
place and results will be accumulated in input0. All inputs and outputs must
have the same shape and data type.
)DOC")
    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
    .Output(0, "min", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
// Sum: variadic elementwise sum over all inputs; input0 may be reused
// in place as the output.
REGISTER_OPERATOR_SCHEMA(Sum)
    .NumInputs(1, INT_MAX)
    .NumOutputs(1)
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Element-wise sum of each of the input tensors. The first input tensor can be
used in-place as the output tensor, in which case the sum will be done in
place and results will be accumulated in input0. All inputs and outputs must
have the same shape and data type.
)DOC")
    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
    .Output(0, "sum", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
// Softmax: normalizes the input (coerced to a 2D [N, D] matrix at `axis`)
// so each row sums to one.
REGISTER_OPERATOR_SCHEMA(Softmax)
    .NumInputs(1).NumOutputs(1)
    .SetDoc(R"DOC(
The operator computes the softmax normalized values for each layer in the batch
of the given input. The input is a 2-D tensor (Tensor<float>) of size
(batch_size x input_feature_dimensions). The output tensor has the same shape
and contains the softmax normalized values of the corresponding input.

X does not need to explicitly be a 2D vector; rather, it will be
coerced into one. For an arbitrary n-dimensional tensor
X \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is
the axis provided, then X will be coerced into a 2-dimensional tensor with
dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
case where axis=1, this means the X tensor will be coerced into a 2D tensor
of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D.
Each of these dimensions must be matched correctly, or else the operator
will throw errors.
)DOC")
    .Attr("axis",
          "(int) default to 1; describes the axis of the inputs when coerced "
          "to 2D; defaults to one because the 0th axis most likely describes "
          "the batch_size",
          AttrType::INT)
    .Input(0, "input",
           "The input tensor that's coerced into a 2D matrix of size (NxD) "
           "as described above.")
    .Output(0, "output",
            "The softmax normalized output values with the same "
            "shape as input tensor.");
|
||||
|
||||
}
|
|
@ -0,0 +1,309 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR {
|
||||
std::function<void(OpSchema&)> AveragePoolOpSchemaGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
{name} consumes an input tensor X and applies average pooling across the
|
||||
the tensor according to kernel sizes, stride sizes, and pad lengths.
|
||||
Average pooling consisting of averaging all values of a subset of the
|
||||
input tensor according to the kernel size and downsampling the
|
||||
data into the output tensor Y for further processing.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"Stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from average pooling across the input "
|
||||
"tensor. Dimensions will vary based on various kernel, stride, and pad "
|
||||
"sizes.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(AveragePool)
|
||||
.FillUsing(AveragePoolOpSchemaGenerator("AveragePool"));
|
||||
|
||||
std::function<void(OpSchema&)> MaxPoolOpSchemaGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
{name} consumes an input tensor X and applies max pooling across the
|
||||
the tensor according to kernel sizes, stride sizes, and pad lengths.
|
||||
Average pooling consisting of averaging all values of a subset of the
|
||||
input tensor according to the kernel size and downsampling the
|
||||
data into the output tensor Y for further processing.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"Stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"Dilaton along each axis, 1 mean no dilation.",
|
||||
AttrType::INTS);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from max pooling across the input "
|
||||
"tensor. Dimensions will vary based on various kernel, stride, and pad "
|
||||
"sizes.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(MaxPool)
|
||||
.FillUsing(MaxPoolOpSchemaGenerator("MaxPool"));
|
||||
|
||||
std::function<void(OpSchema&)> ConvOpSchemaGenerator(const char* filter_desc) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
The convolution operator consumes an input tensor and {filter_desc}, and
|
||||
computes the output.)DOC";
|
||||
ReplaceAll(doc, "{filter_desc}", filter_desc);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(2, 3);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from previous layer; has size (N x C x H x W)"
|
||||
", where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the 2D image."
|
||||
"Otherwise the size is (N x D1 x D2 ... x Dn)");
|
||||
schema.Input(1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the convolutions; "
|
||||
"has size (M x C x kH x kW), where C is the number of channels, "
|
||||
"and kH and kW are the height and width of the kernel.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor that contains the result of the convolution. The "
|
||||
"output dimensions are functions of the kernel size, stride size, "
|
||||
"and pad lengths.");
|
||||
schema.Attr("kernel_shape",
|
||||
"The shape of the convolution kernel.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"dilation value along each axis of the filter.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Attr("group",
|
||||
"number of groups input channels and output channels are divided into",
|
||||
AttrType::INT);
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Conv)
|
||||
.FillUsing(ConvOpSchemaGenerator("a filter"));
|
||||
|
||||
|
||||
std::function<void(OpSchema&)> ConvTransposeOpSchemaGenerator(const char* filter_desc) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
The convolution transpose operator consumes an input tensor and {filter_desc},
|
||||
and computes the output.)DOC";
|
||||
ReplaceAll(doc, "{filter_desc}", filter_desc);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(2);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from previous layer; has size (N x C x H x W)"
|
||||
", where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the 2D image."
|
||||
"Otherwise the size is (N x D1 x D2 ... x Dn)");
|
||||
schema.Input(1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the convolutions; "
|
||||
"has size (M x C x kH x kW), where C is the number of channels, "
|
||||
"and kH and kW are the height and width of the kernel.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor that contains the result of the convolution. The "
|
||||
"output dimensions are functions of the kernel size, stride size, "
|
||||
"and pad lengths.");
|
||||
schema.Attr("kernel_shape",
|
||||
"The shape of the convolution kernel.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("output_shape",
|
||||
"The shape of the output.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"dilation value along each axis of the filter.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ConvTranspose)
|
||||
.FillUsing(ConvTransposeOpSchemaGenerator("a filter"));
|
||||
|
||||
|
||||
std::function<void(OpSchema&)> GlobalPoolingOpSchemaGenerator(const char* op_type, const char* op) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Global{op_type} consumes an input tensor X and applies {op} pooling across the
|
||||
the values in the same channel. This is equivalent to {op_type} with kernel size
|
||||
equal to the spatial dimension of input tensor.)DOC";
|
||||
ReplaceAll(doc, "{op_type}", op_type);
|
||||
ReplaceAll(doc, "{op}", op);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from pooling across the input "
|
||||
"tensor. Dimensions will be N x C x 1 x 1");
|
||||
schema.SetDoc(doc);
|
||||
};
|
||||
}
|
||||
REGISTER_OPERATOR_SCHEMA(GlobalAveragePool)
|
||||
.FillUsing(GlobalPoolingOpSchemaGenerator("AveragePool", "average"));
|
||||
REGISTER_OPERATOR_SCHEMA(GlobalMaxPool)
|
||||
.FillUsing(GlobalPoolingOpSchemaGenerator("MaxPool", "max"));
|
||||
|
||||
// BatchNormalization: 5 inputs (X, scale, bias, mean, var); 1 output in test
// mode, 5 in training mode. Fixed: adjacent literals in the "momentum"
// description previously concatenated to "variance.e.g.," (missing space).
REGISTER_OPERATOR_SCHEMA(BatchNormalization)
    .NumInputs(5)
    .NumOutputs({ 1, 5 })
    .EnforceConsumed({ {3, 1}, {4, 2} })
    .SetDoc(R"DOC(
Carries out batch normalization as described in the paper
https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
there are multiple cases for the number of outputs, which we list below:

Output case #1: Y, mean, var, saved_mean, saved_var (training mode)
Output case #2: Y (test mode)
)DOC")
    .Attr("spatial",
          "Compute the mean and variance across all spatial elements or per feature.",
          AttrType::INT)
    .Attr("is_test",
          "If set to nonzero, run spatial batch normalization in test mode.",
          AttrType::INT)
    .Attr("epsilon",
          "The epsilon value to use to avoid division by zero.",
          AttrType::FLOAT)
    .Attr("momentum",
          "Factor used in computing the running mean and variance. "
          "e.g., running_mean = running_mean * momentum + mean * (1 - momentum)",
          AttrType::FLOAT)
    .Input(0,
           "X",
           "The input 4-dimensional tensor of shape NCHW or NHWC depending "
           "on the order parameter.")
    .Input(1,
           "scale",
           "The scale as a 1-dimensional tensor of size C to be applied to the "
           "output.")
    .Input(2,
           "bias",
           "The bias as a 1-dimensional tensor of size C to be applied to the "
           "output.")
    .Input(3,
           "mean",
           "The running mean (training) or the estimated mean (testing) "
           "as a 1-dimensional tensor of size C.")
    .Input(4,
           "var",
           "The running variance (training) or the estimated "
           "variance (testing) as a 1-dimensional tensor of size C.")
    .Output(0, "Y", "The output 4-dimensional tensor of the same shape as X.")
    .Output(1,
            "mean",
            "The running mean after the BatchNormalization operator. Must be in-place "
            "with the input mean. Should not be used for testing.")
    .Output(2,
            "var",
            "The running variance after the BatchNormalization operator. Must be "
            "in-place with the input var. Should not be used for testing.")
    .Output(3,
            "saved_mean",
            "Saved mean used during training to speed up gradient "
            "computation. Should not be used for testing.")
    .Output(4,
            "saved_var",
            "Saved variance used during training to speed up "
            "gradient computation. Should not be used for testing.");
|
||||
|
||||
// Dropout: one input, one or two outputs (output + optional mask); scaling is
// done in the training phase, so test mode is a plain copy.
REGISTER_OPERATOR_SCHEMA(Dropout)
    .NumInputs(1)
    .NumOutputs({ 1,2 })
    .AllowConsumed({ {0, 0} })
    .SetDoc(R"DOC(
Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
test mode or not, the output Y will either be a random dropout, or a simple
copy of the input. Note that our implementation of Dropout does scaling in
the training phase, so during testing nothing needs to be done.
)DOC")
    .Attr("ratio",
          "(float, default 0.5) the ratio of random dropout",
          AttrType::FLOAT)
    .Attr("is_test",
          "(int, default 0) if nonzero, run dropout in test mode where "
          "the output is simply Y = X.",
          AttrType::INT)
    .Input(0, "data", "The input data as Tensor.")
    .Output(0, "output", "The output.")
    .Output(1, "mask",
            "The output mask. If is_test is nonzero, this output is not filled.");
|
||||
|
||||
// Flatten: collapses all but the first dimension into one.
// Fixed typo in the output description: "flatenned" -> "flattened".
REGISTER_OPERATOR_SCHEMA(Flatten)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Flattens the input tensor into a 2D matrix, keeping the first dimension
unchanged.
)DOC")
    .Input(0, "input", "A tensor of rank >= 2.")
    .Output(
        0,
        "output",
        "A tensor of rank 2 with the contents of the input tensor, "
        "with first dimension equal first dimension of input, and remaining "
        "input dimensions flattened into the inner dimension of the output.");
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
#include <functional>
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
std::function<void(OpSchema&)> ReduceDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the {name} of the input tensor's element along the provided axes. The resulted
|
||||
tensor has the same shape as the input if keepdims equal 1. If keepdims equal 0, then
|
||||
the resulted tensor have the reduced dimension pruned.
|
||||
|
||||
The above behavior is similar to numpy, with the exception that numpy default keepdims to
|
||||
False instead of True.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("axes",
|
||||
"A list of integers, along which to reduce max.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("keepdims",
|
||||
"Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
|
||||
AttrType::INT);
|
||||
schema.Input(0, "data", "An input tensor.");
|
||||
schema.Output(0, "reduced", "Reduced output tensor.");
|
||||
};
|
||||
}
|
||||
|
||||
// Reduction operators; each takes one tensor, produces one reduced tensor,
// and shares its documentation via ReduceDocGenerator.
REGISTER_OPERATOR_SCHEMA(ReduceMax)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("max"));

REGISTER_OPERATOR_SCHEMA(ReduceMin)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("min"));

REGISTER_OPERATOR_SCHEMA(ReduceSum)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("sum"));

REGISTER_OPERATOR_SCHEMA(ReduceMean)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("mean"));

REGISTER_OPERATOR_SCHEMA(ReduceProd)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("product"));

REGISTER_OPERATOR_SCHEMA(ReduceLogSumExp)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ReduceDocGenerator("log sum exponent"));
|
||||
|
||||
std::function<void(OpSchema&)> ArgReduceDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the indices of the {name} elements of the input tensor's element along the
|
||||
provided axes. The resulted tensor has the same shape as the input if keepdims equal 1.
|
||||
If keepdims equal 0, then the resulted tensor have the reduced dimension pruned.
|
||||
The type of the output tensor is integer.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("axes",
|
||||
"A list of integers, along which to reduce max.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("keepdims",
|
||||
"Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
|
||||
AttrType::INT);
|
||||
schema.Input(0, "data", "An input tensor.");
|
||||
schema.Output(0, "reduced", "Reduced output tensor with integer data type.");
|
||||
};
|
||||
}
|
||||
|
||||
// Index-returning reductions; documentation comes from ArgReduceDocGenerator.
REGISTER_OPERATOR_SCHEMA(ArgMax)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ArgReduceDocGenerator("max"));

REGISTER_OPERATOR_SCHEMA(ArgMin)
    .NumInputs(1).NumOutputs(1)
    .FillUsing(ArgReduceDocGenerator("min"));
|
||||
|
||||
} // namespace ONNXIR
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
REGISTER_OPERATOR_SCHEMA(Cast)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
The operator casts the elements of a given input tensor to a data type
|
||||
specified by the 'to' argument and returns an output tensor of the same size in
|
||||
the converted type. The 'to' argument must be one of the data types specified
|
||||
in the 'DataType' enum field in the TensorProto message. If the 'to' argument
|
||||
is not provided or is not one of the enumerated types in DataType, Caffe2
|
||||
throws an Enforce error.
|
||||
|
||||
NOTE: Casting to and from strings is not supported yet.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"to",
|
||||
"The data type to which the elements of the input tensor are cast."
|
||||
"Strictly must be one of the types from DataType enum in TensorProto",
|
||||
AttrType::STRING)
|
||||
.Input(0, "input", "Input tensor to be cast.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor with the same shape as input with type "
|
||||
"specified by the 'to' argument");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Reshape)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Reshape the input tensor similar to numpy.reshape.
|
||||
|
||||
It takes a tensor as input and an argument `shape`. It outputs the reshaped tensor.
|
||||
|
||||
At most one dimension of the new shape can be -1. In this case, the value is
|
||||
inferred from the size of the tensor and the remaining dimensions. A dimension
|
||||
could also be 0, in which case the actual dimension value is going to be copied
|
||||
from the shape argument.)DOC")
|
||||
.Attr("shape", "New shape", AttrType::INTS)
|
||||
.Input(0, "data", "An input tensor.")
|
||||
.Output(0, "reshaped", "Reshaped data.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Concat)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(2)
|
||||
.Attr("axis",
|
||||
"Which axis to concat on",
|
||||
AttrType::INT)
|
||||
.SetDoc("Concatenate a list of tensors into a single tensor")
|
||||
.Output(0, "concat_result", "Concatenated tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Split)
|
||||
.NumInputs(1, 2)
|
||||
.NumOutputs(1, INT_MAX)
|
||||
.Input(0, "input", "The tensor to split")
|
||||
.Input(1, "split", "Optional list of output lengths (see also arg 'split')")
|
||||
.Attr("axis",
|
||||
"Which axis to split on",
|
||||
AttrType::INT)
|
||||
.Attr("split",
|
||||
"length of each output",
|
||||
AttrType::INTS)
|
||||
.SetDoc(R"DOC(Split a tensor into a list of tensors, along the specified
|
||||
'axis'. The lengths of the split can be specified using argument 'axis' or
|
||||
optional second input blob to the operator. Otherwise, the tensor is split
|
||||
to equal sized parts.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Slice)
|
||||
.NumInputs(1, 3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Produces a slice of the input tensor along multiple axes. Similar to numpy:
|
||||
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
|
||||
|
||||
Slices are passed as two keyword argument lists with starting and end indices
|
||||
for each dimension of the input `data` tensor. If a negative value is passed
|
||||
for any of the start or end indices, it represent number of elements before
|
||||
the end of that dimension.
|
||||
|
||||
`strides` is the step sizes when applying slicing, negative value means in
|
||||
reverse order.
|
||||
)DOC")
|
||||
.Input(0, "data", "Tensor of data to extract slices from.")
|
||||
.Attr("starts",
|
||||
"List of starting indices",
|
||||
AttrType::INTS)
|
||||
.Attr("ends",
|
||||
"List of ending indices",
|
||||
AttrType::INTS)
|
||||
.Output(0, "output", "Sliced data tensor.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Transpose)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Transpose the input tensor similar to numpy.transpose. For example, when
|
||||
axes=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
|
||||
will be (2, 1, 3).
|
||||
)DOC")
|
||||
.Attr("perm",
|
||||
"A list of integers. By default, reverse the dimensions, "
|
||||
"otherwise permute the axes according to the values given.",
|
||||
AttrType::INTS)
|
||||
.Input(0, "data", "An input tensor.")
|
||||
.Output(0, "transposed", "Transposed output.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Gather)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Given DATA tensor of rank r >= 1, and INDICES tensor of rank q, gather
|
||||
entries of the outer-most dimension of DATA indexed by INDICES, and concatenate
|
||||
them in an output tensor of rank q + (r - 1).
|
||||
|
||||
Example:
|
||||
DATA = [
|
||||
[1.0, 1.2],
|
||||
[2.3, 3.4],
|
||||
[4.5, 5.7],
|
||||
]
|
||||
INDICES = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
]
|
||||
OUTPUT = [
|
||||
[
|
||||
[1.0, 1.2],
|
||||
[2.3, 3.4],
|
||||
],
|
||||
[
|
||||
[2.3, 3.4],
|
||||
[4.5, 5.7],
|
||||
],
|
||||
]
|
||||
)DOC")
|
||||
.Input(0, "DATA", "Tensor of rank r >= 1.")
|
||||
.Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
|
||||
.Output(0, "OUTPUT", "Tensor of rank q + (r - 1).");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Squeeze)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("axes",
|
||||
"List of positive integers, indicate the dimensions to squeeze.",
|
||||
AttrType::INTS,
|
||||
true)
|
||||
.SetDoc(R"DOC(
|
||||
Remove single-dimensional entries from the shape of a tensor.
|
||||
Takes a parameter `axes` with a list of axes to squeeze.
|
||||
)DOC")
|
||||
.Input(0, "data", "Tensors with at least max(dims) dimensions.")
|
||||
.Output(0, "squeezed", "Reshaped tensor with same data as input.");
|
||||
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
#include "graph.pb.cc"
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,613 @@
|
|||
syntax = "proto2";
|
||||
|
||||
package ONNXIR;
|
||||
|
||||
// Note [Protobuf compatibility]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Based on experience working with downstream vendors, we generally can't
|
||||
// assume recent versions of protobufs. This means that we do not use any
|
||||
// protobuf features that are only available in proto3.
|
||||
//
|
||||
// Here are the most notable contortions we have to carry out to work around
|
||||
// these limitations:
|
||||
//
|
||||
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
|
||||
// of key-value pairs, where order does not matter and duplicates
|
||||
// are not allowed.
|
||||
|
||||
// Note [Namespaces]
|
||||
// ~~~~~~~~~~~~~~~~~
|
||||
// LotusIR gives explicit names to graphs, intermediate values and
|
||||
// serialized tensors. To make it easier to generate names, we organize
|
||||
// these into separate namespaces (so, e.g., a graph can have the same
|
||||
// name as a serialized tensor.) The namespaces are as follows:
|
||||
//
|
||||
// - Node: These names identify specific nodes in the graph (but not, necessarily
|
||||
// any particular input or output of the node.
|
||||
// - Graph: These names identify graphs in the protobuf.
|
||||
// - Attribute: These names identify attribute names for extra attributes that
|
||||
// are passed to operators.
|
||||
// - OperatorOrFunction: These names identify particular operators and
|
||||
// functions.
|
||||
// - Value: These names identify intermediate values (typically tensors) flowing through
|
||||
// the computation of a graph.
|
||||
// - Shape: These names represent parameters for unknown shape dimensions.
|
||||
//
|
||||
// We specify the namespace of a name in LotusIR as comments in the form
|
||||
// of "namespace {Node,Graph,OperatorOrFunction,Attribute,Value,Shape}". Framework is
|
||||
// responsible for supporting the namespaces.
|
||||
|
||||
// To be compatible with both proto2 and proto3, we will use a version number
|
||||
// that is not defined by the default value but an explicit enum number.
|
||||
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
  // The version field is always serialized and we will use it to store the
  // version that the graph is generated from. This helps us set up version
  // control. We should use version as
  //     xx(major) - xx(minor) - xxxx(bugfix)
  // and we are starting with 00000001.
  IR_VERSION = 00000001;
}

// A named attribute containing either singular float, integer, string
// and tensor values, or repeated float, integer, string and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {
  // The name field MUST be present for this version of the IR.
  optional string name = 1;           // namespace Attribute

  // Singular values.
  optional float f = 2;               // float
  optional int64 i = 3;               // int
  optional bytes s = 4;               // UTF-8 string
  optional TensorProto t = 5;         // tensor value
  optional GraphProto g = 6;          // graph

  // Repeated values.
  repeated float floats = 7;          // list of floats
  repeated int64 ints = 8;            // list of ints
  repeated bytes strings = 9;         // list of UTF-8 strings
  repeated TensorProto tensors = 10;  // list of tensors
  repeated GraphProto graphs = 11;    // list of graphs

  optional TypeProto type = 51;
  repeated TypeProto types = 52;
  // ISSUE:13807134,dbox: Do we ever see shape showing up as an attribute value?
  //                      If so, won't it always be accompanied by a TypeProto?
  optional TypeProto.TensorShapeProto shape = 53;
  repeated TypeProto.TensorShapeProto shapes = 54;
}

// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
  // This field MUST be present in this version of the IR.
  optional string name = 1;  // namespace Value
  // This field MUST be present in this version of the IR.
  optional TypeProto type = 2;
}
||||
|
||||
// Defines a node in a computation graph. Each graph node is either an
// operator or a function call. A node is similar to the notion of "layer"
// or "operator" in many deep learning frameworks. For example, it can be a
// node of type "Conv" that takes in an image, a filter tensor and a bias
// tensor, and produces the convolved output.
//
// NOTE: Control flow is defined by two built-in operators:
//
// Cond(p, true_input, false_input) takes three inputs, where p is a
// boolean scalar tensor, true_input is the list of inputs to the true
// branch of cond, and false_input is the list of inputs to the false
// branch of cond. The true and false branches are defined as
// functions that take true_input and false_input as inputs respectively.
// The two functions must have the same number of outputs, and each
// corresponding output must have the same types, and have compatible
// shapes.
//
// While(vars, consts) takes two inputs, where vars are the initial
// values of the loop variables and consts are the values of constants
// used inside the loop. The loop condition and loop body are defined
// as functions. The functions take both vars and consts as inputs.
// The loop condition function returns a boolean scalar tensor. The
// loop body function has the form: body(vars, consts) = new_vars,
// where new_vars are the new values of the loop variables after one
// iteration so must match vars in terms of types and shapes.
message NodeProto {
  // The named inputs of the node.
  repeated string input = 1;    // namespace Value

  // The named outputs of the node.
  repeated string output = 2;   // namespace Value

  // The name of this node.
  // This field is optional and used to uniquely identify nodes in the graph.
  optional string name = 3;     // namespace Node

  // The name of the operator/function called by the node.
  // This field MUST be present for this version of the IR.
  optional string op_type = 4;  // namespace OperatorOrFunction

  // Additional named attributes.
  repeated AttributeProto attribute = 5;

  // An optional human-readable documentation for this node in the graph.
  // This text MAY contain Markdown markup that conforms to http://commonmark.org/.
  optional string doc_string = 6;

  // The number of inputs for each argument of the operator/function.
  // A formal parameter of the op may take a variable number of inputs
  // that is only known when this node is constructed.
  // BUG:13806939,dbox: I'm assuming that this field is like input_arg_info in that
  //                    a zero element/missing array implies that one needs to crawl
  //                    the graph to figure out the input counts, yes? Confirm and I'll
  //                    make clear. Otherwise, we need to require it to be present
  //                    and accurate.
  repeated int32 input_arg_count = 50;

  // Specify a list of named nodes that must be executed before this node.
  // Framework may use this to give users the ability to impose additional
  // execution orders for the operations.
  repeated string control_input = 51;
}
||||
|
||||
// ModelProto is a top-level file/container format for bundling a ML model.
// The semantics of the model are described by the GraphProto that represents
// a parameterized computation graph against a set of named operators that are
// defined independently from the graph.
message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  optional int64 ir_version = 1;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  optional string domain = 4;

  // The version of the graph encoded. See Version enum above.
  optional int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  optional string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  optional GraphProto graph = 7;

  // NOTE: ids between 8 and 49 are reserved for more ONNX fields.

  // The optional name of the author who created the graph.
  optional string model_author = 50;

  // Optional licensing information concerning use or origination of the graph.
  // This text MAY contain Markdown markup that conforms to http://commonmark.org/.
  optional string model_license = 51;
};
||||
|
||||
// GraphProto defines a parameterized series of nodes to form a directed acyclic graph.
// This is the equivalent of the "network" and "graph" in many deep learning
// frameworks.
// All the input/output tensors are explicitly named so a framework can
// run any subgraph of the graph by feeding and fetching the named tensors.
message GraphProto {
  // The nodes in the graph.
  repeated NodeProto node = 1;

  // The name of the graph.
  optional string name = 2;  // namespace Graph

  // A list of named tensor values (constants), used to specify default
  // values for some of the inputs of the graph.
  // Each TensorProto entry must have a distinct name (within the list) that
  // also appears in the input list.
  // In an evaluation, the default value specified here is used if and only if
  // user specifies no value for the corresponding input parameter.
  // May be used to pass serialized parameters for networks.
  repeated TensorProto initializer = 5;

  // A human-readable documentation for this graph. Markdown is allowed.
  optional string doc_string = 10;

  // The inputs and outputs of the graph.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // DO NOT USE the following fields which were deprecated.
  // repeated string input = 3;
  // repeated string output = 4;
  // optional int64 ir_version = 6;
  // optional int64 producer_version = 7;
  // optional string producer_tag = 8;
  // optional string domain = 9;

  // The function definitions of the graph. They can only be used
  // (i.e., called) in this graph.
  // Each FunctionDefProto in function MUST have a unique name.
  repeated FunctionDefProto function = 50;

  // The externally defined operators declared by this graph.
  repeated OperatorDeclProto operator = 51;

  // TODO: When the map type is added, provide for the "model_information"
  // field which holds name/value pairs of strings with additional devops
  // metadata, such as an identifier for which training set this instance
  // of a graph was trained with.

  // Imported libraries are referenced as a collection of strings in the form of absolute
  // URIs or relative paths. Where such relative paths are rooted is defined by tools and
  // runtime implementations.
  repeated string imported_libraries = 52;

  reserved 100 to 200;  // for future extensions.
}
||||
|
||||
// A message defined to store a tensor in its serialized format.
message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool

    // Advanced types
    FLOAT16 = 10;
    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;   // complex with float32 real and imaginary components
    COMPLEX128 = 15;  // complex with float64 real and imaginary components
    // Future extensions go here.
  }

  // The shape of the tensor.
  repeated int64 dims = 1;

  // The data type of the tensor.
  optional DataType data_type = 2;

  // For very large tensors, we may want to store them in chunks, in which
  // case the following fields will specify the segment that is stored in
  // the current TensorProto.
  message Segment {
    optional int64 begin = 1;
    optional int64 end = 2;
  }
  optional Segment segment = 3;

  // Tensor content must be in the row major order.
  //
  // Depending on the data_type field, exactly one of the fields below with
  // name ending in _data is used to store the elements of the tensor.

  // For float and complex64 values
  // Complex64 tensors are encoded as a single array of floats,
  // with the real components appearing in odd numbered positions,
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0 ,3.0 ,4.0]
  // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
  repeated float float_data = 4 [packed = true];

  // For int32, uint8, int8, uint16, int16, bool, and float16 values
  // float16 values must be bit-wise converted to an uint16_t prior
  // to writing to the buffer.
  // When this field is present, the data_type field MUST be
  // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
  repeated int32 int32_data = 5 [packed = true];

  // For strings.
  // Each element of string_data is a UTF-8 encoded Unicode
  // string. No trailing null, no leading BOM. The protobuf "string"
  // scalar type is not used to match ML community conventions.
  // When this field is present, the data_type field MUST be STRING
  repeated bytes string_data = 6;

  // For int64.
  // When this field is present, the data_type field MUST be INT64
  repeated int64 int64_data = 7 [packed = true];

  // Optionally, a name for the tensor.
  optional string name = 8;  // namespace Value

  // Serializations can either use one of the fields above, or use this
  // raw bytes field. The only exception is the string case, where one is
  // required to store the content in the repeated bytes string_data field.
  //
  // When this raw_data field is used to store tensor value, elements MUST
  // be stored in as fixed-width, little-endian order.
  // Floating-point data types MUST be stored in IEEE 754 format.
  // Complex64 elements must be written as two consecutive FLOAT values, real component first.
  // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
  // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
  //
  // Note: the advantage of specific field rather than the raw_data field is
  // that in some cases (e.g. int data), protobuf does a better packing via
  // variable length storage, and may lead to smaller binary footprint.
  // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
  optional bytes raw_data = 9;

  // For double and complex128 values
  // Complex128 tensors are encoded as a single array of doubles,
  // with the real components appearing in odd numbered positions,
  // and the corresponding imaginary component appearing in the
  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
  // is encoded as [1.0, 2.0 ,3.0 ,4.0]
  // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
  repeated double double_data = 10 [packed = true];

  // For uint64 and uint32 values
  // When this field is present, the data_type field MUST be
  // UINT32 or UINT64
  repeated uint64 uint64_data = 11 [packed = true];
}
||||
|
||||
// A sparse tensor must be stored as three dense tensors:
//  1. dims: The shape of the original dense tensor.
//  2. indices: A 2-D tensor specifying the indices of the nonzero elements.
//  3. values: A 1-D tensor containing the values of the nonzero elements.
message SparseTensorProto {
  // The dimensions in the tensor.
  repeated int64 dims = 1;
  // This field MUST be present in this version of the IR.
  optional TensorProto indices = 2;
  // This field MUST be present in this version of the IR.
  optional TensorProto values = 3;
}

// Define the types.
message TypeProto {
  // Defines a tensor shape. A dimension can be either an integer value
  // or a symbolic variable. A symbolic variable represents an unknown
  // dimension.
  message TensorShapeProto {
    message Dimension {
      oneof value {
        int64 dim_value = 1;
        string dim_param = 2;  // namespace Shape
      }
    }
    repeated Dimension dim = 1;
  }

  message TensorTypeProto {
    // This field MUST NOT have the value of UNDEFINED
    // This field MUST be present for this version of the IR.
    optional TensorProto.DataType elem_type = 1;
    optional TensorShapeProto shape = 2;
  }

  message SparseTensorTypeProto {
    // This field MUST NOT have the value of UNDEFINED
    // This field MUST be present for this version of the IR.
    optional TensorProto.DataType elem_type = 1;
    optional TensorShapeProto shape = 2;
  }

  // An opaque handle type carries no further information.
  message HandleTypeProto {
  }

  message TupleTypeProto {
    repeated TypeProto elem_type = 1;
  }

  message SeqTypeProto {
    // This field MUST be present for this version of the IR.
    optional TypeProto elem_type = 1;
  }

  message MapTypeProto {
    // This field MUST be present for this version of the IR.
    // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
    optional TensorProto.DataType key_type = 1;
    // This field MUST be present for this version of the IR.
    // This field MUST NOT refer to UNDEFINED
    optional TensorProto.DataType value_type = 2;
  }

  oneof value {
    // The type of a tensor.
    TensorTypeProto tensor_type = 1;

    // The type of a sparse tensor.
    SparseTensorTypeProto sparse_tensor_type = 2;

    // The type of an opaque handle. A handle is used to represent a
    // reference to a resource managed by the framework runtime.
    HandleTypeProto handle_type = 3;

    // The type of a tuple.
    TupleTypeProto tuple_type = 4;

    // The type of a sequence.
    SeqTypeProto seq_type = 5;

    // The type of a map.
    MapTypeProto map_type = 6;
  }
}
||||
|
||||
// Defines a runtime value in its serialized format. Exactly one of the
// oneof alternatives below is set, mirroring the TypeProto alternatives.
message ValueProto {
  // Defines a handle in its serialized format.
  message HandleProto {
    // This field MUST be present in this version of the IR.
    optional int64 uid = 1;

    // More information to be added. We need to specify the device
    // that the resource managed by the handle is on.
  }

  // Defines a tuple in its serialized format.
  message TupleProto {
    repeated ValueProto elems = 1;
  }

  // Defines a sequence in its serialized format.
  message SequenceProto {
    repeated ValueProto elems = 1;
  }

  // Defines a map in its serialized format.
  // Maps are serialized as two single-dimensional tensors
  // for storage efficiency. The dimensions of each tensor MUST be identical
  // and the key at position N corresponds to the value at position N.
  // Keys SHOULD be unique. When a given key appears multiple times,
  // the value that corresponds to the last occurrence of the key is the value.
  // This is consistent with protobuf3 encoding rules for map.
  message MapProto {
    // This field MUST be present for this version of the IR.
    // The data type of the tensor MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
    optional TensorProto keys = 1;

    // This field MUST be present for this version of the IR.
    optional TensorProto values = 2;
  }

  oneof value {
    // A dense tensor.
    TensorProto dense_tensor = 1;

    // A sparse tensor.
    SparseTensorProto sparse_tensor = 2;

    // A handle.
    HandleProto handle = 3;

    // A tuple.
    TupleProto tuple = 4;

    // A sequence.
    SequenceProto seq = 5;

    // A map.
    MapProto map = 6;
  }
}
||||
|
||||
// Declares a named, typed formal parameter of a function or operator.
message ParameterDeclProto {
  optional string name = 1;
  optional TypeProto type = 2;
  // An optional human-readable documentation for this parameter.
  optional string doc_string = 3;
}

// Defines a function.
message FunctionDefProto {
  // The name of the function.
  // This field MUST be present for this version of the IR.
  optional string name = 1;

  // The input parameters of the function.
  repeated ParameterDeclProto input_params = 2;

  // The output parameters of the function.
  repeated ParameterDeclProto output_params = 3;

  // The body of the function.
  repeated NodeProto node = 4;

  // The named attributes of the function.
  repeated AttributeProto attr = 5;
}

// Declares one type signature of an operation or function.
message SignatureDeclProto {
  // The formal input parameters to the operation or function
  repeated ParameterDeclProto input_params = 1;
  // The formal output parameters to the operation or function
  repeated ParameterDeclProto output_params = 2;
  // The declaration of expected attributes to the operation or function
  repeated ParameterDeclProto input_attributes = 3;
  // An optional human-readable documentation for this signature.
  optional string doc_string = 4;
}

// Declares an externally defined operator by name and signature(s).
message OperatorDeclProto {
  // This field MUST be present for this version of the IR.
  optional string name = 1;

  // This field MUST contain at least one SignatureDeclProto.
  // This field MAY contain multiple SignatureDeclProtos, one
  // per type signature supported by this operator.
  repeated SignatureDeclProto signature = 2;

  // An optional human-readable documentation for this operator.
  optional string doc_string = 3;
}
||||
|
||||
// A library is a top-level format that contains the declaration
// of operators and the definition of functions.
message LibraryProto {
  // The version of the IR this graph targets. See Version enum above.
  // This field MUST be present in this version of the IR.
  optional int64 ir_version = 1;

  // The optional version of the framework runtime that generates this graph.
  // This producer_version has the same format as ir_version.
  optional int64 producer_version = 2;

  // The optional name of the framework used to generate this graph in the form
  // "framework_name[-tag]". Tag is optional and provides additional
  // information such as `alpha` or `beta` or `rc3`.
  optional string producer_tag = 3;

  // An optional version identifier used to track evolution of this library.
  // This model_version has the same format as ir_version.
  optional int64 model_version = 4;

  // The optional name of the author who created the library.
  optional string model_author = 5;

  // Optional licensing information concerning use or origination of the library.
  optional string model_license = 6;

  // The name of the library.
  optional string name = 7;  // namespace Library

  // Domain of the graph.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `name` and `model_version`, this forms the unique identity of
  // the library.
  optional string domain = 8;

  // An optional human-readable documentation for this graph.
  optional string doc_string = 9;

  // The operators declared by this library.
  repeated OperatorDeclProto operator = 10;

  // The function definitions of the library.
  repeated FunctionDefProto function = 11;

  // A given name may appear at most once in either operator or function (but not both).
  // When referring to an operator or function from outside of this library, the op_type
  // field must equal:
  //   LibraryProto.domain + "." + LibraryProto.name + "." + OperatorDeclProto.name
  // or
  //   LibraryProto.domain + "." + LibraryProto.name + "." + FunctionDefProto.name

  // Imported libraries are referenced as a collection of strings in the form of absolute
  // URIs or relative paths. Where such relative paths are rooted is defined by tools and
  // runtime implementations.
  repeated string imported_libraries = 12;
}
|
@ -542,9 +542,9 @@ void fgetText(FILE* f, T& v)
|
|||
{
|
||||
int rc = ftrygetText(f, v);
|
||||
if (rc == 0)
|
||||
RuntimeError("error reading value from file (invalid format)");
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading value from file (invalid format)");
|
||||
else if (rc == EOF)
|
||||
RuntimeError("error reading from file: %s", strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading from file: %s", strerror(errno));
|
||||
assert(rc == 1);
|
||||
}
|
||||
|
||||
|
@ -575,9 +575,9 @@ void fputText(FILE* f, T v)
|
|||
const wchar_t* formatString = GetFormatString(v);
|
||||
int rc = fwprintf(f, formatString, v);
|
||||
if (rc == 0)
|
||||
RuntimeError("error writing value to file, no values written");
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing value to file, no values written");
|
||||
else if (rc < 0)
|
||||
RuntimeError("error writing to file: %s", strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing to file: %s", strerror(errno));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -723,7 +723,7 @@ class auto_file_ptr
|
|||
#pragma warning(disable : 4996)
|
||||
void openfailed(const std::string& path)
|
||||
{
|
||||
RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
|
||||
}
|
||||
#pragma warning(pop)
|
||||
protected:
|
||||
|
|
|
@ -88,6 +88,7 @@
|
|||
<Compile Include="SwigProxyClasses\MinibatchInfo.cs" />
|
||||
<Compile Include="SwigProxyClasses\MinibatchSource.cs" />
|
||||
<Compile Include="SwigProxyClasses\MinibatchSourceConfig.cs" />
|
||||
<Compile Include="SwigProxyClasses\ModelFormat.cs" />
|
||||
<Compile Include="SwigProxyClasses\PaddingMode.cs" />
|
||||
<Compile Include="SwigProxyClasses\PairDoubleDouble.cs" />
|
||||
<Compile Include="SwigProxyClasses\PairFloatFloat.cs" />
|
||||
|
|
|
@ -105,6 +105,7 @@
|
|||
<None Include="cntk\pytest.ini" />
|
||||
<None Include="cntk\sample_installer.py" />
|
||||
<None Include="cntk\tensor.py" />
|
||||
<None Include="cntk\tests\onnx_format_test.py" />
|
||||
<None Include="cntk\tests\user_learner.py" />
|
||||
<None Include="cntk\tests\variables_test.py" />
|
||||
<None Include="cntk\train\distributed.py" />
|
||||
|
|
|
@ -361,6 +361,9 @@
|
|||
<None Include="cntk\ops\tests\sequence_test.py">
|
||||
<Filter>cntk\ops\tests</Filter>
|
||||
</None>
|
||||
<None Include="cntk\tests\onnx_format_test.py">
|
||||
<Filter>cntk\tests</Filter>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="cntk\cntk_py_wrap.h">
|
||||
|
|
|
@ -12,7 +12,7 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import numbers
|
||||
from . import sequence
|
||||
from .functions import CloneMethod, Function, BlockFunction, load_model, register_native_user_function, native_user_function
|
||||
from .functions import ModelFormat, CloneMethod, Function, BlockFunction, load_model, register_native_user_function, native_user_function
|
||||
from cntk.internal import sanitize_input, sanitize_shape, sanitize_axis, sanitize_dynamic_axes, sanitize_axis_list, sanitize_multi_axis_reduction_list, typemap, sanitize_pooling_args, sanitize_convolution_args, sanitize_permutation
|
||||
from cntk.internal.utils import get_data_type
|
||||
from ..axis import Axis
|
||||
|
@ -517,6 +517,25 @@ def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spa
|
|||
normalization_time_constant, blend_time_constant,
|
||||
epsilon, use_cudnn_engine, name)
|
||||
|
||||
@typemap
|
||||
def local_response_normalization(operand, depth_radius, bias, alpha, beta, name=''):
|
||||
'''
|
||||
Local Response Normalization layer.
|
||||
|
||||
Args:
|
||||
operand: input of the Local Response Normalization.
|
||||
depth_radius (int): the radius on the channel dimension to apply the normalization.
|
||||
bias (double): a bias term to avoid divide by zero.
|
||||
alpha (double): the alpha term of the above equation.
|
||||
beta (double): the beta term of the above equation.
|
||||
name (str, optional): the name of the Function instance in the network.
|
||||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
from cntk.cntk_py import local_response_normalization
|
||||
operand = sanitize_input(operand)
|
||||
return local_response_normalization(operand, depth_radius, bias, alpha, beta, name)
|
||||
|
||||
##########################################################################
|
||||
# comparison ops
|
||||
##########################################################################
|
||||
|
@ -852,8 +871,6 @@ def element_times(left, right, name=''):
|
|||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_times(left, right, name)
|
||||
|
||||
|
||||
# TODO: move element_max/min to C++
|
||||
@associative_multi_arg
|
||||
@typemap
|
||||
def element_max(left, right, name=''):
|
||||
|
@ -869,10 +886,11 @@ def element_max(left, right, name=''):
|
|||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
gt = greater(left, right)
|
||||
# TODO: use as_block()
|
||||
return element_select(gt, left, right, name)
|
||||
|
||||
from cntk.cntk_py import element_max as cntk_py_element_max
|
||||
dtype = get_data_type(left, right)
|
||||
left = sanitize_input(left, dtype)
|
||||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_max(left, right, name)
|
||||
|
||||
@associative_multi_arg
|
||||
@typemap
|
||||
|
@ -889,10 +907,11 @@ def element_min(left, right, name=''):
|
|||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
lt = less(left, right)
|
||||
# TODO: use as_block()
|
||||
return element_select(lt, left, right, name)
|
||||
|
||||
from cntk.cntk_py import element_min as cntk_py_element_min
|
||||
dtype = get_data_type(left, right)
|
||||
left = sanitize_input(left, dtype)
|
||||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_min(left, right, name)
|
||||
|
||||
@typemap
|
||||
def element_divide(left, right, name=''):
|
||||
|
|
|
@ -26,6 +26,24 @@ from cntk.internal.sanitize import is_byte_buffer
|
|||
from ..variables import Record, Variable
|
||||
|
||||
|
||||
|
||||
@unique
|
||||
class ModelFormat(Enum):
|
||||
'''
|
||||
Describes the supported disk format for CNTK model.
|
||||
'''
|
||||
|
||||
CNTKv2 = cntk_py.ModelFormat_CNTKv2
|
||||
'''
|
||||
Default CNTK version 2 format, it supports all CNTK functionalities.
|
||||
'''
|
||||
|
||||
ONNX = cntk_py.ModelFormat_ONNX
|
||||
'''
|
||||
Open Neural Network Exchange format from https://github.com/onnx/onnx, ONNX currently support
|
||||
subset of CNTK functionalities.
|
||||
'''
|
||||
|
||||
@unique
|
||||
class CloneMethod(Enum):
|
||||
'''
|
||||
|
@ -1435,10 +1453,9 @@ class Function(cntk_py.Function):
|
|||
return collector.test_summaries[-1]
|
||||
|
||||
@typemap
|
||||
def save(self, filename):
|
||||
def save(self, filename, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Save this function graph into a model file using protobuf-based
|
||||
serialization.
|
||||
Save this function graph into a model file using the specified format.
|
||||
|
||||
Use distributed.Communicator.is_main() to gate your call to save()
|
||||
in distributed environment.
|
||||
|
@ -1446,7 +1463,7 @@ class Function(cntk_py.Function):
|
|||
Args:
|
||||
filename (str): model path
|
||||
'''
|
||||
return super(Function, self).save(filename)
|
||||
return super(Function, self).save(filename, format.value)
|
||||
|
||||
@typemap
|
||||
def restore(self, filename):
|
||||
|
@ -1489,7 +1506,7 @@ class Function(cntk_py.Function):
|
|||
|
||||
@staticmethod
|
||||
@typemap
|
||||
def load(model, device=None):
|
||||
def load(model, device=None, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Load the ``model``, that has been saved using :func:`~cntk.ops.functions.Function.save`.
|
||||
|
||||
|
@ -1498,6 +1515,8 @@ class Function(cntk_py.Function):
|
|||
containing the binary representation of a model.
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, defaults to the current globally default device):
|
||||
specifies the device to allocate the model on.
|
||||
format (:class:`~cntk.ModelFormat`, defaults to CNTKv2 format): specifies the format of the file to load.
|
||||
if the specified format is ONNX, then model must be a filename.
|
||||
|
||||
Returns:
|
||||
root node
|
||||
|
@ -1515,10 +1534,12 @@ class Function(cntk_py.Function):
|
|||
pass
|
||||
|
||||
if is_buffer:
|
||||
if format != ModelFormat.CNTKv2:
|
||||
raise ValueError('Loading from buffer only supported for CNTKv2 format.')
|
||||
return cntk_py.Function.load_from_buffer(model, device)
|
||||
|
||||
if is_file:
|
||||
return cntk_py.Function.load(str(model), device)
|
||||
return cntk_py.Function.load(str(model), device, format.value)
|
||||
|
||||
raise ValueError('Cannot load the model {} that is neither a file nor a byte buffer.'.format(model))
|
||||
|
||||
|
@ -1600,11 +1621,11 @@ def native_user_function(op_id, operands, attributes=None, user_function_instanc
|
|||
return cntk_py.Function_native_user_function(op_id, operands, attributes, user_function_instance_name)
|
||||
|
||||
@typemap
|
||||
def load_model(model, device=None):
|
||||
def load_model(model, device=None, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Alias for :func:`~cntk.ops.functions.Function.load`.
|
||||
'''
|
||||
return Function.load(model, device)
|
||||
return Function.load(model, device, format)
|
||||
|
||||
class UserFunction(Function):
|
||||
'''
|
||||
|
|
|
@ -613,6 +613,36 @@ def test_op_batch_normalization(use_cudnn, sample, device_id, precision):
|
|||
|
||||
unittest_helper(op_node, forward_input, expected_forward, expected_backward=None, device_id=device_id, precision=precision)
|
||||
|
||||
def test_local_response_normalization(device_id, precision):
|
||||
dtype = PRECISION_TO_TYPE[precision]
|
||||
dev = cntk_device(device_id)
|
||||
|
||||
def lrn(x, depth_radius, bias, alpha, beta, name=''):
|
||||
x2 = C.square(x)
|
||||
# reshape to insert a fake singleton reduction dimension after the 3th axis (channel axis). Note Python axis order and BrainScript are reversed.
|
||||
x2s = C.reshape(x2, (1, C.InferredDimension), 0, 1)
|
||||
W = C.constant(alpha/(2*depth_radius+1), shape=(1,2*depth_radius+1,1,1), dtype=dtype, name='W')
|
||||
# 3D convolution with a filter that has a non 1-size only in the 3rd axis, and does not reduce since the reduction dimension is fake and 1
|
||||
y = C.convolution (W, x2s)
|
||||
# reshape back to remove the fake singleton reduction dimension
|
||||
b = C.reshape(y, C.InferredDimension, 0, 2)
|
||||
den = C.exp(beta * C.log(bias + b))
|
||||
return C.element_divide(x, den)
|
||||
|
||||
from cntk import local_response_normalization
|
||||
|
||||
img_shape = (64, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=dtype)
|
||||
x_gt = C.input_variable(shape=img_shape, dtype=dtype)
|
||||
x_r = C.input_variable(shape=img_shape, dtype=dtype)
|
||||
|
||||
gt = lrn(x_gt, 2, 1.0, 0.0001, 0.75)
|
||||
r = local_response_normalization(x_r, 2, 1.0, 0.0001, 0.75)
|
||||
ss = gt.eval({x_gt:img})
|
||||
sa = r.eval({x_r:img})
|
||||
|
||||
assert np.allclose(r.eval({x_r:img}), gt.eval({x_gt:img}))
|
||||
|
||||
TENSOR_PAIRS = [
|
||||
([0.3], [0.1]),
|
||||
([[0.1]], [[0.3]]),
|
||||
|
|
|
@ -27,7 +27,8 @@ def test_convolution_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': False,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
@ -41,7 +42,8 @@ def test_convolution_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': False,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
@ -59,7 +61,8 @@ def test_convolution_transpose_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': True,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
|
|
@ -0,0 +1,280 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cntk as C
|
||||
|
||||
def test_load_save_constant(tmpdir):
|
||||
# import pdb;pdb.set_trace()
|
||||
c = C.constant(value=[1,3])
|
||||
root_node = c * 5
|
||||
|
||||
result = root_node.eval()
|
||||
expected = [[[[5,15]]]]
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'c_plus_c.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
loaded_result = loaded_node.eval()
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
||||
def test_dense_layer(tmpdir):
|
||||
img_shape = (1, 5, 5)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = C.layers.Dense(5, activation=C.softmax)(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'dense_layer.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_convolution(tmpdir):
|
||||
img_shape = (1, 5, 5)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
filter = np.reshape(np.array([2, -1, -1, 2], dtype = np.float32), (1, 2, 2))
|
||||
kernel = C.constant(value = filter)
|
||||
root_node = C.convolution(kernel, x, auto_padding=[False])
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:[img]}), root_node.eval({x:[img]}))
|
||||
|
||||
def test_convolution_transpose(tmpdir):
|
||||
img_shape = (1, 3, 3)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
filter = np.reshape(np.array([2, -1, -1, 2], dtype = np.float32), (1, 2, 2))
|
||||
kernel = C.constant(value = filter)
|
||||
root_node = C.convolution_transpose(kernel, x, auto_padding=[False], output_shape=(1, 4, 4))
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv_transpose.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:[img]}), root_node.eval({x:[img]}))
|
||||
|
||||
def test_conv_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((5,5), [32,32,64][i], pad=True),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.Dense(64),
|
||||
C.layers.Dense(10, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_batch_norm_model(tmpdir):
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
num_channels = 3
|
||||
num_classes = 10
|
||||
|
||||
input_var = C.input_variable((num_channels, image_height, image_width))
|
||||
label_var = C.input_variable((num_classes))
|
||||
def create_basic_model_with_batch_normalization(input, out_dims):
|
||||
with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((5,5), [image_width,image_height,64][i], pad=True),
|
||||
C.layers.BatchNormalization(map_rank=1),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.Dense(64),
|
||||
C.layers.BatchNormalization(map_rank=1),
|
||||
C.layers.Dense(out_dims, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
feature_scale = 1.0 / 256.0
|
||||
input_var_norm = C.element_times(feature_scale, input_var)
|
||||
|
||||
# apply model to input
|
||||
z = create_basic_model_with_batch_normalization(input_var_norm, out_dims=10)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'bn_model.onnx')
|
||||
z.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert z.shape == loaded_node.shape
|
||||
|
||||
img_shape = (num_channels, image_width, image_height)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = z.arguments[0];
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), z.eval({x:img}))
|
||||
|
||||
def test_vgg9_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((3,3), [64,96,128][i], pad=True),
|
||||
C.layers.Convolution((3,3), [64,96,128][i], pad=True),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.For(range(2), lambda : [
|
||||
C.layers.Dense(1024)
|
||||
]),
|
||||
C.layers.Dense(10, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'vgg9_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_conv3d_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.default_options (activation=C.relu):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.Convolution3D((3,3,3), 64, pad=True),
|
||||
C.layers.MaxPooling((1,2,2), (1,2,2)),
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution3D((3,3,3), [96, 128, 128][i], pad=True),
|
||||
C.layers.Convolution3D((3,3,3), [96, 128, 128][i], pad=True),
|
||||
C.layers.MaxPooling((2,2,2), (2,2,2))
|
||||
]),
|
||||
C.layers.For(range(2), lambda : [
|
||||
C.layers.Dense(1024),
|
||||
C.layers.Dropout(0.5)
|
||||
]),
|
||||
C.layers.Dense(100, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
video_shape = (3, 20, 32, 32)
|
||||
video = np.asarray(np.random.uniform(-1, 1, video_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(video.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv3d_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:video}), root_node.eval({x:video}))
|
||||
|
||||
def test_resnet_model(tmpdir):
|
||||
def convolution_bn(input, filter_size, num_filters, strides=(1,1), init=C.normal(0.01), activation=C.relu):
|
||||
r = C.layers.Convolution(filter_size,
|
||||
num_filters,
|
||||
strides=strides,
|
||||
init=init,
|
||||
activation=None,
|
||||
pad=True, bias=False)(input)
|
||||
r = C.layers.BatchNormalization(map_rank=1)(r)
|
||||
r = r if activation is None else activation(r)
|
||||
return r
|
||||
|
||||
def resnet_basic(input, num_filters):
|
||||
c1 = convolution_bn(input, (3,3), num_filters)
|
||||
c2 = convolution_bn(c1, (3,3), num_filters, activation=None)
|
||||
p = c2 + input
|
||||
return C.relu(p)
|
||||
|
||||
def resnet_basic_inc(input, num_filters):
|
||||
c1 = convolution_bn(input, (3,3), num_filters, strides=(2,2))
|
||||
c2 = convolution_bn(c1, (3,3), num_filters, activation=None)
|
||||
|
||||
s = convolution_bn(input, (1,1), num_filters, strides=(2,2), activation=None)
|
||||
|
||||
p = c2 + s
|
||||
return C.relu(p)
|
||||
|
||||
def resnet_basic_stack(input, num_filters, num_stack):
|
||||
assert (num_stack > 0)
|
||||
|
||||
r = input
|
||||
for _ in range(num_stack):
|
||||
r = resnet_basic(r, num_filters)
|
||||
return r
|
||||
|
||||
def create_model(input):
|
||||
conv = convolution_bn(input, (3,3), 16)
|
||||
r1_1 = resnet_basic_stack(conv, 16, 3)
|
||||
|
||||
r2_1 = resnet_basic_inc(r1_1, 32)
|
||||
r2_2 = resnet_basic_stack(r2_1, 32, 2)
|
||||
|
||||
r3_1 = resnet_basic_inc(r2_2, 64)
|
||||
r3_2 = resnet_basic_stack(r3_1, 64, 2)
|
||||
|
||||
# Global average pooling
|
||||
pool = C.layers.AveragePooling(filter_shape=(8,8), strides=(1,1))(r3_2)
|
||||
return C.layers.Dense(10, init=C.normal(0.01), activation=None)(pool)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(0, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'resnet_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
Загрузка…
Ссылка в новой задаче