This commit is contained in:
Wenbing Li 2020-10-12 10:52:26 -07:00
Родитель 126b59b78b
Коммит da95244190
9 изменённых файлов: 2691 добавлений и 3 удалений

22
.clang-format Normal file
Просмотреть файл

@ -0,0 +1,22 @@
---
# Defaults for all languages.
BasedOnStyle: Google
# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
# Developers are responsible for adhering to the 120 character maximum.
ColumnLimit: 0
SortIncludes: false
DerivePointerAlignment: false
# if you want to customize when working locally see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for options.
# See ReformatSource.ps1 for a script to update all source according to the current options in this file.
# e.g. customizations to use Allman bracing and more indenting.
# AccessModifierOffset: -2
# BreakBeforeBraces: Allman
# CompactNamespaces: false
# IndentCaseLabels: true
# IndentWidth: 4
# NamespaceIndentation: All
...

30
.clang-tidy Normal file
Просмотреть файл

@ -0,0 +1,30 @@
---
# turn off readability-braces-around-statements to allow single line statement like 'if (x == y) doSomething();'
Checks: '-*,cppcoreguidelines-*,google-*,readability-*,modernize-*,-readability-braces-around-statements,-google-runtime-references,-cppcoreguidelines-pro-type-reinterpret-cast'
WarningsAsErrors: ''
HeaderFilterRegex: '.*onnxruntime\/core\/.*'
AnalyzeTemporaryDtors: false
FormatStyle: none
CheckOptions:
- key: google-readability-braces-around-statements.ShortStatementLines
value: '1'
- key: google-readability-function-size.StatementThreshold
value: '800'
- key: google-readability-namespace-comments.ShortNamespaceLines
value: '10'
- key: google-readability-namespace-comments.SpacesBeforeComments
value: '2'
- key: modernize-loop-convert.MaxCopySize
value: '16'
- key: modernize-loop-convert.MinConfidence
value: reasonable
- key: modernize-loop-convert.NamingStyle
value: CamelCase
- key: modernize-pass-by-value.IncludeStyle
value: google
- key: modernize-replace-auto-ptr.IncludeStyle
value: google
- key: modernize-use-nullptr.NullMacros
value: 'NULL'
...

5
.flake8 Normal file
Просмотреть файл

@ -0,0 +1,5 @@
[flake8]
max-line-length = 120
per-file-ignores =
__init__.py:F401
format = [flake8 PEP8 ERROR] %(path)s:%(row)d:%(col)d: %(code)s %(text)s

27
.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1,27 @@
# build, distribute, and bins (+ python proto bindings)
mybuild.*
build
build_host_protoc
build_android
build_ios
build_*
.build_debug/*
.build_release/*
distribute/*
*.testbin
*.bin
cmake_build
.cmake_build
cmake-build-debug
gen
*~
.vs
TestResults/
.idea/
onnxruntime.egg-info
nuget_root/
.packages/
.vscode/
*.code-workspace
__pycache__
out/

Просмотреть файл

@ -1,12 +1,17 @@
# Introduction
The onnxruntime-customops package is an onnxuntime custom op library which supports the ONNX model inference with non-standard ONNX operators. Besides, and the custom op also is implemented with python function.
# License
[MIT License](LICENSE)
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
the rights to use your contribution. For details, visit https://cla.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,538 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
//
// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone'
// methods for this purpose.
#pragma once
#include "onnxruntime_c_api.h"
#include <cstddef>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
namespace Ort {
// All C++ methods that can fail will throw an exception of this type
struct Exception : std::exception {
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
const char* what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
OrtErrorCode code_;
};
#ifdef ORT_NO_EXCEPTIONS
#define ORT_CXX_API_THROW(string, code) \
do { \
std::cerr << Ort::Exception(string, code) \
.what() \
<< std::endl; \
abort(); \
} while (false)
#else
#define ORT_CXX_API_THROW(string, code) \
throw Ort::Exception(string, code)
#endif
// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make
// it transparent to the users of the API.
template <typename T>
struct Global {
static const OrtApi* api_;
};
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
template <typename T>
#ifdef ORT_API_MANUAL_INIT
const OrtApi* Global<T>::api_{};
inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
#else
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
#endif
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
inline const OrtApi& GetApi() { return *Global<void>::api_; }
// This is a C++ wrapper for GetAvailableProviders() C API and returns
// a vector of strings representing the available execution providers.
std::vector<std::string> GetAvailableProviders();
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
// This can't be done in the C API since C doesn't have function overloading.
#define ORT_DEFINE_RELEASE(NAME) \
inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
ORT_DEFINE_RELEASE(Allocator);
ORT_DEFINE_RELEASE(MemoryInfo);
ORT_DEFINE_RELEASE(CustomOpDomain);
ORT_DEFINE_RELEASE(Env);
ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(Session);
ORT_DEFINE_RELEASE(SessionOptions);
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
ORT_DEFINE_RELEASE(SequenceTypeInfo);
ORT_DEFINE_RELEASE(MapTypeInfo);
ORT_DEFINE_RELEASE(TypeInfo);
ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ModelMetadata);
ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(IoBinding);
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
template <typename T>
struct Base {
using contained_type = T;
Base() = default;
Base(T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
}
~Base() { OrtRelease(p_); }
operator T*() { return p_; }
operator const T*() const { return p_; }
T* release() {
T* p = p_;
p_ = nullptr;
return p;
}
protected:
Base(const Base&) = delete;
Base& operator=(const Base&) = delete;
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
void operator=(Base&& v) noexcept {
OrtRelease(p_);
p_ = v.p_;
v.p_ = nullptr;
}
T* p_{};
template <typename>
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
};
template <typename T>
struct Base<const T> {
using contained_type = const T;
Base() = default;
Base(const T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT);
}
~Base() = default;
operator const T*() const { return p_; }
protected:
Base(const Base&) = delete;
Base& operator=(const Base&) = delete;
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
void operator=(Base&& v) noexcept {
p_ = v.p_;
v.p_ = nullptr;
}
const T* p_{};
};
template <typename T>
struct Unowned : T {
Unowned(decltype(T::p_) p) : T{p} {}
Unowned(Unowned&& v) : T{v.p_} {}
~Unowned() { this->release(); }
};
struct AllocatorWithDefaultOptions;
struct MemoryInfo;
struct Env;
struct TypeInfo;
struct Value;
struct ModelMetadata;
struct Env : Base<OrtEnv> {
Env(std::nullptr_t) {}
Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
Env& EnableTelemetryEvents();
Env& DisableTelemetryEvents();
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
static const OrtApi* s_api;
};
struct CustomOpDomain : Base<OrtCustomOpDomain> {
explicit CustomOpDomain(std::nullptr_t) {}
explicit CustomOpDomain(const char* domain);
void Add(OrtCustomOp* op);
};
struct RunOptions : Base<OrtRunOptions> {
RunOptions(std::nullptr_t) {}
RunOptions();
RunOptions& SetRunLogVerbosityLevel(int);
int GetRunLogVerbosityLevel() const;
RunOptions& SetRunLogSeverityLevel(int);
int GetRunLogSeverityLevel() const;
RunOptions& SetRunTag(const char* run_tag);
const char* GetRunTag() const;
// terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
RunOptions& SetTerminate();
// unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
RunOptions& UnsetTerminate();
};
struct SessionOptions : Base<OrtSessionOptions> {
explicit SessionOptions(std::nullptr_t) {}
SessionOptions();
explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {}
SessionOptions Clone() const;
SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
SessionOptions& EnableCpuMemArena();
SessionOptions& DisableCpuMemArena();
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
SessionOptions& DisableProfiling();
SessionOptions& EnableMemPattern();
SessionOptions& DisableMemPattern();
SessionOptions& SetExecutionMode(ExecutionMode execution_mode);
SessionOptions& SetLogId(const char* logid);
SessionOptions& SetLogSeverityLevel(int level);
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
SessionOptions& DisablePerSessionThreads();
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
};
struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {}
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
char* GetProducerName(OrtAllocator* allocator) const;
char* GetGraphName(OrtAllocator* allocator) const;
char* GetDomain(OrtAllocator* allocator) const;
char* GetDescription(OrtAllocator* allocator) const;
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
int64_t GetVersion() const;
};
struct Session : Base<OrtSession> {
explicit Session(std::nullptr_t) {}
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
// Run that will allocate the output values
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count);
// Run for when there is a list of prealloated outputs
void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count);
void Run(const RunOptions& run_options, const struct IoBinding&);
size_t GetInputCount() const;
size_t GetOutputCount() const;
size_t GetOverridableInitializerCount() const;
char* GetInputName(size_t index, OrtAllocator* allocator) const;
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
char* EndProfiling(OrtAllocator* allocator) const;
uint64_t GetProfilingStartTimeNs() const;
ModelMetadata GetModelMetadata() const;
TypeInfo GetInputTypeInfo(size_t index) const;
TypeInfo GetOutputTypeInfo(size_t index) const;
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
};
struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
ONNXTensorElementDataType GetElementType() const;
size_t GetElementCount() const;
size_t GetDimensionsCount() const;
void GetDimensions(int64_t* values, size_t values_count) const;
void GetSymbolicDimensions(const char** values, size_t values_count) const;
std::vector<int64_t> GetShape() const;
};
struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
explicit SequenceTypeInfo(std::nullptr_t) {}
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {}
TypeInfo GetSequenceElementType() const;
};
struct MapTypeInfo : Base<OrtMapTypeInfo> {
explicit MapTypeInfo(std::nullptr_t) {}
explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {}
ONNXTensorElementDataType GetMapKeyType() const;
TypeInfo GetMapValueType() const;
};
struct TypeInfo : Base<OrtTypeInfo> {
explicit TypeInfo(std::nullptr_t) {}
explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
Unowned<MapTypeInfo> GetMapTypeInfo() const;
ONNXType GetONNXType() const;
};
struct Value : Base<OrtValue> {
template <typename T>
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type);
template <typename T>
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
static Value CreateMap(Value& keys, Value& values);
static Value CreateSequence(std::vector<Value>& values);
template <typename T>
static Value CreateOpaque(const char* domain, const char* type_name, const T&);
template <typename T>
void GetOpaqueData(const char* domain, const char* type_name, T&) const;
explicit Value(std::nullptr_t) {}
explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
Value(Value&&) = default;
Value& operator=(Value&&) = default;
bool IsTensor() const;
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
Value GetValue(int index, OrtAllocator* allocator) const;
size_t GetStringTensorDataLength() const;
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
template <typename T>
T* GetTensorMutableData();
template <typename T>
const T* GetTensorData() const;
template <typename T>
T& At(const std::vector<int64_t>& location);
TypeInfo GetTypeInfo() const;
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
size_t GetStringTensorElementLength(size_t element_index) const;
void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
void FillStringTensor(const char* const* s, size_t s_len);
void FillStringTensorElement(const char* s, size_t index);
};
// Represents native memory allocation
struct MemoryAllocation {
MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
~MemoryAllocation();
MemoryAllocation(const MemoryAllocation&) = delete;
MemoryAllocation& operator=(const MemoryAllocation&) = delete;
MemoryAllocation(MemoryAllocation&&);
MemoryAllocation& operator=(MemoryAllocation&&);
void* get() { return p_; }
size_t size() const { return size_; }
private:
OrtAllocator* allocator_;
void* p_;
size_t size_;
};
struct AllocatorWithDefaultOptions {
AllocatorWithDefaultOptions();
operator OrtAllocator*() { return p_; }
operator const OrtAllocator*() const { return p_; }
void* Alloc(size_t size);
// The return value will own the allocation
MemoryAllocation GetAllocation(size_t size);
void Free(void* p);
const OrtMemoryInfo* GetInfo() const;
private:
OrtAllocator* p_{};
};
template <typename B>
struct BaseMemoryInfo : B {
BaseMemoryInfo() = default;
explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {}
~BaseMemoryInfo() = default;
BaseMemoryInfo(BaseMemoryInfo&&) = default;
BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default;
std::string GetAllocatorName() const;
OrtAllocatorType GetAllocatorType() const;
int GetDeviceId() const;
OrtMemType GetMemoryType() const;
template <typename U>
bool operator==(const BaseMemoryInfo<U>& o) const;
};
struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > {
explicit UnownedMemoryInfo(std::nullptr_t) {}
explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
};
struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > {
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
explicit MemoryInfo(std::nullptr_t) {}
explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
};
struct Allocator : public Base<OrtAllocator> {
Allocator(const Session& session, const MemoryInfo&);
void* Alloc(size_t size) const;
// The return value will own the allocation
MemoryAllocation GetAllocation(size_t size);
void Free(void* p) const;
UnownedMemoryInfo GetInfo() const;
};
struct IoBinding : public Base<OrtIoBinding> {
private:
std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
public:
explicit IoBinding(Session& session);
void BindInput(const char* name, const Value&);
void BindOutput(const char* name, const Value&);
void BindOutput(const char* name, const MemoryInfo&);
std::vector<std::string> GetOutputNames() const;
std::vector<std::string> GetOutputNames(Allocator&) const;
std::vector<Value> GetOutputValues() const;
std::vector<Value> GetOutputValues(Allocator&) const;
void ClearBoundInputs();
void ClearBoundOutputs();
};
//
// Custom OPs (only needed to implement custom OPs)
//
struct CustomOpApi {
CustomOpApi(const OrtApi& api) : api_(api) {}
template <typename T> // T is only implemented for float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
template <typename T>
T* GetTensorMutableData(_Inout_ OrtValue* value);
template <typename T>
const T* GetTensorData(_Inout_ const OrtValue* value);
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
size_t KernelContext_GetInputCount(const OrtKernelContext* context);
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
void ThrowOnError(OrtStatus* result);
private:
const OrtApi& api_;
};
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetName(); };
OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetExecutionProviderType(); };
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetOutputType(index); };
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
}
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
const char* GetExecutionProviderType() const { return nullptr; }
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"

Просмотреть файл

@ -0,0 +1,930 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
//
// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
// the main C++ file with implementation details.
namespace Ort {
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
if (status) {
std::string error_message = ort.GetErrorMessage(status);
OrtErrorCode error_code = ort.GetErrorCode(status);
ort.ReleaseStatus(status);
ORT_CXX_API_THROW(std::move(error_message), error_code);
}
}
inline void ThrowOnError(OrtStatus* status) {
ThrowOnError(GetApi(), status);
}
// This template converts a C++ type into it's ONNXTensorElementDataType
template <typename T>
struct TypeToTensorType;
template <>
struct TypeToTensorType<float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
template <>
struct TypeToTensorType<double> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; };
template <>
struct TypeToTensorType<int8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
template <>
struct TypeToTensorType<int16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; };
template <>
struct TypeToTensorType<int32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
template <>
struct TypeToTensorType<int64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; };
template <>
struct TypeToTensorType<uint8_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; };
template <>
struct TypeToTensorType<uint16_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; };
template <>
struct TypeToTensorType<uint32_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; };
template <>
struct TypeToTensorType<uint64_t> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; };
template <>
struct TypeToTensorType<bool> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; };
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
: allocator_(allocator), p_(p), size_(size) {
}
inline MemoryAllocation::~MemoryAllocation() {
if (p_ != nullptr) {
// We do not throw out of destructor
auto ret = GetApi().AllocatorFree(allocator_, p_);
static_cast<void>(ret);
}
}
inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) : allocator_(nullptr), p_(nullptr), size_(0) {
*this = std::move(o);
}
inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) {
OrtAllocator* alloc = nullptr;
void* p = nullptr;
size_t sz = 0;
// Swap out this
std::swap(alloc, allocator_);
std::swap(p, p_);
std::swap(sz, size_);
// Swap with incoming
std::swap(allocator_, o.allocator_);
std::swap(p_, o.p_);
std::swap(size_, o.size_);
// Destroy this instance if needed
MemoryAllocation this_alloc(alloc, p, sz);
return *this;
}
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&p_));
}
inline void* AllocatorWithDefaultOptions::Alloc(size_t size) {
void* out;
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
return out;
}
inline MemoryAllocation Ort::AllocatorWithDefaultOptions::GetAllocation(size_t size) {
void* out;
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
MemoryAllocation result(p_, out, size);
return result;
}
inline void AllocatorWithDefaultOptions::Free(void* p) {
ThrowOnError(GetApi().AllocatorFree(p_, p));
}
inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const {
const OrtMemoryInfo* out;
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
return out;
}
template <typename B>
inline std::string BaseMemoryInfo<B>::GetAllocatorName() const {
const char* name = nullptr;
ThrowOnError(GetApi().MemoryInfoGetName(*this, &name));
return std::string(name);
}
template <typename B>
inline OrtAllocatorType BaseMemoryInfo<B>::GetAllocatorType() const {
OrtAllocatorType type;
ThrowOnError(GetApi().MemoryInfoGetType(*this, &type));
return type;
}
template <typename B>
int BaseMemoryInfo<B>::GetDeviceId() const {
int id = 0;
ThrowOnError(GetApi().MemoryInfoGetId(*this, &id));
return id;
}
template <typename B>
inline OrtMemType BaseMemoryInfo<B>::GetMemoryType() const {
OrtMemType type;
ThrowOnError(GetApi().MemoryInfoGetMemType(*this, &type));
return type;
}
template <typename B>
template <typename U>
inline bool BaseMemoryInfo<B>::operator==(const BaseMemoryInfo<U>& o) const {
int comp_result = 0;
ThrowOnError(Ort::GetApi().CompareMemoryInfo(*this, o, &comp_result));
return comp_result == 0;
}
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
OrtMemoryInfo* p;
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
return MemoryInfo(p);
}
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &p_));
}
inline Allocator::Allocator(const Session& sess, const MemoryInfo& mem_info) {
ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &p_));
}
inline void* Allocator::Alloc(size_t size) const {
void* out = nullptr;
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
return out;
}
inline MemoryAllocation Ort::Allocator::GetAllocation(size_t size) {
void* out = nullptr;
ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out));
MemoryAllocation result(p_, out, size);
return result;
}
inline void Allocator::Free(void* p) const {
ThrowOnError(GetApi().AllocatorFree(p_, p));
}
inline UnownedMemoryInfo Allocator::GetInfo() const {
const OrtMemoryInfo* out = nullptr;
ThrowOnError(GetApi().AllocatorGetInfo(p_, &out));
return UnownedMemoryInfo(out);
}
inline IoBinding::IoBinding(Session& session) {
ThrowOnError(GetApi().CreateIoBinding(session, &p_));
}
inline void IoBinding::BindInput(const char* name, const Value& value) {
ThrowOnError(GetApi().BindInput(p_, name, value));
}
inline void IoBinding::BindOutput(const char* name, const Value& value) {
ThrowOnError(GetApi().BindOutput(p_, name, value));
}
inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info) {
ThrowOnError(GetApi().BindOutputToDevice(p_, name, mem_info));
}
inline std::vector<std::string> IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const {
std::vector<std::string> result;
auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); };
using Ptr = std::unique_ptr<void, decltype(free_fn)>;
char* buffer = nullptr;
size_t* lengths = nullptr;
size_t count = 0;
ThrowOnError(GetApi().GetBoundOutputNames(p_, allocator, &buffer, &lengths, &count));
if (count == 0) {
return result;
}
Ptr buffer_g(buffer, free_fn);
Ptr lengths_g(lengths, free_fn);
result.reserve(count);
for (size_t i = 0; i < count; ++i) {
auto sz = *lengths;
result.emplace_back(buffer, sz);
buffer += sz;
++lengths;
}
return result;
}
inline std::vector<std::string> IoBinding::GetOutputNames() const {
AllocatorWithDefaultOptions allocator;
return GetOutputNamesHelper(allocator);
}
inline std::vector<std::string> IoBinding::GetOutputNames(Allocator& allocator) const {
return GetOutputNamesHelper(allocator);
}
inline std::vector<Value> Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* allocator) const {
std::vector<Value> result;
size_t owned = 0;
size_t output_count = 0;
// Lambda to release the buffer when no longer needed and
// make sure that we destroy all instances on exception
auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
if (buffer) {
while (owned < output_count) {
auto* p = buffer + owned++;
GetApi().ReleaseValue(*p);
}
allocator->Free(allocator, buffer);
}
};
using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
OrtValue** output_buffer = nullptr;
ThrowOnError(GetApi().GetBoundOutputValues(p_, allocator, &output_buffer, &output_count));
if (output_count == 0) {
return result;
}
Ptr buffer_g(output_buffer, free_fn);
result.reserve(output_count);
for (size_t i = 0; i < output_count; ++i) {
result.emplace_back(output_buffer[i]);
++owned;
}
return result;
}
inline std::vector<Value> Ort::IoBinding::GetOutputValues(Allocator& allocator) const {
return GetOutputValuesHelper(allocator);
}
inline std::vector<Value> Ort::IoBinding::GetOutputValues() const {
AllocatorWithDefaultOptions allocator;
return GetOutputValuesHelper(allocator);
}
inline void IoBinding::ClearBoundInputs() {
GetApi().ClearBoundInputs(p_);
}
inline void IoBinding::ClearBoundOutputs() {
GetApi().ClearBoundOutputs(p_);
}
inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(default_warning_level, logid, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env& Env::EnableTelemetryEvents() {
ThrowOnError(GetApi().EnableTelemetryEvents(p_));
return *this;
}
inline Env& Env::DisableTelemetryEvents() {
ThrowOnError(GetApi().DisableTelemetryEvents(p_));
return *this;
}
inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
return *this;
}
inline CustomOpDomain::CustomOpDomain(const char* domain) {
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
}
inline void CustomOpDomain::Add(OrtCustomOp* op) {
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
}
inline RunOptions::RunOptions() {
ThrowOnError(GetApi().CreateRunOptions(&p_));
}
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
return *this;
}
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
return *this;
}
inline int RunOptions::GetRunLogVerbosityLevel() const {
int out;
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
return out;
}
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
return *this;
}
inline const char* RunOptions::GetRunTag() const {
const char* out;
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
return out;
}
inline RunOptions& RunOptions::SetTerminate() {
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
return *this;
}
inline RunOptions& RunOptions::UnsetTerminate() {
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
return *this;
}
inline SessionOptions::SessionOptions() {
ThrowOnError(GetApi().CreateSessionOptions(&p_));
}
inline SessionOptions SessionOptions::Clone() const {
OrtSessionOptions* out;
ThrowOnError(GetApi().CloneSessionOptions(p_, &out));
return SessionOptions{out};
}
inline SessionOptions& SessionOptions::SetIntraOpNumThreads(int intra_op_num_threads) {
ThrowOnError(GetApi().SetIntraOpNumThreads(p_, intra_op_num_threads));
return *this;
}
inline SessionOptions& SessionOptions::SetInterOpNumThreads(int inter_op_num_threads) {
ThrowOnError(GetApi().SetInterOpNumThreads(p_, inter_op_num_threads));
return *this;
}
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(p_, graph_optimization_level));
return *this;
}
inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
ThrowOnError(GetApi().SetOptimizedModelFilePath(p_, optimized_model_filepath));
return *this;
}
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
ThrowOnError(GetApi().EnableProfiling(p_, profile_file_prefix));
return *this;
}
inline SessionOptions& SessionOptions::DisableProfiling() {
ThrowOnError(GetApi().DisableProfiling(p_));
return *this;
}
inline SessionOptions& SessionOptions::EnableMemPattern() {
ThrowOnError(GetApi().EnableMemPattern(p_));
return *this;
}
inline SessionOptions& SessionOptions::DisableMemPattern() {
ThrowOnError(GetApi().DisableMemPattern(p_));
return *this;
}
inline SessionOptions& SessionOptions::EnableCpuMemArena() {
ThrowOnError(GetApi().EnableCpuMemArena(p_));
return *this;
}
inline SessionOptions& SessionOptions::DisableCpuMemArena() {
ThrowOnError(GetApi().DisableCpuMemArena(p_));
return *this;
}
inline SessionOptions& SessionOptions::SetExecutionMode(ExecutionMode execution_mode) {
ThrowOnError(GetApi().SetSessionExecutionMode(p_, execution_mode));
return *this;
}
inline SessionOptions& SessionOptions::SetLogId(const char* logid) {
ThrowOnError(GetApi().SetSessionLogId(p_, logid));
return *this;
}
inline SessionOptions& SessionOptions::SetLogSeverityLevel(int level) {
ThrowOnError(GetApi().SetSessionLogSeverityLevel(p_, level));
return *this;
}
inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) {
ThrowOnError(GetApi().AddCustomOpDomain(p_, custom_op_domain));
return *this;
}
inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, const char* config_value) {
ThrowOnError(GetApi().AddSessionConfigEntry(p_, config_key, config_value));
return *this;
}
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
}
inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &p_));
}
inline std::vector<Value> Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_names_count) {
std::vector<Ort::Value> output_values;
for (size_t i = 0; i < output_names_count; i++)
output_values.emplace_back(nullptr);
Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_names_count);
return output_values;
}
inline void Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count) {
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
auto ort_input_values = reinterpret_cast<const OrtValue**>(const_cast<Value*>(input_values));
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
ThrowOnError(GetApi().Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
}
inline void Session::Run(const RunOptions& run_options, const IoBinding& io_binding) {
ThrowOnError(GetApi().RunWithBinding(p_, run_options, io_binding));
}
inline size_t Session::GetInputCount() const {
size_t out;
ThrowOnError(GetApi().SessionGetInputCount(p_, &out));
return out;
}
inline size_t Session::GetOutputCount() const {
size_t out;
ThrowOnError(GetApi().SessionGetOutputCount(p_, &out));
return out;
}
inline size_t Session::GetOverridableInitializerCount() const {
size_t out;
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(p_, &out));
return out;
}
inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out));
return out;
}
inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out));
return out;
}
inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out));
return out;
}
inline char* Session::EndProfiling(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out));
return out;
}
inline uint64_t Session::GetProfilingStartTimeNs() const {
uint64_t out;
ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(p_, &out));
return out;
}
inline ModelMetadata Session::GetModelMetadata() const {
OrtModelMetadata* out;
ThrowOnError(GetApi().SessionGetModelMetadata(p_, &out));
return ModelMetadata{out};
}
inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
return out;
}
inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
return out;
}
inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const {
char** out;
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
return out;
}
inline int64_t ModelMetadata::GetVersion() const {
int64_t out;
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
return out;
}
inline TypeInfo Session::GetInputTypeInfo(size_t index) const {
OrtTypeInfo* out;
ThrowOnError(GetApi().SessionGetInputTypeInfo(p_, index, &out));
return TypeInfo{out};
}
inline TypeInfo Session::GetOutputTypeInfo(size_t index) const {
OrtTypeInfo* out;
ThrowOnError(GetApi().SessionGetOutputTypeInfo(p_, index, &out));
return TypeInfo{out};
}
inline TypeInfo Session::GetOverridableInitializerTypeInfo(size_t index) const {
OrtTypeInfo* out;
ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(p_, index, &out));
return TypeInfo{out};
}
inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const {
ONNXTensorElementDataType out;
ThrowOnError(GetApi().GetTensorElementType(p_, &out));
return out;
}
inline size_t TensorTypeAndShapeInfo::GetElementCount() const {
size_t out;
ThrowOnError(GetApi().GetTensorShapeElementCount(p_, &out));
return static_cast<size_t>(out);
}
inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const {
size_t out;
ThrowOnError(GetApi().GetDimensionsCount(p_, &out));
return out;
}
inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const {
ThrowOnError(GetApi().GetDimensions(p_, values, values_count));
}
inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const {
ThrowOnError(GetApi().GetSymbolicDimensions(p_, values, values_count));
}
inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
std::vector<int64_t> out(GetDimensionsCount(), 0);
GetDimensions(out.data(), out.size());
return out;
}
inline Unowned<TensorTypeAndShapeInfo> TypeInfo::GetTensorTypeAndShapeInfo() const {
const OrtTensorTypeAndShapeInfo* out;
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(p_, &out));
return Unowned<TensorTypeAndShapeInfo>(const_cast<OrtTensorTypeAndShapeInfo*>(out));
}
inline Unowned<SequenceTypeInfo> TypeInfo::GetSequenceTypeInfo() const {
const OrtSequenceTypeInfo* out;
ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(p_, &out));
return Unowned<SequenceTypeInfo>{const_cast<OrtSequenceTypeInfo*>(out)};
}
inline TypeInfo SequenceTypeInfo::GetSequenceElementType() const {
OrtTypeInfo* output;
ThrowOnError(GetApi().GetSequenceElementType(p_, &output));
return TypeInfo{output};
}
inline Unowned<MapTypeInfo> TypeInfo::GetMapTypeInfo() const {
const OrtMapTypeInfo* out;
ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(p_, &out));
return Unowned<MapTypeInfo>{const_cast<OrtMapTypeInfo*>(out)};
}
inline ONNXTensorElementDataType MapTypeInfo::GetMapKeyType() const {
ONNXTensorElementDataType out;
ThrowOnError(GetApi().GetMapKeyType(p_, &out));
return out;
}
inline TypeInfo MapTypeInfo::GetMapValueType() const {
OrtTypeInfo* output;
ThrowOnError(GetApi().GetMapValueType(p_, &output));
return TypeInfo{output};
}
inline ONNXType TypeInfo::GetONNXType() const {
ONNXType out;
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(p_, &out));
return out;
}
template <typename T>
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
}
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type) {
OrtValue* out;
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
return Value{out};
}
template <typename T>
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
}
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
OrtValue* out;
ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
return Value{out};
}
inline Value Value::CreateMap(Value& keys, Value& values) {
OrtValue* out;
OrtValue* inputs[2] = {keys, values};
ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
return Value{out};
}
inline Value Value::CreateSequence(std::vector<Value>& values) {
OrtValue* out;
std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
return Value{out};
}
template <typename T>
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
OrtValue* out;
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
return Value{out};
}
template <typename T>
inline void Value::GetOpaqueData(const char* domain, const char* type_name, T& out) const {
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, p_, &out, sizeof(T)));
}
inline bool Value::IsTensor() const {
int out;
ThrowOnError(GetApi().IsTensor(p_, &out));
return out != 0;
}
inline size_t Value::GetCount() const {
size_t out;
ThrowOnError(GetApi().GetValueCount(p_, &out));
return out;
}
inline Value Value::GetValue(int index, OrtAllocator* allocator) const {
OrtValue* out;
ThrowOnError(GetApi().GetValue(p_, index, allocator, &out));
return Value{out};
}
inline size_t Value::GetStringTensorDataLength() const {
size_t out;
ThrowOnError(GetApi().GetStringTensorDataLength(p_, &out));
return out;
}
inline size_t Value::GetStringTensorElementLength(size_t element_index) const {
size_t out;
ThrowOnError(GetApi().GetStringTensorElementLength(p_, element_index, &out));
return out;
}
inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
ThrowOnError(GetApi().GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count));
}
inline void Value::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
ThrowOnError(GetApi().GetStringTensorElement(p_, buffer_length, element_index, buffer));
}
inline void Value::FillStringTensor(const char* const* s, size_t s_len) {
ThrowOnError(GetApi().FillStringTensor(p_, s, s_len));
}
inline void Value::FillStringTensorElement(const char* s, size_t index) {
ThrowOnError(GetApi().FillStringTensorElement(p_, s, index));
}
template <typename T>
T* Value::GetTensorMutableData() {
T* out;
ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out));
return out;
}
template <typename T>
const T* Value::GetTensorData() const {
T* out;
ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out));
return out;
}
template <typename T>
inline T& Value::At(const std::vector<int64_t>& location) {
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
T* out;
ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out));
return *out;
}
inline TypeInfo Value::GetTypeInfo() const {
OrtTypeInfo* output;
ThrowOnError(GetApi().GetTypeInfo(p_, &output));
return TypeInfo{output};
}
inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const {
OrtTensorTypeAndShapeInfo* output;
ThrowOnError(GetApi().GetTensorTypeAndShape(p_, &output));
return TensorTypeAndShapeInfo{output};
}
//
// Custom OP API Inlines
//
inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
Ort::ThrowOnError(api_, status);
}
template <>
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
float out;
ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
return out;
}
template <>
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
int64_t out;
ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
return out;
}
template <>
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
size_t size = 0;
std::string out;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
if (api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) {
api_.ReleaseStatus(status);
out.resize(size);
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
out.resize(size - 1); // remove the terminating character '\0'
} else {
ThrowOnError(status);
}
return out;
}
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
OrtTensorTypeAndShapeInfo* out;
ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
return out;
}
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
size_t out;
ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
return out;
}
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
ONNXTensorElementDataType out;
ThrowOnError(api_.GetTensorElementType(info, &out));
return out;
}
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
size_t out;
ThrowOnError(api_.GetDimensionsCount(info, &out));
return out;
}
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
}
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
}
template <typename T>
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
T* data;
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
return data;
}
template <typename T>
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
}
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
std::vector<int64_t> output(GetDimensionsCount(info));
GetDimensions(info, output.data(), output.size());
return output;
}
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
api_.ReleaseTensorTypeAndShapeInfo(input);
}
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
size_t out;
ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
return out;
}
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
const OrtValue* out;
ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
return out;
}
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
size_t out;
ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
return out;
}
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
OrtValue* out;
ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
return out;
}
inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
ThrowOnError(GetApi().DisablePerSessionThreads(p_));
return *this;
}
inline std::vector<std::string> GetAvailableProviders() {
int len;
char** providers;
const OrtApi& api = GetApi();
ThrowOnError(api.GetAvailableProviders(&providers, &len));
std::vector<std::string> available_providers(providers, providers + len);
ThrowOnError(api.ReleaseAvailableProviders(providers, len));
return available_providers;
}
} // namespace Ort

Просмотреть файл

@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/*
* This file defines SessionOptions Config Keys and format of the Config Values.
*
* The Naming Convention for a SessionOptions Config Key,
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
* Such as "ep.cuda.use_arena"
* The Config Key cannot be empty
* The maximum length of the Config Key is 128
*
* The string format of a SessionOptions Config Value is defined individually for each Config.
* The maximum length of the Config Value is 1024
*/
// Key for disable PrePacking,
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
// will be used. Use this to override the usage of env allocators on a per session level.
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
// Set to 'ORT' (case sensitive) to load an ORT format model.
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";