Adding ONNX format support to CNTK.
This commit is contained in:
Родитель
0b384cba0d
Коммит
d57ad1d673
|
@ -234,6 +234,10 @@ bindings/python/doc/_build
|
|||
Source/CNTKv2LibraryDll/proto/CNTK.pb.cc
|
||||
Source/CNTKv2LibraryDll/proto/CNTK.pb.h
|
||||
|
||||
# Auto-generated sources from ONNX proto
|
||||
Source/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.cc
|
||||
Source/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.h
|
||||
|
||||
bindings/python/cntk/c_plus_c.mod
|
||||
bindings/python/cntk/i_plus_c_0.mod
|
||||
bindings/python/cntk/i_plus_i_0.mod
|
||||
|
|
|
@ -71,7 +71,7 @@ namespace CNTK.CSTrainingExamples
|
|||
var labels = CNTKLib.InputVariable(new int[] { numClasses }, DataType.Float, labelsStreamName);
|
||||
var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction");
|
||||
var prediction = CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError");
|
||||
|
||||
|
||||
// prepare training data
|
||||
var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
|
||||
Path.Combine(ImageDataFolder, "Train_cntk_text.txt"), streamConfigurations, MinibatchSource.InfinitelyRepeat);
|
||||
|
@ -122,7 +122,7 @@ namespace CNTK.CSTrainingExamples
|
|||
imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
|
||||
}
|
||||
|
||||
private static Function CreateMLPClassifier(DeviceDescriptor device, int numOutputClasses, int hiddenLayerDim,
|
||||
private static Function CreateMLPClassifier(DeviceDescriptor device, int numOutputClasses, int hiddenLayerDim,
|
||||
Function scaledInput, string classifierName)
|
||||
{
|
||||
Function dense1 = TestHelper.Dense(scaledInput, hiddenLayerDim, device, Activation.Sigmoid, "");
|
||||
|
@ -160,8 +160,8 @@ namespace CNTK.CSTrainingExamples
|
|||
return denseLayer;
|
||||
}
|
||||
|
||||
private static Function ConvolutionWithMaxPooling(Variable features, DeviceDescriptor device,
|
||||
int kernelWidth, int kernelHeight, int numInputChannels, int outFeatureMapCount,
|
||||
private static Function ConvolutionWithMaxPooling(Variable features, DeviceDescriptor device,
|
||||
int kernelWidth, int kernelHeight, int numInputChannels, int outFeatureMapCount,
|
||||
int hStride, int vStride, int poolingWindowWidth, int poolingWindowHeight)
|
||||
{
|
||||
// parameter initialization hyper parameter
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace CNTK.CSTrainingExamples
|
|||
}
|
||||
public class TestHelper
|
||||
{
|
||||
public static Function Dense(Variable input, int outputDim, DeviceDescriptor device,
|
||||
public static Function Dense(Variable input, int outputDim, DeviceDescriptor device,
|
||||
Activation activation = Activation.None, string outputName = "")
|
||||
{
|
||||
if (input.Shape.Rank != 1)
|
||||
|
|
21
Makefile
21
Makefile
|
@ -493,6 +493,27 @@ CNTKLIBRARY_COMMON_SRC =\
|
|||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/tensorboard.pb.cc \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardFileWriter.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/tensorboard/TensorBoardUtils.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/protobuf/graph.pb.cc \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/experiments/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/generator/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/logical/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/math/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/nn/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/reduction/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/rnn/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/defs/tensor/defs.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/constants.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/status.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/utils.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/opsignature.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/op.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/shape_inference.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/graph.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/core/model.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/Operators.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp \
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/proto/onnx/ONNX.cpp \
|
||||
|
||||
CNTKLIBRARY_SRC =\
|
||||
$(SOURCEDIR)/CNTKv2LibraryDll/ComputeInputStatistics.cpp \
|
||||
|
|
63
README.md
63
README.md
|
@ -1,26 +1,49 @@
|
|||
[![Join the chat at https://gitter.im/Microsoft/CNTK](https://badges.gitter.im/Microsoft/CNTK.svg)](https://gitter.im/Microsoft/CNTK?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## Latest news
|
||||
## Latest news
|
||||
|
||||
***2017-09-25.*** CNTK September interation plan posted [here](https://github.com/Microsoft/CNTK/issues/2410).
|
||||
***2017-10-10.*** Preview: CNTK ONNX Format Support
|
||||
Update CNTK to support load and save ONNX format from https://github.com/onnx/onnx, please try it and provide feedback. We only support ONNX OPs. This is a preview, and we expect a breaking change in the future.
|
||||
|
||||
***2017-09-24.*** CNTK R-binding now available [here](https://github.com/Microsoft/CNTK-R).
|
||||
* Support loading a model saved in ONNX format.
|
||||
* Support saving a model in ONNX format, not all CNTK models are currently supported. Only a subset of CNTK models are supported and no RNN. We will add more in the future.
|
||||
|
||||
To load an ONNX model, simply specify the format parameter for the load function.
|
||||
```
|
||||
import cntk as C
|
||||
|
||||
C.Function.load(<path of your ONNX model>, format=C.ModelFormat.ONNX)
|
||||
```
|
||||
|
||||
To save a CNTK graph as ONNX model, simply specify the format in the save function.
|
||||
|
||||
```
|
||||
import cntk as C
|
||||
|
||||
x = C.input_variable(<input shape>)
|
||||
z = create_model(x)
|
||||
z.save(<path of where to save your ONNX model>, format=C.ModelFormat.ONNX)
|
||||
```
|
||||
|
||||
***2017-09-25.*** CNTK September interation plan posted [here](https://github.com/Microsoft/CNTK/issues/2410).
|
||||
|
||||
***2017-09-24.*** CNTK R-binding now available [here](https://github.com/Microsoft/CNTK-R).
|
||||
|
||||
***2017-09-15.* CNTK 2.2**
|
||||
Release of Cognitive Toolkit v2.2.
|
||||
Release of Cognitive Toolkit v2.2.
|
||||
|
||||
Hightlights:
|
||||
* NCCL 2 support
|
||||
* New learner interface
|
||||
* A C#/.NET API that enables people to build and train networks
|
||||
* New C++ and C# eval examples
|
||||
* New nodes
|
||||
Hightlights:
|
||||
* NCCL 2 support
|
||||
* New learner interface
|
||||
* A C#/.NET API that enables people to build and train networks
|
||||
* New C++ and C# eval examples
|
||||
* New nodes
|
||||
* Tensorboard image support for CNTK
|
||||
|
||||
See more in the [Release Notes](https://docs.microsoft.com/en-us/cognitive-toolkit/ReleaseNotes/CNTK_2_2_Release_Notes).
|
||||
Get the Release from the [CNTK Releases page](https://github.com/Microsoft/CNTK/releases).
|
||||
|
||||
***2017-08-04.*** CNTK August interation plan posted [here](https://github.com/Microsoft/CNTK/issues/2194).
|
||||
***2017-08-04.*** CNTK August interation plan posted [here](https://github.com/Microsoft/CNTK/issues/2194).
|
||||
|
||||
***2017-07-31.* CNTK 2.1**
|
||||
Release of Cognitive Toolkit v.2.1.
|
||||
|
@ -42,23 +65,23 @@ See [all news](https://docs.microsoft.com/en-us/cognitive-toolkit/news)
|
|||
|
||||
The Microsoft Cognitive Toolkit (https://cntk.ai), is a unified deep-learning toolkit that describes neural networks as a series of computational steps via a directed graph. In this directed graph, leaf nodes represent input values or network parameters, while other nodes represent matrix operations upon their inputs. CNTK allows to easily realize and combine popular model types such as feed-forward DNNs, convolutional nets (CNNs), and recurrent networks (RNNs/LSTMs). It implements stochastic gradient descent (SGD, error backpropagation) learning with automatic differentiation and parallelization across multiple GPUs and servers. CNTK has been available under an open-source license since April 2015. It is our hope that the community will take advantage of CNTK to share ideas more quickly through the exchange of open source working code.
|
||||
|
||||
## Installation
|
||||
## Installation
|
||||
|
||||
* [Setup CNTK](https://docs.microsoft.com/en-us/cognitive-toolkit/Setup-CNTK-on-your-machine)
|
||||
* Windows [Python-only](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-python) / [Script-driven](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-binary-script) / [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-binary-manual)
|
||||
* Windows [Python-only](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-python) / [Script-driven](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-binary-script) / [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-windows-binary-manual)
|
||||
* Linux [Python-only](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-linux-python) / [Script-driven](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-linux-binary-script) / [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-linux-binary-manual) / [Docker](https://docs.microsoft.com/en-us/cognitive-toolkit/cntk-docker-containers)
|
||||
* [CNTK backend for Keras](https://docs.microsoft.com/en-us/cognitive-toolkit/using-cntk-with-keras)
|
||||
* [Setup CNTK development environment](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-development-environment)
|
||||
* Windows [Script-driven](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-with-script-on-windows) / [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-on-windows)
|
||||
* [Setup CNTK development environment](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-development-environment)
|
||||
* Windows [Script-driven](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-with-script-on-windows) / [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-on-windows)
|
||||
* Linux [Manual](https://docs.microsoft.com/en-us/cognitive-toolkit/setup-cntk-on-linux)
|
||||
|
||||
## Learning CNTK
|
||||
## Learning CNTK
|
||||
|
||||
You may learn more about CNTK with the following resources:
|
||||
You may learn more about CNTK with the following resources:
|
||||
* [General documentation](https://docs.microsoft.com/en-us/cognitive-toolkit/)
|
||||
* [Python API documentation](https://cntk.ai/pythondocs/)
|
||||
* [BrainScript documentation](https://docs.microsoft.com/en-us/cognitive-toolkit/Using-CNTK-with-BrainScript)
|
||||
* [Evaluation documentation (C++, C#/.NET, Python, Java)](https://docs.microsoft.com/en-us/cognitive-toolkit/CNTK-Evaluation-Overview)
|
||||
* [Evaluation documentation (C++, C#/.NET, Python, Java)](https://docs.microsoft.com/en-us/cognitive-toolkit/CNTK-Evaluation-Overview)
|
||||
* [Manual](https://github.com/Microsoft/CNTK/tree/master/Manual)
|
||||
* [Tutorials](https://docs.microsoft.com/en-us/cognitive-toolkit/tutorials)
|
||||
* [Examples](https://docs.microsoft.com/en-us/cognitive-toolkit/Examples)
|
||||
|
@ -67,9 +90,9 @@ You may learn more about CNTK with the following resources:
|
|||
* [Presentations](https://docs.microsoft.com/en-us/cognitive-toolkit/Presentations)
|
||||
* [License](./LICENSE.md)
|
||||
|
||||
## More information
|
||||
## More information
|
||||
|
||||
* [Reasons to switch from TensorFlow to CNTK](https://docs.microsoft.com/en-us/cognitive-toolkit/reasons-to-switch-from-tensorflow-to-cntk)
|
||||
* [Reasons to switch from TensorFlow to CNTK](https://docs.microsoft.com/en-us/cognitive-toolkit/reasons-to-switch-from-tensorflow-to-cntk)
|
||||
* [Contribute to CNTK](https://docs.microsoft.com/en-us/cognitive-toolkit/Contributing-to-CNTK)
|
||||
* [FAQ](https://docs.microsoft.com/en-us/cognitive-toolkit/CNTK-FAQ)
|
||||
* [Feedback](https://docs.microsoft.com/en-us/cognitive-toolkit/Feedback-Channels)
|
||||
|
|
|
@ -1797,6 +1797,7 @@ namespace CNTK
|
|||
friend class Trainer;
|
||||
friend class PrimitiveFunction;
|
||||
friend class Utils;
|
||||
friend class CNTKToONNXHelper;
|
||||
|
||||
template <typename T>
|
||||
friend struct std::hash;
|
||||
|
@ -1874,6 +1875,14 @@ namespace CNTK
|
|||
///
|
||||
bool IsPlaceholder() const { return Kind() == VariableKind::Placeholder; }
|
||||
|
||||
///
|
||||
/// Returns a boolean value indicating if 'this' variable has a batch axis or not.
|
||||
///
|
||||
bool HasBatchAxis() const {
|
||||
return std::any_of(DynamicAxes().begin(), DynamicAxes().end(),
|
||||
[](const Axis& axis) { return (axis == Axis::DefaultBatchAxis()); });
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the name of 'this' variable
|
||||
///
|
||||
|
@ -2995,6 +3004,24 @@ namespace CNTK
|
|||
};
|
||||
typedef std::shared_ptr<BackPropState> BackPropStatePtr;
|
||||
|
||||
///
|
||||
/// List of supported disk formats for CNTK model.
|
||||
///
|
||||
enum class ModelFormat
|
||||
{
|
||||
///
|
||||
/// Default CNTK version 2 format, support all CNTK features.
|
||||
///
|
||||
CNTKv2,
|
||||
|
||||
///
|
||||
/// Open Neural Network Exchange format from https://github.com/onnx/onnx
|
||||
/// ONNX support limited subset of CNTK.
|
||||
///
|
||||
ONNX,
|
||||
};
|
||||
|
||||
|
||||
///
|
||||
/// How are Parameters handled when cloning a Function
|
||||
///
|
||||
|
@ -3406,7 +3433,7 @@ namespace CNTK
|
|||
///
|
||||
/// Save this Function graph into a model file.
|
||||
///
|
||||
CNTK_API void Save(const std::wstring& filepath);
|
||||
CNTK_API void Save(const std::wstring& filepath, ModelFormat format = ModelFormat::CNTKv2);
|
||||
|
||||
///
|
||||
/// Restore the models parameters (in-place) from a model file
|
||||
|
@ -3417,7 +3444,8 @@ namespace CNTK
|
|||
/// Load a Function from a model file
|
||||
///
|
||||
CNTK_API static FunctionPtr Load(const std::wstring& filepath,
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
|
||||
ModelFormat format = ModelFormat::CNTKv2);
|
||||
|
||||
///
|
||||
/// Load a Function from a memory buffer
|
||||
|
@ -3894,6 +3922,17 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr ElementDivide(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Compute the element wise maximum operation between the given operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementMax(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name);
|
||||
|
||||
|
||||
///
|
||||
/// Compute the element wise minimum operation between the given operands.
|
||||
///
|
||||
CNTK_API FunctionPtr ElementMin(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name);
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise equality comparison operation on specified tensor input operands.
|
||||
///
|
||||
|
@ -4270,7 +4309,6 @@ namespace CNTK
|
|||
///
|
||||
/// TODO:
|
||||
///
|
||||
// TODO: Do we need a separate "spatial" parameter or can it be inferred from the tensor dimensions
|
||||
CNTK_API FunctionPtr BatchNormalization(const Variable& operand,
|
||||
const Variable& scale,
|
||||
const Variable& bias,
|
||||
|
@ -4284,6 +4322,17 @@ namespace CNTK
|
|||
bool useCuDNNEngine = true,
|
||||
const std::wstring& name = L"");
|
||||
|
||||
//
|
||||
// Local response normalization as described in http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks
|
||||
//
|
||||
CNTK_API FunctionPtr LocalResponseNormalization(const Variable& operand,
|
||||
size_t depthRadius,
|
||||
double bias,
|
||||
double alpha,
|
||||
double beta,
|
||||
const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in OptimizedRNNStack operation on specified input operands
|
||||
///
|
||||
CNTK_API FunctionPtr OptimizedRNNStack(const Variable& operand, const Variable& weights, size_t hiddenSize, size_t numLayers, bool bidirectional = false, const std::wstring& recurrentOp = L"lstm", const std::wstring& name = L"");
|
||||
|
@ -4322,6 +4371,12 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockOpName, const std::wstring& blockName = L"");
|
||||
|
||||
///
|
||||
/// Creates a Block Function that encapsulates a composite to create an opaque Function object that
|
||||
/// appears as any other primitive Function
|
||||
///
|
||||
CNTK_API FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, Dictionary&& attributes, const std::wstring& blockOpName, const std::wstring& blockName);
|
||||
|
||||
///
|
||||
/// Creates a new Function instance which output its input as it is and previent any gradient contribution from its output.
|
||||
///
|
||||
|
@ -4352,7 +4407,7 @@ namespace CNTK
|
|||
///
|
||||
/// Create an instance of the CNTK built-in elementwise scaled exponential linear unit operation with the specified input operand.
|
||||
///
|
||||
CNTK_API FunctionPtr SELU(const Variable& operand, double scale = 1.0507009873554804934193349852946, double alpha = 1.6732632423543772848170429916717, const std::wstring& name = L"");
|
||||
CNTK_API FunctionPtr SELU(const Variable& operand, double gamma = 1.0507009873554804934193349852946, double alpha = 1.6732632423543772848170429916717, const std::wstring& name = L"");
|
||||
|
||||
///
|
||||
/// Create an instance of the CNTK built-in elementwise leaky linear rectifier operation with the specified input operand.
|
||||
|
|
|
@ -87,7 +87,7 @@
|
|||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;_DEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<TreatWarningAsError>true</TreatWarningAsError>
|
||||
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
|
@ -103,7 +103,7 @@
|
|||
<ClCompile>
|
||||
<WarningLevel>Level4</WarningLevel>
|
||||
<PrecompiledHeader>NotUsing</PrecompiledHeader>
|
||||
<PreprocessorDefinitions>CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CNTKV2LIBRARYDLL;WIN32;NDEBUG;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<SDLCheck>true</SDLCheck>
|
||||
<AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
<RuntimeLibrary>MultiThreadedDLL</RuntimeLibrary>
|
||||
|
@ -119,7 +119,9 @@
|
|||
</ItemDefinitionGroup>
|
||||
<ItemDefinitionGroup Condition="$(CpuOnlyBuild)">
|
||||
<ClCompile>
|
||||
<PreprocessorDefinitions>CPUONLY;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>ONNX_V1_OPSCHEMA_COMPAT;CPUONLY;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_CpuOnly|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug_CpuOnly|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<DelayLoadDLLs>Cntk.Math$(OutputSuffix)-$(CntkComponentVersion).dll; msmpi.dll;</DelayLoadDLLs>
|
||||
|
@ -130,6 +132,9 @@
|
|||
<ItemDefinitionGroup Condition="$(GpuBuild)">
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>%(AdditionalIncludeDirectories);$(CudaInclude)</AdditionalIncludeDirectories>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>%(AdditionalLibraryDirectories);$(CudaLibPath)</AdditionalLibraryDirectories>
|
||||
|
@ -146,6 +151,12 @@
|
|||
xcopy /D /I /Y "$(TargetPath)" "$(TargetDir).."
|
||||
</Command>
|
||||
</PostBuildEvent>
|
||||
<ClCompile>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release_UWP|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
<ClCompile>
|
||||
<ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug_UWP|x64'">$(IntDir)\$(ProjectName)\$(ConfigurationName)\%(RelativeDir)</ObjectFileName>
|
||||
</ClCompile>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="API\CNTKLibrary.h" />
|
||||
|
@ -161,6 +172,18 @@
|
|||
<ClInclude Include="MinibatchSource.h" />
|
||||
<ClInclude Include="PrimitiveFunction.h" />
|
||||
<ClInclude Include="PrimitiveOpType.h" />
|
||||
<ClInclude Include="proto\onnx\CNTKToONNX.h" />
|
||||
<ClInclude Include="proto\onnx\core\constants.h" />
|
||||
<ClInclude Include="proto\onnx\core\graph.h" />
|
||||
<ClInclude Include="proto\onnx\core\model.h" />
|
||||
<ClInclude Include="proto\onnx\core\op.h" />
|
||||
<ClInclude Include="proto\onnx\core\opsignature.h" />
|
||||
<ClInclude Include="proto\onnx\core\shape_inference.h" />
|
||||
<ClInclude Include="proto\onnx\core\status.h" />
|
||||
<ClInclude Include="proto\onnx\core\utils.h" />
|
||||
<ClInclude Include="proto\onnx\ONNX.h" />
|
||||
<ClInclude Include="proto\onnx\ONNXToCNTK.h" />
|
||||
<ClInclude Include="proto\onnx\Operators.h" />
|
||||
<ClInclude Include="Serialization.h" />
|
||||
<ClInclude Include="tensorboard\TensorBoardUtils.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
|
@ -192,6 +215,27 @@
|
|||
<ClCompile Include="NDMask.cpp" />
|
||||
<ClCompile Include="PrimitiveFunction.cpp" />
|
||||
<ClCompile Include="proto\CNTK.pb.cc.VS_wrapper.cpp" />
|
||||
<ClCompile Include="proto\onnx\CNTKToONNX.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\constants.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\graph.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\model.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\op.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\opsignature.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\shape_inference.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\status.cpp" />
|
||||
<ClCompile Include="proto\onnx\core\utils.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\experiments\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\generator\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\logical\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\math\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\nn\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\reduction\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\rnn\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\defs\tensor\defs.cpp" />
|
||||
<ClCompile Include="proto\onnx\ONNX.cpp" />
|
||||
<ClCompile Include="proto\onnx\ONNXToCNTK.cpp" />
|
||||
<ClCompile Include="proto\onnx\Operators.cpp" />
|
||||
<ClCompile Include="proto\onnx\protobuf\graph.pb.cc.VS_wrapper.cpp" />
|
||||
<ClCompile Include="Serialization.cpp" />
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
<PrecompiledHeader>Create</PrecompiledHeader>
|
||||
|
@ -209,6 +253,7 @@
|
|||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Proto Include="proto\CNTK.proto" />
|
||||
<Proto Include="proto\onnx\protobuf\graph.proto" />
|
||||
<Proto Include="tensorboard\tensorboard.proto" />
|
||||
</ItemGroup>
|
||||
<Target Name="ProtoGen" Inputs="@(Proto)" Outputs="@(Proto->'%(RelativeDir)%(Filename).pb.cc')">
|
||||
|
|
|
@ -37,6 +37,69 @@
|
|||
<ClCompile Include="ProgressWriter.cpp" />
|
||||
<ClCompile Include="Evaluator.cpp" />
|
||||
<ClCompile Include="UserDefinedFunction.cpp" />
|
||||
<ClCompile Include="proto\onnx\protobuf\graph.pb.cc.VS_wrapper.cpp">
|
||||
<Filter>proto\onnx\protobuf</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\CNTKToONNX.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\ONNX.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\ONNXToCNTK.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\Operators.cpp">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\constants.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\graph.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\model.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\op.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\opsignature.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\shape_inference.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\status.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\core\utils.cpp">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\experiments\defs.cpp">
|
||||
<Filter>proto\onnx\defs\experiments</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\generator\defs.cpp">
|
||||
<Filter>proto\onnx\defs\generator</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\logical\defs.cpp">
|
||||
<Filter>proto\onnx\defs\logical</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\math\defs.cpp">
|
||||
<Filter>proto\onnx\defs\math</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\nn\defs.cpp">
|
||||
<Filter>proto\onnx\defs\nn</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\reduction\defs.cpp">
|
||||
<Filter>proto\onnx\defs\reduction</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\rnn\defs.cpp">
|
||||
<Filter>proto\onnx\defs\rnn</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="proto\onnx\defs\tensor\defs.cpp">
|
||||
<Filter>proto\onnx\defs\tensor</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="stdafx.h" />
|
||||
|
@ -69,6 +132,42 @@
|
|||
<ClInclude Include="Variable.h" />
|
||||
<ClInclude Include="UserFunctionFactory.h" />
|
||||
<ClInclude Include="UserDefinedFunction.h" />
|
||||
<ClInclude Include="proto\onnx\CNTKToONNX.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\ONNX.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\ONNXToCNTK.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\Operators.h">
|
||||
<Filter>proto\onnx</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\constants.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\graph.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\model.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\op.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\opsignature.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\shape_inference.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\status.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="proto\onnx\core\utils.h">
|
||||
<Filter>proto\onnx\core</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Filter Include="API">
|
||||
|
@ -80,6 +179,42 @@
|
|||
<Filter Include="tensorboard">
|
||||
<UniqueIdentifier>{4242f3a9-0e06-4bf5-b2c2-85b292fe0b43}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx">
|
||||
<UniqueIdentifier>{ca68761d-44d4-41a9-b055-4b192402ed0b}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\protobuf">
|
||||
<UniqueIdentifier>{77b30c82-208d-4c0d-825c-9b0eda61caa6}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\core">
|
||||
<UniqueIdentifier>{ac45f7f4-5f65-40d4-9163-46580266ae16}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs">
|
||||
<UniqueIdentifier>{cb1e39c1-bd5e-4d7f-8c83-39de28e70307}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\experiments">
|
||||
<UniqueIdentifier>{6168d648-8e32-4ad7-b16d-78d6a2e7a461}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\generator">
|
||||
<UniqueIdentifier>{e52f27a1-b8c4-4d67-874f-51aa43b04815}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\logical">
|
||||
<UniqueIdentifier>{e8f5863d-409a-4b02-814b-a93241e4e56d}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\math">
|
||||
<UniqueIdentifier>{6a99c5c8-2686-4007-83ed-988b1e1ae852}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\nn">
|
||||
<UniqueIdentifier>{7a757445-6a05-4bd2-b37d-6d95da3112fc}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\reduction">
|
||||
<UniqueIdentifier>{cd1c3786-8e2a-44bf-a7fc-28114019cf2d}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\rnn">
|
||||
<UniqueIdentifier>{4bc38915-7be4-42e2-951b-9445abf8038c}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="proto\onnx\defs\tensor">
|
||||
<UniqueIdentifier>{deb73515-13a1-4926-b5fc-e8c6e97f7784}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Proto Include="proto\CNTK.proto">
|
||||
|
@ -88,5 +223,8 @@
|
|||
<Proto Include="tensorboard\tensorboard.proto">
|
||||
<Filter>tensorboard</Filter>
|
||||
</Proto>
|
||||
<Proto Include="proto\onnx\protobuf\graph.proto">
|
||||
<Filter>proto\onnx\protobuf</Filter>
|
||||
</Proto>
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -11,6 +11,7 @@
|
|||
#include "Utils.h"
|
||||
#include "UserFunctionFactory.h"
|
||||
#include "TrainingNodes.h"
|
||||
#include "proto/onnx/ONNX.h"
|
||||
|
||||
using namespace Microsoft::MSR::CNTK;
|
||||
|
||||
|
@ -467,27 +468,53 @@ namespace CNTK
|
|||
vectorBuf.assign(s.begin(), s.end());
|
||||
}
|
||||
|
||||
void Function::Save(const std::wstring& filepath)
|
||||
void Function::Save(const std::wstring& filepath, ModelFormat format)
|
||||
{
|
||||
Dictionary model = Serialize();
|
||||
auto stream = GetFstream(filepath, false);
|
||||
*stream << model;
|
||||
stream->flush();
|
||||
switch (format)
|
||||
{
|
||||
case ModelFormat::CNTKv2:
|
||||
{
|
||||
Dictionary model = Serialize();
|
||||
auto stream = GetFstream(filepath, false);
|
||||
*stream << model;
|
||||
stream->flush();
|
||||
break;
|
||||
}
|
||||
|
||||
case ModelFormat::ONNX:
|
||||
{
|
||||
ONNXFormat::Save(RootFunction(), filepath);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
|
||||
/*static*/ FunctionPtr Function::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice, ModelFormat format)
|
||||
{
|
||||
auto stream = GetFstream(filepath, true);
|
||||
if (!Internal::IsLegacyModel(*stream))
|
||||
switch (format)
|
||||
{
|
||||
Dictionary model;
|
||||
*stream >> model;
|
||||
return Function::Deserialize(model, computeDevice);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Internal::LoadLegacyModel(filepath, computeDevice); // throw an exception if deserializer != nullptr?
|
||||
case ModelFormat::CNTKv2:
|
||||
{
|
||||
auto stream = GetFstream(filepath, true);
|
||||
if (!Internal::IsLegacyModel(*stream))
|
||||
{
|
||||
Dictionary model;
|
||||
*stream >> model;
|
||||
return Function::Deserialize(model, computeDevice);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Internal::LoadLegacyModel(filepath, computeDevice); // throw an exception if deserializer != nullptr?
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case ModelFormat::ONNX:
|
||||
return ONNXFormat::Load(filepath, computeDevice);
|
||||
break;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
|
||||
|
@ -1145,10 +1172,13 @@ namespace CNTK
|
|||
if (!axis.IsStaticAxis() && (axis != Axis::AllStaticAxes()))
|
||||
LogicError("Softmax: support only static axes.");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAxis] = axis;
|
||||
|
||||
if (((operand.Shape().Rank() == 1) && (axis.StaticAxisIndex() == 0)) ||
|
||||
(axis == Axis::AllStaticAxes()))
|
||||
{
|
||||
return UnaryOp(PrimitiveOpType::Softmax, operand, Dictionary(), name);
|
||||
return UnaryOp(PrimitiveOpType::Softmax, operand, std::move(additionalProperties), name);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1157,7 +1187,7 @@ namespace CNTK
|
|||
auto expOperandDelta = Exp(operandDelta);
|
||||
auto result = ElementDivide(expOperandDelta, ReduceSum(expOperandDelta, axis));
|
||||
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"Softmax", name);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"Softmax", name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1416,6 +1446,28 @@ namespace CNTK
|
|||
return ElementTimes(leftOperand, Reciprocal(rightOperand), name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementMax(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
auto leftOperandPlaceholder = PlaceholderVariable();
|
||||
auto rightOperandPlaceholder = PlaceholderVariable();
|
||||
|
||||
auto result = ElementSelect(Greater(leftOperandPlaceholder, rightOperandPlaceholder),
|
||||
leftOperandPlaceholder,
|
||||
rightOperandPlaceholder);
|
||||
return AsBlock(std::move(result), { { leftOperandPlaceholder, leftOperand },{ rightOperandPlaceholder, rightOperand } }, L"ElementMax", name);
|
||||
}
|
||||
|
||||
FunctionPtr ElementMin(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
auto leftOperandPlaceholder = PlaceholderVariable();
|
||||
auto rightOperandPlaceholder = PlaceholderVariable();
|
||||
|
||||
auto result = ElementSelect(Less(leftOperandPlaceholder, rightOperandPlaceholder),
|
||||
leftOperandPlaceholder,
|
||||
rightOperandPlaceholder);
|
||||
return AsBlock(std::move(result), { { leftOperandPlaceholder, leftOperand },{ rightOperandPlaceholder, rightOperand } }, L"ElementMin", name);
|
||||
}
|
||||
|
||||
FunctionPtr Equal(const Variable& leftOperand, const Variable& rightOperand, const std::wstring& name)
|
||||
{
|
||||
return BinaryOp(PrimitiveOpType::Equal, leftOperand, rightOperand, Dictionary(), name);
|
||||
|
@ -1966,6 +2018,26 @@ namespace CNTK
|
|||
name);
|
||||
}
|
||||
|
||||
FunctionPtr LocalResponseNormalization(const Variable& operand, size_t depthRadius, double bias, double alpha, double beta, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameDepthRadius] = depthRadius;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameBias] = bias;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = alpha;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameBeta] = beta;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto operandSquare = Square(operandPlaceholder);
|
||||
operandSquare = Reshape(operandSquare, { NDShape::InferredDimension, 1 }, Axis(2), Axis(3));
|
||||
auto weights = Constant({ 1, 1, 2 * depthRadius + 1, 1 }, operand.GetDataType(), alpha / (2 * depthRadius + 1));
|
||||
auto convResult = Convolution(weights, operandSquare);
|
||||
convResult = Reshape(convResult, { NDShape::InferredDimension }, Axis(2), Axis(4));
|
||||
auto denom = Exp(ElementTimes(Constant::Scalar(operand.GetDataType(), beta), Log(Plus(Constant::Scalar(operand.GetDataType(), bias), convResult))));
|
||||
|
||||
auto result = ElementDivide(operandPlaceholder, denom);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"LocalResponseNormalization", name);
|
||||
}
|
||||
|
||||
FunctionPtr Clip(const Variable& operand, const Variable& min, const Variable& max, const std::wstring& name)
|
||||
{
|
||||
std::vector<Variable> operands = { operand, min, max };
|
||||
|
@ -1999,11 +2071,16 @@ namespace CNTK
|
|||
}
|
||||
|
||||
FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, const std::wstring& blockOpName, const std::wstring& blockName)
|
||||
{
|
||||
return AsBlock(std::move(composite), argumentsMap, Dictionary(), blockOpName, blockName);
|
||||
}
|
||||
|
||||
FunctionPtr AsBlock(FunctionPtr&& composite, const std::vector<std::pair<Variable, Variable>>& argumentsMap, Dictionary&& attributes, const std::wstring& blockOpName, const std::wstring& blockName)
|
||||
{
|
||||
if (!composite->IsComposite())
|
||||
InvalidArgument("Composite argument '%S' to AsBlock is not a composite Function.", composite->AsString().c_str());
|
||||
|
||||
return AsComposite(MakeSharedObject<BlockFunction>(std::move(composite), argumentsMap, blockOpName, Dictionary(), blockName), blockName);
|
||||
return AsComposite(MakeSharedObject<BlockFunction>(std::move(composite), argumentsMap, blockOpName, std::move(attributes), blockName), blockName);
|
||||
}
|
||||
|
||||
FunctionPtr AsComposite(const FunctionPtr& rootFunction, const std::wstring& name)
|
||||
|
@ -2027,26 +2104,33 @@ namespace CNTK
|
|||
return UnaryOp(PrimitiveOpType::ELU, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr SELU(const Variable& operand, double scale, double alpha, const std::wstring& name)
|
||||
FunctionPtr SELU(const Variable& operand, double gamma, double alpha, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameGamma] = gamma;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = alpha;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto lessThanZero = Less(operandPlaceholder, Constant::Scalar(operand.GetDataType(), 0.0));
|
||||
auto result = ElementSelect(lessThanZero,
|
||||
ElementTimes(Constant::Scalar(operand.GetDataType(), alpha), ELU(operandPlaceholder)),
|
||||
operandPlaceholder);
|
||||
result = ElementTimes(Constant::Scalar(operand.GetDataType(), scale), result);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"SELU", name);
|
||||
result = ElementTimes(Constant::Scalar(operand.GetDataType(), gamma), result);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"SELU", name);
|
||||
}
|
||||
|
||||
FunctionPtr LeakyReLU(const Variable& operand, const std::wstring& name)
|
||||
{
|
||||
auto additionalProperties = Dictionary();
|
||||
additionalProperties[PrimitiveFunction::AttributeNameAlpha] = 0.01;
|
||||
|
||||
auto operandPlaceholder = PlaceholderVariable();
|
||||
auto lessThanZero = Less(operandPlaceholder, Constant::Scalar(operand.GetDataType(), 0.0));
|
||||
auto result = ElementSelect(lessThanZero,
|
||||
ElementTimes(Constant::Scalar(operand.GetDataType(), 0.01), operandPlaceholder),
|
||||
operandPlaceholder);
|
||||
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, L"LeakyReLU", name);
|
||||
return AsBlock(std::move(result), { { operandPlaceholder, operand } }, std::move(additionalProperties), L"LeakyReLU", name);
|
||||
}
|
||||
|
||||
FunctionPtr PReLU(const Variable& alpha, const Variable& operand, const std::wstring& name)
|
||||
|
@ -2530,6 +2614,7 @@ namespace CNTK
|
|||
additionalProperties[PrimitiveFunction::AttributeNameUpperPad] = NDShape({0});
|
||||
additionalProperties[PrimitiveFunction::AttributeNameTranspose] = transpose;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameOutputShape] = outputShape;
|
||||
additionalProperties[PrimitiveFunction::AttributeNameKernelShape] = NDShape({0});
|
||||
additionalProperties[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples] = maxTempMemSizeInSamples;
|
||||
|
||||
return BinaryOp(PrimitiveOpType::Convolution, convolutionMap, operand, std::move(additionalProperties), name);
|
||||
|
|
|
@ -111,6 +111,13 @@ namespace CNTK
|
|||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingFoot = L"paddingFoot";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingMode = L"paddingMode";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNamePaddingConstantValue = L"paddingConstantValue";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameAlpha = L"alpha";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBeta = L"beta";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameGamma = L"gamma";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameKernelShape = L"kernelShape";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameBias = L"bias";
|
||||
/*static*/ const std::wstring PrimitiveFunction::AttributeNameDepthRadius = L"depthRadius";
|
||||
|
||||
|
||||
/*static*/ DataType PrimitiveFunction::GetOutputDataType(PrimitiveOpType op, std::vector<Variable>& inputs, bool inferDimensions)
|
||||
{
|
||||
|
@ -876,6 +883,7 @@ namespace CNTK
|
|||
m_attributes[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing);
|
||||
m_attributes[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding);
|
||||
m_attributes[PrimitiveFunction::AttributeNameDilation] = dilation;
|
||||
m_attributes[PrimitiveFunction::AttributeNameKernelShape] = kernelShape;
|
||||
break;
|
||||
}
|
||||
case PrimitiveOpType::CrossEntropyWithSoftmax:
|
||||
|
|
|
@ -281,6 +281,12 @@ namespace CNTK
|
|||
static const std::wstring AttributeNamePaddingFoot;
|
||||
static const std::wstring AttributeNamePaddingMode;
|
||||
static const std::wstring AttributeNamePaddingConstantValue;
|
||||
static const std::wstring AttributeNameAlpha;
|
||||
static const std::wstring AttributeNameBeta;
|
||||
static const std::wstring AttributeNameGamma;
|
||||
static const std::wstring AttributeNameKernelShape;
|
||||
static const std::wstring AttributeNameBias;
|
||||
static const std::wstring AttributeNameDepthRadius;
|
||||
|
||||
protected:
|
||||
PrimitiveFunction(PrimitiveOpType op, const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& functionName, const std::wstring& uid)
|
||||
|
|
|
@ -0,0 +1,840 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "CNTKToONNX.h"
|
||||
#include "proto/onnx/core/model.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
|
||||
#include "Utils.h"
|
||||
#include "Operators.h"
|
||||
#include "BlockFunction.h"
|
||||
#include <vector>
|
||||
|
||||
using namespace CNTK::ONNX;
|
||||
using namespace CNTK;
|
||||
|
||||
//
|
||||
// A helper function, to reverse any iterable container and return a copy
|
||||
// of the reversed container.
|
||||
//
|
||||
template<typename ItrType>
|
||||
ItrType reverse(ItrType v)
|
||||
{
|
||||
std::reverse(std::begin(v), std::end(v));
|
||||
return v;
|
||||
}
|
||||
|
||||
//
|
||||
// Helper function to reduce the rank of a shape.
|
||||
//
|
||||
ONNXIR::TypeProto ReduceRank(const ONNXIR::TypeProto::TensorShapeProto* inputShape, int reductionRank, bool rightReduction)
|
||||
{
|
||||
assert(inputShape != nullptr);
|
||||
|
||||
int inputRank = inputShape->dim_size();
|
||||
assert(inputRank > reductionRank);
|
||||
|
||||
ONNXIR::TypeProto newShape;
|
||||
int64_t reduceDim = 1;
|
||||
|
||||
if (rightReduction)
|
||||
{
|
||||
for (int index = 0; index < (inputRank - reductionRank); index++)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(inputShape->dim(index).dim_value());
|
||||
|
||||
for (int index = (inputRank - reductionRank); index < inputRank; index++)
|
||||
reduceDim *= inputShape->dim(index).dim_value();
|
||||
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(reduceDim);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int index = 0; index < reductionRank; index++)
|
||||
reduceDim *= inputShape->dim(index).dim_value();
|
||||
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(reduceDim);
|
||||
|
||||
for (int index = reductionRank; index < inputRank; index++)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(inputShape->dim(index).dim_value());
|
||||
}
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
|
||||
class CNTKToONNXHelper
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Copy the entire CNTK graph to ONNX graph.
|
||||
//
|
||||
static void Copy(const FunctionPtr& src, ONNXIR::Graph* dst);
|
||||
|
||||
private:
|
||||
//
|
||||
// Recursively create ONNX nodes corresponding to each CNTK node.
|
||||
//
|
||||
static ONNXIR::Node* CreateNode(const FunctionPtr& src,
|
||||
ONNXIR::Graph* graph,
|
||||
std::unordered_map<FunctionPtr, ONNXIR::Node*>& functionNodes,
|
||||
std::unordered_map<Variable, ONNXIR::Node*>& variableNodes,
|
||||
const std::unordered_map<Variable, Variable>& compositeOutputsMap);
|
||||
|
||||
//
|
||||
// Traverse the entire graph and collect variable mapping between graph inside and outside the block.
|
||||
//
|
||||
static void TraverseGraph(const FunctionPtr& src,
|
||||
std::set<FunctionPtr>& visited,
|
||||
std::unordered_map<Variable, Variable>& compositeOutputsMap);
|
||||
|
||||
//
|
||||
// Copy the content of NDArrayView to TensorProto, and do the needed
|
||||
// convergence.
|
||||
//
|
||||
static void CopyTensor(const NDArrayViewPtr src, ONNXIR::TensorProto& dst);
|
||||
|
||||
//
|
||||
// Copy supported attributes from CNTK node to corresponding ONNX node.
|
||||
//
|
||||
static void CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node);
|
||||
|
||||
//
|
||||
// Convert Axis object to actual tensor index.
|
||||
//
|
||||
static int ToIndex(const Axis& axis);
|
||||
|
||||
//
|
||||
// Convert NDShape and various std::vector types to TensorShape
|
||||
//
|
||||
static ONNXIR::TypeProto ToTypeProto(const NDShape& shape, bool hasBatchAxis = false);
|
||||
static ONNXIR::TypeProto ToTypeProto(const std::vector<bool>& shape);
|
||||
static ONNXIR::TypeProto ToTypeProto(const std::vector<int>& shape);
|
||||
static ONNXIR::TypeProto ToTypeProto(const std::vector<Axis>& axes);
|
||||
|
||||
//
|
||||
// Convert TypeProto, NDShape and various std::vector types to std::vector
|
||||
//
|
||||
static std::vector<int64_t> ToINTS(const ONNXIR::TypeProto& shape);
|
||||
static std::vector<int64_t> ToINTS(const NDShape& shape, bool hasBatchAxis = false);
|
||||
static std::vector<int64_t> ToINTS(const std::vector<bool>& shape);
|
||||
static std::vector<int64_t> ToINTS(const std::vector<int>& shape);
|
||||
static std::vector<int64_t> ToINTS(const std::vector<Axis>& axes);
|
||||
|
||||
//
|
||||
// Convert data types from CNTK to ONNX.
|
||||
//
|
||||
static void UpdateONNXType(DataType dataType, ONNXIR::TypeProto& type);
|
||||
|
||||
//
|
||||
// Map CNTK OP names to ONNX OP Names.
|
||||
//
|
||||
static std::string ToOPName(const FunctionPtr& src);
|
||||
|
||||
//
|
||||
// Check that the CNTK variable is compatible with ONNX.
|
||||
//
|
||||
static void ValidateVariable(const Variable& v);
|
||||
|
||||
//
|
||||
// Which input to ignore during converting a CNTK block to a primitive OP in ONNX.
|
||||
//
|
||||
static bool FilterInput(const FunctionPtr& src, const CNTK::Variable& input, size_t inputIndex);
|
||||
|
||||
//
|
||||
// Argument orders between CNTK and ONNX aren't always the same.
|
||||
//
|
||||
static std::vector<ONNXIR::NodeArg> MapInputsOrderToONNX(const FunctionPtr& src, const std::vector<ONNXIR::NodeArg>& inputs);
|
||||
|
||||
//
|
||||
// Add current CNTK node to ONNX graph.
|
||||
//
|
||||
static ONNXIR::Node* AddNode(const FunctionPtr& src, ONNXIR::Graph* graph, const std::vector<ONNXIR::NodeArg>& inputs, const std::vector<ONNXIR::NodeArg>& outputs);
|
||||
};
|
||||
}
|
||||
|
||||
std::unique_ptr<ONNXIR::Model> CNTKToONNX::CreateModel(const FunctionPtr& src)
|
||||
{
|
||||
std::unique_ptr<ONNXIR::Model> model(new ONNXIR::Model("CNTKGraph", true));
|
||||
auto dstGraph = model->MainGraph();
|
||||
CNTKToONNXHelper::Copy(src, dstGraph);
|
||||
ONNXIR::Status status = dstGraph->Resolve();
|
||||
if (!status.Ok())
|
||||
LogicError("%s", status.ErrorMsg().c_str());
|
||||
return model;
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::Copy(const FunctionPtr& src, ONNXIR::Graph* dst)
|
||||
{
|
||||
std::set<FunctionPtr> visited;
|
||||
std::unordered_map<Variable, Variable> compositeOutputsMap;
|
||||
std::unordered_map<FunctionPtr, ONNXIR::Node*> functionNodes;
|
||||
std::unordered_map<Variable, ONNXIR::Node*> variableNodes;
|
||||
|
||||
//
|
||||
// Traverse the graph and collect some information.
|
||||
//
|
||||
TraverseGraph(src, visited, compositeOutputsMap);
|
||||
|
||||
//
|
||||
// Iterate through each node in CNTK graph and create an equivalent node
|
||||
// in ONNX graph.
|
||||
//
|
||||
CreateNode(src, dst, functionNodes, variableNodes, compositeOutputsMap);
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::CopyTensor(const NDArrayViewPtr src, ONNXIR::TensorProto& dst)
|
||||
{
|
||||
auto dataType = src->GetDataType();
|
||||
auto srcTemp = src->DeepClone();
|
||||
auto srcShape = srcTemp->Shape();
|
||||
auto totalSize = srcShape.TotalSize();
|
||||
|
||||
// This is our own copy so move it to the CPU.
|
||||
srcTemp->ChangeDevice(DeviceDescriptor::CPUDevice());
|
||||
|
||||
switch (dataType)
|
||||
{
|
||||
case DataType::Float:
|
||||
{
|
||||
dst.set_data_type(ONNXIR::TensorProto_DataType_FLOAT);
|
||||
auto data = srcTemp->DataBuffer<float>();
|
||||
for (size_t index = 0; index < totalSize; index++)
|
||||
*(dst.mutable_float_data()->Add()) = data[index];
|
||||
|
||||
break;
|
||||
}
|
||||
case DataType::Double:
|
||||
{
|
||||
dst.set_data_type(ONNXIR::TensorProto_DataType_DOUBLE);
|
||||
auto data = srcTemp->DataBuffer<double>();
|
||||
for (size_t index = 0; index < totalSize; index++)
|
||||
*(dst.mutable_double_data()->Add()) = data[index];
|
||||
|
||||
break;
|
||||
}
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
auto dimensions = reverse(srcShape.Dimensions());
|
||||
for (auto dim : dimensions)
|
||||
*(dst.mutable_dims()->Add()) = dim;
|
||||
}
|
||||
|
||||
int CNTKToONNXHelper::ToIndex(const Axis& axis)
|
||||
{
|
||||
if ((axis == Axis::AllAxes()) || (axis == Axis::AllStaticAxes()))
|
||||
LogicError("AllAxes and AllStaticAxes are currently not supported.");
|
||||
|
||||
if (axis.IsSequenceAxis())
|
||||
LogicError("Sequence axis are currently not supported.");
|
||||
|
||||
if (axis.IsBatchAxis())
|
||||
return 0;
|
||||
|
||||
return axis.StaticAxisIndex() + 1;
|
||||
}
|
||||
|
||||
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const NDShape& shape, bool hasBatchAxis)
|
||||
{
|
||||
ONNXIR::TypeProto newShape;
|
||||
if (shape.HasUnboundDimension())
|
||||
LogicError("Inferred and FreeDimension aren't currently supported.");
|
||||
|
||||
if (hasBatchAxis)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
auto dimensions = reverse(shape.Dimensions());
|
||||
for (auto dimension : dimensions)
|
||||
{
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dimension);
|
||||
}
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<bool>& shape)
|
||||
{
|
||||
ONNXIR::TypeProto newShape;
|
||||
auto dimensions = reverse(shape);
|
||||
for (auto dimension : dimensions)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dimension ? 1:0);
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<int>& shape)
|
||||
{
|
||||
ONNXIR::TypeProto newShape;
|
||||
auto dimensions = reverse(shape);
|
||||
for (auto dimension : dimensions)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dimension);
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
ONNXIR::TypeProto CNTKToONNXHelper::ToTypeProto(const std::vector<Axis>& axes)
|
||||
{
|
||||
std::vector<int> axesValue;
|
||||
for (auto axis : axes)
|
||||
{
|
||||
axesValue.push_back(ToIndex(axis));
|
||||
}
|
||||
std::sort(axesValue.begin(), axesValue.end());
|
||||
|
||||
ONNXIR::TypeProto newShape;
|
||||
for (auto dimension : axesValue)
|
||||
newShape.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dimension);
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const ONNXIR::TypeProto& shape)
|
||||
{
|
||||
std::vector<int64_t> newShape;
|
||||
|
||||
for (int i = 0; i < shape.tensor_type().shape().dim_size(); i++)
|
||||
newShape.push_back((int64_t)shape.tensor_type().shape().dim(i).dim_value());
|
||||
|
||||
return newShape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const NDShape& shape, bool hasBatchAxis)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape, hasBatchAxis));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<bool>& shape)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<int>& shape)
|
||||
{
|
||||
return ToINTS(ToTypeProto(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> CNTKToONNXHelper::ToINTS(const std::vector<Axis>& axes)
|
||||
{
|
||||
return ToINTS(ToTypeProto(axes));
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::UpdateONNXType(DataType dataType, ONNXIR::TypeProto& type)
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case DataType::Float:
|
||||
type.mutable_tensor_type()->set_elem_type(ONNXIR::TensorProto_DataType_FLOAT);
|
||||
break;
|
||||
case DataType::Double:
|
||||
type.mutable_tensor_type()->set_elem_type(ONNXIR::TensorProto_DataType_DOUBLE);
|
||||
break;
|
||||
default:
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
}
|
||||
|
||||
std::string CNTKToONNXHelper::ToOPName(const FunctionPtr& src)
|
||||
{
|
||||
auto lookup = Operators::CntkToONNXLookup();
|
||||
assert(lookup.count(src->OpName()) != 0);
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (lookup.count(src->OpName()) == 1)
|
||||
{
|
||||
auto attributesMap = lookup.find(src->OpName())->second.map;
|
||||
opName = attributesMap[src->OpName()];
|
||||
}
|
||||
else
|
||||
{
|
||||
// Some nodes map one to many.
|
||||
if (src->OpName() == L"Convolution")
|
||||
{
|
||||
auto transpose = (bool)src->Attributes()[L"transpose"].Value<bool>();
|
||||
if (transpose)
|
||||
opName = "ConvTranspose";
|
||||
else
|
||||
opName = "Conv";
|
||||
}
|
||||
else if (src->OpName() == L"Pooling")
|
||||
{
|
||||
PoolingType poolingType = (PoolingType)src->Attributes()[L"poolingType"].Value<size_t>();
|
||||
if (poolingType == PoolingType::Max)
|
||||
opName = "MaxPool";
|
||||
else
|
||||
opName = "AveragePool";
|
||||
}
|
||||
}
|
||||
|
||||
return opName;
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::ValidateVariable(const Variable& v)
|
||||
{
|
||||
if ((v.HasBatchAxis() && (v.DynamicAxes().size() > 1)) ||
|
||||
(!v.HasBatchAxis() && (v.DynamicAxes().size() > 0)))
|
||||
{
|
||||
LogicError("Sequence and user defined dynamic axis are currently not supported.");
|
||||
}
|
||||
}
|
||||
|
||||
bool CNTKToONNXHelper::FilterInput(const FunctionPtr& src, const CNTK::Variable& input, size_t inputIndex)
|
||||
{
|
||||
// In CNTK block functions, they expose all constants inside the block. For block functions that
|
||||
// map directly to ONNX OP, we don't care about constanst inside the block.
|
||||
if (input.IsConstant())
|
||||
return !Operators::IsValidInputs(src->OpName(), inputIndex);
|
||||
return false;
|
||||
}
|
||||
|
||||
//
|
||||
// This is the main horsepower, it navigate CNTK graph recursivley while keep track of all visited nodes and variables,
|
||||
// and create the corresponding ONNX graph.
|
||||
//
|
||||
ONNXIR::Node* CNTKToONNXHelper::CreateNode(const FunctionPtr& src,
|
||||
ONNXIR::Graph* graph,
|
||||
std::unordered_map<FunctionPtr, ONNXIR::Node*>& functionNodes,
|
||||
std::unordered_map<Variable, ONNXIR::Node*>& variableNodes,
|
||||
const std::unordered_map<Variable, Variable>& compositeOutputsMap)
|
||||
{
|
||||
auto iter = functionNodes.find(src);
|
||||
if (iter != functionNodes.end())
|
||||
return iter->second;
|
||||
|
||||
ONNXIR::Node* functionNode = nullptr;
|
||||
std::string opName = ToString(src->OpName());
|
||||
|
||||
//
|
||||
// If this block node equivalent to a primitive ONNX OP, then treated as such.
|
||||
// And just maps its argument to ONNX node.
|
||||
//
|
||||
if (src->IsBlock() &&
|
||||
(!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName())))
|
||||
{
|
||||
functionNode = CreateNode(src->BlockRoot(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
}
|
||||
//
|
||||
// For compatibility of other framework that support ONNX, we will limit the list of OPs to the one
|
||||
// supported by ONNX https://github.com/onnx/onnx/tree/master/onnx/defs.
|
||||
//
|
||||
else if (Operators::IsSupportedCNTKOP(src->OpName()))
|
||||
{
|
||||
std::vector<ONNXIR::NodeArg> inputs;
|
||||
std::vector<ONNXIR::NodeArg> outputs;
|
||||
|
||||
for (const auto& output : src->Outputs())
|
||||
{
|
||||
ValidateVariable(output);
|
||||
|
||||
auto outputArgType = ToTypeProto(output.Shape(), output.HasBatchAxis());
|
||||
UpdateONNXType(output.GetDataType(), outputArgType);
|
||||
|
||||
ONNXIR::NodeArg outputArg(ToString(output.Uid()), &outputArgType);
|
||||
outputs.push_back(outputArg);
|
||||
}
|
||||
|
||||
for (size_t inputIndex = 0; inputIndex < src->Inputs().size(); ++inputIndex)
|
||||
{
|
||||
auto input = src->Inputs()[inputIndex];
|
||||
|
||||
if (input.IsPlaceholder())
|
||||
{
|
||||
input = input.BlockFunctionVariableMapping();
|
||||
if (input.IsPlaceholder())
|
||||
LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
|
||||
}
|
||||
ValidateVariable(input);
|
||||
|
||||
if (src->IsBlock() && FilterInput(src, input, inputIndex))
|
||||
continue;
|
||||
|
||||
//
|
||||
// Use user defined name if available otherwise use our internel unique name ID.
|
||||
//
|
||||
std::string inputName = ToString(input.Uid());
|
||||
auto inputItr = compositeOutputsMap.find(input);
|
||||
if (inputItr != compositeOutputsMap.end())
|
||||
inputName = ToString(inputItr->second.Uid());
|
||||
|
||||
auto inputArgType = ToTypeProto(input.Shape(), input.HasBatchAxis());
|
||||
UpdateONNXType(input.GetDataType(), inputArgType);
|
||||
ONNXIR::NodeArg inputArg(inputName, &inputArgType);
|
||||
|
||||
inputs.push_back(inputArg);
|
||||
|
||||
//
|
||||
// Leaf nodes are data entry to the graph and need their own node with only output arg.
|
||||
//
|
||||
if ((input.IsParameter() || input.IsConstant()) &&
|
||||
!Operators::IgnoreConstantAndParameter(src->OpName(), inputIndex))
|
||||
{
|
||||
if (variableNodes.find(input) == variableNodes.end())
|
||||
{
|
||||
std::vector<ONNXIR::NodeArg> varInputs;
|
||||
std::vector<ONNXIR::NodeArg> varOutputs;
|
||||
|
||||
varOutputs.push_back({ inputArg });
|
||||
ONNXIR::Node* variableNode = nullptr;
|
||||
if (input.IsParameter() || input.IsConstant())
|
||||
{
|
||||
variableNode = graph->AddNode(inputName, "Constant", "", varInputs, varOutputs);
|
||||
auto srcTensor = input.IsParameter() ? Parameter(input).Value() : Constant(input).Value();
|
||||
|
||||
ONNXIR::TensorProto dstTensor;
|
||||
CopyTensor(srcTensor, dstTensor);
|
||||
|
||||
variableNode->AddAttribute("value", dstTensor);
|
||||
variableNodes.emplace(input, variableNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
//
|
||||
// If this input is output, then it is the ouput of an up stream node. Recursively add all upstream nodes.
|
||||
// Pretty much, we are doing DFS.
|
||||
//
|
||||
else if (input.IsOutput())
|
||||
CreateNode(input.Owner(), graph, functionNodes, variableNodes, compositeOutputsMap);
|
||||
}
|
||||
|
||||
//
|
||||
// Finally add a new node to ONNX graph.
|
||||
//
|
||||
functionNode = AddNode(src, graph, inputs, outputs);
|
||||
}
|
||||
else
|
||||
LogicError("Node '%S': Unsupported node.", src->AsString().c_str());
|
||||
|
||||
functionNodes.emplace(src, functionNode);
|
||||
return functionNode;
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::TraverseGraph(const FunctionPtr& src,
|
||||
std::set<FunctionPtr>& visited,
|
||||
std::unordered_map<Variable, Variable>& compositeOutputsMap)
|
||||
{
|
||||
auto iter = visited.find(src);
|
||||
if (iter != visited.end())
|
||||
return;
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (src->IsBlock())
|
||||
{
|
||||
if (!Operators::IsSupportedCNTKOP(src->OpName()) || Operators::IsLayerCNTKOP(src->OpName()))
|
||||
{
|
||||
auto blockSrc = dynamic_cast<BlockFunction*>(src.get());
|
||||
for (auto map : blockSrc->CompositeOutputsMap())
|
||||
compositeOutputsMap.insert(map);
|
||||
TraverseGraph(src->BlockRoot(), visited, compositeOutputsMap);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto input : src->Inputs())
|
||||
{
|
||||
if (input.IsPlaceholder())
|
||||
{
|
||||
input = input.BlockFunctionVariableMapping();
|
||||
if (input.IsPlaceholder())
|
||||
LogicError("Node '%S': Placeholder isn't supported currently.", src->AsString().c_str());
|
||||
}
|
||||
|
||||
if (input.IsOutput())
|
||||
TraverseGraph(input.Owner(), visited, compositeOutputsMap);
|
||||
}
|
||||
}
|
||||
|
||||
visited.emplace(src);
|
||||
}
|
||||
|
||||
void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node)
|
||||
{
|
||||
auto lookup = Operators::CntkToONNXLookup();
|
||||
assert(lookup.count(src->OpName()) != 0);
|
||||
|
||||
std::string opName = ToString(src->OpName());
|
||||
if (lookup.count(src->OpName()) == 1)
|
||||
{
|
||||
auto attributesMap = lookup.find(src->OpName())->second.map;
|
||||
opName = attributesMap[src->OpName()];
|
||||
|
||||
if (src->OpName() == L"BatchNormalization")
|
||||
{
|
||||
auto spatial = (int64_t)((bool)src->Attributes()[L"spatial"].Value<bool>() ? 1 : 0);
|
||||
auto normalizationTimeConstant = (float)src->Attributes()[L"normalizationTimeConstant"].Value<double>();
|
||||
// auto blendTimeConstant = (float)src->Attributes()[L"blendTimeConstant"].Value<double>();
|
||||
auto epsilon = (float)src->Attributes()[L"epsilon"].Value<double>();
|
||||
|
||||
//
|
||||
// onnx: running_mean = running_mean * momentum + mean * (1 - momentum)
|
||||
// cntk: expAvgFactor * MB stats + (1-expAvgFactor) * prev running stats
|
||||
//
|
||||
auto momentum = 0.0f;
|
||||
if (!isfinite(normalizationTimeConstant))
|
||||
momentum = 1.0f;
|
||||
else if (normalizationTimeConstant > 0)
|
||||
momentum = 1.0f + expm1(-48.0f / normalizationTimeConstant);
|
||||
|
||||
node->AddAttribute(attributesMap[L"spatial"], spatial);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
node->AddAttribute(attributesMap[L"epsilon"], epsilon);
|
||||
node->AddAttribute("momentum", momentum);
|
||||
}
|
||||
else if (src->OpName() == L"LocalResponseNormalization")
|
||||
{
|
||||
auto depthRadius = (int64_t)src->Attributes()[L"depthRadius"].Value<size_t>();
|
||||
auto bias = (float)src->Attributes()[L"bias"].Value<double>();
|
||||
auto alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
auto beta = (float)src->Attributes()[L"beta"].Value<double>();
|
||||
|
||||
node->AddAttribute(attributesMap[L"depthRadius"], depthRadius);
|
||||
node->AddAttribute(attributesMap[L"bias"], bias);
|
||||
node->AddAttribute(attributesMap[L"alpha"], alpha);
|
||||
node->AddAttribute(attributesMap[L"beta"], beta);
|
||||
}
|
||||
else if ((src->OpName() == L"LeakyReLU") || (src->OpName() == L"ELU"))
|
||||
{
|
||||
auto alpha = 0.01f;
|
||||
if (src->Attributes().Contains(L"alpha"))
|
||||
alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
node->AddAttribute("alpha", 0.01f);
|
||||
}
|
||||
else if (src->OpName() == L"SELU")
|
||||
{
|
||||
auto alpha = 1.6732f;
|
||||
if (src->Attributes().Contains(L"alpha"))
|
||||
alpha = (float)src->Attributes()[L"alpha"].Value<double>();
|
||||
|
||||
auto gamma = 1.0507f;
|
||||
if (src->Attributes().Contains(L"gamma"))
|
||||
gamma = (float)src->Attributes()[L"gamma"].Value<double>();
|
||||
|
||||
node->AddAttribute("alpha", alpha);
|
||||
node->AddAttribute("gamma", gamma);
|
||||
}
|
||||
else if (src->OpName() == L"Dropout")
|
||||
{
|
||||
auto dropoutRate = (float)src->Attributes()[L"dropoutRate"].Value<double>();
|
||||
node->AddAttribute(attributesMap[L"dropoutRate"], dropoutRate);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
}
|
||||
else if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom") ||
|
||||
(src->OpName() == L"UniformRandomLike") || (src->OpName() == L"NormalRandomLike"))
|
||||
{
|
||||
auto randomArgs = AsVector<double>(src->Attributes()[L"randomDistributionArgs"].Value<std::vector<DictionaryValue>>());
|
||||
auto seed = (int64_t)src->Attributes()[L"rngSeed"].Value<int>();
|
||||
|
||||
if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"UniformRandomLike"))
|
||||
{
|
||||
node->AddAttribute("low", (float)randomArgs[0]);
|
||||
node->AddAttribute("high", (float)randomArgs[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
node->AddAttribute("mean", (float)randomArgs[0]);
|
||||
node->AddAttribute("scale", (float)randomArgs[1]);
|
||||
}
|
||||
|
||||
node->AddAttribute(attributesMap[L"rngSeed"], seed);
|
||||
if ((src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom"))
|
||||
{
|
||||
auto shape = (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
|
||||
node->AddAttribute(attributesMap[L"newShape"], ToINTS(shape));
|
||||
}
|
||||
}
|
||||
else if ((src->OpName() == L"ReduceMax") || (src->OpName() == L"ReduceMin") ||
|
||||
(src->OpName() == L"ReduceSum") || (src->OpName() == L"ReduceMean") ||
|
||||
(src->OpName() == L"ReduceProd") || (src->OpName() == L"ReduceLogSum") ||
|
||||
(src->OpName() == L"Argmax") || (src->OpName() == L"Argmin"))
|
||||
{
|
||||
auto keepReducedDimensions = (int64_t)((bool)src->Attributes()[L"reductionKeepDimensions"].Value<bool>() ? 1 : 0);
|
||||
std::vector<Axis> reductionAxes;
|
||||
if (src->Attributes().Contains(L"axisVec"))
|
||||
reductionAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
else if (src->Attributes().Contains(L"axis"))
|
||||
reductionAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
|
||||
|
||||
node->AddAttribute(attributesMap[L"reductionKeepDimensions"], keepReducedDimensions);
|
||||
node->AddAttribute("axes", ToINTS(reductionAxes));
|
||||
}
|
||||
else if (src->OpName() == L"Transpose")
|
||||
{
|
||||
std::vector<Axis> perm = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
node->AddAttribute(attributesMap[L"axisVec"], ToINTS(perm));
|
||||
}
|
||||
else if (src->OpName() == L"Reshape")
|
||||
{
|
||||
auto shape = (NDShape)src->Attributes()[L"newShape"].Value<NDShape>();
|
||||
node->AddAttribute(attributesMap[L"newShape"], ToINTS(shape, true));
|
||||
}
|
||||
else if (src->OpName() == L"Splice")
|
||||
{
|
||||
Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
node->AddAttribute(attributesMap[L"axis"], (int64_t)ToIndex(axis));
|
||||
}
|
||||
else if (src->OpName() == L"Slice")
|
||||
{
|
||||
std::vector<Axis> sliceAxes;
|
||||
std::vector<int> beginIndex;
|
||||
std::vector<int> endIndex;
|
||||
|
||||
if (src->Attributes().Contains(L"axisVec"))
|
||||
{
|
||||
sliceAxes = AsVector<Axis>(src->Attributes()[L"axisVec"].Value<std::vector<DictionaryValue>>());
|
||||
beginIndex = AsVector<int>(src->Attributes()[L"beginIndexVec"].Value<std::vector<DictionaryValue>>());
|
||||
endIndex = AsVector<int>(src->Attributes()[L"endIndexVec"].Value<std::vector<DictionaryValue>>());
|
||||
}
|
||||
else if (src->Attributes().Contains(L"axis"))
|
||||
{
|
||||
sliceAxes.push_back((Axis)(src->Attributes()[L"axis"].Value<Axis>()));
|
||||
beginIndex.push_back((int)(src->Attributes()[L"beginIndex"].Value<int>()));
|
||||
endIndex.push_back((int)(src->Attributes()[L"endIndex"].Value<int>()));
|
||||
}
|
||||
|
||||
node->AddAttribute(attributesMap[L"axes"], ToINTS(sliceAxes));
|
||||
node->AddAttribute(attributesMap[L"starts"], ToINTS(beginIndex));
|
||||
node->AddAttribute(attributesMap[L"ends"], ToINTS(endIndex));
|
||||
}
|
||||
else if (src->OpName() == L"Softmax")
|
||||
{
|
||||
Axis axis = Axis(0);
|
||||
if (src->Attributes().Contains(L"axis"))
|
||||
axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
node->AddAttribute(attributesMap[L"axis"], (int64_t)ToIndex(axis));
|
||||
}
|
||||
else if ((src->OpName() == L"Plus") || (src->OpName() == L"Minus") ||
|
||||
(src->OpName() == L"ElementTimes") || (src->OpName() == L"ElementDivide"))
|
||||
{
|
||||
node->AddAttribute("broadcast", (int64_t)1);
|
||||
// node->AddAttribute("axis", (int64_t)1);
|
||||
}
|
||||
else if (src->OpName() == L"Times")
|
||||
{
|
||||
size_t outputRank = src->Attributes()[L"outputRank"].Value<size_t>();
|
||||
if (outputRank > 1)
|
||||
LogicError("Output rank other than 1 is not supported.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Some nodes map one to many.
|
||||
if (src->OpName() == L"Convolution")
|
||||
{
|
||||
auto kernelShape = (NDShape)src->Attributes()[L"kernelShape"].Value<NDShape>();
|
||||
auto strides = (NDShape)src->Attributes()[L"strides"].Value<NDShape>();
|
||||
auto autoPadding = AsVector<bool>(src->Attributes()[L"autoPadding"].Value<std::vector<DictionaryValue>>());
|
||||
auto dilations = (NDShape)src->Attributes()[L"dilation"].Value<NDShape>();
|
||||
auto transpose = (bool)src->Attributes()[L"transpose"].Value<bool>();
|
||||
|
||||
//
|
||||
// Remove the channel part for ONNX.
|
||||
//
|
||||
kernelShape = kernelShape.SubShape(0, kernelShape.Rank() - 1);
|
||||
strides = strides.SubShape(0, strides.Rank() - 1);
|
||||
autoPadding.pop_back();
|
||||
dilations = dilations.SubShape(0, dilations.Rank() - 1);
|
||||
|
||||
node->AddAttribute("kernel_shape", ToINTS(kernelShape));
|
||||
node->AddAttribute("strides", ToINTS(strides));
|
||||
node->AddAttribute("pads", ToINTS(autoPadding));
|
||||
node->AddAttribute("dilations", ToINTS(dilations));
|
||||
node->AddAttribute("group", (int64_t)1);
|
||||
|
||||
if (transpose)
|
||||
{
|
||||
auto outputShape = (NDShape)src->Attributes()[L"outputShape"].Value<NDShape>();
|
||||
node->AddAttribute("output_shape", ToINTS(outputShape, src->Inputs()[1].HasBatchAxis()));
|
||||
}
|
||||
}
|
||||
else if (src->OpName() == L"Pooling")
|
||||
{
|
||||
auto kernelShape = (NDShape)src->Attributes()[L"poolingWindowShape"].Value<NDShape>();
|
||||
auto strides = (NDShape)src->Attributes()[L"strides"].Value<NDShape>();
|
||||
auto autoPadding = AsVector<bool>(src->Attributes()[L"autoPadding"].Value<std::vector<DictionaryValue>>());
|
||||
|
||||
node->AddAttribute("kernel_shape", ToINTS(kernelShape));
|
||||
node->AddAttribute("strides", ToINTS(strides));
|
||||
node->AddAttribute("pads", ToINTS(autoPadding));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ONNXIR::NodeArg> CNTKToONNXHelper::MapInputsOrderToONNX(const FunctionPtr& src, const std::vector<ONNXIR::NodeArg>& inputs)
|
||||
{
|
||||
if (Operators::HasInputIndexMap(src->OpName()))
|
||||
{
|
||||
std::vector<ONNXIR::NodeArg> orderedInputs;
|
||||
std::map<int, ONNXIR::NodeArg> orderedInputsMap;
|
||||
auto map = Operators::ToONNXInputIndexMap(src->OpName());
|
||||
|
||||
for (size_t inputIndex = 0; inputIndex < inputs.size(); ++inputIndex)
|
||||
{
|
||||
if (map[inputIndex] >= 0)
|
||||
orderedInputsMap.insert(std::pair<int, ONNXIR::NodeArg>(map[inputIndex], inputs[inputIndex]));
|
||||
}
|
||||
|
||||
for (const auto& item : orderedInputsMap)
|
||||
orderedInputs.push_back(item.second);
|
||||
|
||||
return orderedInputs;
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
ONNXIR::Node* CNTKToONNXHelper::AddNode(const FunctionPtr& src, ONNXIR::Graph* graph, const std::vector<ONNXIR::NodeArg>& inputs, const std::vector<ONNXIR::NodeArg>& outputs)
|
||||
{
|
||||
ONNXIR::Node* node = nullptr;
|
||||
auto orderedInputs = MapInputsOrderToONNX(src, inputs);
|
||||
auto nodeName = src->Name().empty() ? ToString(src->Uid()) : ToString(src->Name());
|
||||
|
||||
//
|
||||
// CNTK Times OP is way more flexible for ONNX, so depend on the inputs and output shape,
|
||||
// we will need to insert some reshapes.
|
||||
//
|
||||
if (src->OpName() == L"Times")
|
||||
{
|
||||
auto input1Shape = orderedInputs[0].Shape();
|
||||
auto input2Shape = orderedInputs[1].Shape();
|
||||
auto outputShape = outputs[0].Shape();
|
||||
|
||||
int input1Rank = input1Shape->dim_size();
|
||||
int input2Rank = input2Shape->dim_size();
|
||||
int outputRank = outputShape->dim_size();
|
||||
int reductionRank = (input1Rank + input2Rank - outputRank) / 2;
|
||||
|
||||
if (reductionRank > 1) // We need to insert reshape.
|
||||
{
|
||||
auto input1Reshape = ReduceRank(input1Shape, reductionRank, true);
|
||||
auto input2Reshape = ReduceRank(input2Shape, reductionRank, false);
|
||||
|
||||
UpdateONNXType(src->Inputs()[1].GetDataType(), input1Reshape);
|
||||
UpdateONNXType(src->Inputs()[0].GetDataType(), input2Reshape);
|
||||
|
||||
ONNXIR::NodeArg inputOutput1Arg(orderedInputs[0].Name() + string("_reshape0"), &input1Reshape);
|
||||
ONNXIR::NodeArg inputOutput2Arg(orderedInputs[1].Name() + string("_reshape1"), &input2Reshape);
|
||||
|
||||
auto reshapeNode1 = graph->AddNode(nodeName + string("_reshape0"), "Reshape", "", { orderedInputs[0] }, { inputOutput1Arg });
|
||||
auto reshapeNode2 = graph->AddNode(nodeName + string("_reshape1"), "Reshape", "", { orderedInputs[1] }, { inputOutput2Arg });
|
||||
|
||||
reshapeNode1->AddAttribute("shape", ToINTS(input1Reshape));
|
||||
reshapeNode2->AddAttribute("shape", ToINTS(input2Reshape));
|
||||
|
||||
node = graph->AddNode(nodeName, ToOPName(src), "", { inputOutput1Arg , inputOutput2Arg }, outputs);
|
||||
}
|
||||
else
|
||||
node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
|
||||
}
|
||||
else
|
||||
node = graph->AddNode(nodeName, ToOPName(src), "", orderedInputs, outputs);
|
||||
|
||||
//
|
||||
// Copy and validate attributes.
|
||||
//
|
||||
CopyAttributes(src, node);
|
||||
|
||||
return node;
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class Model;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class CNTKToONNX
|
||||
{
|
||||
public:
|
||||
static std::unique_ptr<ONNXIR::Model> CreateModel(const FunctionPtr& src);
|
||||
};
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "ONNX.h"
|
||||
#include "CNTKToONNX.h"
|
||||
#include "proto/onnx/core/model.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
#include "Utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ONNXToCNTK.h"
|
||||
|
||||
using namespace CNTK;
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
static void PrintGraph(FunctionPtr function, int spaces, bool useName = false)
|
||||
{
|
||||
if (function->Inputs().size() == 0)
|
||||
{
|
||||
cout << string(spaces, '.') + "(" + ToString(useName ? function->Name() : function->Uid()) + ")" + ToString(function->AsString()) << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto input : function->Inputs())
|
||||
{
|
||||
cout << string(spaces, '.') + "(" + ToString(useName ? function->Name() : function->Uid()) + ")" + "->" +
|
||||
"(" + ToString(useName ? input.Name() : input.Uid()) + ")" + ToString(input.AsString()) << std::endl;
|
||||
}
|
||||
|
||||
for (auto input : function->Inputs())
|
||||
{
|
||||
if (input.Owner() != NULL)
|
||||
{
|
||||
FunctionPtr f = input.Owner();
|
||||
PrintGraph(f, spaces + 4, useName);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXFormat::Save(const FunctionPtr& src, const std::wstring& filepath)
|
||||
{
|
||||
auto model = CNTKToONNX::CreateModel(src);
|
||||
#ifdef _WIN32
|
||||
ONNXIR::Model::Save(*model, filepath);
|
||||
#else
|
||||
ONNXIR::Model::Save(*model, ToString(filepath));
|
||||
#endif
|
||||
}
|
||||
|
||||
FunctionPtr ONNXFormat::Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice)
|
||||
{
|
||||
ONNXIR::ModelProto modelProto;
|
||||
|
||||
#ifdef _WIN32
|
||||
bool loadStatus = ONNXIR::Model::Load(filepath, &modelProto);
|
||||
#else
|
||||
bool loadStatus = ONNXIR::Model::Load(ToString(filepath), &modelProto);
|
||||
#endif
|
||||
loadStatus;
|
||||
//if (!loadStatus)
|
||||
// LogicError("Failed to load the model.");
|
||||
|
||||
ONNXIR::Model model(modelProto);
|
||||
auto status = model.MainGraph()->Resolve();
|
||||
if (!status.Ok())
|
||||
LogicError("%s", status.ErrorMsg().c_str());
|
||||
|
||||
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(model.MainGraph(), computeDevice);
|
||||
return cntkFunction;
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class ONNXFormat
|
||||
{
|
||||
public:
|
||||
static void Save(const FunctionPtr& src, const std::wstring& filepath);
|
||||
static FunctionPtr Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
|
||||
};
|
||||
}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,27 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class Graph;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
class ONNXToCNTK
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Create a CNTK graph (Function) given an ONNX graph. The function is created to use the
|
||||
// specified computing device.
|
||||
//
|
||||
static FunctionPtr CreateGraph(ONNXIR::Graph* src, const DeviceDescriptor& computeDevice);
|
||||
};
|
||||
}
|
|
@ -0,0 +1,297 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#include "Operators.h"
|
||||
#include "proto/onnx/core/graph.h"
|
||||
#include "Utils.h"
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
namespace ONNX
|
||||
{
|
||||
//
|
||||
// Support ONNX OPs from https://github.com/onnx/onnx/tree/master/onnx/defs
|
||||
//
|
||||
// The format of the below structure is simply a key which is CNTK OpName and a corresponding
|
||||
// lookup table, the corrsponding lookup table map the OpName and all its attributes from CNTK
|
||||
// to ONNX.
|
||||
//
|
||||
// Eventually, it would be good to change CNTK OpName to match ONNX in order to avoid the need
|
||||
// of the below table.
|
||||
//
|
||||
std::unordered_multimap<std::wstring, AttributesMapping> Operators::_cntkToONNXOpName = {
|
||||
// From nn
|
||||
{ L"Pooling", { {
|
||||
{ L"Pooling", "AveragePool" },
|
||||
{ L"poolingWindowShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
} } },
|
||||
{ L"Pooling", { {
|
||||
{ L"Pooling", "MaxPool" },
|
||||
{ L"poolingWindowShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
} } },
|
||||
{ L"Convolution", { {
|
||||
{ L"Convolution", "Conv" },
|
||||
{ L"kernelShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
{ L"dilation", "dilations" },
|
||||
// { L"", "group" },
|
||||
} } },
|
||||
{ L"Convolution", { {
|
||||
{ L"ConvolutionTranspose", "ConvTranspose" },
|
||||
{ L"kernelShape", "kernel_shape" },
|
||||
{ L"strides", "strides" },
|
||||
{ L"autoPadding", "pads" },
|
||||
{ L"dilation", "dilations" },
|
||||
{ L"outputShape", "output_shape" },
|
||||
} } },
|
||||
{ L"GlobalMaxPooling", { {
|
||||
{ L"GlobalMaxPooling", "GlobalMaxPool" },
|
||||
} } },
|
||||
{ L"GlobalAveragePooling", { {
|
||||
{ L"GlobalAveragePooling", "GlobalAveragePool" },
|
||||
} } },
|
||||
{ L"BatchNormalization", { {
|
||||
{ L"BatchNormalization", "BatchNormalization" },
|
||||
{ L"spatial", "spatial" },
|
||||
// { L"", "is_test" },
|
||||
{ L"epsilon", "epsilon" },
|
||||
// { L"", "momentum" },
|
||||
} } },
|
||||
// from ONNX experiament, added to test Caffe models
|
||||
// TODO: set key as BatchNormalization instead of BatchNormalizationCaffe
|
||||
{ L"BatchNormalizationCaffe",{ {
|
||||
{ L"BatchNormalization", "SpatialBN" },
|
||||
{ L"spatial", "spatial" },
|
||||
// { L"", "is_test" },
|
||||
{ L"epsilon", "epsilon" },
|
||||
// { L"", "momentum" },
|
||||
} } },
|
||||
{ L"LocalResponseNormalization",{ {
|
||||
{ L"LocalResponseNormalization", "LRN" },
|
||||
{ L"depthRadius", "size" },
|
||||
{ L"bias", "bias" },
|
||||
{ L"alpha", "alpha" },
|
||||
{ L"beta", "beta" },
|
||||
} } },
|
||||
{ L"Dropout", { {
|
||||
{ L"Dropout", "Dropout" },
|
||||
{ L"dropoutRate", "ratio" },
|
||||
// { L"", "is_test" },
|
||||
} } },
|
||||
{ L"Flatten",{ {
|
||||
{ L"Flatten", "Flatten" },
|
||||
} } },
|
||||
|
||||
// From Generator
|
||||
{ L"UniformRandom", { {
|
||||
{ L"UniformRandom", "RandomUniform" },
|
||||
// { L"", "low" },
|
||||
// { L"", "high" },
|
||||
{ L"rngSeed", "seed" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"NormalRandom", { {
|
||||
{ L"NormalRandom", "RandomNormal" },
|
||||
// { L"", "mean" },
|
||||
// { L"", "scale" },
|
||||
{ L"rngSeed", "seed" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"UniformRandomLike", { {
|
||||
{ L"UniformRandomLike", "RandomUniformLike" },
|
||||
// { L"", "low" },
|
||||
// { L"", "high" },
|
||||
{ L"rngSeed", "seed" },
|
||||
} } },
|
||||
{ L"NormalRandomLike", { {
|
||||
{ L"NormalRandomLike", "RandomNormalLike" },
|
||||
// { L"", "mean" },
|
||||
// { L"", "scale" },
|
||||
{ L"rngSeed", "seed" },
|
||||
} } },
|
||||
|
||||
// From Math
|
||||
{ L"Plus", { {
|
||||
{ L"Plus", "Add" },
|
||||
} } },
|
||||
{ L"Minus", { {
|
||||
{ L"Minus", "Sub" },
|
||||
} } },
|
||||
{ L"ElementTimes", { {
|
||||
{ L"ElementTimes", "Mul" },
|
||||
} } },
|
||||
{ L"ElementDivide", { {
|
||||
{ L"ElementDivide", "Div" },
|
||||
} } },
|
||||
{ L"Negate", { {
|
||||
{ L"Negate", "Neg" },
|
||||
} } },
|
||||
{ L"Abs", { {
|
||||
{ L"Abs", "Abs" },
|
||||
} } },
|
||||
{ L"Reciprocal", { {
|
||||
{ L"Reciprocal", "Reciprocal" },
|
||||
} } },
|
||||
{ L"Floor", { {
|
||||
{ L"Floor", "Floor" },
|
||||
} } },
|
||||
{ L"Ceil", { {
|
||||
{ L"Ceil", "Ceil" },
|
||||
} } },
|
||||
{ L"Sqrt", { {
|
||||
{ L"Sqrt", "Sqrt" },
|
||||
} } },
|
||||
{ L"ReLU", { {
|
||||
{ L"ReLU", "Relu" },
|
||||
} } },
|
||||
{ L"LeakyReLU", { {
|
||||
{ L"LeakyReLU", "LeakyRelu" },
|
||||
{ L"alpha", "alpha" },
|
||||
} } },
|
||||
{ L"SELU", { {
|
||||
{ L"SELU", "Selu" },
|
||||
{ L"alpha", "alpha" },
|
||||
{ L"gamma", "gamma" },
|
||||
} } },
|
||||
{ L"ELU", { {
|
||||
{ L"ELU", "Elu" },
|
||||
// { L"", "alpha" },
|
||||
} } },
|
||||
{ L"Exp", { {
|
||||
{ L"Exp", "Exp" },
|
||||
} } },
|
||||
{ L"Log", { {
|
||||
{ L"Log", "Log" },
|
||||
} } },
|
||||
{ L"Tanh", { {
|
||||
{ L"Tanh", "Tanh" },
|
||||
} } },
|
||||
{ L"Pow", { {
|
||||
{ L"Pow", "Pow" },
|
||||
// { L"", "exponent" },
|
||||
} } },
|
||||
{ L"Times", { {
|
||||
{ L"Times", "Dot" },
|
||||
} } },
|
||||
{ L"PReLU", { {
|
||||
{ L"PReLU", "PRelu" },
|
||||
} } },
|
||||
{ L"StableSigmoid", { {
|
||||
{ L"StableSigmoid", "Sigmoid" },
|
||||
} } },
|
||||
{ L"ElementMax", { {
|
||||
{ L"ElementMax", "Max" },
|
||||
} } },
|
||||
{ L"ElementMax", { {
|
||||
{ L"ElementMax", "Min" },
|
||||
} } },
|
||||
// { L"", "Sum" },
|
||||
{ L"Softmax", { {
|
||||
{ L"Softmax", "Softmax" },
|
||||
{ L"axis", "axis" },
|
||||
} } },
|
||||
|
||||
// From reduction
|
||||
{ L"ReduceMax", { {
|
||||
{ L"ReduceMax", "ReduceMax" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceMin", { {
|
||||
{ L"ReduceMin", "ReduceMin" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceSum", { {
|
||||
{ L"ReduceSum", "ReduceSum" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceMean", { {
|
||||
{ L"ReduceMean", "ReduceMean" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceProd", { {
|
||||
{ L"ReduceProd", "ReduceProd" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"ReduceLogSum", { {
|
||||
{ L"ReduceLogSum", "ReduceLogSumExp" },
|
||||
{ L"axisVec", "axes" },
|
||||
{ L"reductionKeepDimensions", "keepdims" },
|
||||
} } },
|
||||
{ L"Argmax", { {
|
||||
{ L"Argmax", "ArgMax" },
|
||||
{ L"axis", "axes" },
|
||||
// { L"", "keepdims" },
|
||||
} } },
|
||||
{ L"Argmin", { {
|
||||
{ L"Argmin", "ArgMin" },
|
||||
{ L"axis", "axes" },
|
||||
// { L"", "keepdims" },
|
||||
} } },
|
||||
|
||||
// From tensor
|
||||
// { L"", "Cast" },
|
||||
{ L"Reshape", { {
|
||||
{ L"Reshape", "Reshape" },
|
||||
{ L"newShape", "shape" },
|
||||
} } },
|
||||
{ L"Splice", { {
|
||||
{ L"Splice", "Concat" },
|
||||
{ L"axis", "axis" },
|
||||
} } },
|
||||
// { L"", "Split" },
|
||||
{ L"Slice", { {
|
||||
{ L"Slice", "Slice" },
|
||||
{ L"beginIndexVec", "starts" },
|
||||
{ L"endIndexVec", "ends" },
|
||||
} } },
|
||||
{ L"Transpose", { {
|
||||
{ L"Transpose", "Transpose" },
|
||||
{ L"axisVec", "perm" },
|
||||
} } },
|
||||
{ L"GatherOp", { {
|
||||
{ L"GatherOp", "Gather" },
|
||||
} } },
|
||||
// { L"", "Squeeze" },
|
||||
};
|
||||
|
||||
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
|
||||
{ L"LeakyReLU", { 0, 1 } },
|
||||
{ L"SELU", { 0, 1, 2 } },
|
||||
{ L"PReLU", { 0 } },
|
||||
{ L"ElementMax", {} },
|
||||
{ L"ElementMin", {} },
|
||||
{ L"Softmax", {} },
|
||||
{ L"LocalResponseNormalization", { 0, 1, 2 } }
|
||||
};
|
||||
|
||||
std::unordered_map<std::wstring, std::vector<int>> Operators::_cntkToONNXInputIndices = {
|
||||
{ L"Convolution", { 1, 0 } },
|
||||
{ L"ConvolutionTranspose", { 1, 0 } },
|
||||
{ L"BatchNormalization", { 0, 1, 2, 3, 4, -1 } },
|
||||
{ L"Times", { 1, 0 } },
|
||||
};
|
||||
|
||||
//
|
||||
// CNTK Layer API needs to be treated specially.
|
||||
//
|
||||
std::set<std::wstring> Operators::_cntkLayerOPName = {
|
||||
{ L"Convolution" },
|
||||
{ L"ConvolutionTranspose" },
|
||||
{ L"BatchNormalization" },
|
||||
{ L"Dropout" },
|
||||
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class Graph;
|
||||
}
|
||||
|
||||
namespace CNTK
|
||||
{
|
||||
namespace ONNX
|
||||
{
|
||||
|
||||
struct AttributesMapping
|
||||
{
|
||||
std::unordered_map<std::wstring, std::string> map;
|
||||
};
|
||||
|
||||
class Operators
|
||||
{
|
||||
public:
|
||||
//
|
||||
// Check if opName is one of the supported ONNX OP.
|
||||
//
|
||||
static inline bool IsSupportedCNTKOP(const std::wstring& opName)
|
||||
{
|
||||
return _cntkToONNXOpName.find(opName) != _cntkToONNXOpName.end();
|
||||
}
|
||||
|
||||
//
|
||||
// Layer APIs use block function as a wrapper, so we need to handle them with care.
|
||||
//
|
||||
static inline bool IsLayerCNTKOP(const std::wstring& opName)
|
||||
{
|
||||
return _cntkLayerOPName.find(opName) != _cntkLayerOPName.end();
|
||||
}
|
||||
|
||||
//
|
||||
// Return a lookup table which is keyed on CNTK OP, and the value is another table
|
||||
// that contain name mapping from CNTK to ONNX.
|
||||
//
|
||||
static inline const std::unordered_multimap<std::wstring, AttributesMapping>& CntkToONNXLookup()
|
||||
{
|
||||
return _cntkToONNXOpName;
|
||||
}
|
||||
|
||||
//
|
||||
// Because in CNTK block, we can't filtered out the external inputs to the block.
|
||||
// We need a way to filter out leaf input from its subgraph.
|
||||
//
|
||||
static inline bool IsValidInputs(const std::wstring& opName, size_t index)
|
||||
{
|
||||
assert(_cntkBlockOPInvalidIndices.find(opName) != _cntkBlockOPInvalidIndices.end());
|
||||
|
||||
auto invalidIndices = _cntkBlockOPInvalidIndices[opName];
|
||||
return invalidIndices.find(index) == invalidIndices.end();
|
||||
}
|
||||
|
||||
//
|
||||
// The positional of the argument between CNTK and ONNX aren't the same.
|
||||
// The below function return true, if we need a remap.
|
||||
//
|
||||
static inline bool HasInputIndexMap(const std::wstring& opName)
|
||||
{
|
||||
return _cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end();
|
||||
}
|
||||
|
||||
//
|
||||
// If we need a remap, the below function return a remapping map.
|
||||
//
|
||||
static inline const std::vector<int>& ToONNXInputIndexMap(const std::wstring& opName)
|
||||
{
|
||||
assert(_cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end());
|
||||
return _cntkToONNXInputIndices[opName];
|
||||
}
|
||||
|
||||
//
|
||||
// For block function with internal constant or parameter, we don't want to create
|
||||
// the corresponding ONNX tensor for some of the parameters.
|
||||
//
|
||||
static inline bool IgnoreConstantAndParameter(const std::wstring& opName, size_t index)
|
||||
{
|
||||
if (_cntkToONNXInputIndices.find(opName) != _cntkToONNXInputIndices.end())
|
||||
{
|
||||
auto indexMap = _cntkToONNXInputIndices[opName];
|
||||
assert (index < indexMap.size());
|
||||
|
||||
return (indexMap[index] < 0);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
static std::unordered_multimap<std::wstring, AttributesMapping> _cntkToONNXOpName;
|
||||
static std::unordered_map<std::wstring, std::set<size_t>> _cntkBlockOPInvalidIndices;
|
||||
static std::unordered_map<std::wstring, std::vector<int>> _cntkToONNXInputIndices;
|
||||
static std::set<std::wstring> _cntkLayerOPName;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
#include "constants.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
TypesWrapper& TypesWrapper::GetTypesWrapper()
|
||||
{
|
||||
static TypesWrapper* types = new TypesWrapper();
|
||||
return *types;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes()
|
||||
{
|
||||
static std::unordered_set<std::string>* allowedDataTypes =
|
||||
new std::unordered_set<std::string>({
|
||||
c_float16, c_float, c_double,
|
||||
c_int8, c_int16, c_int32, c_int64,
|
||||
c_uint8, c_uint16, c_uint32, c_uint64,
|
||||
c_complex64, c_complex128,
|
||||
c_string, c_bool });
|
||||
return *allowedDataTypes;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
static const std::string c_noOp = "NoOp";
|
||||
static const std::string c_constantOp = "Constant";
|
||||
static const std::string c_constantValue = "value";
|
||||
|
||||
// Singleton wrapper around allowed data types.
|
||||
// This implements construct on first use which is needed to ensure
|
||||
// static objects are initialized before use. Ops registration does not work
|
||||
// properly without this.
|
||||
class TypesWrapper
|
||||
{
|
||||
public:
|
||||
static TypesWrapper& GetTypesWrapper();
|
||||
|
||||
// DataType strings. These should match the DataTypes defined in Data.proto
|
||||
const std::string c_float16 = "float16";
|
||||
const std::string c_float = "float";
|
||||
const std::string c_double = "double";
|
||||
const std::string c_int8 = "int8";
|
||||
const std::string c_int16 = "int16";
|
||||
const std::string c_int32 = "int32";
|
||||
const std::string c_int64 = "int64";
|
||||
const std::string c_uint8 = "uint8";
|
||||
const std::string c_uint16 = "uint16";
|
||||
const std::string c_uint32 = "uint32";
|
||||
const std::string c_uint64 = "uint64";
|
||||
const std::string c_complex64 = "complex64";
|
||||
const std::string c_complex128 = "complex128";
|
||||
const std::string c_string = "string";
|
||||
const std::string c_bool = "bool";
|
||||
std::unordered_set<std::string>& GetAllowedDataTypes();
|
||||
~TypesWrapper() = default;
|
||||
TypesWrapper(const TypesWrapper&) = delete;
|
||||
void operator=(const TypesWrapper&) = delete;
|
||||
private:
|
||||
TypesWrapper() = default;
|
||||
};
|
||||
}
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,604 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#ifndef CORE_GRAPH_GRAPH_H
|
||||
#define CORE_GRAPH_GRAPH_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "constants.h"
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "status.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
typedef size_t NODEINDEX;
|
||||
typedef int64_t VERSION;
|
||||
typedef std::unordered_map<std::string, AttributeProto> NodeAttributes;
|
||||
typedef ValueInfoProto NodeArgInfo;
|
||||
typedef std::unordered_map<std::string, TensorProto> InitialTensorSet;
|
||||
typedef std::unordered_map<std::string, TypeProto> ArgNameToTypeMap;
|
||||
|
||||
class Graph;
|
||||
class Node;
|
||||
class OpSignature;
|
||||
|
||||
// Node argument definition, for both input and output,
|
||||
// including arg name, arg type (contains both type and shape).
|
||||
//
|
||||
// Design Question: in my (Ke's) opinion, shape should not be part of type.
|
||||
// We may align the protobuf design with our operator registry interface,
|
||||
// which has type specified for each operator, but no shape. Well, shape
|
||||
// should be inferred with a separate shape inference function given
|
||||
// input shapes, or input tensor data sometimes.
|
||||
// With shape as part of type (current protobuf design),
|
||||
// 1) we'll have to split the "TypeProto" into type and shape in this internal
|
||||
// representation interface so that it could be easily used when doing type
|
||||
// inference and matching with operator registry.
|
||||
// 2) SetType should be always called before SetShape, otherwise, SetShape()
|
||||
// will fail. Because shape is located in a TypeProto.
|
||||
// Thoughts?
|
||||
//
|
||||
class NodeArg
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor by specifying node arg name and type&shape which is
|
||||
// optional. This is called when loading a <Graph> from <GraphProto>
|
||||
// normally.
|
||||
NodeArg(const std::string& p_name,
|
||||
const TypeProto* p_argType);
|
||||
|
||||
// Get node arg name.
|
||||
const std::string& Name() const;
|
||||
|
||||
// Get node arg type.
|
||||
const PTYPE& Type() const;
|
||||
|
||||
// Get node arg shape.
|
||||
// Return null pointer if there's no shape specified.
|
||||
const TypeProto::TensorShapeProto* Shape() const;
|
||||
|
||||
// Set node arg shape.
|
||||
// Shape could only be set after setting type since shape information
|
||||
// now is part of TypeProto.
|
||||
void SetShape(const TypeProto::TensorShapeProto& p_shape);
|
||||
|
||||
// Get node arg info proto.
|
||||
const NodeArgInfo& ToProto() const;
|
||||
|
||||
private:
|
||||
|
||||
friend class Node;
|
||||
friend class Graph;
|
||||
|
||||
void SetType(PTYPE p_type);
|
||||
void SetType(const TypeProto& p_typeProto);
|
||||
|
||||
// Node arg PType.
|
||||
PTYPE m_type;
|
||||
|
||||
// Node arg name, type and shape.
|
||||
NodeArgInfo m_nodeArgInfo;
|
||||
};
|
||||
|
||||
// Function representation.
|
||||
// It could present two cases of functions.
|
||||
// 1. Function without instantiation (No Node* sent to constructor). This
|
||||
// may be used in pure function optimization.
|
||||
// Function body (subgraph) should not be able to executed since no real
|
||||
// tensor binded with inputs/outputs of the function.
|
||||
// 2. Function with instantiation (A non-empty Node* sent to constructor).
|
||||
// Function body (subgraph) should be able to be executed since all
|
||||
// input/output names among nodes inside are refering real tensor names.
|
||||
// 2_a. Function with template type parameter.
|
||||
// 2_b. Function without template type parameter.
|
||||
// Function definition (FunctionDefProto) will be synced only when its body
|
||||
// is changed. Meanwhile, in 2_a case above, the function definition name
|
||||
// will be appended with real type string.
|
||||
class Function
|
||||
{
|
||||
public:
|
||||
|
||||
// Get function body - a subgraph.
|
||||
// Returned pointer owned by <*this> Function.
|
||||
Graph* Body();
|
||||
|
||||
// Get function name.
|
||||
// A function's name could be either its function definition name
|
||||
// m_functionDefProto.name(), or m_functionDefProto.name() + template
|
||||
// argument value.
|
||||
const std::string& Name();
|
||||
|
||||
// Get the protobuf representation of <*this> function.
|
||||
const FunctionDefProto& ToProto();
|
||||
|
||||
private:
|
||||
|
||||
friend class Graph;
|
||||
|
||||
Function() = delete;
|
||||
|
||||
// Constructor.
|
||||
// <p_node> specifies the node that refers to <*this> function. It's
|
||||
// used to instantiate <p_funcProto> if <p_funcProto> is a function
|
||||
// template.
|
||||
// <p_funcProto> specifies a function definition that a node refers to.
|
||||
Function(Node* p_node,
|
||||
const FunctionDefProto& p_funcProto);
|
||||
|
||||
// Function body which is a SubGraph.
|
||||
std::unique_ptr<Graph> m_body;
|
||||
};
|
||||
|
||||
// A node representation class.
|
||||
class Node {
|
||||
|
||||
public:
|
||||
|
||||
// An edge end. It could be input or output edge end of a node.
|
||||
// For node's input edge end, it's the source end, as the destination
|
||||
// end is the node itself.
|
||||
// For node's ouput edge end, it's the destination end, as the source
|
||||
// end is the node itself.
|
||||
class EdgeEnd
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor.
|
||||
// An EdgeEnd contains a Node pointer, a NodeArg pointer.
|
||||
// NOTE: it does not own the Node pointer and NodeArg pointer.
|
||||
EdgeEnd(const Node& p_node, const NodeArg& p_nodeArg);
|
||||
|
||||
// Get the <Node*> that this edge end refers to.
|
||||
const Node* GetNode() const;
|
||||
|
||||
// Get the <NodeArg*> that this edge end refers to.
|
||||
const NodeArg* GetNodeArg() const;
|
||||
|
||||
private:
|
||||
|
||||
const Node* m_node;
|
||||
|
||||
const NodeArg* m_nodeArg;
|
||||
};
|
||||
|
||||
// An iterator helper class for iterating a Node's neighbour nodes.
|
||||
class NodeConstIterator
|
||||
{
|
||||
public:
|
||||
|
||||
NodeConstIterator(std::set<const Node*>::const_iterator p_iter);
|
||||
|
||||
bool operator==(const NodeConstIterator& p_other) const;
|
||||
|
||||
bool operator!=(const NodeConstIterator& p_other) const;
|
||||
|
||||
void operator++();
|
||||
|
||||
const Node* operator*();
|
||||
|
||||
private:
|
||||
|
||||
std::set<const Node*>::const_iterator m_iter;
|
||||
};
|
||||
|
||||
// Get node index.
|
||||
NODEINDEX Index() const;
|
||||
|
||||
// Get node name.
|
||||
const std::string& Name() const;
|
||||
|
||||
// Get node operator type.
|
||||
const std::string& OpType() const;
|
||||
|
||||
// Get node description.
|
||||
const std::string& Description() const;
|
||||
|
||||
// Read/Write <*this> node's input args' definition, including name,
|
||||
// type and shape.
|
||||
const std::vector<NodeArg>& InputDefs() const;
|
||||
std::vector<NodeArg>& Mutable_InputDefs();
|
||||
|
||||
const std::vector<int>& InputArgCount() const;
|
||||
std::vector<int>& Mutable_InputArgCount();
|
||||
|
||||
// Read/Write <*this> node's output args' definition, including name,
|
||||
// type and shape.
|
||||
const std::vector<NodeArg>& OutputDefs() const;
|
||||
std::vector<NodeArg>& Mutable_OutputDefs();
|
||||
|
||||
// Functions defined to traverse a Graph as below.
|
||||
// Read all input nodes of <*this>.
|
||||
Node::NodeConstIterator InputNodes_begin() const;
|
||||
Node::NodeConstIterator InputNodes_end() const;
|
||||
// Read all output nodes of <*this>.
|
||||
Node::NodeConstIterator OutputNodes_begin() const;
|
||||
Node::NodeConstIterator OutputNodes_end() const;
|
||||
// Given input arg, get the source end of an input edge.
|
||||
bool InputEdgeSrcEnd(NodeArg* p_inputArg,
|
||||
/*out*/const EdgeEnd** p_inputEdgeSrcEnd) const;
|
||||
|
||||
// Add a node attribute with specified attribute name and value.
|
||||
bool AddAttribute(const std::string& p_attrName, const AttributeProto& p_value);
|
||||
|
||||
#define ADD_ATTR_INTERFACES(TypeName) \
|
||||
bool AddAttribute(const std::string& p_attrName, \
|
||||
const TypeName& p_value); \
|
||||
bool AddAttribute(const std::string& p_attrName, \
|
||||
const std::vector<TypeName>& p_values); \
|
||||
|
||||
ADD_ATTR_INTERFACES(int64_t)
|
||||
ADD_ATTR_INTERFACES(float)
|
||||
ADD_ATTR_INTERFACES(std::string)
|
||||
ADD_ATTR_INTERFACES(TensorProto)
|
||||
ADD_ATTR_INTERFACES(GraphProto)
|
||||
ADD_ATTR_INTERFACES(TypeProto)
|
||||
ADD_ATTR_INTERFACES(TypeProto::TensorShapeProto)
|
||||
|
||||
// Clear specified node attribute.
|
||||
bool ClearAttribute(const std::string& p_attrName);
|
||||
|
||||
// Get node attributes.
|
||||
const NodeAttributes& GetAttributes() const;
|
||||
|
||||
// Indicates on which we will run this node in runtime.
|
||||
// Executor will decide which device that this node will run against
|
||||
// and set it properly.
|
||||
// TODO: may change the return value type to be an ENUM.
|
||||
const std::string& Device() const;
|
||||
void SetDevice(const std::string& p_device);
|
||||
|
||||
// Get the corresponding <NodeProto>.
|
||||
void ToProto(NodeProto& p_proto) const;
|
||||
|
||||
private:
|
||||
|
||||
friend class Graph;
|
||||
|
||||
// Node could ONLY be constructed and owned by a <Graph>.
|
||||
Node() {}
|
||||
Node(NODEINDEX p_index, Graph* p_graph)
|
||||
: m_index(p_index),
|
||||
m_graph(p_graph) {}
|
||||
Node(const Node& p_other);
|
||||
|
||||
// Init node per <NodeProto>.
|
||||
// <p_nameToValueInfoMap> specifies the node's inputs'/outputs' value information,
|
||||
// including name, type and shape.
|
||||
void Init(const NodeProto& p_nodeProto,
|
||||
const ArgNameToTypeMap& p_nameToType);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<int>& p_inputArgCount,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
void Init(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
|
||||
// Node index.
|
||||
NODEINDEX m_index;
|
||||
|
||||
// Node name.
|
||||
std::string m_name;
|
||||
|
||||
// Node operator type.
|
||||
std::string m_opType;
|
||||
|
||||
// Node doc string.
|
||||
std::string m_description;
|
||||
|
||||
// Node inputs' definition.
|
||||
std::vector<NodeArg> m_inputDefs;
|
||||
// The number of inputs for each argument of the operator or function which
|
||||
// this node refers.
|
||||
// For example, <m_inputDefs> has 10 elements (inputs), and <m_inputArgCount>
|
||||
// is {4, 6}. This means that 4 elements (inputs) of <m_inputDefs> map to the
|
||||
// first argument of the operator or function, and the other 6 map to the
|
||||
// second argument.
|
||||
std::vector<int> m_inputArgCount;
|
||||
|
||||
// Node outputs' definition.
|
||||
std::vector<NodeArg> m_outputDefs;
|
||||
|
||||
// Node inputs' instantiation.
|
||||
std::unordered_map<const NodeArg*, EdgeEnd> m_inputs;
|
||||
// Node input nodes, besides input nodes mentioned in <m_inputs> above,
|
||||
// it also contains all control input nodes;
|
||||
std::set<const Node*> m_inputNodes;
|
||||
// Control input nodes' names.
|
||||
std::set<std::string> m_controlInputs;
|
||||
// Node's output nodes.
|
||||
std::set<const Node*> m_outputNodes;
|
||||
|
||||
// Device.
|
||||
std::string m_device;
|
||||
|
||||
// Map from attribute name to attribute.
|
||||
// This allows attribute adding and removing.
|
||||
NodeAttributes m_attributes;
|
||||
|
||||
Graph* m_graph;
|
||||
};
|
||||
|
||||
// A graph representation class.
|
||||
class Graph
|
||||
{
|
||||
public:
|
||||
|
||||
// An iterator helper to access graph nodes without copy.
|
||||
// The iterator itself does not own any data.
|
||||
class NodeIterator
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor.
|
||||
NodeIterator(NODEINDEX p_currentNodeIndex, Graph* p_graph)
|
||||
: m_graph(p_graph),
|
||||
m_currentNodeIndex(p_currentNodeIndex)
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(const NodeIterator& p_other) const;
|
||||
|
||||
bool operator!=(const NodeIterator& p_other) const;
|
||||
|
||||
void operator++();
|
||||
|
||||
Node* operator*();
|
||||
|
||||
private:
|
||||
|
||||
Graph* m_graph;
|
||||
|
||||
// it's the Node Index in <m_nodes> of the <m_graph>.
|
||||
NODEINDEX m_currentNodeIndex;
|
||||
};
|
||||
|
||||
// Constructor from scratch.
|
||||
// <p_isONNX> is a special flag to indicate whether it's
|
||||
// going to construct a ONNX graph. With ONNX graph, strict
|
||||
// type checking will be skiped.
|
||||
Graph(const std::string& p_name, bool p_isONNX = false);
|
||||
Graph(const std::string& p_name, const std::string& p_docString);
|
||||
|
||||
// Constructor: Given a <GraphProto> loaded from model file, construct
|
||||
// a <Graph> object.
|
||||
Graph(const GraphProto& p_graphProto);
|
||||
|
||||
// Constructor: Given a function definition and a node which refers to
|
||||
// the function, construct a <Graph> object.
|
||||
// Normally the <p_name> could be the parent node name and the
|
||||
// <p_version> could be the parent graph's version.
|
||||
// Question: will a node defined in a function refers another function
|
||||
// please? I (Ke) am assuming we don't allow such case here for now.
|
||||
Graph(Node* p_node,
|
||||
const FunctionDefProto& p_functionProto);
|
||||
|
||||
// Resolve <*this> graph to ensure it's in a good shape with full
|
||||
// functionality.
|
||||
// 1. Run through all validation rules.
|
||||
// a. Node name and node output's names should be unique.
|
||||
// b. Attribute match between node and op definition.
|
||||
// c. Input/Output match between node and op definition.
|
||||
// d. Graph is acyclic.
|
||||
// 2. Check & Setup inner nodes' dependency.
|
||||
// 3. Cleanup function definition lists.
|
||||
// Returns resolving status.
|
||||
Status Resolve();
|
||||
|
||||
// Getter and Setter for graph name.
|
||||
const std::string& Name() const;
|
||||
void SetName(const std::string& p_name);
|
||||
|
||||
// Add/Remove/Get initial tensors for some graph inputs.
|
||||
void AddInitialTensor(const TensorProto& p_tensor);
|
||||
void RemoveInitialTensor(const std::string& p_tensorName);
|
||||
bool GetInitialTensor(const std::string& p_tensorName,
|
||||
TensorProto& p_value) const;
|
||||
const InitialTensorSet& GetAllInitialTensors() const;
|
||||
|
||||
// Add or Remove a function definition.
|
||||
bool AddFunctionDef(const FunctionDefProto& p_function);
|
||||
void RemoveFunctionDef(const std::string& p_functionName);
|
||||
|
||||
// Get node given specific node index.
|
||||
Node* GetNode(NODEINDEX p_nodeIndex);
|
||||
|
||||
// Get node iterator to access all effective nodes in the graph.
|
||||
Graph::NodeIterator Nodes_begin();
|
||||
Graph::NodeIterator Nodes_end();
|
||||
|
||||
// Max Node Index.
|
||||
NODEINDEX MaxNodeIndex() const;
|
||||
|
||||
// Number of nodes in the <Graph>.
|
||||
// This is smaller than MaxNodeIndex(), since there may be nodes
|
||||
// removed during optimization.
|
||||
int NumberOfNodes() const;
|
||||
|
||||
// Add, remove node from <*this> graph.
|
||||
Node* AddNode(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
Node* AddNode(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_inputArgs,
|
||||
const std::vector<int>& p_inputArgCount,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
Node* AddNode(const std::string& p_name,
|
||||
const std::string& p_opType,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_outputArgs);
|
||||
Node* AddNode(const Node& p_other);
|
||||
bool RemoveNode(NODEINDEX p_nodeIndex);
|
||||
|
||||
// Convenience method for adding a constant op
|
||||
Node* AddConstantNode(const std::string& p_name,
|
||||
const std::string& p_description,
|
||||
const std::vector<NodeArg>& p_outputArgs,
|
||||
const TensorProto& p_tensor);
|
||||
|
||||
// Add control edge into <*this> graph.
|
||||
// The <p_dstNodeIndex> node does not consume any data output by
|
||||
// <p_srcNodeIndex>, but it's designed to be executed behind.
|
||||
bool AddControlEdge(NODEINDEX p_srcNodeIndex, NODEINDEX p_dstNodeIndex);
|
||||
|
||||
// Try to get function with specified <p_nodeIndex>. Return true if the
|
||||
// specified node refers to a function, and <p_function> will be the
|
||||
// function; false otherwise, and <p_function> will be unchanged.
|
||||
bool TryGetFunction(NODEINDEX p_nodeIndex,
|
||||
/*out*/ Function** p_function);
|
||||
|
||||
// Serialize the <Graph> into <GraphProto>.
|
||||
const GraphProto& ToGraphProto();
|
||||
|
||||
// Serialize the <Graph> into <FunctionDefProto>.
|
||||
// This is used when the graph is a subgraph of a main graph.
|
||||
const FunctionDefProto& ToFuncProto();
|
||||
|
||||
// Inline all function in <*this> and construct <p_graph>
|
||||
// without any functions. <p_graph> owned by caller.
|
||||
bool InlineAllFunctions(/*out*/Graph* p_graph) const;
|
||||
|
||||
bool IsSourceNode(NODEINDEX p_index) const;
|
||||
bool IsSinkNode(NODEINDEX p_index) const;
|
||||
|
||||
const Node* SourceNode() const;
|
||||
const Node* SinkNode() const;
|
||||
|
||||
Status GetNodesInTopologicalOrder(std::vector<NODEINDEX>** nodes);
|
||||
|
||||
private:
|
||||
|
||||
enum Type
|
||||
{
|
||||
// A main graph.
|
||||
Main = 1,
|
||||
// A sub graph (function).
|
||||
Sub = 2,
|
||||
// A graph with strict type checking.
|
||||
Strict = 4,
|
||||
};
|
||||
|
||||
friend class Node;
|
||||
|
||||
Node* AllocateNode();
|
||||
void ReleaseNode(NODEINDEX p_nodeIndex);
|
||||
|
||||
// Add node with specified <p_nodeProto>.
|
||||
Node* AddNode(const NodeProto& p_nodeProto,
|
||||
const ArgNameToTypeMap& p_nameToType);
|
||||
|
||||
Status VerifyNoDuplicateName(
|
||||
/*out*/ std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
|
||||
/*out*/ std::unordered_map<std::string, NODEINDEX>& p_nodeNameToIndex);
|
||||
|
||||
// Build and verify node connection (edges).
|
||||
// Verify NodeArg name/type/shape matching correctly.
|
||||
Status BuildConnections(
|
||||
const std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
|
||||
const std::unordered_map<std::string, NODEINDEX>& p_nodeNameToIndex);
|
||||
|
||||
// Check whether <*this> graph is acyclic.
|
||||
// Depth-first going thru the graph and check whether there's any back
|
||||
// edge.
|
||||
// <p_nodesInToplogicalOrder> returns nodes' indexes in toplogical
|
||||
// order if <Status> returned is "OK", otherwise it's undefined.
|
||||
Status CheckIsAcyclic(
|
||||
/*out*/std::vector<NODEINDEX>& p_nodesInToplogicalOrder);
|
||||
|
||||
// Depth-first graph access.
|
||||
// <p_ancestors> specifies all ancestor nodes of <p_current> node.
|
||||
// <p_current> specifies current node being accessed.
|
||||
// <p_visitedNodes> specifies nodes already visited.
|
||||
// <p_nodesInToplogicalOrder> returns nodes' indexes in toplogical
|
||||
// order if the graph is acyclic.
|
||||
Status DepthFirstAccess(std::unordered_set<NODEINDEX> p_ancestors,
|
||||
NODEINDEX p_current,
|
||||
/*in | out*/std::unordered_set<NODEINDEX>& p_visitedNodes,
|
||||
/*out*/std::vector<NODEINDEX>& p_nodesInToplogicalOrder);
|
||||
|
||||
// Given nodes in toplogical order, infer and set type information
|
||||
// across <*this> graph if needed, and verify type/attribute
|
||||
// information match between node and op.
|
||||
Status VerifyNodeAndOpMatch(
|
||||
const std::vector<NODEINDEX>& p_nodesInToplogicalOrder,
|
||||
std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs,
|
||||
/*out*/ std::set<std::string>& p_funcDefNames);
|
||||
|
||||
Status InferAndVerifyTypeMatch(Node* p_node,
|
||||
const OpSignature* p_op,
|
||||
const std::unordered_map<std::string, Node::EdgeEnd>& p_outputArgs);
|
||||
|
||||
// Clean function definition map.
|
||||
// Remove function definitions not refered by any node.
|
||||
void CleanFunctionDefMap(const std::set<std::string>& p_funcDefNames);
|
||||
|
||||
// Add source/sink nodes to <*this> graph.
|
||||
void AddSourceSinkNodes();
|
||||
|
||||
// Set graph inputs/outputs when serializing to proto.
|
||||
void SetGraphInputsOutputs();
|
||||
|
||||
// Graph nodes.
|
||||
// Element in <m_nodes> may be nullptr due to graph optimization.
|
||||
std::vector<std::unique_ptr<Node>> m_nodes;
|
||||
|
||||
// Number of nodes.
|
||||
// Normally this is smaller than the size of <m_nodes>, as some
|
||||
// elements in <m_nodes> may be removed when doing graph optimization,
|
||||
// or some elements may be merged, etc.
|
||||
int m_numOfNodes;
|
||||
|
||||
NODEINDEX m_sourceNodeIndex;
|
||||
NODEINDEX m_sinkNodeIndex;
|
||||
|
||||
// GraphProto to store name, version, initializer.
|
||||
// When serilizing <*this> Graph to a GraphProto, the nodes and
|
||||
// functions in <Graph> will also be fed into <m_graphProto> so that
|
||||
// it's consistent with <*this> graph.
|
||||
GraphProto m_graphProto;
|
||||
FunctionDefProto m_funcDefProto;
|
||||
|
||||
// The node which refers to <*this> graph (Function).
|
||||
Node* m_node;
|
||||
|
||||
// Graph function instantiations.
|
||||
std::unordered_map<std::string,
|
||||
std::unique_ptr<Function>> m_functionMap;
|
||||
|
||||
// Graph function definitions.
|
||||
std::unordered_map<std::string, FunctionDefProto> m_funcDefMap;
|
||||
|
||||
InitialTensorSet m_nameToInitialTensor;
|
||||
|
||||
// A flag indicates whether <*this> graph needs to be resolved.
|
||||
bool m_graphResolveNeeded;
|
||||
|
||||
bool m_graphProtoSyncNeeded;
|
||||
|
||||
int m_graphType = 0;
|
||||
|
||||
// the topologic order of node index
|
||||
std::vector<NODEINDEX> m_nodesInTopologicalOrder;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // CORE_GRAPH_GRAPH_H
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,288 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456 4189 4996)
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <google/protobuf/io/coded_stream.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#ifdef _WIN32
|
||||
#include <io.h>
|
||||
#else
|
||||
#include <sys/io.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include "model.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
#ifdef _WIN32
|
||||
inline int FileOpenRd(const std::wstring& p_path)
|
||||
{
|
||||
int fd = -1;
|
||||
bool err = _wsopen_s(&fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline int FileOpenWr(const std::wstring& p_path)
|
||||
{
|
||||
int fd = -1;
|
||||
_wsopen_s(&fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
return fd;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline int FileOpenRd(const std::string& p_path)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
int fd = -1;
|
||||
_sopen_s(&fd, p_path.c_str(), _O_RDONLY | _O_SEQUENTIAL | _O_BINARY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
return fd;
|
||||
#else
|
||||
return open(p_path.c_str(), O_RDONLY);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline int FileOpenWr(const std::string& p_path)
|
||||
{
|
||||
#ifdef _WIN32
|
||||
int fd = -1;
|
||||
_sopen_s(&fd, p_path.c_str(), _O_CREAT | _O_SEQUENTIAL | _O_BINARY | _O_WRONLY, _SH_DENYWR, _S_IREAD | _S_IWRITE);
|
||||
return fd;
|
||||
#else
|
||||
return open(p_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
Model::Model(const std::string& p_graphName, bool p_isONNX)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_isONNX));
|
||||
}
|
||||
|
||||
Model::Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_graphDocString));
|
||||
}
|
||||
|
||||
Model::Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString,
|
||||
VERSION p_irVersion,
|
||||
const std::string& p_producerName,
|
||||
const std::string& p_producerVersion,
|
||||
const std::string& p_domain,
|
||||
VERSION p_modelVersion,
|
||||
const std::string& p_docString,
|
||||
const std::string& p_modelAuthor,
|
||||
const std::string& p_modelLicense)
|
||||
{
|
||||
m_graph.reset(new Graph(p_graphName, p_graphDocString));
|
||||
m_modelProto.set_ir_version(p_irVersion);
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
m_modelProto.set_domain(p_domain);
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
m_modelProto.set_model_author(p_modelAuthor);
|
||||
m_modelProto.set_model_license(p_modelLicense);
|
||||
}
|
||||
|
||||
Model::Model(const ModelProto& p_modelProto)
|
||||
{
|
||||
m_modelProto = p_modelProto;
|
||||
if (m_modelProto.has_graph())
|
||||
{
|
||||
m_graph.reset(new Graph(m_modelProto.graph()));
|
||||
}
|
||||
}
|
||||
|
||||
VERSION Model::IrVersion() const
|
||||
{
|
||||
if (m_modelProto.has_ir_version())
|
||||
{
|
||||
return m_modelProto.ir_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
void Model::SetIrVersion(VERSION p_irVersion)
|
||||
{
|
||||
m_modelProto.set_ir_version(p_irVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerName() const
|
||||
{
|
||||
return m_modelProto.producer_name();
|
||||
}
|
||||
|
||||
void Model::SetProducerName(const std::string& p_producerName)
|
||||
{
|
||||
m_modelProto.set_producer_name(p_producerName);
|
||||
}
|
||||
|
||||
const std::string& Model::ProducerVersion() const
|
||||
{
|
||||
return m_modelProto.producer_version();
|
||||
}
|
||||
|
||||
void Model::SetProducerVersion(const std::string& p_producerVersion)
|
||||
{
|
||||
m_modelProto.set_producer_version(p_producerVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::Domain() const
|
||||
{
|
||||
return m_modelProto.domain();
|
||||
}
|
||||
|
||||
void Model::SetDomain(const std::string& p_domain)
|
||||
{
|
||||
m_modelProto.set_domain(p_domain);
|
||||
}
|
||||
|
||||
VERSION Model::ModelVersion() const
|
||||
{
|
||||
if (m_modelProto.has_model_version())
|
||||
{
|
||||
return m_modelProto.model_version();
|
||||
}
|
||||
return c_noVersion;
|
||||
}
|
||||
|
||||
void Model::SetModelversion(VERSION p_modelVersion)
|
||||
{
|
||||
m_modelProto.set_model_version(p_modelVersion);
|
||||
}
|
||||
|
||||
const std::string& Model::DocString() const
|
||||
{
|
||||
return m_modelProto.doc_string();
|
||||
}
|
||||
|
||||
void Model::SetDocString(const std::string& p_docString)
|
||||
{
|
||||
m_modelProto.set_doc_string(p_docString);
|
||||
}
|
||||
|
||||
const std::string& Model::ModelAuthor() const
|
||||
{
|
||||
return m_modelProto.model_author();
|
||||
}
|
||||
|
||||
void Model::SetModelAuthor(const std::string& p_modelAuthor)
|
||||
{
|
||||
m_modelProto.set_model_author(p_modelAuthor);
|
||||
}
|
||||
|
||||
const std::string& Model::ModelLicense() const
|
||||
{
|
||||
return m_modelProto.model_license();
|
||||
}
|
||||
|
||||
void Model::SetModelLicense(const std::string& p_modelLicense)
|
||||
{
|
||||
m_modelProto.set_model_license(p_modelLicense);
|
||||
}
|
||||
|
||||
Graph* Model::MainGraph()
|
||||
{
|
||||
return m_graph.get();
|
||||
}
|
||||
|
||||
const ModelProto& Model::ToProto()
|
||||
{
|
||||
*(m_modelProto.mutable_graph()) = m_graph->ToGraphProto();
|
||||
return m_modelProto;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
bool Model::Load(const std::wstring& p_filePath, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath), p_modelProto);
|
||||
}
|
||||
std::shared_ptr<Model> Model::Load(const std::wstring& p_filePath)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath));
|
||||
}
|
||||
bool Model::Save(Model& p_model, const std::wstring& p_filePath)
|
||||
{
|
||||
return Save(p_model.ToProto(), FileOpenWr(p_filePath));
|
||||
}
|
||||
bool Model::Save(const ModelProto& p_modelProto, const std::wstring& p_filePath)
|
||||
{
|
||||
return Save(p_modelProto, FileOpenWr(p_filePath));
|
||||
}
|
||||
#endif
|
||||
|
||||
bool Model::Load(const std::string& p_filePath, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath), p_modelProto);
|
||||
}
|
||||
std::shared_ptr<Model> Model::Load(const std::string& p_filePath)
|
||||
{
|
||||
return Load(FileOpenRd(p_filePath));
|
||||
}
|
||||
bool Model::Save(Model& p_model, const std::string& p_filePath)
|
||||
{
|
||||
return Save(p_model.ToProto(), FileOpenWr(p_filePath));
|
||||
}
|
||||
bool Model::Save(const ModelProto& p_modelProto, const std::string& p_filePath)
|
||||
{
|
||||
return Save(p_modelProto, FileOpenWr(p_filePath));
|
||||
}
|
||||
|
||||
using ::google::protobuf::io::ZeroCopyInputStream;
|
||||
using ::google::protobuf::io::FileInputStream;
|
||||
using ::google::protobuf::io::CodedInputStream;
|
||||
bool Model::Load(int p_fd, /*out*/ ModelProto* p_modelProto)
|
||||
{
|
||||
if (nullptr == p_modelProto || p_fd < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(p_fd));
|
||||
std::unique_ptr<CodedInputStream> coded_input(
|
||||
new CodedInputStream(raw_input.get()));
|
||||
// Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB.
|
||||
coded_input->SetTotalBytesLimit(INT_MAX, INT_MAX);
|
||||
bool result = p_modelProto->ParseFromCodedStream(coded_input.get());
|
||||
coded_input.reset();
|
||||
raw_input.reset();
|
||||
close(p_fd);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::shared_ptr<Model> Model::Load(int p_fd)
|
||||
{
|
||||
ModelProto modelProto;
|
||||
bool result = Load(p_fd, &modelProto);
|
||||
if (!result || p_fd < 0)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
auto model = std::shared_ptr<Model>(new Model(modelProto));
|
||||
auto status = model->MainGraph()->Resolve();
|
||||
|
||||
close(p_fd);
|
||||
if (status.Ok())
|
||||
{
|
||||
return model;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool Model::Save(const ModelProto& p_modelProto, int p_fd)
|
||||
{
|
||||
if (p_fd < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool result = p_modelProto.SerializeToFileDescriptor(p_fd);
|
||||
close(p_fd);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,121 @@
|
|||
#ifndef CORE_GRAPH_MODEL_H
|
||||
#define CORE_GRAPH_MODEL_H
|
||||
|
||||
#include "graph.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
// A machine learning model representation class.
|
||||
// Besides a main <Graph>, it also holds basic information, say,
|
||||
// version, domain, model author, license etc.
|
||||
class Model
|
||||
{
|
||||
public:
|
||||
|
||||
const VERSION c_noVersion = INT64_MAX;
|
||||
|
||||
// <p_isONNX> is a special flag to indicate whether it's
|
||||
// going to construct a ONNX graph. With ONNX graph, strict
|
||||
// type checking will be skiped.
|
||||
Model(const std::string& p_graphName, bool p_isONNX = false);
|
||||
|
||||
Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString);
|
||||
|
||||
Model(const std::string& p_graphName,
|
||||
const std::string& p_graphDocString,
|
||||
VERSION p_irVersion,
|
||||
const std::string& p_producerName,
|
||||
const std::string& p_producerVersion,
|
||||
const std::string& p_domain,
|
||||
VERSION p_modelVersion,
|
||||
const std::string& p_modelDocString,
|
||||
const std::string& p_modelAuthor,
|
||||
const std::string& p_modelLicense);
|
||||
|
||||
Model(const ModelProto& p_modelProto);
|
||||
|
||||
// Get model's IR version.
|
||||
// Return <c_noVersion> if not specified.
|
||||
VERSION IrVersion() const;
|
||||
// Set model's IR version.
|
||||
void SetIrVersion(VERSION p_irVersion);
|
||||
|
||||
// Get model's producer name.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& ProducerName() const;
|
||||
// Set model's producer name.
|
||||
void SetProducerName(const std::string& p_producerName);
|
||||
|
||||
// Get model's producer version.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& ProducerVersion() const;
|
||||
// Set model's producer version.
|
||||
void SetProducerVersion(const std::string& p_producerVersion);
|
||||
|
||||
// Get model's domain.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& Domain() const;
|
||||
// Set models' damain.
|
||||
void SetDomain(const std::string& p_domain);
|
||||
|
||||
// Get model's version.
|
||||
// Return null pointer if not specified.
|
||||
VERSION ModelVersion() const;
|
||||
// Set models' version.
|
||||
void SetModelversion(VERSION p_modelVersion);
|
||||
|
||||
// Get model's doc string.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& DocString() const;
|
||||
// Set models' doc string.
|
||||
void SetDocString(const std::string& p_docString);
|
||||
|
||||
// Get model's author.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& ModelAuthor() const;
|
||||
// Set models' author.
|
||||
void SetModelAuthor(const std::string& p_modelAuthor);
|
||||
|
||||
// Get model's license.
|
||||
// Return null pointer if not specified.
|
||||
const std::string& ModelLicense() const;
|
||||
// Set models' license.
|
||||
void SetModelLicense(const std::string& p_modelLicense);
|
||||
|
||||
// Get model's main graph.
|
||||
// The return pointer is owned by <*this> model.
|
||||
Graph* MainGraph();
|
||||
|
||||
// Get model's serlization proto data.
|
||||
const ModelProto& ToProto();
|
||||
|
||||
#ifdef _WIN32
|
||||
// wstring versions for Windows only.
|
||||
static bool Save(const ModelProto& p_modelProto, const std::wstring& p_filePath);
|
||||
static bool Save(Model& p_model, const std::wstring& p_filePath);
|
||||
// Load a ModelProto from a file.
|
||||
static bool Load(const std::wstring& p_filePath, /*out*/ ModelProto* p_modelProto);
|
||||
static std::shared_ptr<Model> Load(const std::wstring& p_filePath);
|
||||
#endif
|
||||
// Save a ModelProto to a file.
|
||||
static bool Save(const ModelProto& p_modelProto, const std::string& p_filePath);
|
||||
static bool Save(Model& p_model, const std::string& p_filePath);
|
||||
static bool Save(const ModelProto& p_modelProto, int p_fd);
|
||||
// Load a ModelProto from a file.
|
||||
static bool Load(const std::string& p_filePath, /*out*/ ModelProto* p_modelProto);
|
||||
static std::shared_ptr<Model> Load(const std::string& p_filePath);
|
||||
static bool Load(int p_fd, /*out*/ ModelProto* p_modelProto);
|
||||
static std::shared_ptr<Model> Load(int p_fd);
|
||||
|
||||
private:
|
||||
|
||||
// Model data.
|
||||
ModelProto m_modelProto;
|
||||
|
||||
// Main graph of the model.
|
||||
std::unique_ptr<Graph> m_graph;
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,354 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#include "op.h"
|
||||
#include "opsignature.h"
|
||||
#include "utils.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
const std::string& OperatorSchema::GetName() const
|
||||
{
|
||||
return m_opSignature.GetName();
|
||||
}
|
||||
|
||||
const OpSignature& OperatorSchema::GetOpSignature() const
|
||||
{
|
||||
return m_opSignature;
|
||||
}
|
||||
|
||||
ShapeInferenceFunc OperatorSchema::GetShapeInferenceFn() const
|
||||
{
|
||||
return m_shapeInferenceFunc;
|
||||
}
|
||||
|
||||
AttributeParser OperatorSchema::GetAttributeParser() const
|
||||
{
|
||||
return m_attrParser;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::Name(const std::string& p_opName)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_name = p_opName;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::Description(const std::string& p_description)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_description = p_description;
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::Input(const std::string& p_inputName,
|
||||
const std::string& p_description,
|
||||
const std::string& p_type)
|
||||
{
|
||||
m_inputs.push_back(std::make_tuple(p_inputName, p_description, p_type));
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::Output(const std::string& p_outputName,
|
||||
const std::string& p_description,
|
||||
const std::string& p_type)
|
||||
{
|
||||
m_outputs.push_back(std::make_tuple(p_outputName, p_description, p_type));
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::Attr(const std::string& p_attrName,
|
||||
const std::string& p_description,
|
||||
AttrType p_attrType, bool required)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_attributes.push_back(
|
||||
OpSignature::Attribute(p_attrName, p_attrType, p_description));
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
#define ATTR_SETTER_BASIC_IMPL(type, field) \
|
||||
OperatorSchemaSetter& \
|
||||
OperatorSchemaSetter::Attr(const std::string& p_attrName, \
|
||||
const std::string& p_description, \
|
||||
AttrType p_attrType, \
|
||||
const type& p_defaultValue) \
|
||||
{ \
|
||||
AttributeProto a; \
|
||||
a.set_name(p_attrName); \
|
||||
a.set_##field(p_defaultValue); \
|
||||
\
|
||||
m_opSchema.m_opSignature.m_attributes.push_back( \
|
||||
OpSignature::Attribute(p_attrName, \
|
||||
p_attrType, \
|
||||
p_description, \
|
||||
a)); \
|
||||
\
|
||||
return *this; \
|
||||
} \
|
||||
|
||||
#define ATTR_SETTER_LIST_IMPL(type, field) \
|
||||
OperatorSchemaSetter& \
|
||||
OperatorSchemaSetter::Attr(const std::string& p_attrName, \
|
||||
const std::string& p_description, \
|
||||
AttrType p_attrType, \
|
||||
const std::vector<type>& p_defaultValue) \
|
||||
{ \
|
||||
AttributeProto a; \
|
||||
a.set_name(p_attrName); \
|
||||
for (const auto& v : p_defaultValue) \
|
||||
{ \
|
||||
a.add_##field(v); \
|
||||
} \
|
||||
\
|
||||
m_opSchema.m_opSignature.m_attributes.push_back( \
|
||||
OpSignature::Attribute(p_attrName, \
|
||||
p_attrType, \
|
||||
p_description, \
|
||||
a)); \
|
||||
return *this; \
|
||||
} \
|
||||
|
||||
ATTR_SETTER_BASIC_IMPL(int64_t, i)
|
||||
ATTR_SETTER_BASIC_IMPL(float, f)
|
||||
ATTR_SETTER_BASIC_IMPL(std::string, s)
|
||||
ATTR_SETTER_LIST_IMPL(int64_t, ints)
|
||||
ATTR_SETTER_LIST_IMPL(float, floats)
|
||||
ATTR_SETTER_LIST_IMPL(std::string, strings)
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::TypeConstraint(const std::string& p_typeName,
|
||||
const std::vector<std::string>& p_constraints,
|
||||
const std::string& p_description)
|
||||
{
|
||||
m_constraints.push_back(std::make_tuple(p_typeName, p_constraints, p_description));
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::SetShapeInferenceFunc(
|
||||
ShapeInferenceFunc p_shapeInferFunc)
|
||||
{
|
||||
m_opSchema.m_shapeInferenceFunc = p_shapeInferFunc;
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaSetter&
|
||||
OperatorSchemaSetter::SetAttributeParser(
|
||||
AttributeParser p_attrParser)
|
||||
{
|
||||
m_opSchema.m_attrParser = p_attrParser;
|
||||
return *this;
|
||||
}
|
||||
|
||||
OperatorSchemaRegistry::RegisterOnce::RegisterOnce(
|
||||
OperatorSchemaSetter& p_opSchemaSetter)
|
||||
{
|
||||
auto& opSchema = p_opSchemaSetter.m_opSchema;
|
||||
// Process type constraints.
|
||||
for (const auto& constraint : p_opSchemaSetter.m_constraints)
|
||||
{
|
||||
std::string name;
|
||||
std::vector<std::string> types;
|
||||
std::string desc;
|
||||
std::tie(name, types, desc) = constraint;
|
||||
|
||||
auto it = opSchema.m_opSignature.m_typeConstraintMap.find(name);
|
||||
if (it == opSchema.m_opSignature.m_typeConstraintMap.end())
|
||||
{
|
||||
DataTypeSet d;
|
||||
for (const auto& t : types)
|
||||
{
|
||||
d.insert(Utils::OpUtils::ToType(t));
|
||||
}
|
||||
opSchema.m_opSignature.m_typeConstraintMap.insert(std::make_pair(name, std::make_pair(d, desc)));
|
||||
}
|
||||
else
|
||||
{
|
||||
// already a constraint with the same name. error.
|
||||
}
|
||||
}
|
||||
|
||||
opSchema.m_opSignature.m_inputs.reserve(p_opSchemaSetter.m_inputs.size());
|
||||
for (const auto& input : p_opSchemaSetter.m_inputs)
|
||||
{
|
||||
std::string name;
|
||||
std::string type;
|
||||
std::string desc;
|
||||
std::tie(name, desc, type) = input;
|
||||
opSchema.m_opSignature.m_inputs.push_back(
|
||||
OpSignature::FormalParameter(name, type, desc, opSchema.m_opSignature.m_typeConstraintMap));
|
||||
}
|
||||
|
||||
opSchema.m_opSignature.m_outputs.reserve(p_opSchemaSetter.m_outputs.size());
|
||||
for (const auto& output : p_opSchemaSetter.m_outputs)
|
||||
{
|
||||
std::string name;
|
||||
std::string type;
|
||||
std::string desc;
|
||||
std::tie(name, desc, type) = output;
|
||||
opSchema.m_opSignature.m_outputs.push_back(
|
||||
OpSignature::FormalParameter(name, type, desc,
|
||||
opSchema.m_opSignature.m_typeConstraintMap));
|
||||
}
|
||||
|
||||
auto& opSignature = p_opSchemaSetter.m_opSchema.m_opSignature;
|
||||
if (0 == opSignature.m_inputs.size())
|
||||
{
|
||||
for (int i = 0; i < opSignature.m_onnxMinInput; ++i)
|
||||
{
|
||||
std::string name = "p" + std::to_string(i);
|
||||
std::string desc = "Input Parameter " + std::to_string(i);
|
||||
opSignature.m_inputs.push_back(
|
||||
OpSignature::FormalParameter(name, "", desc, opSignature.m_typeConstraintMap));
|
||||
}
|
||||
}
|
||||
|
||||
if (0 == opSignature.m_outputs.size())
|
||||
{
|
||||
for (int i = 0; i < opSignature.m_onnxMinOutput; ++i)
|
||||
{
|
||||
std::string name = "p" + std::to_string(i);
|
||||
std::string desc = "Output Result " + std::to_string(i);
|
||||
opSignature.m_outputs.push_back(
|
||||
OpSignature::FormalParameter(name, "", desc, opSignature.m_typeConstraintMap));
|
||||
}
|
||||
}
|
||||
OperatorSchemaRegistry::Get()->Register(p_opSchemaSetter.m_opSchema);
|
||||
}
|
||||
|
||||
bool OperatorSchemaRegistry::TryGetOp(const std::string& p_name,
|
||||
const OperatorSchema** p_opSchema) const
|
||||
{
|
||||
if (nullptr == p_opSchema)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto iter = m_opNameToOpSchemaMap.find(p_name);
|
||||
if (m_opNameToOpSchemaMap.end() == iter)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
*p_opSchema = &(iter->second);
|
||||
return true;
|
||||
}
|
||||
|
||||
Status OperatorSchemaRegistry::Register(
|
||||
const OperatorSchema& p_opSchema)
|
||||
{
|
||||
auto iter = m_opNameToOpSchemaMap.find(p_opSchema.GetName());
|
||||
if (m_opNameToOpSchemaMap.end() != iter)
|
||||
{
|
||||
Status status(false,
|
||||
"Error: operator schema with same name ("
|
||||
+ p_opSchema.GetName() + ") exists.");
|
||||
return status;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_opNameToOpSchemaMap[p_opSchema.GetName()] = p_opSchema;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
OperatorSchemaRegistry* OperatorSchemaRegistry::Get()
|
||||
{
|
||||
static OperatorSchemaRegistry* s_registry
|
||||
= new OperatorSchemaRegistry();
|
||||
return s_registry;
|
||||
}
|
||||
|
||||
Status TypeUtils::GetType(const AttributeProto& p_attr, AttrType& p_type)
|
||||
{
|
||||
if (!OpSignature::IsValidAttribute(p_attr))
|
||||
{
|
||||
return Status(false, "Invalid AttributeProto.");
|
||||
}
|
||||
|
||||
if (p_attr.has_f())
|
||||
{
|
||||
p_type = AttrType::FLOAT;
|
||||
}
|
||||
else if (p_attr.has_i())
|
||||
{
|
||||
p_type = AttrType::INT;
|
||||
}
|
||||
else if (p_attr.has_s())
|
||||
{
|
||||
p_type = AttrType::STRING;
|
||||
}
|
||||
else if (p_attr.has_t())
|
||||
{
|
||||
p_type = AttrType::TENSOR;
|
||||
}
|
||||
else if (p_attr.has_g())
|
||||
{
|
||||
p_type = AttrType::GRAPH;
|
||||
}
|
||||
else if (p_attr.floats_size())
|
||||
{
|
||||
p_type = AttrType::FLOATS;
|
||||
}
|
||||
else if (p_attr.ints_size())
|
||||
{
|
||||
p_type = AttrType::INTS;
|
||||
}
|
||||
else if (p_attr.strings_size())
|
||||
{
|
||||
p_type = AttrType::STRINGS;
|
||||
}
|
||||
else if (p_attr.tensors_size())
|
||||
{
|
||||
p_type = AttrType::TENSORS;
|
||||
}
|
||||
else if (p_attr.graphs_size())
|
||||
{
|
||||
p_type = AttrType::GRAPHS;
|
||||
}
|
||||
else if (p_attr.has_type())
|
||||
{
|
||||
p_type = AttrType::TYPE;
|
||||
}
|
||||
else if (p_attr.types_size())
|
||||
{
|
||||
p_type = AttrType::TYPES;
|
||||
}
|
||||
else if (p_attr.has_shape())
|
||||
{
|
||||
p_type = AttrType::SHAPE;
|
||||
}
|
||||
else if (p_attr.has_shape())
|
||||
{
|
||||
p_type = AttrType::SHAPES;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_type = AttrType::NONE;
|
||||
return Status(false, "Invalid AttributeProto.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
size_t ReplaceAll(std::string& s, const char* from, const char* to)
|
||||
{
|
||||
size_t numReplaced = 0;
|
||||
std::string::size_type lenFrom = std::strlen(from);
|
||||
std::string::size_type lenTo = std::strlen(to);
|
||||
for (std::string::size_type pos = s.find(from); pos != std::string::npos;
|
||||
pos = s.find(from, pos + lenTo)) {
|
||||
s.replace(pos, lenFrom, to);
|
||||
numReplaced++;
|
||||
}
|
||||
return numReplaced;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,264 @@
|
|||
#ifndef CORE_GRAPH_OP_H
|
||||
#define CORE_GRAPH_OP_H
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "opsignature.h"
|
||||
#include "shape_inference.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
class OpSignature;
|
||||
class OperatorSchemaSetter;
|
||||
typedef OperatorSchemaSetter OpSchema;
|
||||
|
||||
class TypeUtils
|
||||
{
|
||||
public:
|
||||
|
||||
// Get attribute type given attribute proto data.
|
||||
static Status GetType(const AttributeProto& p_attr, AttrType& p_type);
|
||||
|
||||
};
|
||||
|
||||
// An attribute parser - it's specified when registering an operator.
|
||||
// The parser is designed and used in two ways.
|
||||
// 1) It will be used to verify whether a Node's attributes match the
|
||||
// operator's definition.
|
||||
// 2) It will be used to parse a Node's attributes into a <T> object,
|
||||
// which makes it be easier to access node attributes.
|
||||
// TODO: to implement the 2nd point above, NodeAttributes should be changed
|
||||
// to contain a <T> field, which is structured attributes.
|
||||
typedef std::function<Status(const NodeAttributes&)> AttributeParser;
|
||||
|
||||
class OperatorSchema
|
||||
{
|
||||
public:
|
||||
|
||||
const std::string& GetName() const;
|
||||
const OpSignature& GetOpSignature() const;
|
||||
ShapeInferenceFunc GetShapeInferenceFn() const;
|
||||
AttributeParser GetAttributeParser() const;
|
||||
|
||||
private:
|
||||
|
||||
friend class OperatorSchemaSetter;
|
||||
friend class OperatorSchemaRegistry;
|
||||
|
||||
OpSignature m_opSignature;
|
||||
ShapeInferenceFunc m_shapeInferenceFunc;
|
||||
AttributeParser m_attrParser;
|
||||
};
|
||||
|
||||
typedef std::tuple<std::string, std::string, std::string> InputOutputParam;
|
||||
typedef std::tuple<std::string, std::string, AttrType, AttributeProto> AttrParam;
|
||||
typedef std::tuple<std::string, std::vector<std::string>, std::string> TypeConstraintParam;
|
||||
|
||||
#define ATTR_SETTER_INTERFACE(TypeName) \
|
||||
OperatorSchemaSetter& Attr(const std::string& p_attrName, \
|
||||
const std::string& p_description, \
|
||||
AttrType p_attrType, \
|
||||
const TypeName& p_defaultValue); \
|
||||
OperatorSchemaSetter& Attr(const std::string& p_attrName, \
|
||||
const std::string& p_description, \
|
||||
AttrType p_attrType, \
|
||||
const std::vector<TypeName>& p_defaultValues); \
|
||||
|
||||
// Operator registry setter helper.
|
||||
// This is used in "OPERATOR_DEFINITION" macro, to separate setters from getters
|
||||
// in OpSignature.
|
||||
class OperatorSchemaSetter
|
||||
{
|
||||
public:
|
||||
|
||||
OperatorSchemaSetter() = default;
|
||||
|
||||
OperatorSchemaSetter& Name(const std::string& p_opName);
|
||||
|
||||
OperatorSchemaSetter& Description(const std::string& p_description);
|
||||
|
||||
OperatorSchemaSetter& Input(const std::string& p_inputName,
|
||||
const std::string& p_description,
|
||||
const std::string& p_type = "");
|
||||
|
||||
OperatorSchemaSetter& Output(const std::string& p_outputName,
|
||||
const std::string& p_description,
|
||||
const std::string& p_type = "");
|
||||
|
||||
OperatorSchemaSetter& Attr(const std::string& p_attrName,
|
||||
const std::string& p_description,
|
||||
AttrType p_attrType, bool required = false);
|
||||
|
||||
ATTR_SETTER_INTERFACE(int64_t)
|
||||
ATTR_SETTER_INTERFACE(float)
|
||||
ATTR_SETTER_INTERFACE(std::string)
|
||||
ATTR_SETTER_INTERFACE(TensorProto)
|
||||
ATTR_SETTER_INTERFACE(GraphProto)
|
||||
ATTR_SETTER_INTERFACE(TypeProto)
|
||||
ATTR_SETTER_INTERFACE(TypeProto::TensorShapeProto)
|
||||
|
||||
OperatorSchemaSetter& TypeConstraint(const std::string& p_typeName,
|
||||
const std::vector<std::string>& p_constraints,
|
||||
const std::string& p_description);
|
||||
|
||||
// Shape inference function will be used to infer outputs' shape with
|
||||
// inputs' shape.
|
||||
OperatorSchemaSetter& SetShapeInferenceFunc(
|
||||
ShapeInferenceFunc p_shapeInferFunc);
|
||||
|
||||
// Attribute parser will be used to parse Node's attributes to see
|
||||
// whether Node attributes are matching operator attributes definition.
|
||||
OperatorSchemaSetter& SetAttributeParser(
|
||||
AttributeParser p_attrParser);
|
||||
|
||||
enum class SupportType {
|
||||
COMMON,
|
||||
EXPERIMENTAL,
|
||||
};
|
||||
// Methods added for compatibility with ONNX OpSchema registration API
|
||||
OpSchema& NumInputs(int n)
|
||||
{
|
||||
return NumInputs(n, n);
|
||||
}
|
||||
OpSchema& NumInputs(int min, int max)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_onnxMinInput = min;
|
||||
m_opSchema.m_opSignature.m_onnxMaxInput = max;
|
||||
return *this;
|
||||
}
|
||||
OpSchema& NumInputs(std::set<int> allowed_input_nums)
|
||||
{
|
||||
return NumInputs([allowed_input_nums](int n)-> bool {
|
||||
return allowed_input_nums.count(n) > 0;
|
||||
});
|
||||
}
|
||||
OpSchema& NumInputs(std::function<bool(int)> func)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_onnxNumInputsAllowed = func;
|
||||
return *this;
|
||||
}
|
||||
OpSchema& NumOutputs(int n) {
|
||||
return NumOutputs(n, n);
|
||||
}
|
||||
OpSchema& NumOutputs(int min, int max)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_onnxMinOutput = min;
|
||||
m_opSchema.m_opSignature.m_onnxMaxOutput = max;
|
||||
return *this;
|
||||
}
|
||||
OpSchema& NumOutputs(std::set<int> allowed_output_nums)
|
||||
{
|
||||
return NumOutputs([allowed_output_nums](int n)-> bool {
|
||||
return allowed_output_nums.count(n) > 0;
|
||||
});
|
||||
}
|
||||
OpSchema& NumOutputs(std::function<bool(int)> func)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_onnxNumOutputsAllowed = func;
|
||||
return *this;
|
||||
}
|
||||
OpSchema& NumInputsOutputs(std::function<bool(int, int)> func)
|
||||
{
|
||||
m_opSchema.m_opSignature.m_onnxNumInputsOutputsAllowed = func;
|
||||
return *this;
|
||||
}
|
||||
OpSchema& OutputCalculator(std::function<int(int)> calc) { return *this; }
|
||||
OpSchema& SameNumberOfOutput() { return *this; }
|
||||
OpSchema& AllowConsumed(std::function<std::pair<bool, int>(int)> inplace) { return *this; }
|
||||
OpSchema& AllowConsumed(std::unordered_map<int, int> inplace) { return *this; }
|
||||
OpSchema& AllowOneToOneConsumed() { return *this; }
|
||||
OpSchema& EnforceConsumed(std::function<std::pair<bool, int>(int)> inplace) { return *this; }
|
||||
OpSchema& EnforceConsumed(std::unordered_map<int, int> inplace) { return *this; }
|
||||
OpSchema& EnforceOneToOneConsumed() { return *this; }
|
||||
OpSchema& SetSupportLevel(SupportType) { return *this; }
|
||||
OpSchema& AllowUncheckedAttributes() { return *this; }
|
||||
OpSchema& FillUsing(std::function<void(OpSchema&)> populator)
|
||||
{
|
||||
if (populator)
|
||||
{
|
||||
populator(*this);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
OpSchema& Input(const int, const char* name, const char* description)
|
||||
{
|
||||
return Input(name, description);
|
||||
}
|
||||
OpSchema& Output(const int, const char* name, const char* description)
|
||||
{
|
||||
return Output(name, description);
|
||||
}
|
||||
OpSchema& SetDoc(const std::string& doc)
|
||||
{
|
||||
return Description(doc);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
//friend class OpSignature;
|
||||
friend class OperatorSchemaRegistry;
|
||||
|
||||
OperatorSchema m_opSchema;
|
||||
|
||||
// Operator input formal parameters.
|
||||
std::vector<InputOutputParam> m_inputs;
|
||||
|
||||
// Operator output formal parameters.
|
||||
std::vector<InputOutputParam> m_outputs;
|
||||
|
||||
// Operator type constraints.
|
||||
std::vector<TypeConstraintParam> m_constraints;
|
||||
};
|
||||
|
||||
// Operator schema registry. A singleton registry to manage all operator
|
||||
// schemas.
|
||||
class OperatorSchemaRegistry
|
||||
{
|
||||
public:
|
||||
|
||||
// Helper function providing a way to call
|
||||
// OpSignatureFactory::Register().
|
||||
class RegisterOnce
|
||||
{
|
||||
public:
|
||||
|
||||
RegisterOnce(OperatorSchemaSetter& p_opRegistry);
|
||||
};
|
||||
|
||||
// Try to get operator with specified operator name.
|
||||
bool TryGetOp(const std::string& p_name,
|
||||
const OperatorSchema** p_opRegistry) const;
|
||||
|
||||
// Register an operator.
|
||||
Status Register(const OperatorSchema& p_opSchema);
|
||||
|
||||
// Get the global operator registry factory instance.
|
||||
static OperatorSchemaRegistry* Get();
|
||||
|
||||
private:
|
||||
|
||||
OperatorSchemaRegistry() = default;
|
||||
|
||||
// An operator name to operator definition data map.
|
||||
std::unordered_map<std::string, OperatorSchema> m_opNameToOpSchemaMap;
|
||||
};
|
||||
|
||||
// utility function used by ONNX v1 op registration defs.
|
||||
size_t ReplaceAll(std::string& s, const char* from, const char* to);
|
||||
|
||||
#define REGISTER_OPERATOR_SCHEMA(OpName) OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, OpName)
|
||||
#define OPERATOR_SCHEMA_UNIQ_HELPER(Counter, OpName) OPERATOR_SCHEMA_UNIQ(Counter, OpName)
|
||||
#define OPERATOR_SCHEMA_UNIQ(Counter, OpName) \
|
||||
static OperatorSchemaRegistry::RegisterOnce op_##Counter \
|
||||
= OperatorSchemaSetter().Name(#OpName)
|
||||
|
||||
// Operator registration example.
|
||||
// OPERATOR_DEFINITION(Add).Description("An operator to sum two float numbers.")
|
||||
// .Input("input_1", "docstr for input_1.", "T")
|
||||
// .Input("input_2", "docstr for input_2.", "T")
|
||||
// .Output("output_1", "docstr for output_1.", "T")
|
||||
// .TypeConstraint("T", { "float16", "float32", "float64" }, "Constrain input and output types to floats.");
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,151 @@
|
|||
#include "opsignature.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
OpSignature::FormalParameter::FormalParameter(
|
||||
const std::string& p_name, const std::string& p_type,
|
||||
const std::string& p_description,
|
||||
const TypeConstraintMap& p_constraintMap)
|
||||
: m_name(p_name), m_typeStr(p_type), m_description(p_description)
|
||||
{
|
||||
auto it = p_constraintMap.find(p_type);
|
||||
if (it != p_constraintMap.end())
|
||||
{
|
||||
m_types = it->second.first;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!p_type.empty())
|
||||
{
|
||||
m_types.emplace(Utils::OpUtils::ToType(m_typeStr));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& OpSignature::FormalParameter::GetName() const
|
||||
{
|
||||
return m_name;
|
||||
}
|
||||
|
||||
const DataTypeSet& OpSignature::FormalParameter::GetTypes() const
|
||||
{
|
||||
return m_types;
|
||||
}
|
||||
|
||||
const std::string& OpSignature::FormalParameter::GetTypeStr() const
|
||||
{
|
||||
return m_typeStr;
|
||||
}
|
||||
|
||||
const std::string& OpSignature::FormalParameter::GetDescription() const
|
||||
{
|
||||
return m_description;
|
||||
}
|
||||
|
||||
OpSignature::Attribute::Attribute(
|
||||
const std::string& p_attrName,
|
||||
AttrType p_type,
|
||||
const std::string& p_description,
|
||||
const AttributeProto& p_defaultVal)
|
||||
: m_name(p_attrName), m_type(p_type), m_description(p_description),
|
||||
m_hasDefaultValue(true)
|
||||
{
|
||||
m_allowedValues.push_back(p_defaultVal);
|
||||
}
|
||||
|
||||
OpSignature::Attribute::Attribute(
|
||||
const std::string& p_attrName,
|
||||
AttrType p_type,
|
||||
const std::string& p_description)
|
||||
: m_name(p_attrName), m_type(p_type), m_description(p_description),
|
||||
m_hasDefaultValue(false)
|
||||
{
|
||||
}
|
||||
|
||||
const std::string& OpSignature::Attribute::GetName() const
|
||||
{
|
||||
return m_name;
|
||||
}
|
||||
|
||||
AttrType OpSignature::Attribute::GetType() const
|
||||
{
|
||||
return m_type;
|
||||
}
|
||||
|
||||
bool OpSignature::Attribute::HasDefaultValue(
|
||||
const AttributeProto** p_value) const
|
||||
{
|
||||
if (m_hasDefaultValue
|
||||
&& nullptr != p_value)
|
||||
{
|
||||
*p_value = &(m_allowedValues[0]);
|
||||
}
|
||||
|
||||
return m_hasDefaultValue;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const std::string& OpSignature::GetName() const
|
||||
{
|
||||
return m_name;
|
||||
}
|
||||
|
||||
const std::string& OpSignature::GetDescription() const
|
||||
{
|
||||
return m_description;
|
||||
}
|
||||
|
||||
const std::vector<OpSignature::FormalParameter>&
|
||||
OpSignature::GetInputs() const
|
||||
{
|
||||
return m_inputs;
|
||||
}
|
||||
|
||||
const std::vector<OpSignature::FormalParameter>&
|
||||
OpSignature::GetOutputs() const
|
||||
{
|
||||
return m_outputs;
|
||||
}
|
||||
|
||||
const std::vector<OpSignature::Attribute>&
|
||||
OpSignature::GetAttributes() const
|
||||
{
|
||||
return m_attributes;
|
||||
}
|
||||
|
||||
const TypeConstraintMap& OpSignature::GetTypeConstraintMap() const
|
||||
{
|
||||
return m_typeConstraintMap;
|
||||
}
|
||||
|
||||
bool OpSignature::IsValidAttribute(const AttributeProto& p_attr)
|
||||
{
|
||||
if (p_attr.name().empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
int num_fields =
|
||||
p_attr.has_f() +
|
||||
p_attr.has_i() +
|
||||
p_attr.has_s() +
|
||||
p_attr.has_t() +
|
||||
p_attr.has_g() +
|
||||
(p_attr.floats_size() > 0) +
|
||||
(p_attr.ints_size() > 0) +
|
||||
(p_attr.strings_size() > 0) +
|
||||
(p_attr.tensors_size() > 0) +
|
||||
(p_attr.graphs_size() > 0) +
|
||||
p_attr.has_type() +
|
||||
(p_attr.types_size() > 0) +
|
||||
p_attr.has_shape() +
|
||||
(p_attr.shapes_size() > 0);
|
||||
|
||||
if (num_fields == 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#ifndef CORE_GRAPH_OPSCHEMA_H
|
||||
#define CORE_GRAPH_OPSCHEMA_H
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
enum class AttrType {
|
||||
NONE,
|
||||
FLOAT,
|
||||
INT,
|
||||
STRING,
|
||||
GRAPH,
|
||||
TENSOR,
|
||||
TYPE,
|
||||
SHAPE,
|
||||
FLOATS,
|
||||
INTS,
|
||||
STRINGS,
|
||||
GRAPHS,
|
||||
TENSORS,
|
||||
TYPES,
|
||||
SHAPES
|
||||
};
|
||||
|
||||
// This string array should exactly match the AttrType defined above.
|
||||
static const std::string c_attrTypeStr[14] =
|
||||
{
|
||||
"FLOAT",
|
||||
"INT",
|
||||
"STRING",
|
||||
"GRAPH",
|
||||
"TENSOR",
|
||||
"TYPE",
|
||||
"SHAPE",
|
||||
"FLOATS",
|
||||
"INTS",
|
||||
"STRINGS",
|
||||
"GRAPHS",
|
||||
"TENSORS",
|
||||
"TYPES",
|
||||
"SHAPES"
|
||||
};
|
||||
|
||||
typedef std::unordered_set<PTYPE> DataTypeSet;
|
||||
typedef std::unordered_map<std::string, std::pair<DataTypeSet, std::string>> TypeConstraintMap;
|
||||
|
||||
// Operator signature declaration.
|
||||
// It defines input formal parameter, output formal parameters and
|
||||
// attributes.
|
||||
// Once an operator signature created, it's "Read-Only".
|
||||
class OpSignature
|
||||
{
|
||||
public:
|
||||
|
||||
// Formal parameter represenation, including parameter name, type.
|
||||
class FormalParameter
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor.
|
||||
explicit FormalParameter(const std::string& p_name,
|
||||
const std::string& p_type,
|
||||
const std::string& p_description,
|
||||
const TypeConstraintMap& p_constraintMap = TypeConstraintMap());
|
||||
|
||||
// Get formal parameter name.
|
||||
const std::string& GetName() const;
|
||||
|
||||
// Get supported data types.
|
||||
const DataTypeSet& GetTypes() const;
|
||||
|
||||
// Get formal parameter type string.
|
||||
const std::string& GetTypeStr() const;
|
||||
|
||||
// Get formal parameter description.
|
||||
const std::string& GetDescription() const;
|
||||
|
||||
private:
|
||||
|
||||
FormalParameter() {}
|
||||
|
||||
// Formal parameter name.
|
||||
std::string m_name;
|
||||
|
||||
// A set of data types supported for <*this> formal parameter.
|
||||
// It should contain at least one element if this formal parameter
|
||||
// is good.
|
||||
DataTypeSet m_types;
|
||||
|
||||
// The <parameter type> string specified when registring an op.
|
||||
// It could be a supported data type or a type constraint key, which
|
||||
// maps to a set of supported data types.
|
||||
std::string m_typeStr;
|
||||
|
||||
// Formal parameter description
|
||||
std::string m_description;
|
||||
|
||||
};
|
||||
|
||||
// Attribute representation, including name, type, and allowed values.
|
||||
// The first element of allowed values (if specified) is the default
|
||||
// value.
|
||||
class Attribute
|
||||
{
|
||||
public:
|
||||
|
||||
// Constructor.
|
||||
explicit Attribute(const std::string& p_attrName,
|
||||
AttrType p_type,
|
||||
const std::string& p_description);
|
||||
|
||||
// Constructor with default value.
|
||||
explicit Attribute(const std::string& p_attrName,
|
||||
AttrType p_type,
|
||||
const std::string& p_description,
|
||||
const AttributeProto& p_defaultVal);
|
||||
|
||||
// Get attribute name.
|
||||
const std::string& GetName() const;
|
||||
|
||||
// Get attribute type.
|
||||
AttrType GetType() const;
|
||||
|
||||
// Get to know whether this attribute has default value,
|
||||
// if yes, <p_value> will be assigned to be the default value.
|
||||
bool HasDefaultValue(const AttributeProto** p_value) const;
|
||||
|
||||
private:
|
||||
|
||||
Attribute() {}
|
||||
|
||||
// Attribute name.
|
||||
std::string m_name;
|
||||
|
||||
// Attribute type.
|
||||
AttrType m_type;
|
||||
|
||||
// Attribute description.
|
||||
std::string m_description;
|
||||
|
||||
// Flag indicates whether a default value specified.
|
||||
// It it's true, the first element of <m_allowedValues> is the
|
||||
// default value.
|
||||
bool m_hasDefaultValue;
|
||||
|
||||
// Allowed attribute values.
|
||||
std::vector<AttributeProto> m_allowedValues;
|
||||
};
|
||||
|
||||
static bool IsValidAttribute(const AttributeProto& p_attribute);
|
||||
|
||||
// Constructor.
|
||||
OpSignature() = default;
|
||||
|
||||
// Get operator name.
|
||||
const std::string& GetName() const;
|
||||
|
||||
// Get operator description.
|
||||
const std::string& GetDescription() const;
|
||||
|
||||
// Get input formal parameters.
|
||||
const std::vector<FormalParameter>& GetInputs() const;
|
||||
|
||||
// Get output formal parameters.
|
||||
const std::vector<FormalParameter>& GetOutputs() const;
|
||||
|
||||
// Get attributes.
|
||||
const std::vector<Attribute>& GetAttributes() const;
|
||||
|
||||
// Get type constraint map.
|
||||
const TypeConstraintMap& GetTypeConstraintMap() const;
|
||||
|
||||
// To support ONNX variable input/output compatibility.
|
||||
// Min and Max num arguments of last input/output.
|
||||
int GetOnnxMinInput() const { return m_onnxMinInput; }
|
||||
int GetOnnxMaxInput() const { return m_onnxMaxInput; }
|
||||
int GetOnnxMinOutput() const { return m_onnxMinOutput; }
|
||||
int GetOnnxMaxOutput() const { return m_onnxMaxOutput; }
|
||||
std::function<bool(int)> GetOnnxNumInputsAllowedFunc() const
|
||||
{
|
||||
return m_onnxNumInputsAllowed;
|
||||
}
|
||||
std::function<bool(int)> GetOnnxNumOutputsAllowedFunc() const
|
||||
{
|
||||
return m_onnxNumOutputsAllowed;
|
||||
}
|
||||
std::function<bool(int, int)> GetOnnxNumInputsOutputsAllowedFunc() const
|
||||
{
|
||||
return m_onnxNumInputsOutputsAllowed;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
friend class OperatorSchemaSetter;
|
||||
friend class OperatorSchemaRegistry;
|
||||
|
||||
// Operator name.
|
||||
std::string m_name;
|
||||
|
||||
// Operator description.
|
||||
std::string m_description;
|
||||
|
||||
// Operator input formal parameters.
|
||||
std::vector<FormalParameter> m_inputs;
|
||||
|
||||
// Operator output formal parameters.
|
||||
std::vector<FormalParameter> m_outputs;
|
||||
|
||||
// Operator attributes' definitions.
|
||||
std::vector<Attribute> m_attributes;
|
||||
|
||||
// Map from constraint name to DataTypeSet
|
||||
TypeConstraintMap m_typeConstraintMap;
|
||||
|
||||
// To support ONNX variable input/output compatibility.
|
||||
// Min and Max num arguments of last input/output.
|
||||
int m_onnxMinInput = 0;
|
||||
int m_onnxMaxInput = std::numeric_limits<int>::max();
|
||||
int m_onnxMinOutput = 0;
|
||||
int m_onnxMaxOutput = std::numeric_limits<int>::max();
|
||||
std::function<bool(int)> m_onnxNumInputsAllowed =
|
||||
[](int) { return true; };
|
||||
std::function<bool(int)> m_onnxNumOutputsAllowed =
|
||||
[](int) { return true; };
|
||||
std::function<bool(int, int)> m_onnxNumInputsOutputsAllowed =
|
||||
[](int, int) { return true; };
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,39 @@
|
|||
#include "shape_inference.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
InferenceContext::InferenceContext(Node* p_node,
|
||||
const OpSignature* p_opSchema)
|
||||
: m_node(p_node),
|
||||
m_opSignature(p_opSchema)
|
||||
{
|
||||
}
|
||||
|
||||
const Node* InferenceContext::GetNode() const
|
||||
{
|
||||
return m_node;
|
||||
}
|
||||
|
||||
const OpSignature* InferenceContext::GetOp() const
|
||||
{
|
||||
return m_opSignature;
|
||||
}
|
||||
|
||||
const std::vector<NodeArg>* InferenceContext::GetInputs() const
|
||||
{
|
||||
if (nullptr == m_node)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return &(m_node->InputDefs());
|
||||
}
|
||||
|
||||
std::vector<NodeArg>* InferenceContext::Mutable_Outputs()
|
||||
{
|
||||
if (nullptr == m_node)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
return &(m_node->Mutable_OutputDefs());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
#ifndef CORE_GRAPH_SHAPEINFERENCE_H
|
||||
#define CORE_GRAPH_SHAPEINFERENCE_H
|
||||
|
||||
#include "graph.h"
|
||||
#include "opsignature.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
// A context to contain information for shape inference function.
|
||||
// It includes the operator registry, input arguments definition,
|
||||
// and mutable output arguments, whose shapes needs to be filled.
|
||||
class InferenceContext
|
||||
{
|
||||
public:
|
||||
|
||||
// TODO: Add input tensors into constructor.
|
||||
// TODO: An abstract tensor interface will be needed.
|
||||
// In some cases, node evaluation will be needed to get output shapes.
|
||||
InferenceContext(Node* p_node,
|
||||
const OpSignature* p_opSchema);
|
||||
|
||||
const Node* GetNode() const;
|
||||
|
||||
const OpSignature* GetOp() const;
|
||||
|
||||
const std::vector<NodeArg>* GetInputs() const;
|
||||
|
||||
std::vector<NodeArg>* Mutable_Outputs();
|
||||
|
||||
private:
|
||||
|
||||
Node* m_node;
|
||||
|
||||
const OpSignature* m_opSignature;
|
||||
};
|
||||
|
||||
// Shape inference function define.
|
||||
typedef std::function<Status(InferenceContext&)> ShapeInferenceFunc;
|
||||
}
|
||||
#endif
|
|
@ -0,0 +1,32 @@
|
|||
#include "status.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
Status::Status(bool p_ok, const std::string& p_errMsg)
|
||||
{
|
||||
m_ok = p_ok;
|
||||
m_errMsg = p_errMsg;
|
||||
}
|
||||
|
||||
Status::Status(const Status& p_other)
|
||||
{
|
||||
m_ok = p_other.m_ok;
|
||||
m_errMsg = p_other.m_errMsg;
|
||||
}
|
||||
|
||||
bool Status::Ok() const
|
||||
{
|
||||
return m_ok;
|
||||
}
|
||||
|
||||
const std::string& Status::ErrorMsg() const
|
||||
{
|
||||
return m_errMsg;
|
||||
}
|
||||
|
||||
Status Status::OK()
|
||||
{
|
||||
static Status ok(true, "");
|
||||
return ok;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
#ifndef CORE_GRAPH_STATUS_H
|
||||
#define CORE_GRAPH_STATUS_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
|
||||
#define RETURN_IF_ERROR(expr) \
|
||||
do { \
|
||||
auto status = (expr); \
|
||||
if ((!status.Ok())) return status; \
|
||||
} while (0)
|
||||
|
||||
class Status
|
||||
{
|
||||
public:
|
||||
Status() = delete;
|
||||
|
||||
// Constructor.
|
||||
Status(bool p_ok, const std::string& p_errMsg);
|
||||
|
||||
// Copy constructor.
|
||||
Status(const Status& p_other);
|
||||
|
||||
// Getter of <m_ok>.
|
||||
bool Ok() const;
|
||||
|
||||
// Getter of <m_errMsg>.
|
||||
const std::string& ErrorMsg() const;
|
||||
|
||||
static Status OK();
|
||||
|
||||
private:
|
||||
|
||||
bool m_ok;
|
||||
std::string m_errMsg;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // !CORE_GRAPH_STATUS_H
|
|
@ -0,0 +1,422 @@
|
|||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
|
||||
#include <cctype>
|
||||
#include <iterator>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "constants.h"
|
||||
#include "proto/onnx/protobuf/graph.pb.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
namespace Utils
|
||||
{
|
||||
std::unordered_map<std::string, TypeProto>& OpUtils::GetTypeStrToProtoMap()
|
||||
{
|
||||
static std::unordered_map<std::string, TypeProto>* typeStrToProtoMap =
|
||||
new std::unordered_map<std::string, TypeProto>();
|
||||
return *typeStrToProtoMap;
|
||||
}
|
||||
|
||||
PTYPE OpUtils::ToType(const TypeProto& p_type)
|
||||
{
|
||||
auto typeStr = ToString(p_type);
|
||||
if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end())
|
||||
{
|
||||
GetTypeStrToProtoMap()[typeStr] = p_type;
|
||||
}
|
||||
return &(GetTypeStrToProtoMap().find(typeStr)->first);
|
||||
}
|
||||
|
||||
PTYPE OpUtils::ToType(const std::string& p_type)
|
||||
{
|
||||
TypeProto type;
|
||||
FromString(p_type, type);
|
||||
return ToType(type);
|
||||
}
|
||||
|
||||
const TypeProto& OpUtils::ToTypeProto(const PTYPE& p_type)
|
||||
{
|
||||
auto it = GetTypeStrToProtoMap().find(*p_type);
|
||||
if (it != GetTypeStrToProtoMap().end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("PTYPE not found: " + *p_type);
|
||||
}
|
||||
}
|
||||
|
||||
std::string OpUtils::ToString(const TypeProto& p_type)
|
||||
{
|
||||
switch (p_type.value_case())
|
||||
{
|
||||
case TypeProto::ValueCase::kTensorType:
|
||||
return ToString(p_type.tensor_type().elem_type());
|
||||
case TypeProto::ValueCase::kSparseTensorType:
|
||||
return "sparse(" + ToString(p_type.sparse_tensor_type().elem_type()) + ")";
|
||||
case TypeProto::ValueCase::kSeqType:
|
||||
return "seq(" + ToString(p_type.seq_type().elem_type()) + ")";
|
||||
case TypeProto::ValueCase::kTupleType:
|
||||
{
|
||||
int size = p_type.tuple_type().elem_type_size();
|
||||
std::string tuple_str("tuple(");
|
||||
for (int i = 0; i < size - 1; i++)
|
||||
{
|
||||
tuple_str = tuple_str + ToString(p_type.tuple_type().elem_type(i)) + ",";
|
||||
}
|
||||
tuple_str += ToString(p_type.tuple_type().elem_type(size - 1));
|
||||
tuple_str += ")";
|
||||
return tuple_str;
|
||||
}
|
||||
case TypeProto::ValueCase::kMapType:
|
||||
{
|
||||
std::string map_str("map(");
|
||||
map_str = map_str + ToString(p_type.map_type().key_type()) + ","
|
||||
+ ToString(p_type.map_type().value_type()) + ")";
|
||||
return map_str;
|
||||
}
|
||||
case TypeProto::ValueCase::kHandleType:
|
||||
return "handle";
|
||||
default:
|
||||
throw std::invalid_argument("Unknown TypeProto");
|
||||
}
|
||||
}
|
||||
|
||||
std::string OpUtils::ToString(const TensorProto::DataType& p_type)
|
||||
{
|
||||
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
|
||||
switch (p_type)
|
||||
{
|
||||
case TensorProto::DataType::TensorProto_DataType_BOOL:
|
||||
return t.c_bool;
|
||||
case TensorProto::DataType::TensorProto_DataType_STRING:
|
||||
return t.c_string;
|
||||
case TensorProto::DataType::TensorProto_DataType_FLOAT16:
|
||||
return t.c_float16;
|
||||
case TensorProto::DataType::TensorProto_DataType_FLOAT:
|
||||
return t.c_float;
|
||||
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
|
||||
return t.c_double;
|
||||
case TensorProto::DataType::TensorProto_DataType_INT8:
|
||||
return t.c_int8;
|
||||
case TensorProto::DataType::TensorProto_DataType_INT16:
|
||||
return t.c_int16;
|
||||
case TensorProto::DataType::TensorProto_DataType_INT32:
|
||||
return t.c_int32;
|
||||
case TensorProto::DataType::TensorProto_DataType_INT64:
|
||||
return t.c_int64;
|
||||
case TensorProto::DataType::TensorProto_DataType_UINT8:
|
||||
return t.c_uint8;
|
||||
case TensorProto::DataType::TensorProto_DataType_UINT16:
|
||||
return t.c_uint16;
|
||||
case TensorProto::DataType::TensorProto_DataType_UINT32:
|
||||
return t.c_uint32;
|
||||
case TensorProto::DataType::TensorProto_DataType_UINT64:
|
||||
return t.c_uint64;
|
||||
case TensorProto::DataType::TensorProto_DataType_COMPLEX64:
|
||||
return t.c_complex64;
|
||||
case TensorProto::DataType::TensorProto_DataType_COMPLEX128:
|
||||
return t.c_complex128;
|
||||
}
|
||||
|
||||
throw std::invalid_argument("Unknown DataType");
|
||||
}
|
||||
|
||||
|
||||
void OpUtils::FromString(const std::string& p_src, TypeProto& p_type)
|
||||
{
|
||||
StringRange s(p_src);
|
||||
s.LAndRStrip();
|
||||
p_type.Clear();
|
||||
|
||||
if (s.LStrip("seq("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
FromString(std::string(s.Data(), s.Size()), *p_type.mutable_seq_type()->mutable_elem_type());
|
||||
}
|
||||
else if (s.LStrip("tuple("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
std::istringstream types(std::string(s.Data(), s.Size()));
|
||||
std::string type;
|
||||
while (std::getline(types, type, ','))
|
||||
{
|
||||
FromString(type, *p_type.mutable_tuple_type()->mutable_elem_type()->Add());
|
||||
}
|
||||
}
|
||||
else if (s.LStrip("map("))
|
||||
{
|
||||
size_t key_size = s.Find(',');
|
||||
StringRange k(s.Data(), key_size);
|
||||
std::string key = std::string(k.Data(), k.Size());
|
||||
s.LStrip(key_size);
|
||||
s.LStrip(",");
|
||||
size_t val_size = s.Find(')');
|
||||
StringRange v(s.Data(), val_size);
|
||||
std::string val = std::string(v.Data(), v.Size());
|
||||
|
||||
TensorProto::DataType key_type;
|
||||
FromString(key, key_type);
|
||||
TensorProto::DataType val_type;
|
||||
FromString(val, val_type);
|
||||
p_type.mutable_map_type()->set_key_type(key_type);
|
||||
p_type.mutable_map_type()->set_value_type(val_type);
|
||||
}
|
||||
else if (s.LStrip("handle"))
|
||||
{
|
||||
p_type.mutable_handle_type();
|
||||
}
|
||||
else if (s.LStrip("sparse("))
|
||||
{
|
||||
s.RStrip(")");
|
||||
TensorProto::DataType e;
|
||||
FromString(std::string(s.Data(), s.Size()), e);
|
||||
p_type.mutable_sparse_tensor_type()->set_elem_type(e);
|
||||
}
|
||||
else
|
||||
{
|
||||
// dense tensor
|
||||
TensorProto::DataType e;
|
||||
FromString(std::string(s.Data(), s.Size()), e);
|
||||
p_type.mutable_tensor_type()->set_elem_type(e);
|
||||
}
|
||||
}
|
||||
|
||||
bool OpUtils::IsValidDataTypeString(const std::string& p_dataType)
|
||||
{
|
||||
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
|
||||
return (t.GetAllowedDataTypes().find(p_dataType) != t.GetAllowedDataTypes().end());
|
||||
}
|
||||
|
||||
void OpUtils::FromString(const std::string& p_typeStr, TensorProto::DataType& p_type)
|
||||
{
|
||||
if (!IsValidDataTypeString(p_typeStr))
|
||||
{
|
||||
throw std::invalid_argument("Unknown DataType: " + p_typeStr);
|
||||
}
|
||||
|
||||
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
|
||||
if (p_typeStr == t.c_bool)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_BOOL;
|
||||
}
|
||||
else if (p_typeStr == t.c_float)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_FLOAT;
|
||||
}
|
||||
else if (p_typeStr == t.c_float16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_FLOAT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_double)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_DOUBLE;
|
||||
}
|
||||
else if (p_typeStr == t.c_int8)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT8;
|
||||
}
|
||||
else if (p_typeStr == t.c_int16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_int32)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT32;
|
||||
}
|
||||
else if (p_typeStr == t.c_int64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_INT64;
|
||||
}
|
||||
else if (p_typeStr == t.c_string)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_STRING;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint8)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT8;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint16)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT16;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint32)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT32;
|
||||
}
|
||||
else if (p_typeStr == t.c_uint64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UINT64;
|
||||
}
|
||||
else if (p_typeStr == t.c_complex64)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_COMPLEX64;
|
||||
}
|
||||
else if (p_typeStr == t.c_complex128)
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_COMPLEX128;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_type = TensorProto::DataType::TensorProto_DataType_UNDEFINED;
|
||||
}
|
||||
}
|
||||
|
||||
StringRange::StringRange()
|
||||
: m_data(""), m_size(0)
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const char* p_data, size_t p_size)
|
||||
: m_data(p_data), m_size(p_size)
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const std::string& p_str)
|
||||
: m_data(p_str.data()), m_size(p_str.size())
|
||||
{}
|
||||
|
||||
StringRange::StringRange(const char* p_data)
|
||||
: m_data(p_data), m_size(strlen(p_data))
|
||||
{}
|
||||
|
||||
const char* StringRange::Data() const
|
||||
{
|
||||
return m_data;
|
||||
}
|
||||
|
||||
size_t StringRange::Size() const
|
||||
{
|
||||
return m_size;
|
||||
}
|
||||
|
||||
bool StringRange::Empty() const
|
||||
{
|
||||
return m_size == 0;
|
||||
}
|
||||
|
||||
char StringRange::operator[](size_t p_idx) const
|
||||
{
|
||||
return m_data[p_idx];
|
||||
}
|
||||
|
||||
void StringRange::Reset()
|
||||
{
|
||||
m_data = "";
|
||||
m_size = 0;
|
||||
}
|
||||
|
||||
void StringRange::Reset(const char* p_data, size_t p_size)
|
||||
{
|
||||
m_data = p_data;
|
||||
m_size = p_size;
|
||||
}
|
||||
|
||||
void StringRange::Reset(const std::string& p_str)
|
||||
{
|
||||
m_data = p_str.data();
|
||||
m_size = p_str.size();
|
||||
}
|
||||
|
||||
bool StringRange::StartsWith(const StringRange& p_str) const
|
||||
{
|
||||
return ((m_size >= p_str.m_size) && (memcmp(m_data, p_str.m_data, p_str.m_size) == 0));
|
||||
}
|
||||
|
||||
bool StringRange::EndsWith(const StringRange& p_str) const
|
||||
{
|
||||
return ((m_size >= p_str.m_size) &&
|
||||
(memcmp(m_data + (m_size - p_str.m_size), p_str.m_data, p_str.m_size) == 0));
|
||||
}
|
||||
|
||||
bool StringRange::LStrip() {
|
||||
size_t count = 0;
|
||||
const char* ptr = m_data;
|
||||
while (count < m_size && isspace(*ptr)) {
|
||||
count++;
|
||||
ptr++;
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
return LStrip(count);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool StringRange::LStrip(size_t p_size)
|
||||
{
|
||||
if (p_size <= m_size)
|
||||
{
|
||||
m_data += p_size;
|
||||
m_size -= p_size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::LStrip(StringRange p_str)
|
||||
{
|
||||
if (StartsWith(p_str)) {
|
||||
return LStrip(p_str.m_size);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip() {
|
||||
size_t count = 0;
|
||||
const char* ptr = m_data + m_size - 1;
|
||||
while (count < m_size && isspace(*ptr)) {
|
||||
++count;
|
||||
--ptr;
|
||||
}
|
||||
|
||||
if (count > 0)
|
||||
{
|
||||
return RStrip(count);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip(size_t p_size)
|
||||
{
|
||||
if (m_size >= p_size)
|
||||
{
|
||||
m_size -= p_size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::RStrip(StringRange p_str)
|
||||
{
|
||||
if (EndsWith(p_str)) {
|
||||
return RStrip(p_str.m_size);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StringRange::LAndRStrip()
|
||||
{
|
||||
return LStrip() || RStrip();
|
||||
}
|
||||
|
||||
size_t StringRange::Find(const char p_ch) const
|
||||
{
|
||||
size_t idx = 0;
|
||||
while (idx < m_size)
|
||||
{
|
||||
if (m_data[idx] == p_ch)
|
||||
{
|
||||
return idx;
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,64 @@
|
|||
#ifndef ONNXIR_UTILS_H
|
||||
#define ONNXIR_UTILS_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
|
||||
class TensorProto;
|
||||
class TypeProto;
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
typedef const std::string* PTYPE;
|
||||
|
||||
namespace Utils
|
||||
{
|
||||
class OpUtils
|
||||
{
|
||||
public:
|
||||
static PTYPE ToType(const TypeProto& p_type);
|
||||
static PTYPE ToType(const std::string& p_type);
|
||||
static const TypeProto& ToTypeProto(const PTYPE& p_type);
|
||||
static std::string ToString(const TypeProto& p_type);
|
||||
static std::string ToString(const TensorProto::DataType& p_type);
|
||||
static void FromString(const std::string& p_src, TypeProto& p_type);
|
||||
static void FromString(const std::string& p_src, TensorProto::DataType& p_type);
|
||||
static bool IsValidDataTypeString(const std::string &p_dataType);
|
||||
private:
|
||||
static std::unordered_map<std::string, TypeProto>& GetTypeStrToProtoMap();
|
||||
};
|
||||
|
||||
class StringRange
|
||||
{
|
||||
public:
|
||||
StringRange();
|
||||
StringRange(const char* p_data, size_t p_size);
|
||||
StringRange(const std::string& p_str);
|
||||
StringRange(const char* p_data);
|
||||
const char* Data() const;
|
||||
size_t Size() const;
|
||||
bool Empty() const;
|
||||
char operator[](size_t p_idx) const;
|
||||
void Reset();
|
||||
void Reset(const char* p_data, size_t p_size);
|
||||
void Reset(const std::string& p_str);
|
||||
bool StartsWith(const StringRange& p_str) const;
|
||||
bool EndsWith(const StringRange& p_str) const;
|
||||
bool LStrip();
|
||||
bool LStrip(size_t p_size);
|
||||
bool LStrip(StringRange p_str);
|
||||
bool RStrip();
|
||||
bool RStrip(size_t p_size);
|
||||
bool RStrip(StringRange p_str);
|
||||
bool LAndRStrip();
|
||||
size_t Find(const char p_ch) const;
|
||||
|
||||
private:
|
||||
const char* m_data;
|
||||
size_t m_size;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // ! ONNXIR_UTILS_H
|
|
@ -0,0 +1,329 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
using SupportType = ONNXIR::OpSchema::SupportType;
|
||||
namespace ONNXIR {
|
||||
REGISTER_OPERATOR_SCHEMA(ConstantFill)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(0, 1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
The operator fills the elements of the output tensor with a constant value
|
||||
specified by the 'value' argument.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message. If the 'dtype' argument is not provided, the data type of
|
||||
'value' is used.
|
||||
|
||||
The output tensor shape is specified by the 'shape' argument. If the number of
|
||||
input is 1, the shape will be identical to that of the input at run time with
|
||||
optional additional dimensions appended at the end as specified by 'extra_shape'
|
||||
argument. In that case the 'shape' argument should not be set.
|
||||
|
||||
If input_as_shape is set to true, then the input should be a 1D tensor
|
||||
containing the desired output shape (the dimensions specified in extra_shape
|
||||
will also be appended)
|
||||
|
||||
NOTE: Currently, it supports data type of float, int32, int64, and bool.
|
||||
)DOC")
|
||||
.Attr("value",
|
||||
"The value for the elements of the output tensor.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"The data type for the elements of the output tensor."
|
||||
"Strictly must be one of the types from DataType enum in TensorProto.",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"shape",
|
||||
"The shape of the output tensor."
|
||||
"Cannot set the shape argument and pass in an input at the same time.",
|
||||
AttrType::INTS)
|
||||
.Attr(
|
||||
"extra_shape",
|
||||
"The additional dimensions appended at the end of the shape indicated"
|
||||
"by the input blob."
|
||||
"Cannot set the extra_shape argument when there is no input blob.",
|
||||
AttrType::INTS)
|
||||
.Attr(
|
||||
"input_as_shape",
|
||||
"1D tensor containing the desired output shape. First input must be in "
|
||||
"CPU context.",
|
||||
AttrType::INT)
|
||||
.Input(0, "input", "Input tensor (optional) to provide shape information.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of constant values specified by 'value'"
|
||||
"argument and its type is specified by the 'dtype' argument");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Caffe2ConvTranspose)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
The transposed convolution consumes an input vector, the filter blob, and
|
||||
the bias blob, and computes the output. Note that other parameters, such as
|
||||
the stride and kernel size, or the pads' sizes in each direction are not
|
||||
necessary for input because they are provided by the
|
||||
ConvTransposeUnpoolOpBase operator. Various dimension checks are done
|
||||
implicitly, and the sizes are specified in the Input docs for this operator.
|
||||
As is expected, the filter is deconvolved with a subset of the
|
||||
image and the bias is added; this is done throughout the image data and the
|
||||
output is computed. As a side note on the implementation layout:
|
||||
conv_transpose_op_impl.h is the templated implementation of the
|
||||
conv_transpose_op.h file, which is why they are separate files.
|
||||
)DOC")
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"Input data blob from previous layer; has size "
|
||||
"(N x C x H x W), where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the NCHW usage. On "
|
||||
"the other hand, the NHWC Op has a different set of dimension constraints.")
|
||||
.Input(
|
||||
1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the transposed "
|
||||
"convolution; has size (M x C x kH x kW), where C is the number of channels,"
|
||||
" and kH and kW are the height and width of the kernel.")
|
||||
.Input(
|
||||
2,
|
||||
"bias",
|
||||
"The 1D bias blob that is added through the convolution;"
|
||||
"has size (C)")
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Output data blob that contains the result of the "
|
||||
"transposed convolution. The output dimensions are functions of the kernel"
|
||||
" size, stride size, and pad lengths.")
|
||||
.Attr("pads", "", AttrType::INTS)
|
||||
.Attr("kernel_shape", "", AttrType::INTS)
|
||||
.Attr("dilations", "", AttrType::INTS)
|
||||
.Attr("group", "", AttrType::INT)
|
||||
.Attr("strides", "", AttrType::INTS);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(SpatialBN)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(5)
|
||||
.NumOutputs({ 1, 5 })
|
||||
.EnforceConsumed({ {3, 1}, {4, 2} })
|
||||
.SetDoc(R"DOC(
|
||||
Carries out batch normalization as described in the paper
|
||||
https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
|
||||
there are multiple cases for the number of outputs, which we list below:
|
||||
|
||||
Output case #1: Y, mean, var, saved_mean, saved_var (training mode)
|
||||
Output case #2: Y (test mode)
|
||||
)DOC")
|
||||
.Attr("is_test",
|
||||
"If set to nonzero, run spatial batch normalization in test mode.",
|
||||
AttrType::INT)
|
||||
.Attr("epsilon",
|
||||
"The epsilon value to use to avoid division by zero.",
|
||||
AttrType::FLOAT)
|
||||
.Attr("momentum",
|
||||
"Factor used in computing the running mean and variance."
|
||||
"e.g., running_mean = running_mean * momentum + mean * (1 - momentum)",
|
||||
AttrType::FLOAT)
|
||||
.Input(0,
|
||||
"X",
|
||||
"The input 4-dimensional tensor of shape NCHW.")
|
||||
.Input(1,
|
||||
"scale",
|
||||
"The scale as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(2,
|
||||
"bias",
|
||||
"The bias as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(3,
|
||||
"mean",
|
||||
"The running mean (training) or the estimated mean (testing) "
|
||||
"as a 1-dimensional tensor of size C.")
|
||||
.Input(4,
|
||||
"var",
|
||||
"The running variance (training) or the estimated "
|
||||
"variance (testing) as a 1-dimensional tensor of size C.")
|
||||
.Output(0, "Y", "The output 4-dimensional tensor of the same shape as X.")
|
||||
.Output(1,
|
||||
"mean",
|
||||
"The running mean after the spatial BN operator. Must be in-place "
|
||||
"with the input mean. Should not be used for testing.")
|
||||
.Output(2,
|
||||
"var",
|
||||
"The running variance after the spatial BN operator. Must be "
|
||||
"in-place with the input var. Should not be used for testing.")
|
||||
.Output(3,
|
||||
"saved_mean",
|
||||
"Saved mean used during training to speed up gradient "
|
||||
"computation. Should not be used for testing.")
|
||||
.Output(4,
|
||||
"saved_var",
|
||||
"Saved variance used during training to speed up "
|
||||
"gradient computation. Should not be used for testing.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(LRN)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1, 2)
|
||||
.Attr("size", "", AttrType::INT)
|
||||
.Attr("alpha", "", AttrType::FLOAT)
|
||||
.Attr("beta", "", AttrType::FLOAT)
|
||||
.Attr("bias", "", AttrType::FLOAT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(GivenTensorFill)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(0, 1)
|
||||
.NumOutputs(1)
|
||||
.Input(0, "shape", "The shape of filled tensor")
|
||||
.Output(0, "X", "The filled tensor")
|
||||
.Attr("values", "", AttrType::FLOATS)
|
||||
.Attr("shape", "", AttrType::INTS)
|
||||
.Attr("input_as_shape", "", AttrType::INT)
|
||||
.Attr("extra_shape", "", AttrType::INTS)
|
||||
.AllowConsumed({ {0, 0} });
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(FC)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Computes the result of passing an input vector X into a fully
|
||||
connected layer with 2D weight matrix W and 1D bias vector b. That is,
|
||||
the layer computes Y = X * W^T + b, where X has size (M x K),
|
||||
W has size (N x K), b has size (N), and Y has size (M x N),
|
||||
where M is often the batch size.
|
||||
NOTE: X does not need to explicitly be a 2D vector; rather, it will be
|
||||
coerced into one. For an arbitrary n-dimensional tensor
|
||||
X \in [a_0, a_1, ...,a_{k-1}, a_k, ..., a_{n-1}] where a_i \in N+ and k is
|
||||
the axis provided, then X will be coerced into a 2-dimensional tensor with
|
||||
dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
|
||||
case where axis=1, this means the X tensor will be coerced into a 2D tensor
|
||||
of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
|
||||
In this situation, we must have a_0 = M and a_1 * ... * a_{n-1} = K.
|
||||
Lastly, even though b is a 1D vector of size N, it is copied/resized to
|
||||
be size (M x N) implicitly and added to each vector in the batch.
|
||||
Each of these dimensions must be matched correctly, or else the operator
|
||||
will throw errors.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"axis",
|
||||
"(int32_t) default to 1; describes the axis of the inputs; "
|
||||
"defaults to one because the 0th axis most likely describes "
|
||||
"the batch_size",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"axis_w",
|
||||
"(int32_t) default to 1; describes the axis of the weights; "
|
||||
"defaults to one because the 0th axis most likely describes "
|
||||
"the batch_size",
|
||||
AttrType::INT)
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"input tensor that's coerced into a 2D matrix of size (MxK) "
|
||||
"as described above")
|
||||
.Input(
|
||||
1,
|
||||
"W",
|
||||
"2D blob of size (KxN) containing fully connected weight "
|
||||
"matrix")
|
||||
.Input(2, "b", "1D blob containing bias vector")
|
||||
.Output(0, "Y", "2D output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Normalize)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Given a matrix, apply L2-normalization along the last dimension.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Scale)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Scale takes one input data (Tensor<float>) and produces one output data
|
||||
(Tensor<float>) whose value is the input data tensor scaled element-wise.
|
||||
)DOC")
|
||||
.Attr("scale",
|
||||
"(float, default 1.0) the scale to apply.",
|
||||
AttrType::FLOAT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ChannelShuffle)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis",
|
||||
AttrType::INTS)
|
||||
.Attr("group",
|
||||
"Number of channel groups",
|
||||
AttrType::INT);
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RecurrentNetwork)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(2, INT_MAX)
|
||||
.SetDoc(R"DOC(
|
||||
Run the input network in a recurrent fashion. This can be used to
|
||||
implement fairly general recurrent neural networks (RNNs).
|
||||
The operator proceeds as follows.
|
||||
- First, initialized the states from the input recurrent states
|
||||
- For each timestep T, apply the links (that map offsets from input/output
|
||||
tensors into the inputs/outputs for the `step` network)
|
||||
- Finally, alias the recurrent states to the specified output blobs.
|
||||
This is a fairly special-case meta-operator, and so the implementation
|
||||
is somewhat complex. It trades of generality (and frankly usability)
|
||||
against performance and control (compared to e.g. TF
|
||||
dynamic_rnn, Theano scan, etc).
|
||||
See the usage examples for a flavor of how to use it.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(GRUUnit)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.NumInputs(4)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
GRUUnit computes the activations of a standard GRU,
|
||||
in a sequence-length aware fashion.
|
||||
Concretely, given the (fused) inputs X (TxNxD), the previous hidden
|
||||
state (NxD), and the sequence lengths (N), computes the GRU
|
||||
activations, avoiding computation if the input is invalid (as in, the
|
||||
value at X[t][n] >= seqLengths[n].
|
||||
)DOC")
|
||||
.Attr(
|
||||
"drop_states",
|
||||
"Bool to determine if hidden state is zeroes or passed "
|
||||
"along for timesteps past the given sequence_length.",
|
||||
AttrType::INT)
|
||||
.Input(0, "hidden_prev", "The previous GRU hidden state.")
|
||||
.Input(
|
||||
1,
|
||||
"gates",
|
||||
"Unactivated gate outputs from forget, update, "
|
||||
"and output gates, pre-activation.")
|
||||
.Input(
|
||||
2,
|
||||
"seq_lengths",
|
||||
"Array of sequence lengths. "
|
||||
"len(seq_lengths) should equal batch size N.")
|
||||
.Input(3, "t", "The timestep for this operation.")
|
||||
.Output(0, "hidden", "The new GRU hidden state calculated by this op.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ATen)
|
||||
.SetSupportLevel(SupportType::EXPERIMENTAL)
|
||||
.AllowUncheckedAttributes()
|
||||
.SetDoc(R"DOC(
|
||||
Experimental allowing ATen operations to be accessed directly from Caffe2
|
||||
to allow for quick prototyping when ONNX is missing standard versions of
|
||||
and op)DOC");
|
||||
}
|
|
@ -0,0 +1,168 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
REGISTER_OPERATOR_SCHEMA(Constant)
|
||||
.NumInputs(0)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(A constant tensor.)DOC")
|
||||
.Attr("value",
|
||||
"The value for the elements of the output tensor.",
|
||||
AttrType::TENSOR)
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor containing the same value of the provided tensor.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RandomUniform)
|
||||
.NumInputs(0)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Generate a tensor with random values drawn from a uniform distribution. The shape
|
||||
of the tensor is specified by the `shape` argument and the range by `low` and `high`.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"low",
|
||||
"Lower boundary of the output values.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"high",
|
||||
"Upper boundary of the output values.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"seed",
|
||||
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"The data type for the elements of the output tensor.",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"shape",
|
||||
"The shape of the output tensor.",
|
||||
AttrType::INTS)
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of random values drawn from uniform distribution");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RandomNormal)
|
||||
.NumInputs(0)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Generate a tensor with random values drawn from a normal distribution. The shape
|
||||
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
|
||||
specified by `mean` and `scale`.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"mean",
|
||||
"The mean of the normal distribution.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"scale",
|
||||
"The standard deviation of the normal distribution.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"seed",
|
||||
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"The data type for the elements of the output tensor.",
|
||||
AttrType::INT)
|
||||
.Attr(
|
||||
"shape",
|
||||
"The shape of the output tensor.",
|
||||
AttrType::INTS)
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of random values drawn from normal distribution");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RandomUniformLike)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Generate a tensor with random values drawn from a uniform distribution. The shape
|
||||
of the tensor is computed from the input argument and the range by `low` and `high`.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"low",
|
||||
"Lower boundary of the output values.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"high",
|
||||
"Upper boundary of the output values.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"seed",
|
||||
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"(Optional) The data type for the elements of the output tensor, if not specified, we will use"
|
||||
"the data type of the input tensor.",
|
||||
AttrType::INT)
|
||||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"Input tensor to provide shape information.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of random values drawn from uniform distribution");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(RandomNormalLike)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Generate a tensor with random values drawn from a normal distribution. The shape
|
||||
of the tensor is computed from the input argument and the parameter of the normal distribution
|
||||
specified by `mean` and `scale`.
|
||||
|
||||
The data type is specified by the 'dtype' argument. The 'dtype' argument must
|
||||
be one of the data types specified in the 'DataType' enum field in the
|
||||
TensorProto message.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"mean",
|
||||
"The mean of the normal distribution.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"scale",
|
||||
"The standard deviation of the normal distribution.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"seed",
|
||||
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
|
||||
AttrType::FLOAT)
|
||||
.Attr(
|
||||
"dtype",
|
||||
"(Optional) The data type for the elements of the output tensor, if not specified, we will use"
|
||||
"the data type of the input tensor.",
|
||||
AttrType::INT)
|
||||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"Input tensor to provide shape information.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor of random values drawn from normal distribution");
|
||||
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
std::function<void(OpSchema&)> BinaryLogicDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the `{name} than` elementwise logical operation between `left` and `right` input tensor.
|
||||
The result is a tensor of type integer in which `0` mean false and `1` mean true.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.NumInputs(2);
|
||||
schema.NumOutputs(1);
|
||||
schema.SetDoc(doc);
|
||||
schema.Input(0, "left", "Left input tensor for the logical operator.");
|
||||
schema.Input(1, "right", "Right input tensor for the logical operator.");
|
||||
schema.Output(0, "output", "Result tensor of type `int`, 0 mean False and 1 mean True.");
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ONNXIR
|
|
@ -0,0 +1,385 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
#include <functional>
|
||||
|
||||
using AttrType = ONNXIR::AttrType;
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
const char* kBroadcastDoc = R"DOC(
|
||||
If necessary the right-hand-side argument will be broadcasted to match the
|
||||
shape of left-hand-side argument. When broadcasting is specified, the second
|
||||
tensor can either be of size 1 (a scalar value), or having its shape as a
|
||||
contiguous subset of the first tensor's shape. The starting of the mutually
|
||||
equal shape is specified by the argument "axis", and if it is not set, suffix
|
||||
matching is assumed. 1-dim expansion doesn't work yet.
|
||||
|
||||
For example, the following tensor shapes are supported (with broadcast=1):
|
||||
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (5,)
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
|
||||
shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0
|
||||
|
||||
Attribute `broadcast=1` needs to be passed to enable broadcasting.
|
||||
)DOC";
|
||||
|
||||
std::function<void(OpSchema&)> MathDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Performs element-wise binary {name} (with limited broadcast support).
|
||||
{broadcast_doc})DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("broadcast",
|
||||
"Pass 1 to enable broadcasting",
|
||||
AttrType::INT);
|
||||
schema.Attr("axis",
|
||||
"If set, defines the broadcast dimensions. See doc for details.",
|
||||
AttrType::INT);
|
||||
schema.Input(
|
||||
0,
|
||||
"A",
|
||||
"First operand, should share the type with the second operand.");
|
||||
schema.Input(
|
||||
1,
|
||||
"B",
|
||||
"Second operand. With broadcasting can be of smaller size than A. "
|
||||
"If broadcasting is disabled it should be of the same size.");
|
||||
schema.Output(0, "C", "Result, has same dimensions and type as A");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Add)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0}, {1, 0} })
|
||||
.FillUsing(MathDocGenerator("addition"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Sub)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0}, {1, 0} })
|
||||
.FillUsing(MathDocGenerator("subtraction"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Mul)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0}, {1, 0} })
|
||||
.FillUsing(MathDocGenerator("multiplication"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Div)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0}, {1, 0} })
|
||||
.FillUsing(MathDocGenerator("division"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Neg)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Neg takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where each element flipped sign, y = -x, is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Abs)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Absolute takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the absolute is, y = abs(x), is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Reciprocal)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Reciprocal takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the reciprocal is, y = 1/x, is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Floor)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Floor takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the floor is, y = floor(x), is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Ceil)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Ceil takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the ceil is, y = ceil(x), is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Sqrt)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Square root takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the square root is, y = x^0.5, is applied to
|
||||
the tensor elementwise. If x is negative, then it will return NaN.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Relu)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Relu takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
|
||||
the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(LeakyRelu)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("alpha",
|
||||
"Coefficient of leakage",
|
||||
AttrType::FLOAT)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one
|
||||
output data (Tensor<T>) where the function `f(x) = alpha * x for x < 0`,
|
||||
`f(x) = x for x >= 0`, is applied to the data tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Selu)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.Attr("alpha",
|
||||
"Coefficient of SELU default to 1.6732.",
|
||||
AttrType::FLOAT)
|
||||
.Attr("gamma",
|
||||
"Coefficient of SELU default to 1.0507.",
|
||||
AttrType::FLOAT)
|
||||
.SetDoc(R"DOC(
|
||||
Selu takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the scaled exponential linear unit function,
|
||||
`y = gamma * (alpha * e^x - alpha) for x <= 0`, `f(x) = gamma * x for x > 0`,
|
||||
is applied to the tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Elu)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
|
||||
Elu takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the function `f(x) = alpha * (exp(x) - 1.) for x <
|
||||
0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise.
|
||||
|
||||
)DOC")
|
||||
.Input(0, "X", "1D input tensor")
|
||||
.Output(0, "Y", "1D input tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Exp)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Calculates the exponential of the given input tensor, element-wise. This
|
||||
operation can be done in an in-place fashion too, by providing the same input
|
||||
and output blobs.
|
||||
)DOC")
|
||||
.Input(0, "input", "Input tensor")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"The exponential of the input tensor computed "
|
||||
"element-wise");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Log)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Calculates the natural log of the given input tensor, element-wise. This
|
||||
operation can be done in an in-place fashion too, by providing the same input
|
||||
and output blobs.
|
||||
)DOC")
|
||||
.Input(0, "input", "Input tensor")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"The natural log of the input tensor computed "
|
||||
"element-wise");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Tanh)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Calculates the hyperbolic tangent of the given input tensor element-wise. This
|
||||
operation can be done in an in-place fashion too, by providing the same input
|
||||
and output blobs.
|
||||
)DOC")
|
||||
.Input(0, "input", "1-D input tensor")
|
||||
.Output(0, "output", "The hyperbolic tangent values of the input tensor "
|
||||
"computed element-wise");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Pow)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("exponent",
|
||||
"The exponent of the power function.",
|
||||
AttrType::FLOAT)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Pow takes input data (Tensor<T>) and an argument exponent, and
|
||||
produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
|
||||
is applied to the data tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor of any shape")
|
||||
.Output(0, "Y", "Output tensor (same size as X)");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Dot)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Apply dot product between 2 tensors. Similar to numpy implementation:
|
||||
https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor of any shape")
|
||||
.Input(1, "Y", "Input tensor of any shape")
|
||||
.Output(0, "Z", "Output tensor the dot product between X and Y.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(PRelu)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
|
||||
PRelu takes input data (Tensor<T>) and slope tensor as input, and produces one
|
||||
output data (Tensor<T>) where the function `f(x) = slope * x for x < 0`,
|
||||
`f(x) = x for x >= 0`., is applied to the data tensor elementwise.
|
||||
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Input(
|
||||
1,
|
||||
"Slope",
|
||||
"Slope tensor. If `Slope` is of size 1, the value is shared"
|
||||
"across different channels")
|
||||
.Output(0, "Y", "Input tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Sigmoid)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Sigmoid takes one input data (Tensor<T>) and produces one output data
|
||||
(Tensor<T>) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the
|
||||
tensor elementwise.
|
||||
)DOC")
|
||||
.Input(0, "X", "Input tensor")
|
||||
.Output(0, "Y", "Output tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Max)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Element-wise max of each of the input tensors. The first input tensor can be
|
||||
used in-place as the output tensor, in which case the max will be done in
|
||||
place and results will be accumulated in input0. All inputs and outputs must
|
||||
have the same shape and data type.
|
||||
)DOC")
|
||||
.Input(0, "data_0", "First of the input tensors. Can be inplace.")
|
||||
.Output(0, "max", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Min)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Element-wise min of each of the input tensors. The first input tensor can be
|
||||
used in-place as the output tensor, in which case the max will be done in
|
||||
place and results will be accumulated in input0. All inputs and outputs must
|
||||
have the same shape and data type.
|
||||
)DOC")
|
||||
.Input(0, "data_0", "First of the input tensors. Can be inplace.")
|
||||
.Output(0, "max", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Sum)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Element-wise sum of each of the input tensors. The first input tensor can be
|
||||
used in-place as the output tensor, in which case the sum will be done in
|
||||
place and results will be accumulated in input0. All inputs and outputs must
|
||||
have the same shape and data type.
|
||||
)DOC")
|
||||
.Input(0, "data_0", "First of the input tensors. Can be inplace.")
|
||||
.Output(0, "sum", "Output tensor. Same dimension as inputs.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Softmax)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
The operator computes the softmax normalized values for each layer in the batch
|
||||
of the given input. The input is a 2-D tensor (Tensor<float>) of size
|
||||
(batch_size x input_feature_dimensions). The output tensor has the same shape
|
||||
and contains the softmax normalized values of the corresponding input.
|
||||
|
||||
X does not need to explicitly be a 2D vector; rather, it will be
|
||||
coerced into one. For an arbitrary n-dimensional tensor
|
||||
X \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is
|
||||
the axis provided, then X will be coerced into a 2-dimensional tensor with
|
||||
dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
|
||||
case where axis=1, this means the X tensor will be coerced into a 2D tensor
|
||||
of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
|
||||
In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D.
|
||||
Each of these dimensions must be matched correctly, or else the operator
|
||||
will throw errors.
|
||||
)DOC")
|
||||
.Attr("axis",
|
||||
"(int) default to 1; describes the axis of the inputs when coerced "
|
||||
"to 2D; defaults to one because the 0th axis most likely describes "
|
||||
"the batch_size",
|
||||
AttrType::INT)
|
||||
.Input(0, "input",
|
||||
"The input tensor that's coerced into a 2D matrix of size (NxD) "
|
||||
"as described above.")
|
||||
.Output(0, "output", "The softmax normalized output values with the same "
|
||||
"shape as input tensor.");
|
||||
|
||||
}
|
|
@ -0,0 +1,309 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR {
|
||||
std::function<void(OpSchema&)> AveragePoolOpSchemaGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
{name} consumes an input tensor X and applies average pooling across the
|
||||
the tensor according to kernel sizes, stride sizes, and pad lengths.
|
||||
Average pooling consisting of averaging all values of a subset of the
|
||||
input tensor according to the kernel size and downsampling the
|
||||
data into the output tensor Y for further processing.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"Stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from average pooling across the input "
|
||||
"tensor. Dimensions will vary based on various kernel, stride, and pad "
|
||||
"sizes.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(AveragePool)
|
||||
.FillUsing(AveragePoolOpSchemaGenerator("AveragePool"));
|
||||
|
||||
std::function<void(OpSchema&)> MaxPoolOpSchemaGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
{name} consumes an input tensor X and applies max pooling across the
|
||||
the tensor according to kernel sizes, stride sizes, and pad lengths.
|
||||
Average pooling consisting of averaging all values of a subset of the
|
||||
input tensor according to the kernel size and downsampling the
|
||||
data into the output tensor Y for further processing.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Attr("kernel_shape",
|
||||
"The size of the kernel along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"Stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"Dilaton along each axis, 1 mean no dilation.",
|
||||
AttrType::INTS);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from max pooling across the input "
|
||||
"tensor. Dimensions will vary based on various kernel, stride, and pad "
|
||||
"sizes.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(MaxPool)
|
||||
.FillUsing(MaxPoolOpSchemaGenerator("MaxPool"));
|
||||
|
||||
std::function<void(OpSchema&)> ConvOpSchemaGenerator(const char* filter_desc) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
The convolution operator consumes an input tensor and {filter_desc}, and
|
||||
computes the output.)DOC";
|
||||
ReplaceAll(doc, "{filter_desc}", filter_desc);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(2, 3);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from previous layer; has size (N x C x H x W)"
|
||||
", where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the 2D image."
|
||||
"Otherwise the size is (N x D1 x D2 ... x Dn)");
|
||||
schema.Input(1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the convolutions; "
|
||||
"has size (M x C x kH x kW), where C is the number of channels, "
|
||||
"and kH and kW are the height and width of the kernel.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor that contains the result of the convolution. The "
|
||||
"output dimensions are functions of the kernel size, stride size, "
|
||||
"and pad lengths.");
|
||||
schema.Attr("kernel_shape",
|
||||
"The shape of the convolution kernel.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"dilation value along each axis of the filter.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
schema.Attr("group",
|
||||
"number of groups input channels and output channels are divided into",
|
||||
AttrType::INT);
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Conv)
|
||||
.FillUsing(ConvOpSchemaGenerator("a filter"));
|
||||
|
||||
|
||||
std::function<void(OpSchema&)> ConvTransposeOpSchemaGenerator(const char* filter_desc) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
The convolution transpose operator consumes an input tensor and {filter_desc},
|
||||
and computes the output.)DOC";
|
||||
ReplaceAll(doc, "{filter_desc}", filter_desc);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(2);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from previous layer; has size (N x C x H x W)"
|
||||
", where N is the batch size, C is the number of channels, and"
|
||||
" H and W are the height and width. Note that this is for the 2D image."
|
||||
"Otherwise the size is (N x D1 x D2 ... x Dn)");
|
||||
schema.Input(1,
|
||||
"filter",
|
||||
"The filter blob that will be used in the convolutions; "
|
||||
"has size (M x C x kH x kW), where C is the number of channels, "
|
||||
"and kH and kW are the height and width of the kernel.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor that contains the result of the convolution. The "
|
||||
"output dimensions are functions of the kernel size, stride size, "
|
||||
"and pad lengths.");
|
||||
schema.Attr("kernel_shape",
|
||||
"The shape of the convolution kernel.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("output_shape",
|
||||
"The shape of the output.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("dilations",
|
||||
"dilation value along each axis of the filter.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("strides",
|
||||
"stride along each axis.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("pads",
|
||||
"Padding along each axis, can take the value 0 (False) or non 0 (True)",
|
||||
AttrType::INTS);
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ConvTranspose)
|
||||
.FillUsing(ConvTransposeOpSchemaGenerator("a filter"));
|
||||
|
||||
|
||||
std::function<void(OpSchema&)> GlobalPoolingOpSchemaGenerator(const char* op_type, const char* op) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Global{op_type} consumes an input tensor X and applies {op} pooling across the
|
||||
the values in the same channel. This is equivalent to {op_type} with kernel size
|
||||
equal to the spatial dimension of input tensor.)DOC";
|
||||
ReplaceAll(doc, "{op_type}", op_type);
|
||||
ReplaceAll(doc, "{op}", op);
|
||||
schema.SetDoc(doc);
|
||||
schema.NumInputs(1);
|
||||
schema.NumOutputs(1);
|
||||
schema.Input(0,
|
||||
"X",
|
||||
"Input data tensor from the previous operator; dimensions for image case "
|
||||
"are (N x C x H x W), where N is the batch size, C is the number of channels, "
|
||||
"and H and W are the height and the width of the data. For non image case, the "
|
||||
"dimension are in the form of (N x D1 x D2 ... Dn), where N is the batch size.");
|
||||
schema.Output(0,
|
||||
"Y",
|
||||
"Output data tensor from pooling across the input "
|
||||
"tensor. Dimensions will be N x C x 1 x 1");
|
||||
schema.SetDoc(doc);
|
||||
};
|
||||
}
|
||||
REGISTER_OPERATOR_SCHEMA(GlobalAveragePool)
|
||||
.FillUsing(GlobalPoolingOpSchemaGenerator("AveragePool", "average"));
|
||||
REGISTER_OPERATOR_SCHEMA(GlobalMaxPool)
|
||||
.FillUsing(GlobalPoolingOpSchemaGenerator("MaxPool", "max"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(BatchNormalization)
|
||||
.NumInputs(5)
|
||||
.NumOutputs({ 1, 5 })
|
||||
.EnforceConsumed({ {3, 1}, {4, 2} })
|
||||
.SetDoc(R"DOC(
|
||||
Carries out batch normalization as described in the paper
|
||||
https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,
|
||||
there are multiple cases for the number of outputs, which we list below:
|
||||
|
||||
Output case #1: Y, mean, var, saved_mean, saved_var (training mode)
|
||||
Output case #2: Y (test mode)
|
||||
)DOC")
|
||||
.Attr("spatial",
|
||||
"Compute the mean and variance across all spatial elements or per feature.",
|
||||
AttrType::INT)
|
||||
.Attr("is_test",
|
||||
"If set to nonzero, run spatial batch normalization in test mode.",
|
||||
AttrType::INT)
|
||||
.Attr("epsilon",
|
||||
"The epsilon value to use to avoid division by zero.",
|
||||
AttrType::FLOAT)
|
||||
.Attr("momentum",
|
||||
"Factor used in computing the running mean and variance."
|
||||
"e.g., running_mean = running_mean * momentum + mean * (1 - momentum)",
|
||||
AttrType::FLOAT)
|
||||
.Input(0,
|
||||
"X",
|
||||
"The input 4-dimensional tensor of shape NCHW or NHWC depending "
|
||||
"on the order parameter.")
|
||||
.Input(1,
|
||||
"scale",
|
||||
"The scale as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(2,
|
||||
"bias",
|
||||
"The bias as a 1-dimensional tensor of size C to be applied to the "
|
||||
"output.")
|
||||
.Input(3,
|
||||
"mean",
|
||||
"The running mean (training) or the estimated mean (testing) "
|
||||
"as a 1-dimensional tensor of size C.")
|
||||
.Input(4,
|
||||
"var",
|
||||
"The running variance (training) or the estimated "
|
||||
"variance (testing) as a 1-dimensional tensor of size C.")
|
||||
.Output(0, "Y", "The output 4-dimensional tensor of the same shape as X.")
|
||||
.Output(1,
|
||||
"mean",
|
||||
"The running mean after the BatchNormalization operator. Must be in-place "
|
||||
"with the input mean. Should not be used for testing.")
|
||||
.Output(2,
|
||||
"var",
|
||||
"The running variance after the BatchNormalization operator. Must be "
|
||||
"in-place with the input var. Should not be used for testing.")
|
||||
.Output(3,
|
||||
"saved_mean",
|
||||
"Saved mean used during training to speed up gradient "
|
||||
"computation. Should not be used for testing.")
|
||||
.Output(4,
|
||||
"saved_var",
|
||||
"Saved variance used during training to speed up "
|
||||
"gradient computation. Should not be used for testing.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Dropout)
|
||||
.NumInputs(1)
|
||||
.NumOutputs({ 1,2 })
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
|
||||
output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
|
||||
test mode or not, the output Y will either be a random dropout, or a simple
|
||||
copy of the input. Note that our implementation of Dropout does scaling in
|
||||
the training phase, so during testing nothing needs to be done.
|
||||
)DOC")
|
||||
.Attr("ratio",
|
||||
"(float, default 0.5) the ratio of random dropout",
|
||||
AttrType::FLOAT)
|
||||
.Attr("is_test",
|
||||
"(int, default 0) if nonzero, run dropout in test mode where "
|
||||
"the output is simply Y = X.",
|
||||
AttrType::INT)
|
||||
.Input(0, "data", "The input data as Tensor.")
|
||||
.Output(0, "output", "The output.")
|
||||
.Output(1, "mask",
|
||||
"The output mask. If is_test is nonzero, this output is not filled.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Flatten)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Flattens the input tensor into a 2D matrix, keeping the first dimension
|
||||
unchanged.
|
||||
)DOC")
|
||||
.Input(0, "input", "A tensor of rank >= 2.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"A tensor of rank 2 with the contents of the input tensor, "
|
||||
"with first dimension equal first dimension of input, and remaining "
|
||||
"input dimensions flatenned into the inner dimension of the output.");
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
#include <functional>
|
||||
|
||||
namespace ONNXIR {
|
||||
|
||||
std::function<void(OpSchema&)> ReduceDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the {name} of the input tensor's element along the provided axes. The resulted
|
||||
tensor has the same shape as the input if keepdims equal 1. If keepdims equal 0, then
|
||||
the resulted tensor have the reduced dimension pruned.
|
||||
|
||||
The above behavior is similar to numpy, with the exception that numpy default keepdims to
|
||||
False instead of True.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("axes",
|
||||
"A list of integers, along which to reduce max.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("keepdims",
|
||||
"Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
|
||||
AttrType::INT);
|
||||
schema.Input(0, "data", "An input tensor.");
|
||||
schema.Output(0, "reduced", "Reduced output tensor.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceMax)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("max"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceMin)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("min"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceSum)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("sum"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceMean)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("mean"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceProd)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("product"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ReduceLogSumExp)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ReduceDocGenerator("log sum exponent"));
|
||||
|
||||
std::function<void(OpSchema&)> ArgReduceDocGenerator(const char* name) {
|
||||
return [=](OpSchema& schema) {
|
||||
std::string doc = R"DOC(
|
||||
Computes the indices of the {name} elements of the input tensor's element along the
|
||||
provided axes. The resulted tensor has the same shape as the input if keepdims equal 1.
|
||||
If keepdims equal 0, then the resulted tensor have the reduced dimension pruned.
|
||||
The type of the output tensor is integer.)DOC";
|
||||
ReplaceAll(doc, "{name}", name);
|
||||
schema.SetDoc(doc);
|
||||
schema.Attr("axes",
|
||||
"A list of integers, along which to reduce max.",
|
||||
AttrType::INTS);
|
||||
schema.Attr("keepdims",
|
||||
"Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
|
||||
AttrType::INT);
|
||||
schema.Input(0, "data", "An input tensor.");
|
||||
schema.Output(0, "reduced", "Reduced output tensor with integer data type.");
|
||||
};
|
||||
}
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ArgMax)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ArgReduceDocGenerator("max"));
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(ArgMin)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.FillUsing(ArgReduceDocGenerator("min"));
|
||||
|
||||
} // namespace ONNXIR
|
|
@ -0,0 +1,6 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
// Copyright (c) Facebook Inc. and Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "proto/onnx/core/op.h"
|
||||
|
||||
namespace ONNXIR
|
||||
{
|
||||
REGISTER_OPERATOR_SCHEMA(Cast)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
The operator casts the elements of a given input tensor to a data type
|
||||
specified by the 'to' argument and returns an output tensor of the same size in
|
||||
the converted type. The 'to' argument must be one of the data types specified
|
||||
in the 'DataType' enum field in the TensorProto message. If the 'to' argument
|
||||
is not provided or is not one of the enumerated types in DataType, Caffe2
|
||||
throws an Enforce error.
|
||||
|
||||
NOTE: Casting to and from strings is not supported yet.
|
||||
)DOC")
|
||||
.Attr(
|
||||
"to",
|
||||
"The data type to which the elements of the input tensor are cast."
|
||||
"Strictly must be one of the types from DataType enum in TensorProto",
|
||||
AttrType::STRING)
|
||||
.Input(0, "input", "Input tensor to be cast.")
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"Output tensor with the same shape as input with type "
|
||||
"specified by the 'to' argument");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Reshape)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.AllowConsumed({ {0, 0} })
|
||||
.SetDoc(R"DOC(
|
||||
Reshape the input tensor similar to numpy.reshape.
|
||||
|
||||
It takes a tensor as input and an argument `shape`. It outputs the reshaped tensor.
|
||||
|
||||
At most one dimension of the new shape can be -1. In this case, the value is
|
||||
inferred from the size of the tensor and the remaining dimensions. A dimension
|
||||
could also be 0, in which case the actual dimension value is going to be copied
|
||||
from the shape argument.)DOC")
|
||||
.Attr("shape", "New shape", AttrType::INTS)
|
||||
.Input(0, "data", "An input tensor.")
|
||||
.Output(0, "reshaped", "Reshaped data.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Concat)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(2)
|
||||
.Attr("axis",
|
||||
"Which axis to concat on",
|
||||
AttrType::INT)
|
||||
.SetDoc("Concatenate a list of tensors into a single tensor")
|
||||
.Output(0, "concat_result", "Concatenated tensor");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Split)
|
||||
.NumInputs(1, 2)
|
||||
.NumOutputs(1, INT_MAX)
|
||||
.Input(0, "input", "The tensor to split")
|
||||
.Input(1, "split", "Optional list of output lengths (see also arg 'split')")
|
||||
.Attr("axis",
|
||||
"Which axis to split on",
|
||||
AttrType::INT)
|
||||
.Attr("split",
|
||||
"length of each output",
|
||||
AttrType::INTS)
|
||||
.SetDoc(R"DOC(Split a tensor into a list of tensors, along the specified
|
||||
'axis'. The lengths of the split can be specified using argument 'axis' or
|
||||
optional second input blob to the operator. Otherwise, the tensor is split
|
||||
to equal sized parts.
|
||||
)DOC");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Slice)
|
||||
.NumInputs(1, 3)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Produces a slice of the input tensor along multiple axes. Similar to numpy:
|
||||
https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
|
||||
|
||||
Slices are passed as two keyword argument lists with starting and end indices
|
||||
for each dimension of the input `data` tensor. If a negative value is passed
|
||||
for any of the start or end indices, it represent number of elements before
|
||||
the end of that dimension.
|
||||
|
||||
`strides` is the step sizes when applying slicing, negative value means in
|
||||
reverse order.
|
||||
)DOC")
|
||||
.Input(0, "data", "Tensor of data to extract slices from.")
|
||||
.Attr("starts",
|
||||
"List of starting indices",
|
||||
AttrType::INTS)
|
||||
.Attr("ends",
|
||||
"List of ending indices",
|
||||
AttrType::INTS)
|
||||
.Output(0, "output", "Sliced data tensor.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Transpose)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Transpose the input tensor similar to numpy.transpose. For example, when
|
||||
axes=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
|
||||
will be (2, 1, 3).
|
||||
)DOC")
|
||||
.Attr("perm",
|
||||
"A list of integers. By default, reverse the dimensions, "
|
||||
"otherwise permute the axes according to the values given.",
|
||||
AttrType::INTS)
|
||||
.Input(0, "data", "An input tensor.")
|
||||
.Output(0, "transposed", "Transposed output.");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Gather)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.SetDoc(R"DOC(
|
||||
Given DATA tensor of rank r >= 1, and INDICES tensor of rank q, gather
|
||||
entries of the outer-most dimension of DATA indexed by INDICES, and concatenate
|
||||
them in an output tensor of rank q + (r - 1).
|
||||
|
||||
Example:
|
||||
DATA = [
|
||||
[1.0, 1.2],
|
||||
[2.3, 3.4],
|
||||
[4.5, 5.7],
|
||||
]
|
||||
INDICES = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
]
|
||||
OUTPUT = [
|
||||
[
|
||||
[1.0, 1.2],
|
||||
[2.3, 3.4],
|
||||
],
|
||||
[
|
||||
[2.3, 3.4],
|
||||
[4.5, 5.7],
|
||||
],
|
||||
]
|
||||
)DOC")
|
||||
.Input(0, "DATA", "Tensor of rank r >= 1.")
|
||||
.Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.")
|
||||
.Output(0, "OUTPUT", "Tensor of rank q + (r - 1).");
|
||||
|
||||
REGISTER_OPERATOR_SCHEMA(Squeeze)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Attr("axes",
|
||||
"List of positive integers, indicate the dimensions to squeeze.",
|
||||
AttrType::INTS,
|
||||
true)
|
||||
.SetDoc(R"DOC(
|
||||
Remove single-dimensional entries from the shape of a tensor.
|
||||
Takes a parameter `axes` with a list of axes to squeeze.
|
||||
)DOC")
|
||||
.Input(0, "data", "Tensors with at least max(dims) dimensions.")
|
||||
.Output(0, "squeezed", "Reshaped tensor with same data as input.");
|
||||
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456)
|
||||
#include "graph.pb.cc"
|
||||
#pragma warning(pop)
|
|
@ -0,0 +1,613 @@
|
|||
syntax = "proto2";
|
||||
|
||||
package ONNXIR;
|
||||
|
||||
// Note [Protobuf compatibility]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Based on experience working with downstream vendors, we generally can't
|
||||
// assume recent versions of protobufs. This means that we do not use any
|
||||
// protobuf features that are only available in proto3.
|
||||
//
|
||||
// Here are the most notable contortions we have to carry out to work around
|
||||
// these limitations:
|
||||
//
|
||||
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
|
||||
// of key-value pairs, where order does not matter and duplicates
|
||||
// are not allowed.
|
||||
|
||||
// Note [Namespaces]
|
||||
// ~~~~~~~~~~~~~~~~~
|
||||
// LotusIR gives explicit names to graphs, intermediate values and
|
||||
// serialized tensors. To make it easier to generate names, we organize
|
||||
// these into separate namespaces (so, e.g., a graph can have the same
|
||||
// name as a serialized tensor.) The namespaces are as follows:
|
||||
//
|
||||
// - Node: These names identify specific nodes in the graph (but not, necessarily
|
||||
// any particular input or output of the node.
|
||||
// - Graph: These names identify graphs in the protobuf.
|
||||
// - Attribute: These names identify attribute names for extra attributes that
|
||||
// are passed to operators.
|
||||
// - OperatorOrFunction: These names identify particular operators and
|
||||
// functions.
|
||||
// - Value: These names identify intermediate values (typically tensors) flowing through
|
||||
// the computation of a graph.
|
||||
// - Shape: These names represent parameters for unknown shape dimensions.
|
||||
//
|
||||
// We specify the namespace of a name in LotusIR as comments in the form
|
||||
// of "namespace {Node,Graph,OperatorOrFunction,Attribute,Value,Shape}". Framework is
|
||||
// responsible for supporting the namespaces.
|
||||
|
||||
// To be compatible with both proto2 and proto3, we will use a version number
|
||||
// that is not defined by the default value but an explicit enum number.
|
||||
enum Version {
|
||||
// The version field is always serialized and we will use it to store the
|
||||
// version that the graph is generated from. This helps us set up version
|
||||
// control. We should use version as
|
||||
// xx(major) - xx(minor) - xxxx(bugfix)
|
||||
// and we are starting with 00000001.
|
||||
IR_VERSION = 00000001;
|
||||
}
|
||||
|
||||
// A named attribute containing either singular float, integer, string
|
||||
// and tensor values, or repeated float, integer, string and tensor values.
|
||||
// An AttributeProto MUST contain the name field, and *only one* of the
|
||||
// following content fields, effectively enforcing a C/C++ union equivalent.
|
||||
message AttributeProto {
|
||||
// The name field MUST be present for this version of the IR.
|
||||
optional string name = 1; // namespace Attribute
|
||||
optional float f = 2; // float
|
||||
optional int64 i = 3; // int
|
||||
optional bytes s = 4; // UTF-8 string
|
||||
optional TensorProto t = 5; // tensor value
|
||||
optional GraphProto g = 6; // graph
|
||||
|
||||
repeated float floats = 7; // list of floats
|
||||
repeated int64 ints = 8; // list of ints
|
||||
repeated bytes strings = 9; // list of UTF-8 strings
|
||||
repeated TensorProto tensors = 10; // list of tensors
|
||||
repeated GraphProto graphs = 11; // list of graph
|
||||
|
||||
optional TypeProto type = 51;
|
||||
repeated TypeProto types = 52;
|
||||
//ISSUE:13807134,dbox: Do we ever see shape showing up as an attribute value?
|
||||
// If so, won't it always be accompanied by a TypeProto?
|
||||
optional TypeProto.TensorShapeProto shape = 53;
|
||||
repeated TypeProto.TensorShapeProto shapes = 54;
|
||||
}
|
||||
|
||||
// Defines information on value, including the name, the type, and
|
||||
// the shape of the value.
|
||||
message ValueInfoProto {
|
||||
// This field MUST be present in this version of the IR.
|
||||
optional string name = 1; // namespace Value
|
||||
// This field MUST be present in this version of the IR.
|
||||
optional TypeProto type = 2;
|
||||
}
|
||||
|
||||
// Defines a node in a computation graph. Each graph node is either an
|
||||
// operator or a function call. A node that is similar to the notion of "layer"
|
||||
// or "operator" in many deep learning frameworks. For example, it can be a
|
||||
// node of type "Conv" that takes in an image, a filter tensor and a bias
|
||||
// tensor, and produces the convolved output.
|
||||
//
|
||||
// NOTE: Control flow is defined by two built-in operators:
|
||||
//
|
||||
// Cond(p, true_input, false_input) takes three inputs, where p is a
|
||||
// boolean scalar tensor, true_input is the list of inputs to the true
|
||||
// branch of cond, and false_input is the list of inputs to the false
|
||||
// branch of cond. The true and false branches are defined as
|
||||
// functions that takes true_input and false_input as inputs respectively.
|
||||
// The two functions must have the same number of outputs, and each
|
||||
// corresponding output must have the same types, and have compatible
|
||||
// shapes.
|
||||
//
|
||||
// While(vars, consts) takes two inputs, where vars are the initial
|
||||
// values of the loop variables and consts are the values of constants
|
||||
// used inside the loop. The loop condition and loop body are defined
|
||||
// as functions. The functions take both vars and consts as inputs.
|
||||
// The loop condition function returns a boolean scalar tensor. The
|
||||
// loop body function has the form: body(vars, consts) = new_vars,
|
||||
// where new_vars are the new values of the loop variables after one
|
||||
// iteration so must match vars in terms of types and shapes.
|
||||
message NodeProto {
|
||||
// The named inputs of the node.
|
||||
repeated string input = 1; // namespace Value
|
||||
|
||||
// The named outputs of the node.
|
||||
repeated string output = 2; // namespace Value
|
||||
|
||||
// The name of this node.
|
||||
// This field is optional and used to uniquely identify nodes in the graph.
|
||||
optional string name = 3; // namespace Node
|
||||
|
||||
// The name of the operator/function called by the node.
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional string op_type = 4; // namespace OperatorOrFunction
|
||||
|
||||
// Additional named attributes.
|
||||
repeated AttributeProto attribute = 5;
|
||||
|
||||
// An optional human-readable documentation for this node in the graph.
|
||||
// This text MAY contain Markdown markup that conforms to http://commonmark.org/.
|
||||
optional string doc_string = 6;
|
||||
|
||||
// The number of inputs for each argument of the operator/function.
|
||||
// A formal parameter of the op may take a variable number of inputs
|
||||
// that is only known when this node is constructed.
|
||||
//BUG:13806939,dbox: I'm assuming that this field is like input_arg_info in that
|
||||
// a zero element/missing array implies that one needs to crawl
|
||||
// the graph to figure out the input counts, yes? Confirm and I'll
|
||||
// make clear. Otherwise, we need to require it to be present
|
||||
// and accurate.
|
||||
repeated int32 input_arg_count = 50;
|
||||
|
||||
// Specify a list of named nodes that must be executed before this node.
|
||||
// Framework may use this to give users the ability to impose additional
|
||||
// execution orders for the operations.
|
||||
repeated string control_input = 51;
|
||||
}
|
||||
|
||||
// ModelProto is a top-level file/container format for bundling a ML model.
|
||||
// The semantics of the model are described by the GraphProto that represents
|
||||
// a parameterized computation graph against a set of named operators that are
|
||||
// defined independently from the graph.
|
||||
message ModelProto {
|
||||
// The version of the IR this model targets. See Version enum above.
|
||||
// This field MUST be present.
|
||||
optional int64 ir_version = 1;
|
||||
|
||||
// The name of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
optional string producer_name = 2;
|
||||
|
||||
// The version of the framework or tool used to generate this model.
|
||||
// This field SHOULD be present to indicate which implementation/tool/framework
|
||||
// emitted the model.
|
||||
optional string producer_version = 3;
|
||||
|
||||
// Domain name of the model.
|
||||
// We use reverse domain names as name space indicators. For example:
|
||||
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
|
||||
//
|
||||
// Together with `model_version` and GraphProto.name, this forms the unique identity of
|
||||
// the graph.
|
||||
optional string domain = 4;
|
||||
|
||||
// The version of the graph encoded. See Version enum below.
|
||||
optional int64 model_version = 5;
|
||||
|
||||
// A human-readable documentation for this model. Markdown is allowed.
|
||||
optional string doc_string = 6;
|
||||
|
||||
// The parameterized graph that is evaluated to execute the model.
|
||||
optional GraphProto graph = 7;
|
||||
|
||||
// NOTE: ids between 8 and 49 are reserved for more ONNX fields.
|
||||
|
||||
// The optional name of the author who created the graph.
|
||||
optional string model_author = 50;
|
||||
|
||||
// Optional licensing information concerning use or origination of the graph.
|
||||
// This text MAY contain Markdown markup that conforms to http://commonmark.org/.
|
||||
optional string model_license = 51;
|
||||
};
|
||||
|
||||
// GraphProto defines a parameterized series of nodes to form a directed acyclic graph.
|
||||
// This is the equivalent of the "network" and "graph" in many deep learning
|
||||
// frameworks.
|
||||
// All the input/output tensors are explicitly named so a framework can
|
||||
// run any subgraph of the graph by feeding and fetching the named tensors.
|
||||
message GraphProto {
|
||||
// The nodes in the graph.
|
||||
repeated NodeProto node = 1;
|
||||
|
||||
// The name of the graph.
|
||||
optional string name = 2; // namespace Graph
|
||||
|
||||
// A list of named tensor values (constants), used to specify default
|
||||
// values for some of the inputs of the graph.
|
||||
// Each TensorProto entry must have a distinct name (within the list) that
|
||||
// also appears in the input list.
|
||||
// In an evaluation, the default value specified here is used if and only if
|
||||
// user specifies no value for the corresponding input parameter.
|
||||
// May be used to pass serialized parameters for networks.
|
||||
repeated TensorProto initializer = 5;
|
||||
|
||||
// A human-readable documentation for this graph. Markdown is allowed.
|
||||
optional string doc_string = 10;
|
||||
|
||||
// The inputs and outputs of the graph.
|
||||
repeated ValueInfoProto input = 11;
|
||||
repeated ValueInfoProto output = 12;
|
||||
|
||||
// Information for the values in the graph. The ValueInfoProto.name's
|
||||
// must be distinct. It is optional for a value to appear in value_info list.
|
||||
repeated ValueInfoProto value_info = 13;
|
||||
|
||||
// DO NOT USE the following fields which were deprecated.
|
||||
// repeated string input = 3;
|
||||
// repeated string output = 4;
|
||||
// optional int64 ir_version = 6;
|
||||
// optional int64 producer_version = 7;
|
||||
// optional string producer_tag = 8;
|
||||
// optional string domain = 9;
|
||||
|
||||
// The function definitions of the graph. They can only only be used
|
||||
// (i.e., called) in this graph.
|
||||
// Each FunctionDefProto in function MUST have a unique name.
|
||||
repeated FunctionDefProto function = 50;
|
||||
|
||||
// The externally defined operators declared by this graph.
|
||||
repeated OperatorDeclProto operator = 51;
|
||||
|
||||
// TODO: When the map type is added, provide for the "model_information"
|
||||
// field which holds name/value pairs of strings with additional devops
|
||||
// metadata, such as an identifier for which training set this instance
|
||||
// of a graph was trained with.
|
||||
|
||||
// Imported libraries are referenced as a collection of strings in the form of absolute
|
||||
// URIs or relative paths. Where such relative paths are rooted is defined by tools and
|
||||
// runtime implementations.
|
||||
repeated string imported_libraries = 52;
|
||||
|
||||
reserved 100 to 200; // for future extensions.
|
||||
}
|
||||
|
||||
// A message defined to store a tensor in its serialized format.
|
||||
message TensorProto {
|
||||
enum DataType {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// Advanced types
|
||||
FLOAT16 = 10;
|
||||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
// Future extensions go here.
|
||||
}
|
||||
|
||||
// The shape of the tensor.
|
||||
repeated int64 dims = 1;
|
||||
|
||||
// The data type of the tensor.
|
||||
optional DataType data_type = 2;
|
||||
|
||||
// For very large tensors, we may want to store them in chunks, in which
|
||||
// case the following fields will specify the segment that is stored in
|
||||
// the current TensorProto.
|
||||
message Segment {
|
||||
optional int64 begin = 1;
|
||||
optional int64 end = 2;
|
||||
}
|
||||
optional Segment segment = 3;
|
||||
|
||||
// Tensor content must be in the row major order.
|
||||
//
|
||||
// Depending on the data_type field, exactly one of the fields below with
|
||||
// name ending in _data is used to store the elements of the tensor.
|
||||
|
||||
// For float and complex64 values
|
||||
// Complex64 tensors are encoded as a single array of floats,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component apparing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
|
||||
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
|
||||
repeated float float_data = 4 [packed = true];
|
||||
|
||||
// For int32, uint8, int8, uint16, int16, bool, and float16 values
|
||||
// float16 values must be bit-wise converted to an uint16_t prior
|
||||
// to writing to the buffer.
|
||||
// When this field is present, the data_type field MUST be
|
||||
// INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
|
||||
repeated int32 int32_data = 5 [packed = true];
|
||||
|
||||
// For strings.
|
||||
// Each element of string_data is a UTF-8 encoded Unicode
|
||||
// string. No trailing null, no leading BOM. The protobuf "string"
|
||||
// scalar type is not used to match ML community conventions.
|
||||
// When this field is present, the data_type field MUST be STRING
|
||||
repeated bytes string_data = 6;
|
||||
|
||||
// For int64.
|
||||
// When this field is present, the data_type field MUST be INT64
|
||||
repeated int64 int64_data = 7 [packed = true];
|
||||
|
||||
// Optionally, a name for the tensor.
|
||||
optional string name = 8; // namespace Value
|
||||
|
||||
// Serializations can either use one of the fields above, or use this
|
||||
// raw bytes field. The only exception is the string case, where one is
|
||||
// required to store the content in the repeated bytes string_data field.
|
||||
//
|
||||
// When this raw_data field is used to store tensor value, elements MUST
|
||||
// be stored in as fixed-width, little-endian order.
|
||||
// Floating-point data types MUST be stored in IEEE 754 format.
|
||||
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
|
||||
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
|
||||
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
|
||||
//
|
||||
// Note: the advantage of specific field rather than the raw_data field is
|
||||
// that in some cases (e.g. int data), protobuf does a better packing via
|
||||
// variable length storage, and may lead to smaller binary footprint.
|
||||
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
|
||||
optional bytes raw_data = 9;
|
||||
|
||||
// For double
|
||||
// Complex64 tensors are encoded as a single array of doubles,
|
||||
// with the real components appearing in odd numbered positions,
|
||||
// and the corresponding imaginary component apparing in the
|
||||
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
|
||||
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
|
||||
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
|
||||
repeated double double_data = 10 [packed = true];
|
||||
|
||||
// For uint64 and uint32 values
|
||||
// When this field is present, the data_type field MUST be
|
||||
// UINT32 or UINT64
|
||||
repeated uint64 uint64_data = 11 [packed = true];
|
||||
}
|
||||
|
||||
// A sparse tensor must be stored as three dense tensors:
|
||||
// 1. dims: The shape of the original dense tensor.
|
||||
// 2. indices: A 2-D tensor specifying the indices of the nonzero elements.
|
||||
// 3. values: A 1-D tensor containing the values of the nonzero elements.
|
||||
message SparseTensorProto {
|
||||
// The dimensions in the tensor.
|
||||
repeated int64 dims = 1;
|
||||
// This field MUST be present this version of the IR.
|
||||
optional TensorProto indices = 2;
|
||||
// This field MUST be present this version of the IR.
|
||||
optional TensorProto values = 3;
|
||||
}
|
||||
|
||||
// Define the types.
|
||||
message TypeProto {
|
||||
// Defines a tensor shape. A dimension can be either an integer value
|
||||
// or a symbolic variable. A symbolic variable represents an unknown
|
||||
// dimension.
|
||||
message TensorShapeProto {
|
||||
message Dimension {
|
||||
oneof value {
|
||||
int64 dim_value = 1;
|
||||
string dim_param = 2; // namespace Shape
|
||||
}
|
||||
}
|
||||
repeated Dimension dim = 1;
|
||||
}
|
||||
|
||||
message TensorTypeProto {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional TensorProto.DataType elem_type = 1;
|
||||
optional TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
message SparseTensorTypeProto {
|
||||
// This field MUST NOT have the value of UNDEFINED
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional TensorProto.DataType elem_type = 1;
|
||||
optional TensorShapeProto shape = 2;
|
||||
}
|
||||
|
||||
message HandleTypeProto {
|
||||
}
|
||||
|
||||
message TupleTypeProto {
|
||||
repeated TypeProto elem_type = 1;
|
||||
}
|
||||
|
||||
message SeqTypeProto {
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional TypeProto elem_type = 1;
|
||||
}
|
||||
|
||||
message MapTypeProto {
|
||||
// This field MUST be present for this version of the IR.
|
||||
// This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
|
||||
optional TensorProto.DataType key_type = 1;
|
||||
// This field MUST be present for this version of the IR.
|
||||
// This field MUST NOT refer to UNDEFINED
|
||||
optional TensorProto.DataType value_type = 2;
|
||||
}
|
||||
|
||||
oneof value {
|
||||
// The type of a tensor.
|
||||
TensorTypeProto tensor_type = 1;
|
||||
|
||||
// The type of a sparse tensor.
|
||||
SparseTensorTypeProto sparse_tensor_type = 2;
|
||||
|
||||
// The type of an opaque handle. A handle is used to represent a
|
||||
// reference to a resource managed by the framework runtime.
|
||||
HandleTypeProto handle_type = 3;
|
||||
|
||||
// The type of a tuple.
|
||||
TupleTypeProto tuple_type = 4;
|
||||
|
||||
// The type of a sequence.
|
||||
SeqTypeProto seq_type = 5;
|
||||
|
||||
// The type of a map.
|
||||
MapTypeProto map_type = 6;
|
||||
}
|
||||
}
|
||||
|
||||
message ValueProto {
|
||||
// Defines a handle in its serialized format.
|
||||
message HandleProto {
|
||||
// This field MUST be present this version of the IR.
|
||||
optional int64 uid = 1;
|
||||
|
||||
// More information to be added. We need to specify the device
|
||||
// that the resource managed by the handle is on.
|
||||
}
|
||||
|
||||
// Defines a tuple in its serialized format.
|
||||
message TupleProto {
|
||||
repeated ValueProto elems = 1;
|
||||
}
|
||||
|
||||
// Defines a sequence in its serialized format.
|
||||
message SequenceProto {
|
||||
repeated ValueProto elems = 1;
|
||||
}
|
||||
|
||||
// Defines a map in its serialized format.
|
||||
// Maps are serialized as two single-dimensional tensors
|
||||
// for storage efficiency. The dimensions of each tensor MUST be identical
|
||||
// and the key at position N corresponds to the value at position N.
|
||||
// Keys SHOULD be unique. When a given key appears multiple times,
|
||||
// the value that corresponds last occurance of the key is the value.
|
||||
// This is consistent with protobuf3 encoding rules for map.
|
||||
message MapProto {
|
||||
// This field MUST be present for this version of the IR.
|
||||
// The data type of the tensor MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
|
||||
optional TensorProto keys = 1;
|
||||
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional TensorProto values = 2;
|
||||
}
|
||||
|
||||
oneof value {
|
||||
// A dense tensor.
|
||||
TensorProto dense_tensor = 1;
|
||||
|
||||
// A sparse tensor.
|
||||
SparseTensorProto sparse_tensor = 2;
|
||||
|
||||
// A handle.
|
||||
HandleProto handle = 3;
|
||||
|
||||
// A tuple.
|
||||
TupleProto tuple = 4;
|
||||
|
||||
// A sequence.
|
||||
SequenceProto seq = 5;
|
||||
|
||||
// A map.
|
||||
MapProto map = 6;
|
||||
}
|
||||
}
|
||||
|
||||
message ParameterDeclProto {
|
||||
optional string name = 1;
|
||||
optional TypeProto type = 2;
|
||||
// An optional human-readable documentation for this parameter.
|
||||
optional string doc_string = 3;
|
||||
}
|
||||
|
||||
// Defines a function.
|
||||
message FunctionDefProto {
|
||||
// The name of the function.
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional string name = 1;
|
||||
|
||||
// The input parameters of the function.
|
||||
repeated ParameterDeclProto input_params = 2;
|
||||
|
||||
// The output parameters of the function.
|
||||
repeated ParameterDeclProto output_params = 3;
|
||||
|
||||
// The body of the function.
|
||||
repeated NodeProto node = 4;
|
||||
|
||||
// The named attributes of the function.
|
||||
repeated AttributeProto attr = 5;
|
||||
}
|
||||
|
||||
message SignatureDeclProto {
|
||||
// The formal input parameters to the operation or function
|
||||
repeated ParameterDeclProto input_params = 1;
|
||||
// The formal output parameters to the operation or function
|
||||
repeated ParameterDeclProto output_params = 2;
|
||||
// The declaration of expected attributes to the operation or function
|
||||
repeated ParameterDeclProto input_attributes = 3;
|
||||
// An optional human-readable documentation for this signature.
|
||||
optional string doc_string = 4;
|
||||
}
|
||||
|
||||
message OperatorDeclProto {
|
||||
// This field MUST be present for this version of the IR.
|
||||
optional string name = 1;
|
||||
|
||||
// This field MUST contain at least one SignatureDeclProto.
|
||||
// This field MAY contain multiple SignatureDeclProtos, one
|
||||
// per type signature supported by this operator.
|
||||
repeated SignatureDeclProto signature = 2;
|
||||
|
||||
// An optional human-readable documentation for this operator.
|
||||
optional string doc_string = 3;
|
||||
}
|
||||
|
||||
// A library is a top-level format that contains the declaration
|
||||
// of operators and the definition of functions.
|
||||
message LibraryProto {
|
||||
// The version of the IR this graph targets. See Version enum below.
|
||||
// This field MUST be present this version of the IR.
|
||||
optional int64 ir_version = 1;
|
||||
|
||||
// The optional version of the framework runtime that generates this graph.
|
||||
// This producer_version has the same format as ir_version.
|
||||
optional int64 producer_version = 2;
|
||||
|
||||
// The optional name of the framework used to generate this graph in the form
|
||||
// "framework_name[-tag]". Tag is optional and provides additional
|
||||
// information such as `alpha` or `beta` or `rc3`.
|
||||
optional string producer_tag = 3;
|
||||
|
||||
// An optional version identifier used to track evolution of this library.
|
||||
// This model_version has the same format as ir_version.
|
||||
optional int64 model_version = 4;
|
||||
|
||||
// The optional name of the author who created the library.
|
||||
optional string model_author = 5;
|
||||
|
||||
// Optional licensing information concerning use or origination of the library.
|
||||
optional string model_license = 6;
|
||||
|
||||
// The name of the library.
|
||||
optional string name = 7; // namespace Library
|
||||
|
||||
// Domain of the graph.
|
||||
// We use reverse domain names as name space indicators. For example:
|
||||
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
|
||||
//
|
||||
// Together with `name` and `model_version`, this forms the unique identity of
|
||||
// the library.
|
||||
optional string domain = 8;
|
||||
|
||||
// An optional human-readable documentation for this graph.
|
||||
optional string doc_string = 9;
|
||||
|
||||
// The operators declared by this library.
|
||||
repeated OperatorDeclProto operator = 10;
|
||||
|
||||
// The function definitions of the library.
|
||||
repeated FunctionDefProto function = 11;
|
||||
|
||||
// A given name may appear at most once in either operator or function (but not both).
|
||||
// When refering to an operator or function from outside of this library, the op_type
|
||||
// field must equal:
|
||||
// LibraryProto.domain + "." + LibraryProto.name + "." OperatorDeclProto.name
|
||||
// or
|
||||
// LibraryProto.domain + "." + LibraryProto.name + "." FunctionDefProto.name
|
||||
|
||||
// Imported libraries are referenced as a collection of strings in the form of absolute
|
||||
// URIs or relative paths. Where such relative paths are rooted is defined by tools and
|
||||
// runtime implementations.
|
||||
repeated string imported_libraries = 12;
|
||||
}
|
|
@ -542,9 +542,9 @@ void fgetText(FILE* f, T& v)
|
|||
{
|
||||
int rc = ftrygetText(f, v);
|
||||
if (rc == 0)
|
||||
RuntimeError("error reading value from file (invalid format)");
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading value from file (invalid format)");
|
||||
else if (rc == EOF)
|
||||
RuntimeError("error reading from file: %s", strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("error reading from file: %s", strerror(errno));
|
||||
assert(rc == 1);
|
||||
}
|
||||
|
||||
|
@ -575,9 +575,9 @@ void fputText(FILE* f, T v)
|
|||
const wchar_t* formatString = GetFormatString(v);
|
||||
int rc = fwprintf(f, formatString, v);
|
||||
if (rc == 0)
|
||||
RuntimeError("error writing value to file, no values written");
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing value to file, no values written");
|
||||
else if (rc < 0)
|
||||
RuntimeError("error writing to file: %s", strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("error writing to file: %s", strerror(errno));
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
@ -723,7 +723,7 @@ class auto_file_ptr
|
|||
#pragma warning(disable : 4996)
|
||||
void openfailed(const std::string& path)
|
||||
{
|
||||
RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
|
||||
Microsoft::MSR::CNTK::RuntimeError("auto_file_ptr: error opening file '%s': %s", path.c_str(), strerror(errno));
|
||||
}
|
||||
#pragma warning(pop)
|
||||
protected:
|
||||
|
|
|
@ -86,7 +86,7 @@ int main(int argc, char *argv[])
|
|||
fprintf(stderr, "Run test on a GPU device.\n");
|
||||
TrainCifarResnet();
|
||||
}
|
||||
|
||||
|
||||
if (ShouldRunOnCpu())
|
||||
{
|
||||
fprintf(stderr, "Cannot run TrainCifarResnet test on a CPU device.\n");
|
||||
|
|
|
@ -88,6 +88,7 @@
|
|||
<Compile Include="SwigProxyClasses\MinibatchInfo.cs" />
|
||||
<Compile Include="SwigProxyClasses\MinibatchSource.cs" />
|
||||
<Compile Include="SwigProxyClasses\MinibatchSourceConfig.cs" />
|
||||
<Compile Include="SwigProxyClasses\ModelFormat.cs" />
|
||||
<Compile Include="SwigProxyClasses\PaddingMode.cs" />
|
||||
<Compile Include="SwigProxyClasses\PairDoubleDouble.cs" />
|
||||
<Compile Include="SwigProxyClasses\PairFloatFloat.cs" />
|
||||
|
|
|
@ -105,6 +105,7 @@
|
|||
<None Include="cntk\pytest.ini" />
|
||||
<None Include="cntk\sample_installer.py" />
|
||||
<None Include="cntk\tensor.py" />
|
||||
<None Include="cntk\tests\onnx_format_test.py" />
|
||||
<None Include="cntk\tests\user_learner.py" />
|
||||
<None Include="cntk\tests\variables_test.py" />
|
||||
<None Include="cntk\train\distributed.py" />
|
||||
|
@ -161,4 +162,4 @@
|
|||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
</ImportGroup>
|
||||
</Project>
|
||||
</Project>
|
|
@ -361,6 +361,9 @@
|
|||
<None Include="cntk\ops\tests\sequence_test.py">
|
||||
<Filter>cntk\ops\tests</Filter>
|
||||
</None>
|
||||
<None Include="cntk\tests\onnx_format_test.py">
|
||||
<Filter>cntk\tests</Filter>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="cntk\cntk_py_wrap.h">
|
||||
|
|
|
@ -12,7 +12,7 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import numbers
|
||||
from . import sequence
|
||||
from .functions import CloneMethod, Function, BlockFunction, load_model, register_native_user_function, native_user_function
|
||||
from .functions import ModelFormat, CloneMethod, Function, BlockFunction, load_model, register_native_user_function, native_user_function
|
||||
from cntk.internal import sanitize_input, sanitize_shape, sanitize_axis, sanitize_dynamic_axes, sanitize_axis_list, sanitize_multi_axis_reduction_list, typemap, sanitize_pooling_args, sanitize_convolution_args, sanitize_permutation
|
||||
from cntk.internal.utils import get_data_type
|
||||
from ..axis import Axis
|
||||
|
@ -517,6 +517,25 @@ def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spa
|
|||
normalization_time_constant, blend_time_constant,
|
||||
epsilon, use_cudnn_engine, name)
|
||||
|
||||
@typemap
|
||||
def local_response_normalization(operand, depth_radius, bias, alpha, beta, name=''):
|
||||
'''
|
||||
Local Response Normalization layer.
|
||||
|
||||
Args:
|
||||
operand: input of the Local Response Normalization.
|
||||
depth_radius (int): the radius on the channel dimension to apply the normalization.
|
||||
bias (double): a bias term to avoid divide by zero.
|
||||
alpha (double): the alpha term of the above equation.
|
||||
beta (double): the beta term of the above equation.
|
||||
name (str, optional): the name of the Function instance in the network.
|
||||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
from cntk.cntk_py import local_response_normalization
|
||||
operand = sanitize_input(operand)
|
||||
return local_response_normalization(operand, depth_radius, bias, alpha, beta, name)
|
||||
|
||||
##########################################################################
|
||||
# comparison ops
|
||||
##########################################################################
|
||||
|
@ -852,8 +871,6 @@ def element_times(left, right, name=''):
|
|||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_times(left, right, name)
|
||||
|
||||
|
||||
# TODO: move element_max/min to C++
|
||||
@associative_multi_arg
|
||||
@typemap
|
||||
def element_max(left, right, name=''):
|
||||
|
@ -869,10 +886,11 @@ def element_max(left, right, name=''):
|
|||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
gt = greater(left, right)
|
||||
# TODO: use as_block()
|
||||
return element_select(gt, left, right, name)
|
||||
|
||||
from cntk.cntk_py import element_max as cntk_py_element_max
|
||||
dtype = get_data_type(left, right)
|
||||
left = sanitize_input(left, dtype)
|
||||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_max(left, right, name)
|
||||
|
||||
@associative_multi_arg
|
||||
@typemap
|
||||
|
@ -889,10 +907,11 @@ def element_min(left, right, name=''):
|
|||
Returns:
|
||||
:class:`~cntk.ops.functions.Function`
|
||||
'''
|
||||
lt = less(left, right)
|
||||
# TODO: use as_block()
|
||||
return element_select(lt, left, right, name)
|
||||
|
||||
from cntk.cntk_py import element_min as cntk_py_element_min
|
||||
dtype = get_data_type(left, right)
|
||||
left = sanitize_input(left, dtype)
|
||||
right = sanitize_input(right, dtype)
|
||||
return cntk_py_element_min(left, right, name)
|
||||
|
||||
@typemap
|
||||
def element_divide(left, right, name=''):
|
||||
|
|
|
@ -26,6 +26,24 @@ from cntk.internal.sanitize import is_byte_buffer
|
|||
from ..variables import Record, Variable
|
||||
|
||||
|
||||
|
||||
@unique
|
||||
class ModelFormat(Enum):
|
||||
'''
|
||||
Describes the supported disk format for CNTK model.
|
||||
'''
|
||||
|
||||
CNTKv2 = cntk_py.ModelFormat_CNTKv2
|
||||
'''
|
||||
Default CNTK version 2 format, it supports all CNTK functionalities.
|
||||
'''
|
||||
|
||||
ONNX = cntk_py.ModelFormat_ONNX
|
||||
'''
|
||||
Open Neural Network Exchange format from https://github.com/onnx/onnx, ONNX currently support
|
||||
subset of CNTK functionalities.
|
||||
'''
|
||||
|
||||
@unique
|
||||
class CloneMethod(Enum):
|
||||
'''
|
||||
|
@ -1435,10 +1453,9 @@ class Function(cntk_py.Function):
|
|||
return collector.test_summaries[-1]
|
||||
|
||||
@typemap
|
||||
def save(self, filename):
|
||||
def save(self, filename, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Save this function graph into a model file using protobuf-based
|
||||
serialization.
|
||||
Save this function graph into a model file using the specified format.
|
||||
|
||||
Use distributed.Communicator.is_main() to gate your call to save()
|
||||
in distributed environment.
|
||||
|
@ -1446,7 +1463,7 @@ class Function(cntk_py.Function):
|
|||
Args:
|
||||
filename (str): model path
|
||||
'''
|
||||
return super(Function, self).save(filename)
|
||||
return super(Function, self).save(filename, format.value)
|
||||
|
||||
@typemap
|
||||
def restore(self, filename):
|
||||
|
@ -1489,7 +1506,7 @@ class Function(cntk_py.Function):
|
|||
|
||||
@staticmethod
|
||||
@typemap
|
||||
def load(model, device=None):
|
||||
def load(model, device=None, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Load the ``model``, that has been saved using :func:`~cntk.ops.functions.Function.save`.
|
||||
|
||||
|
@ -1498,6 +1515,8 @@ class Function(cntk_py.Function):
|
|||
containing the binary representation of a model.
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, defaults to the current globally default device):
|
||||
specifies the device to allocate the model on.
|
||||
format (:class:`~cntk.ModelFormat`, defaults to CNTKv2 format): specifies the format of the file to load.
|
||||
if the specified format is ONNX, then model must be a filename.
|
||||
|
||||
Returns:
|
||||
root node
|
||||
|
@ -1515,10 +1534,12 @@ class Function(cntk_py.Function):
|
|||
pass
|
||||
|
||||
if is_buffer:
|
||||
if format != ModelFormat.CNTKv2:
|
||||
raise ValueError('Loading from buffer only supported for CNTKv2 format.')
|
||||
return cntk_py.Function.load_from_buffer(model, device)
|
||||
|
||||
if is_file:
|
||||
return cntk_py.Function.load(str(model), device)
|
||||
return cntk_py.Function.load(str(model), device, format.value)
|
||||
|
||||
raise ValueError('Cannot load the model {} that is neither a file nor a byte buffer.'.format(model))
|
||||
|
||||
|
@ -1600,11 +1621,11 @@ def native_user_function(op_id, operands, attributes=None, user_function_instanc
|
|||
return cntk_py.Function_native_user_function(op_id, operands, attributes, user_function_instance_name)
|
||||
|
||||
@typemap
|
||||
def load_model(model, device=None):
|
||||
def load_model(model, device=None, format=ModelFormat.CNTKv2):
|
||||
'''
|
||||
Alias for :func:`~cntk.ops.functions.Function.load`.
|
||||
'''
|
||||
return Function.load(model, device)
|
||||
return Function.load(model, device, format)
|
||||
|
||||
class UserFunction(Function):
|
||||
'''
|
||||
|
|
|
@ -613,6 +613,36 @@ def test_op_batch_normalization(use_cudnn, sample, device_id, precision):
|
|||
|
||||
unittest_helper(op_node, forward_input, expected_forward, expected_backward=None, device_id=device_id, precision=precision)
|
||||
|
||||
def test_local_response_normalization(device_id, precision):
|
||||
dtype = PRECISION_TO_TYPE[precision]
|
||||
dev = cntk_device(device_id)
|
||||
|
||||
def lrn(x, depth_radius, bias, alpha, beta, name=''):
|
||||
x2 = C.square(x)
|
||||
# reshape to insert a fake singleton reduction dimension after the 3th axis (channel axis). Note Python axis order and BrainScript are reversed.
|
||||
x2s = C.reshape(x2, (1, C.InferredDimension), 0, 1)
|
||||
W = C.constant(alpha/(2*depth_radius+1), shape=(1,2*depth_radius+1,1,1), dtype=dtype, name='W')
|
||||
# 3D convolution with a filter that has a non 1-size only in the 3rd axis, and does not reduce since the reduction dimension is fake and 1
|
||||
y = C.convolution (W, x2s)
|
||||
# reshape back to remove the fake singleton reduction dimension
|
||||
b = C.reshape(y, C.InferredDimension, 0, 2)
|
||||
den = C.exp(beta * C.log(bias + b))
|
||||
return C.element_divide(x, den)
|
||||
|
||||
from cntk import local_response_normalization
|
||||
|
||||
img_shape = (64, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=dtype)
|
||||
x_gt = C.input_variable(shape=img_shape, dtype=dtype)
|
||||
x_r = C.input_variable(shape=img_shape, dtype=dtype)
|
||||
|
||||
gt = lrn(x_gt, 2, 1.0, 0.0001, 0.75)
|
||||
r = local_response_normalization(x_r, 2, 1.0, 0.0001, 0.75)
|
||||
ss = gt.eval({x_gt:img})
|
||||
sa = r.eval({x_r:img})
|
||||
|
||||
assert np.allclose(r.eval({x_r:img}), gt.eval({x_gt:img}))
|
||||
|
||||
TENSOR_PAIRS = [
|
||||
([0.3], [0.1]),
|
||||
([[0.1]], [[0.3]]),
|
||||
|
|
|
@ -27,7 +27,8 @@ def test_convolution_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': False,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
@ -41,7 +42,8 @@ def test_convolution_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': False,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
@ -59,7 +61,8 @@ def test_convolution_transpose_attributes():
|
|||
'upperPad': (0, 0, 0),
|
||||
'lowerPad': (0, 0, 0),
|
||||
'transpose': True,
|
||||
'outputShape': (0,)
|
||||
'outputShape': (0,),
|
||||
'kernelShape': (1, 2, 2)
|
||||
}
|
||||
_check(expected, d)
|
||||
|
||||
|
|
|
@ -0,0 +1,280 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cntk as C
|
||||
|
||||
def test_load_save_constant(tmpdir):
|
||||
# import pdb;pdb.set_trace()
|
||||
c = C.constant(value=[1,3])
|
||||
root_node = c * 5
|
||||
|
||||
result = root_node.eval()
|
||||
expected = [[[[5,15]]]]
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'c_plus_c.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
loaded_result = loaded_node.eval()
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
||||
def test_dense_layer(tmpdir):
|
||||
img_shape = (1, 5, 5)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = C.layers.Dense(5, activation=C.softmax)(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'dense_layer.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_convolution(tmpdir):
|
||||
img_shape = (1, 5, 5)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
filter = np.reshape(np.array([2, -1, -1, 2], dtype = np.float32), (1, 2, 2))
|
||||
kernel = C.constant(value = filter)
|
||||
root_node = C.convolution(kernel, x, auto_padding=[False])
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:[img]}), root_node.eval({x:[img]}))
|
||||
|
||||
def test_convolution_transpose(tmpdir):
|
||||
img_shape = (1, 3, 3)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
filter = np.reshape(np.array([2, -1, -1, 2], dtype = np.float32), (1, 2, 2))
|
||||
kernel = C.constant(value = filter)
|
||||
root_node = C.convolution_transpose(kernel, x, auto_padding=[False], output_shape=(1, 4, 4))
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv_transpose.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:[img]}), root_node.eval({x:[img]}))
|
||||
|
||||
def test_conv_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((5,5), [32,32,64][i], pad=True),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.Dense(64),
|
||||
C.layers.Dense(10, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_batch_norm_model(tmpdir):
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
num_channels = 3
|
||||
num_classes = 10
|
||||
|
||||
input_var = C.input_variable((num_channels, image_height, image_width))
|
||||
label_var = C.input_variable((num_classes))
|
||||
def create_basic_model_with_batch_normalization(input, out_dims):
|
||||
with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((5,5), [image_width,image_height,64][i], pad=True),
|
||||
C.layers.BatchNormalization(map_rank=1),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.Dense(64),
|
||||
C.layers.BatchNormalization(map_rank=1),
|
||||
C.layers.Dense(out_dims, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
feature_scale = 1.0 / 256.0
|
||||
input_var_norm = C.element_times(feature_scale, input_var)
|
||||
|
||||
# apply model to input
|
||||
z = create_basic_model_with_batch_normalization(input_var_norm, out_dims=10)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'bn_model.onnx')
|
||||
z.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert z.shape == loaded_node.shape
|
||||
|
||||
img_shape = (num_channels, image_width, image_height)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = z.arguments[0];
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), z.eval({x:img}))
|
||||
|
||||
def test_vgg9_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.layers.default_options(activation=C.relu, init=C.glorot_uniform()):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution((3,3), [64,96,128][i], pad=True),
|
||||
C.layers.Convolution((3,3), [64,96,128][i], pad=True),
|
||||
C.layers.MaxPooling((3,3), strides=(2,2))
|
||||
]),
|
||||
C.layers.For(range(2), lambda : [
|
||||
C.layers.Dense(1024)
|
||||
]),
|
||||
C.layers.Dense(10, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(-1, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'vgg9_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
||||
|
||||
def test_conv3d_model(tmpdir):
|
||||
def create_model(input):
|
||||
with C.default_options (activation=C.relu):
|
||||
model = C.layers.Sequential([
|
||||
C.layers.Convolution3D((3,3,3), 64, pad=True),
|
||||
C.layers.MaxPooling((1,2,2), (1,2,2)),
|
||||
C.layers.For(range(3), lambda i: [
|
||||
C.layers.Convolution3D((3,3,3), [96, 128, 128][i], pad=True),
|
||||
C.layers.Convolution3D((3,3,3), [96, 128, 128][i], pad=True),
|
||||
C.layers.MaxPooling((2,2,2), (2,2,2))
|
||||
]),
|
||||
C.layers.For(range(2), lambda : [
|
||||
C.layers.Dense(1024),
|
||||
C.layers.Dropout(0.5)
|
||||
]),
|
||||
C.layers.Dense(100, activation=None)
|
||||
])
|
||||
|
||||
return model(input)
|
||||
|
||||
video_shape = (3, 20, 32, 32)
|
||||
video = np.asarray(np.random.uniform(-1, 1, video_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(video.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'conv3d_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:video}), root_node.eval({x:video}))
|
||||
|
||||
def test_resnet_model(tmpdir):
|
||||
def convolution_bn(input, filter_size, num_filters, strides=(1,1), init=C.normal(0.01), activation=C.relu):
|
||||
r = C.layers.Convolution(filter_size,
|
||||
num_filters,
|
||||
strides=strides,
|
||||
init=init,
|
||||
activation=None,
|
||||
pad=True, bias=False)(input)
|
||||
r = C.layers.BatchNormalization(map_rank=1)(r)
|
||||
r = r if activation is None else activation(r)
|
||||
return r
|
||||
|
||||
def resnet_basic(input, num_filters):
|
||||
c1 = convolution_bn(input, (3,3), num_filters)
|
||||
c2 = convolution_bn(c1, (3,3), num_filters, activation=None)
|
||||
p = c2 + input
|
||||
return C.relu(p)
|
||||
|
||||
def resnet_basic_inc(input, num_filters):
|
||||
c1 = convolution_bn(input, (3,3), num_filters, strides=(2,2))
|
||||
c2 = convolution_bn(c1, (3,3), num_filters, activation=None)
|
||||
|
||||
s = convolution_bn(input, (1,1), num_filters, strides=(2,2), activation=None)
|
||||
|
||||
p = c2 + s
|
||||
return C.relu(p)
|
||||
|
||||
def resnet_basic_stack(input, num_filters, num_stack):
|
||||
assert (num_stack > 0)
|
||||
|
||||
r = input
|
||||
for _ in range(num_stack):
|
||||
r = resnet_basic(r, num_filters)
|
||||
return r
|
||||
|
||||
def create_model(input):
|
||||
conv = convolution_bn(input, (3,3), 16)
|
||||
r1_1 = resnet_basic_stack(conv, 16, 3)
|
||||
|
||||
r2_1 = resnet_basic_inc(r1_1, 32)
|
||||
r2_2 = resnet_basic_stack(r2_1, 32, 2)
|
||||
|
||||
r3_1 = resnet_basic_inc(r2_2, 64)
|
||||
r3_2 = resnet_basic_stack(r3_1, 64, 2)
|
||||
|
||||
# Global average pooling
|
||||
pool = C.layers.AveragePooling(filter_shape=(8,8), strides=(1,1))(r3_2)
|
||||
return C.layers.Dense(10, init=C.normal(0.01), activation=None)(pool)
|
||||
|
||||
img_shape = (3, 32, 32)
|
||||
img = np.asarray(np.random.uniform(0, 1, img_shape), dtype=np.float32)
|
||||
|
||||
x = C.input_variable(img.shape)
|
||||
root_node = create_model(x)
|
||||
|
||||
filename = os.path.join(str(tmpdir), R'resnet_model.onnx')
|
||||
root_node.save(filename, format=C.ModelFormat.ONNX)
|
||||
|
||||
loaded_node = C.Function.load(filename, format=C.ModelFormat.ONNX)
|
||||
assert root_node.shape == loaded_node.shape
|
||||
|
||||
x_ = loaded_node.arguments[0]
|
||||
assert np.allclose(loaded_node.eval({x_:img}), root_node.eval({x:img}))
|
Загрузка…
Ссылка в новой задаче