diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..491a8357 --- /dev/null +++ b/.clang-format @@ -0,0 +1,22 @@ +--- +# Defaults for all languages. +BasedOnStyle: Google + +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +SortIncludes: false +DerivePointerAlignment: false + +# if you want to customize when working locally see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for options. +# See ReformatSource.ps1 for a script to update all source according to the current options in this file. +# e.g. customizations to use Allman bracing and more indenting. +# AccessModifierOffset: -2 +# BreakBeforeBraces: Allman +# CompactNamespaces: false +# IndentCaseLabels: true +# IndentWidth: 4 +# NamespaceIndentation: All + +... + diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 00000000..665bb057 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,30 @@ +--- +# turn off readability-braces-around-statements to allow single line statement like 'if (x == y) doSomething();' +Checks: '-*,cppcoreguidelines-*,google-*,readability-*,modernize-*,-readability-braces-around-statements,-google-runtime-references,-cppcoreguidelines-pro-type-reinterpret-cast' +WarningsAsErrors: '' +HeaderFilterRegex: '.*onnxruntime\/core\/.*' +AnalyzeTemporaryDtors: false +FormatStyle: none +CheckOptions: + - key: google-readability-braces-around-statements.ShortStatementLines + value: '1' + - key: google-readability-function-size.StatementThreshold + value: '800' + - key: google-readability-namespace-comments.ShortNamespaceLines + value: '10' + - key: google-readability-namespace-comments.SpacesBeforeComments + value: '2' + - key: modernize-loop-convert.MaxCopySize + value: '16' + - key: modernize-loop-convert.MinConfidence + value: reasonable + - key: modernize-loop-convert.NamingStyle + value: CamelCase + - key: modernize-pass-by-value.IncludeStyle + value: google + - key: modernize-replace-auto-ptr.IncludeStyle + value: google + - key: modernize-use-nullptr.NullMacros + value: 'NULL' +... + diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..bd1c971b --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +per-file-ignores = + __init__.py:F401 +format = [flake8 PEP8 ERROR] %(path)s:%(row)d:%(col)d: %(code)s %(text)s diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..9c63508c --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# build, distribute, and bins (+ python proto bindings) +mybuild.* +build +build_host_protoc +build_android +build_ios +build_* +.build_debug/* +.build_release/* +distribute/* +*.testbin +*.bin +cmake_build +.cmake_build +cmake-build-debug +gen +*~ +.vs +TestResults/ +.idea/ +onnxruntime.egg-info +nuget_root/ +.packages/ +.vscode/ +*.code-workspace +__pycache__ +out/ diff --git a/README.md b/README.md index 8eeee9c7..b7be03fb 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,17 @@ +# Introduction +The onnxruntime-customops package is an onnxuntime custom op library which supports the ONNX model inference with non-standard ONNX operators. Besides, and the custom op also is implemented with python function. + +# License +[MIT License](LICENSE) # Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. +the rights to use your contribution. For details, visit https://cla.microsoft.com. -When you submit a pull request, a CLA bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide +a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). diff --git a/includes/onnxruntime/onnxruntime_c_api.h b/includes/onnxruntime/onnxruntime_c_api.h new file mode 100644 index 00000000..feaaa0e2 --- /dev/null +++ b/includes/onnxruntime/onnxruntime_c_api.h @@ -0,0 +1,1098 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include "onnxruntime_session_options_config_keys.h" + +// This value is used in structures passed to ORT so that a newer version of ORT will still work with them +#define ORT_API_VERSION 5 + +#ifdef __cplusplus +extern "C" { +#endif + +// SAL2 Definitions +#ifndef _WIN32 +#define _In_ +#define _In_z_ +#define _In_opt_ +#define _In_opt_z_ +#define _Out_ +#define _Outptr_ +#define _Out_opt_ +#define _Inout_ +#define _Inout_opt_ +#define _Frees_ptr_opt_ +#define _Ret_maybenull_ +#define _Ret_notnull_ +#define _Check_return_ +#define _Outptr_result_maybenull_ +#define _In_reads_(X) +#define _Inout_updates_all_(X) +#define _Out_writes_bytes_all_(X) +#define _Out_writes_all_(X) +#define _Success_(X) +#define _Outptr_result_buffer_maybenull_(X) +#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) +#else +#include +#define ORT_ALL_ARGS_NONNULL +#endif + +#ifdef _WIN32 +// Define ORT_DLL_IMPORT if your program is dynamically linked to Ort. +// dllexport is not used, we use a .def file. +#ifdef ORT_DLL_IMPORT +#define ORT_EXPORT __declspec(dllimport) +#else +#define ORT_EXPORT +#endif +#define ORT_API_CALL _stdcall +#define ORT_MUST_USE_RESULT +#define ORTCHAR_T wchar_t +#else +#define ORT_EXPORT +#define ORT_API_CALL +#define ORT_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define ORTCHAR_T char +#endif + +#ifndef ORT_TSTR +#ifdef _WIN32 +#define ORT_TSTR(X) L##X +#else +#define ORT_TSTR(X) X +#endif +#endif + +// Any pointer marked with _In_ or _Out_, cannot be NULL. + +// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation +// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that. +// for ReleaseXXX(...) functions, they can accept NULL pointer. + +#ifdef __cplusplus +// For any compiler with C++11 support, MSVC 2015 and greater, or Clang version supporting noexcept. +// Such complex condition is needed because compilers set __cplusplus value differently. +#ifndef __has_feature +#define __has_feature(x) 0 +#endif +#if ((__cplusplus >= 201103L) || (_MSC_VER >= 1900) || (defined(__has_feature) && __has_feature(cxx_noexcept))) +#define NO_EXCEPTION noexcept +#else +#define NO_EXCEPTION throw() +#endif +#else +#define NO_EXCEPTION +#endif + +// Copied from TensorProto::DataType +// Currently, Ort doesn't support complex64, complex128, bfloat16 types +typedef enum ONNXTensorElementDataType { + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string + ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision +} ONNXTensorElementDataType; + +// Synced with onnx TypeProto oneof +typedef enum ONNXType { + ONNX_TYPE_UNKNOWN, + ONNX_TYPE_TENSOR, + ONNX_TYPE_SEQUENCE, + ONNX_TYPE_MAP, + ONNX_TYPE_OPAQUE, + ONNX_TYPE_SPARSETENSOR, +} ONNXType; + +typedef enum OrtLoggingLevel { + ORT_LOGGING_LEVEL_VERBOSE, + ORT_LOGGING_LEVEL_INFO, + ORT_LOGGING_LEVEL_WARNING, + ORT_LOGGING_LEVEL_ERROR, + ORT_LOGGING_LEVEL_FATAL, +} OrtLoggingLevel; + +typedef enum OrtErrorCode { + ORT_OK, + ORT_FAIL, + ORT_INVALID_ARGUMENT, + ORT_NO_SUCHFILE, + ORT_NO_MODEL, + ORT_ENGINE_ERROR, + ORT_RUNTIME_EXCEPTION, + ORT_INVALID_PROTOBUF, + ORT_MODEL_LOADED, + ORT_NOT_IMPLEMENTED, + ORT_INVALID_GRAPH, + ORT_EP_FAIL, +} OrtErrorCode; + +// This configures the arena based allocator used by ORT +// See ONNX_Runtime_Perf_Tuning.md for details on what these mean and how to choose these values +typedef struct OrtArenaCfg { + size_t max_mem; // use 0 to allow ORT to choose the default + int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default + int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default +} OrtArenaCfg; + +#define ORT_RUNTIME_CLASS(X) \ + struct Ort##X; \ + typedef struct Ort##X Ort##X; + +// The actual types defined have an Ort prefix +ORT_RUNTIME_CLASS(Env); +ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success +ORT_RUNTIME_CLASS(MemoryInfo); +ORT_RUNTIME_CLASS(IoBinding); +ORT_RUNTIME_CLASS(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) +ORT_RUNTIME_CLASS(Value); +ORT_RUNTIME_CLASS(RunOptions); +ORT_RUNTIME_CLASS(TypeInfo); +ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); +ORT_RUNTIME_CLASS(MapTypeInfo); +ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(ModelMetadata); +ORT_RUNTIME_CLASS(ThreadPoolParams); +ORT_RUNTIME_CLASS(ThreadingOptions); + +#ifdef _WIN32 +typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; +#else +typedef OrtStatus* OrtStatusPtr; +#endif + +// __VA_ARGS__ on Windows and Linux are different +#define ORT_API(RETURN_TYPE, NAME, ...) ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_API_STATUS(NAME, ...) \ + ORT_EXPORT _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// XXX: Unfortunately, SAL annotations are known to not work with function pointers +#define ORT_API2_STATUS(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT +#define ORT_API_STATUS_IMPL(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) + +// When passing in an allocator to any ORT function, be sure that the allocator object +// is not destroyed until the last allocated object using it is freed. +typedef struct OrtAllocator { + uint32_t version; // Initialize to ORT_API_VERSION + void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); + void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); + const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); +} OrtAllocator; + +typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + +// Set Graph optimization level. +// Refer https://github.com/microsoft/onnxruntime/blob/master/docs/ONNX_Runtime_Graph_Optimizations.md +// for in-depth undersrtanding of Graph Optimizations in ORT +typedef enum GraphOptimizationLevel { + ORT_DISABLE_ALL = 0, + ORT_ENABLE_BASIC = 1, + ORT_ENABLE_EXTENDED = 2, + ORT_ENABLE_ALL = 99 +} GraphOptimizationLevel; + +typedef enum ExecutionMode { + ORT_SEQUENTIAL = 0, + ORT_PARALLEL = 1, +} ExecutionMode; + +// Set the language projection, default is C, which means it will classify the language not in the list to C also. +typedef enum OrtLanguageProjection { + ORT_PROJECTION_C = 0, // default + ORT_PROJECTION_CPLUSPLUS = 1, + ORT_PROJECTION_CSHARP = 2, + ORT_PROJECTION_PYTHON = 3, + ORT_PROJECTION_JAVA = 4, + ORT_PROJECTION_WINML = 5, +} OrtLanguageProjection; + +struct OrtKernelInfo; +typedef struct OrtKernelInfo OrtKernelInfo; +struct OrtKernelContext; +typedef struct OrtKernelContext OrtKernelContext; +struct OrtCustomOp; +typedef struct OrtCustomOp OrtCustomOp; + +typedef enum OrtAllocatorType { + Invalid = -1, + OrtDeviceAllocator = 0, + OrtArenaAllocator = 1 +} OrtAllocatorType; + +/** + * memory types for allocator, exec provider specific types should be extended in each provider + * Whenever this struct is updated, please also update the MakeKey function in onnxruntime/core/framework/execution_provider.cc +*/ +typedef enum OrtMemType { + OrtMemTypeCPUInput = -2, // Any CPU memory used by non-CPU execution provider + OrtMemTypeCPUOutput = -1, // CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED + OrtMemTypeCPU = OrtMemTypeCPUOutput, // temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED + OrtMemTypeDefault = 0, // the default allocator for execution provider +} OrtMemType; + +struct OrtApi; +typedef struct OrtApi OrtApi; + +struct OrtApiBase { + const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; // Pass in ORT_API_VERSION + // nullptr will be returned if the version is unsupported, for example when using a runtime older than this header file + + const char*(ORT_API_CALL* GetVersionString)() NO_EXCEPTION; +}; +typedef struct OrtApiBase OrtApiBase; + +ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; + +struct OrtApi { + /** +* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus +*/ + OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** + * \param status must not be NULL + * \return The error message inside the `status`. Do not free the returned value. + */ + const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** + * \param out Should be freed by `OrtReleaseEnv` after use + */ + ORT_API2_STATUS(CreateEnv, OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + /** + * \param out Should be freed by `OrtReleaseEnv` after use + */ + ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + // Platform telemetry events are on by default since they are lightweight. You can manually turn them off. + ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); + ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); + + // TODO: document the path separator convention? '/' vs '\' + // TODO: should specify the access characteristics of model_path. Is this read only during the + // execution of CreateSession, or does the OrtSession retain a handle to the file/directory + // and continue to access throughout the OrtSession lifetime? + // What sort of access is needed to model_path : read or read/write? + ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + ORT_API2_STATUS(Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output); + + /** + * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use + */ + ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); + + // Set filepath to save optimized model after graph level transformations. + ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, + _In_ const ORTCHAR_T* optimized_model_filepath); + + // create a copy of an existing OrtSessionOptions + ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, + _Outptr_ OrtSessionOptions** out_options); + + // Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model + // has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. + // See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. + ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); + + // Enable profiling for this session. + ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); + ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); + + // Enable the memory pattern optimization. + // The idea is if the input shapes are the same, we could trace the internal memory allocation + // and generate a memory pattern for future request. So next time we could just do one allocation + // with a big chunk for all the internal memory allocation. + // Note: memory pattern optimization is only available when SequentialExecution enabled. + ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); + ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); + + // Enable the memory arena on CPU + // Arena may pre-allocate memory for future usage. + // set this option to false if you don't want it. + ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); + ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); + + // < logger id to use for session output + ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); + + // < applies to session load, initialization, etc + ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); + ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); + + ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, + GraphOptimizationLevel graph_optimization_level); + + // Sets the number of threads used to parallelize the execution within nodes + // A value of 0 means ORT will pick a default + // Note: If you've built ORT with OpenMP, this API has no effect on the number of threads used. In this case + // use the OpenMP env variables to configure the number of intra op num threads. + ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); + + // Sets the number of threads used to parallelize the execution of the graph (across nodes) + // If sequential execution is enabled this value is ignored + // A value of 0 means ORT will pick a default + ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); + + /* + Create a custom op domain. After all sessions using it are released, call OrtReleaseCustomOpDomain + */ + ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); + + /* + * Add custom ops to the OrtCustomOpDomain + * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released + */ + ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op); + + /* + * Add a custom op domain to the OrtSessionOptions + * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released + */ + ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); + + /* + * Loads a DLL named 'library_path' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in + * session options are destroyed, or if an error occurs and it is non null. + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, + void** library_handle); + + /** + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session: + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_ + * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + */ + + ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* sess, _Out_ size_t* out); + + /** + * \param out should be freed by OrtReleaseTypeInfo after use + */ + ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** + * \param out should be freed by OrtReleaseTypeInfo after use + */ + ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** + * \param out should be freed by OrtReleaseTypeInfo after use + */ + ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* sess, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + */ + ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** + * \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use + */ + ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); + + ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value); + ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int value); + ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions*, _In_ const char* run_tag); + + ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); + ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); + ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const char** out); + + // Set a flag so that ALL incomplete OrtRun calls that are using this instance of OrtRunOptions + // will exit as soon as possible. + ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); + // Unset the terminate flag to enable this OrtRunOptions instance being used in new OrtRun calls. + ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); + + /** + * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value + * \param out Should be freed by calling OrtReleaseValue + * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + */ + ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** + * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. + * p_data is owned by caller. OrtReleaseValue won't release p_data. + * \param out Should be freed by calling OrtReleaseValue + */ + ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** + * \Sets *out to 1 iff an OrtValue is a tensor, 0 otherwise + */ + ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); + + // This function doesn't work with string tensor + // this is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. + ORT_API2_STATUS(GetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** out); + + /** + * \param value A tensor created from OrtCreateTensor... function. + * \param s each A string array. Each string in this array must be null terminated. + * \param s_len length of s + */ + ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); + + /** + * \param value A tensor created from OrtCreateTensor... function. + * \param len total data length, not including the trailing '\0' chars. + */ + ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); + + /** + * \param s string contents. Each string is NOT null-terminated. + * \param value A tensor created from OrtCreateTensor... function. + * \param s_len total data length, get it from OrtGetStringTensorDataLength + */ + ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); + + /** + * Don't free the 'out' value + */ + ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*, + _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); + + /** + * Return OnnxType from OrtTypeInfo + */ + ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ enum ONNXType* out); + + /** + * The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo + */ + ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type); + + /** + * \param info Created from CreateTensorTypeAndShapeInfo() function + * \param dim_values An array with length of `dim_count`. Its elements can contain negative values. + * \param dim_count length of dim_values + */ + ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, + _Out_ enum ONNXTensorElementDataType* out); + ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, + size_t dim_values_length); + ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); + + /** + * Return the number of elements specified by the tensor shape. + * Return a negative value if unknown (i.e., any dimension is negative.) + * e.g. + * [] -> 1 + * [1,3,4] -> 12 + * [2,0,4] -> 0 + * [-1,3,4] -> -1 + */ + ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + + /** + * \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use + */ + ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** + * Get the type information of an OrtValue + * \param value + * \param out The returned value should be freed by OrtReleaseTypeInfo after use + */ + ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); + + ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); + + ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, + enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out); + + /** + * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0. + */ + ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, + _Outptr_ OrtMemoryInfo** out); + + /** + * Test if two memory info are equal + * \Sets 'out' to 0 if equal, -1 if not equal + */ + ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); + + /** + * Do not free the returned value + */ + ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); + ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); + ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); + ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); + + ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out); + ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ptr, void* p); + ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Outptr_ const struct OrtMemoryInfo** out); + + // The returned pointer doesn't have to be freed. + // Always returns the same instance on every invocation. + // Please note that this is a non-arena based allocator. + ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); + + // Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable + // optimizations that can take advantage of fixed values (such as memory planning, etc) + ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, + _In_ int64_t dim_value); + + /** + * APIs to support non-tensor types - map and sequence. + * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ + + /** + * If input OrtValue represents a map, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If input OrtValue represents a sequence, use index to retrieve the index'th element + * of the sequence. + */ + ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out); + + /** + * Returns 2 for type map and N for sequence where N is the number of elements + * in the sequence. + */ + ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); + + /** + * To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues + * representing keys and values. + * To construct a sequence, use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an arrary of N OrtValues. + * \value_type should be either map or sequence. + */ + ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out); + + /** + * Construct OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. OrtValue in this case would contain + * an internal representation of the Opaque type. Opaque types are distinguished between + * each other by two strings 1) domain and 2) type name. The combination of the two + * must be unique, so the type representation is properly identified internally. The combination + * must be properly registered from within ORT at both compile/run time or by another API. + * + * To construct the OrtValue pass domain and type names, also a pointer to a data container + * the type of which must be know to both ORT and the client program. That data container may or may + * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for + * verification purposes. + * + * \domain_name - domain name for the Opaque type, null terminated. + * \type_name - type name for the Opaque type, null terminated. + * \data_contianer - data to populate OrtValue + * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected + * data_container size internally. + */ + ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, + _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); + + /** + * Fetch data from an OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. + * \domain_name - domain name for the Opaque type, null terminated. + * \type_name - type name for the Opaque type, null terminated. + * \data_contianer - data to populate OrtValue + * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected + * data_container size internally. + */ + + ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, + _Out_ void* data_container, size_t data_container_size); + + ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out); + ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out); + ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, + _Inout_ size_t* size); + + ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, + _Out_ const OrtValue** out); + ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); + + ORT_CLASS_RELEASE(Env); + ORT_CLASS_RELEASE(Status); // nullptr for Status* indicates success + ORT_CLASS_RELEASE(MemoryInfo); + ORT_CLASS_RELEASE(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool) + ORT_CLASS_RELEASE(Value); + ORT_CLASS_RELEASE(RunOptions); + ORT_CLASS_RELEASE(TypeInfo); + ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); + ORT_CLASS_RELEASE(SessionOptions); + ORT_CLASS_RELEASE(CustomOpDomain); + + // End of Version 1 - DO NOT MODIFY ABOVE (see above text for more information) + + // Version 2 - In development, feel free to add/remove/rearrange here + + /** + * GetDenotationFromTypeInfo + * This api augments OrtTypeInfo to return denotations on the type. + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + */ + ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, + _Out_ size_t* len); + + // OrtTypeInfo Casting methods + + /** + * CastTypeInfoToMapTypeInfo + * This api augments OrtTypeInfo to return an OrtMapTypeInfo when the type is a map. + * The OrtMapTypeInfo has additional information about the map's key type and value type. + * This is used by WinML to support model reflection APIs. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtMapTypeInfo** out); + + /** + * CastTypeInfoToSequenceTypeInfo + * This api augments OrtTypeInfo to return an OrtSequenceTypeInfo when the type is a sequence. + * The OrtSequenceTypeInfo has additional information about the sequence's element type. + * This is used by WinML to support model reflection APIs. + * + * Don't free the 'out' value + */ + ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); + + // OrtMapTypeInfo Accessors + + /** + * GetMapKeyType + * This api augments get the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. + * This is used by WinML to support model reflection APIs. + */ + ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); + + /** + * GetMapValueType + * This api augments get the value type of a map. + */ + ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); + + // OrtSequenceTypeInfo Accessors + + /** + * GetSequenceElementType + * This api augments get the element type of a sequence. + * This is used by WinML to support model reflection APIs. + */ + ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, + _Outptr_ OrtTypeInfo** type_info); + + ORT_CLASS_RELEASE(MapTypeInfo); + ORT_CLASS_RELEASE(SequenceTypeInfo); + + /** + * \param out is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling() + * on the SessionOptions instance used to create the session. + */ + ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); + + /** + * \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use. + */ + ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* sess, _Outptr_ OrtModelMetadata** out); + + /** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + */ + ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + /** + * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. + * 'value' will be a nullptr if the given key is not found in the custom metadata map. + */ + ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); + + ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); + + ORT_CLASS_RELEASE(ModelMetadata); + + /* + * Creates an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with DisablePerSessionThreads API or else the session will use + * its own thread pools. + */ + ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ const OrtThreadingOptions* t_options, _Outptr_ OrtEnv** out); + + /* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */ + + /* + * Calling this API will make the session use the global threadpools shared across sessions. + * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API. + */ + ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); + + ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); + + ORT_CLASS_RELEASE(ThreadingOptions); + + /** + * \param num_keys contains the number of keys in the custom metadata map + * \param keys is an array of null terminated strings (array count = num_keys) allocated using 'allocator'. + * The caller is responsible for freeing each string and the pointer array. + * 'keys' will be a nullptr if custom metadata map is empty. + */ + ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); + + // Override symbolic dimensions (by specific name strings) with actual values + // if known at session initialization time to enable optimizations that can + // take advantage of fixed values (such as memory planning, etc) + ORT_API2_STATUS(AddFreeDimensionOverrideByName, + _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, + _In_ int64_t dim_value); + + /** + * \param out_ptr will hold a pointer to the array of char * + * representing available providers. + * \param provider_length is a pointer to an int variable where + * the number of available providers will be added. + * The caller is responsible for freeing each char * and the pointer + * array by calling ReleaseAvailableProviders(). + */ + ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, + _In_ int* provider_length); + + /** + * \param ptr is the pointer to an array of available providers you + * get after calling GetAvailableProviders(). + * \param providers_length is the number of available providers. + */ + ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, + _In_ int providers_length); + + /** + * \param value - A tensor created from OrtCreateTensor... function. + * \param index - index of string tensor element, length of element at index will be returned. + * \param out - number of UTF-8 bytes that the string contains + */ + ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); + + /** + * \param s string element contents in UTF-8 encoding. The string is NOT null-terminated. + * \param value A tensor created from OrtCreateTensor... function. + * \param s_len element length, get it from OrtGetStringTensorElementLength. + * \param index offset of element of tensor to return. + */ + ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); + + /** + * \param value - A tensor created from OrtCreateTensor... function. + * \param s - A null terminated UTF-8 encoded string. + * \param index - index of string tensor element to fill + */ + ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); + + /** + * Set a single session configuration entry as a pair of strings + * If a configuration with same key exists, this will overwrite the configuration with the given config_value + * \param config_key A null terminated string representation of the config key + * \param config_value A null terminated string representation of the config value + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + */ + ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, + _In_z_ const char* config_key, _In_z_ const char* config_value); + + /** + * \param sess valid OrtSession instance + * \param mem_info - valid OrtMemoryInfo instance + * \param - out a ptr to a new instance of OrtAllocator according to the spec within mem_info + * if successful + * \return OrtStatus or nullptr if successful + */ + ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* sess, _In_ const OrtMemoryInfo* mem_info, + _Outptr_ OrtAllocator** out); + + // Release instance of OrtAllocator obtained from CreateAllocator API + ORT_CLASS_RELEASE(Allocator); + + ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); + + // Creates an IoBinding instance that allows one to bind pre-allocated OrtValues + // to input names. Thus if you want to use a raw on device buffer as input or output + // you can avoid extra copy during runtime. + ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* sess, _Outptr_ OrtIoBinding** out); + + // Release instance or OrtIoBinding obtained from CreateIoBinding API + ORT_CLASS_RELEASE(IoBinding); + + /** + * The function will bind the OrtValue to a specified input name. + * The OrtValue must be a Tensor. ORT would use that value in place of input for the specified name. + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model input + * \param val_ptr - OrtValue of Tensor type. + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ + ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** + * The function will bind the OrtValue to the specified output name. + * The OrtValue must be a Tensor. ORT would use that value in place of output for the specified name. + * + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model output + * \param val_ptr - OrtValue of Tensor type. + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ + ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** + * The function will bind the OrtValue to a device which specification is contained within OrtMemoryInfo + * You can either create an instance of OrtMemoryInfo with a device id or obtain one from the allocator that you are created/using + * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocated and bind a chunk of + * memory within OrtValue ahead of time. + * + * \param binding_ptr - an instance of OrtIoBinding created by CreateIoBinding() + * \param name - name for the model output + * \param mem_info_ptr - OrtMemoryInfo + * \return OrtStatus instance on error which the caller is responsible to free or nullptr on success + */ + ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* val_ptr); + + /** + * The function returns the names of the outputs in the order they were bound. This is useful after running the model + * with bound outputs because the returned names are in order in which output OrtValues are returned. This API is optional + * to use. If you knew the order of outputs and its names you used for binding you would not need to use this API. + * + * \param binding_ptr - a ptr to an instance of OrtIoBinding created obtained from CreateIoBinding() + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param buffer - pointer to a continuous buffer of non-zero terminated UTF-8 encoded strings. The number of strings stored is returned count parameter. + * this buffer will be allocated with the specified allocator and must be freed after it is no longer needed. + * \param lengths - a pointer to a continuous buffer of size_t lengths of strings returned in the buffer. The number of items is returned + * in the count. This buffer is allocated with the specified allocator and must be freed after it is no longer needed. + * \para count - is the number of strings returned. If the instance of OrtIoBiding has no bound outputs, zero is returned, + * no memory allocation is performed and buffer and lengths are nullptr on return. + */ + ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); + + /** + * The function returns an array of pointers to individually allocated OrtValues that contain results of a model execution with RunWithBinding() + * The array contains the same number of OrtValues and they are in the same order as they were bound with BindOutput() + * or BindOutputToDevice(). + * The returned OrtValues must be individually released after they are no longer needed. + * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after + * all the OrtValues contained therein are individually released. + * + * \param binding_ptr - instance of OrtIoBidning + * \param allocator - instance of allocator to allocate output array + * \param output - pointer to the allocated buffer. Returns nullptr if no outputs. + * \param output_count - pointer to the number of OrtValues returned. Zero if no outputs. + */ + ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); + + /** Clears any previously specified bindings for inputs/outputs + */ + void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** + * Provides element-level access into a tensor. + * \param location_values a pointer to an array of index values that specify an element's location in the tensor data blob + * \param location_values_count length of location_values + * \param out a pointer to the element specified by location_values + * e.g. + * Given a tensor with overall shape [3,224,224], an element at + * location [2,150,128] can be accessed directly. + * + * This function only works for numeric tensors. + * This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. + */ + ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); + + /** + * Creates an allocator instance and registers it with the env to enable + * sharing between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same OrtMemoryInfo is already registered. + * \param mem_info must be non-null. + * \param arena_cfg if nullptr defaults will be used. + * See docs/C_API.md for details. + */ + ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, + _In_ const OrtArenaCfg* arena_cfg); + + /** + * Set the language projection for collecting telemetry data when Env is created + * \param projection the source projected language. + */ + ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); + + /** + * \param out is set to the nanoseconds of profiling's start time + */ + ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Outptr_ uint64_t* out); + + /** + * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. + * A value of 0 means ORT will pick the default. + * A value of 1 means the invoking thread will be used; no threads will be created in the thread pool. + */ + ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); + ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); + + /** + * Use this API to configure the global thread pool options to be used in the call to CreateEnvWithGlobalThreadPools. + * Allow spinning of thread pools when their queues are empty. This API will set the value for both + * inter_op and intra_op threadpools. + * \param allow_spinning valid values are 1 and 0. + * 1: threadpool will spin to wait for queue to become non-empty, 0: it won't spin. + * Prefer a value of 0 if your CPU usage is very high. + */ + ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); +}; + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options +*/ +#define OrtCustomOpApi OrtApi + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. +*/ +struct OrtCustomOp { + uint32_t version; // Initialize to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. + void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); + + // Returns the type of the execution provider, return nullptr to use CPU execution provider + const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op); + + // Op kernel callbacks + void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); +}; + +/* + * END EXPERIMENTAL +*/ + +#ifdef __cplusplus +} +#endif diff --git a/includes/onnxruntime/onnxruntime_cxx_api.h b/includes/onnxruntime/onnxruntime_cxx_api.h new file mode 100644 index 00000000..9e8f365a --- /dev/null +++ b/includes/onnxruntime/onnxruntime_cxx_api.h @@ -0,0 +1,538 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Summary: The Ort C++ API is a header only wrapper around the Ort C API. +// +// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors +// and automatically releasing resources in the destructors. +// +// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. +// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). +// +// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' +// methods for this purpose. + +#pragma once +#include "onnxruntime_c_api.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef ORT_NO_EXCEPTIONS +#include +#endif + +namespace Ort { + +// All C++ methods that can fail will throw an exception of this type +struct Exception : std::exception { + Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} + + OrtErrorCode GetOrtErrorCode() const { return code_; } + const char* what() const noexcept override { return message_.c_str(); } + + private: + std::string message_; + OrtErrorCode code_; +}; + +#ifdef ORT_NO_EXCEPTIONS +#define ORT_CXX_API_THROW(string, code) \ + do { \ + std::cerr << Ort::Exception(string, code) \ + .what() \ + << std::endl; \ + abort(); \ + } while (false) +#else +#define ORT_CXX_API_THROW(string, code) \ + throw Ort::Exception(string, code) +#endif + +// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make +// it transparent to the users of the API. +template +struct Global { + static const OrtApi* api_; +}; + +// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. + +template +#ifdef ORT_API_MANUAL_INIT +const OrtApi* Global::api_{}; +inline void InitApi() { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } +#else +const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); +#endif + +// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions +inline const OrtApi& GetApi() { return *Global::api_; } + +// This is a C++ wrapper for GetAvailableProviders() C API and returns +// a vector of strings representing the available execution providers. +std::vector GetAvailableProviders(); + +// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type +// This can't be done in the C API since C doesn't have function overloading. +#define ORT_DEFINE_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } + +ORT_DEFINE_RELEASE(Allocator); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(CustomOpDomain); +ORT_DEFINE_RELEASE(Env); +ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(Session); +ORT_DEFINE_RELEASE(SessionOptions); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(SequenceTypeInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(TypeInfo); +ORT_DEFINE_RELEASE(Value); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(ThreadingOptions); +ORT_DEFINE_RELEASE(IoBinding); + +// This is used internally by the C++ API. This is the common base class used by the wrapper objects. +template +struct Base { + using contained_type = T; + + Base() = default; + Base(T* p) : p_{p} { + if (!p) + ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); + } + ~Base() { OrtRelease(p_); } + + operator T*() { return p_; } + operator const T*() const { return p_; } + + T* release() { + T* p = p_; + p_ = nullptr; + return p; + } + + protected: + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + void operator=(Base&& v) noexcept { + OrtRelease(p_); + p_ = v.p_; + v.p_ = nullptr; + } + + T* p_{}; + + template + friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error +}; + +template +struct Base { + using contained_type = const T; + + Base() = default; + Base(const T* p) : p_{p} { + if (!p) + ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT); + } + ~Base() = default; + + operator const T*() const { return p_; } + + protected: + Base(const Base&) = delete; + Base& operator=(const Base&) = delete; + Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } + void operator=(Base&& v) noexcept { + p_ = v.p_; + v.p_ = nullptr; + } + + const T* p_{}; +}; + +template +struct Unowned : T { + Unowned(decltype(T::p_) p) : T{p} {} + Unowned(Unowned&& v) : T{v.p_} {} + ~Unowned() { this->release(); } +}; + +struct AllocatorWithDefaultOptions; +struct MemoryInfo; +struct Env; +struct TypeInfo; +struct Value; +struct ModelMetadata; + +struct Env : Base { + Env(std::nullptr_t) {} + Env(OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); + explicit Env(OrtEnv* p) : Base{p} {} + + Env& EnableTelemetryEvents(); + Env& DisableTelemetryEvents(); + + Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); + + static const OrtApi* s_api; +}; + +struct CustomOpDomain : Base { + explicit CustomOpDomain(std::nullptr_t) {} + explicit CustomOpDomain(const char* domain); + + void Add(OrtCustomOp* op); +}; + +struct RunOptions : Base { + RunOptions(std::nullptr_t) {} + RunOptions(); + + RunOptions& SetRunLogVerbosityLevel(int); + int GetRunLogVerbosityLevel() const; + + RunOptions& SetRunLogSeverityLevel(int); + int GetRunLogSeverityLevel() const; + + RunOptions& SetRunTag(const char* run_tag); + const char* GetRunTag() const; + + // terminate ALL currently executing Session::Run calls that were made using this RunOptions instance + RunOptions& SetTerminate(); + // unset the terminate flag so this RunOptions instance can be used in a new Session::Run call + RunOptions& UnsetTerminate(); +}; + +struct SessionOptions : Base { + explicit SessionOptions(std::nullptr_t) {} + SessionOptions(); + explicit SessionOptions(OrtSessionOptions* p) : Base{p} {} + + SessionOptions Clone() const; + + SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); + SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); + SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); + + SessionOptions& EnableCpuMemArena(); + SessionOptions& DisableCpuMemArena(); + + SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); + + SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); + SessionOptions& DisableProfiling(); + + SessionOptions& EnableMemPattern(); + SessionOptions& DisableMemPattern(); + + SessionOptions& SetExecutionMode(ExecutionMode execution_mode); + + SessionOptions& SetLogId(const char* logid); + SessionOptions& SetLogSeverityLevel(int level); + + SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); + + SessionOptions& DisablePerSessionThreads(); + + SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); +}; + +struct ModelMetadata : Base { + explicit ModelMetadata(std::nullptr_t) {} + explicit ModelMetadata(OrtModelMetadata* p) : Base{p} {} + + char* GetProducerName(OrtAllocator* allocator) const; + char* GetGraphName(OrtAllocator* allocator) const; + char* GetDomain(OrtAllocator* allocator) const; + char* GetDescription(OrtAllocator* allocator) const; + char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; + char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; + int64_t GetVersion() const; +}; + +struct Session : Base { + explicit Session(std::nullptr_t) {} + Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + // Run that will allocate the output values + std::vector Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_count); + // Run for when there is a list of prealloated outputs + void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count); + + void Run(const RunOptions& run_options, const struct IoBinding&); + + size_t GetInputCount() const; + size_t GetOutputCount() const; + size_t GetOverridableInitializerCount() const; + + char* GetInputName(size_t index, OrtAllocator* allocator) const; + char* GetOutputName(size_t index, OrtAllocator* allocator) const; + char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; + char* EndProfiling(OrtAllocator* allocator) const; + uint64_t GetProfilingStartTimeNs() const; + ModelMetadata GetModelMetadata() const; + + TypeInfo GetInputTypeInfo(size_t index) const; + TypeInfo GetOutputTypeInfo(size_t index) const; + TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; +}; + +struct TensorTypeAndShapeInfo : Base { + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base{p} {} + + ONNXTensorElementDataType GetElementType() const; + size_t GetElementCount() const; + + size_t GetDimensionsCount() const; + void GetDimensions(int64_t* values, size_t values_count) const; + void GetSymbolicDimensions(const char** values, size_t values_count) const; + + std::vector GetShape() const; +}; + +struct SequenceTypeInfo : Base { + explicit SequenceTypeInfo(std::nullptr_t) {} + explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base{p} {} + + TypeInfo GetSequenceElementType() const; +}; + +struct MapTypeInfo : Base { + explicit MapTypeInfo(std::nullptr_t) {} + explicit MapTypeInfo(OrtMapTypeInfo* p) : Base{p} {} + + ONNXTensorElementDataType GetMapKeyType() const; + TypeInfo GetMapValueType() const; +}; + +struct TypeInfo : Base { + explicit TypeInfo(std::nullptr_t) {} + explicit TypeInfo(OrtTypeInfo* p) : Base{p} {} + + Unowned GetTensorTypeAndShapeInfo() const; + Unowned GetSequenceTypeInfo() const; + Unowned GetMapTypeInfo() const; + + + ONNXType GetONNXType() const; +}; + +struct Value : Base { + template + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + template + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + + static Value CreateMap(Value& keys, Value& values); + static Value CreateSequence(std::vector& values); + + template + static Value CreateOpaque(const char* domain, const char* type_name, const T&); + + template + void GetOpaqueData(const char* domain, const char* type_name, T&) const; + + explicit Value(std::nullptr_t) {} + explicit Value(OrtValue* p) : Base{p} {} + Value(Value&&) = default; + Value& operator=(Value&&) = default; + + bool IsTensor() const; + size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements + Value GetValue(int index, OrtAllocator* allocator) const; + + size_t GetStringTensorDataLength() const; + void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; + + template + T* GetTensorMutableData(); + + template + const T* GetTensorData() const; + + template + T& At(const std::vector& location); + + TypeInfo GetTypeInfo() const; + TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; + + size_t GetStringTensorElementLength(size_t element_index) const; + void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; + + void FillStringTensor(const char* const* s, size_t s_len); + void FillStringTensorElement(const char* s, size_t index); +}; + +// Represents native memory allocation +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&); + MemoryAllocation& operator=(MemoryAllocation&&); + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +struct AllocatorWithDefaultOptions { + AllocatorWithDefaultOptions(); + + operator OrtAllocator*() { return p_; } + operator const OrtAllocator*() const { return p_; } + + void* Alloc(size_t size); + // The return value will own the allocation + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + + const OrtMemoryInfo* GetInfo() const; + + private: + OrtAllocator* p_{}; +}; + +template +struct BaseMemoryInfo : B { + BaseMemoryInfo() = default; + explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {} + ~BaseMemoryInfo() = default; + BaseMemoryInfo(BaseMemoryInfo&&) = default; + BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default; + + std::string GetAllocatorName() const; + OrtAllocatorType GetAllocatorType() const; + int GetDeviceId() const; + OrtMemType GetMemoryType() const; + template + bool operator==(const BaseMemoryInfo& o) const; +}; + +struct UnownedMemoryInfo : BaseMemoryInfo > { + explicit UnownedMemoryInfo(std::nullptr_t) {} + explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {} +}; + +struct MemoryInfo : BaseMemoryInfo > { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + + explicit MemoryInfo(std::nullptr_t) {} + explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {} + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); +}; + +struct Allocator : public Base { + Allocator(const Session& session, const MemoryInfo&); + + void* Alloc(size_t size) const; + // The return value will own the allocation + MemoryAllocation GetAllocation(size_t size); + void Free(void* p) const; + UnownedMemoryInfo GetInfo() const; +}; + +struct IoBinding : public Base { + private: + std::vector GetOutputNamesHelper(OrtAllocator*) const; + std::vector GetOutputValuesHelper(OrtAllocator*) const; + + public: + explicit IoBinding(Session& session); + void BindInput(const char* name, const Value&); + void BindOutput(const char* name, const Value&); + void BindOutput(const char* name, const MemoryInfo&); + std::vector GetOutputNames() const; + std::vector GetOutputNames(Allocator&) const; + std::vector GetOutputValues() const; + std::vector GetOutputValues(Allocator&) const; + void ClearBoundInputs(); + void ClearBoundOutputs(); +}; + +// +// Custom OPs (only needed to implement custom OPs) +// + +struct CustomOpApi { + CustomOpApi(const OrtApi& api) : api_(api) {} + + template // T is only implemented for float, int64_t, and string + T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); + + OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); + size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); + ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); + size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); + void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); + void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + template + T* GetTensorMutableData(_Inout_ OrtValue* value); + template + const T* GetTensorData(_Inout_ const OrtValue* value); + + std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info); + void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); + size_t KernelContext_GetInputCount(const OrtKernelContext* context); + const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); + size_t KernelContext_GetOutputCount(const OrtKernelContext* context); + OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); + + void ThrowOnError(OrtStatus* result); + + private: + const OrtApi& api_; +}; + +template +struct CustomOpBase : OrtCustomOp { + CustomOpBase() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + + OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; + + OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; + OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + + OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; + OrtCustomOp::GetOutputType = [](OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; + } + + // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider + const char* GetExecutionProviderType() const { return nullptr; } +}; + +} // namespace Ort + +#include "onnxruntime_cxx_inline.h" diff --git a/includes/onnxruntime/onnxruntime_cxx_inline.h b/includes/onnxruntime/onnxruntime_cxx_inline.h new file mode 100644 index 00000000..104ae566 --- /dev/null +++ b/includes/onnxruntime/onnxruntime_cxx_inline.h @@ -0,0 +1,930 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead. +// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead. +// +// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter +// the main C++ file with implementation details. + +namespace Ort { + +inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) { + if (status) { + std::string error_message = ort.GetErrorMessage(status); + OrtErrorCode error_code = ort.GetErrorCode(status); + ort.ReleaseStatus(status); + ORT_CXX_API_THROW(std::move(error_message), error_code); + } +} + +inline void ThrowOnError(OrtStatus* status) { + ThrowOnError(GetApi(), status); +} + +// This template converts a C++ type into it's ONNXTensorElementDataType +template +struct TypeToTensorType; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; }; +template <> +struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; }; + +inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size) + : allocator_(allocator), p_(p), size_(size) { +} + +inline MemoryAllocation::~MemoryAllocation() { + if (p_ != nullptr) { + // We do not throw out of destructor + auto ret = GetApi().AllocatorFree(allocator_, p_); + static_cast(ret); + } +} + +inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) : allocator_(nullptr), p_(nullptr), size_(0) { + *this = std::move(o); +} + +inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) { + OrtAllocator* alloc = nullptr; + void* p = nullptr; + size_t sz = 0; + + // Swap out this + std::swap(alloc, allocator_); + std::swap(p, p_); + std::swap(sz, size_); + + // Swap with incoming + std::swap(allocator_, o.allocator_); + std::swap(p_, o.p_); + std::swap(size_, o.size_); + + // Destroy this instance if needed + MemoryAllocation this_alloc(alloc, p, sz); + return *this; +} + +inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() { + ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&p_)); +} + +inline void* AllocatorWithDefaultOptions::Alloc(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); + return out; +} + +inline MemoryAllocation Ort::AllocatorWithDefaultOptions::GetAllocation(size_t size) { + void* out; + ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); + MemoryAllocation result(p_, out, size); + return result; +} + +inline void AllocatorWithDefaultOptions::Free(void* p) { + ThrowOnError(GetApi().AllocatorFree(p_, p)); +} + +inline const OrtMemoryInfo* AllocatorWithDefaultOptions::GetInfo() const { + const OrtMemoryInfo* out; + ThrowOnError(GetApi().AllocatorGetInfo(p_, &out)); + return out; +} + +template +inline std::string BaseMemoryInfo::GetAllocatorName() const { + const char* name = nullptr; + ThrowOnError(GetApi().MemoryInfoGetName(*this, &name)); + return std::string(name); +} + +template +inline OrtAllocatorType BaseMemoryInfo::GetAllocatorType() const { + OrtAllocatorType type; + ThrowOnError(GetApi().MemoryInfoGetType(*this, &type)); + return type; +} + +template +int BaseMemoryInfo::GetDeviceId() const { + int id = 0; + ThrowOnError(GetApi().MemoryInfoGetId(*this, &id)); + return id; +} + +template +inline OrtMemType BaseMemoryInfo::GetMemoryType() const { + OrtMemType type; + ThrowOnError(GetApi().MemoryInfoGetMemType(*this, &type)); + return type; +} + +template +template +inline bool BaseMemoryInfo::operator==(const BaseMemoryInfo& o) const { + int comp_result = 0; + ThrowOnError(Ort::GetApi().CompareMemoryInfo(*this, o, &comp_result)); + return comp_result == 0; +} + +inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) { + OrtMemoryInfo* p; + ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p)); + return MemoryInfo(p); +} + +inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &p_)); +} + +inline Allocator::Allocator(const Session& sess, const MemoryInfo& mem_info) { + ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &p_)); +} + +inline void* Allocator::Alloc(size_t size) const { + void* out = nullptr; + ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); + return out; +} + +inline MemoryAllocation Ort::Allocator::GetAllocation(size_t size) { + void* out = nullptr; + ThrowOnError(GetApi().AllocatorAlloc(p_, size, &out)); + MemoryAllocation result(p_, out, size); + return result; +} + +inline void Allocator::Free(void* p) const { + ThrowOnError(GetApi().AllocatorFree(p_, p)); +} + +inline UnownedMemoryInfo Allocator::GetInfo() const { + const OrtMemoryInfo* out = nullptr; + ThrowOnError(GetApi().AllocatorGetInfo(p_, &out)); + return UnownedMemoryInfo(out); +} + +inline IoBinding::IoBinding(Session& session) { + ThrowOnError(GetApi().CreateIoBinding(session, &p_)); +} + +inline void IoBinding::BindInput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindInput(p_, name, value)); +} + +inline void IoBinding::BindOutput(const char* name, const Value& value) { + ThrowOnError(GetApi().BindOutput(p_, name, value)); +} + +inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info) { + ThrowOnError(GetApi().BindOutputToDevice(p_, name, mem_info)); +} + +inline std::vector IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const { + std::vector result; + auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); }; + using Ptr = std::unique_ptr; + + char* buffer = nullptr; + size_t* lengths = nullptr; + size_t count = 0; + ThrowOnError(GetApi().GetBoundOutputNames(p_, allocator, &buffer, &lengths, &count)); + + if (count == 0) { + return result; + } + + Ptr buffer_g(buffer, free_fn); + Ptr lengths_g(lengths, free_fn); + + result.reserve(count); + for (size_t i = 0; i < count; ++i) { + auto sz = *lengths; + result.emplace_back(buffer, sz); + buffer += sz; + ++lengths; + } + return result; +} + +inline std::vector IoBinding::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + return GetOutputNamesHelper(allocator); +} + +inline std::vector IoBinding::GetOutputNames(Allocator& allocator) const { + return GetOutputNamesHelper(allocator); +} + +inline std::vector Ort::IoBinding::GetOutputValuesHelper(OrtAllocator* allocator) const { + std::vector result; + size_t owned = 0; + size_t output_count = 0; + // Lambda to release the buffer when no longer needed and + // make sure that we destroy all instances on exception + auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { + if (buffer) { + while (owned < output_count) { + auto* p = buffer + owned++; + GetApi().ReleaseValue(*p); + } + allocator->Free(allocator, buffer); + } + }; + using Ptr = std::unique_ptr; + + OrtValue** output_buffer = nullptr; + ThrowOnError(GetApi().GetBoundOutputValues(p_, allocator, &output_buffer, &output_count)); + if (output_count == 0) { + return result; + } + + Ptr buffer_g(output_buffer, free_fn); + + result.reserve(output_count); + for (size_t i = 0; i < output_count; ++i) { + result.emplace_back(output_buffer[i]); + ++owned; + } + return result; +} + +inline std::vector Ort::IoBinding::GetOutputValues(Allocator& allocator) const { + return GetOutputValuesHelper(allocator); +} + +inline std::vector Ort::IoBinding::GetOutputValues() const { + AllocatorWithDefaultOptions allocator; + return GetOutputValuesHelper(allocator); +} + +inline void IoBinding::ClearBoundInputs() { + GetApi().ClearBoundInputs(p_); +} + +inline void IoBinding::ClearBoundOutputs() { + GetApi().ClearBoundOutputs(p_); +} + +inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnv(default_warning_level, logid, &p_)); + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); +} + +inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) { + ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_)); + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); +} + +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) { + ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_)); + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); +} + +inline Env& Env::EnableTelemetryEvents() { + ThrowOnError(GetApi().EnableTelemetryEvents(p_)); + return *this; +} + +inline Env& Env::DisableTelemetryEvents() { + ThrowOnError(GetApi().DisableTelemetryEvents(p_)); + return *this; +} + +inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) { + ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg)); + return *this; +} + +inline CustomOpDomain::CustomOpDomain(const char* domain) { + ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); +} + +inline void CustomOpDomain::Add(OrtCustomOp* op) { + ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); +} + +inline RunOptions::RunOptions() { + ThrowOnError(GetApi().CreateRunOptions(&p_)); +} + +inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level)); + return *this; +} + +inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) { + ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level)); + return *this; +} + +inline int RunOptions::GetRunLogVerbosityLevel() const { + int out; + ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::SetRunTag(const char* run_tag) { + ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag)); + return *this; +} + +inline const char* RunOptions::GetRunTag() const { + const char* out; + ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out)); + return out; +} + +inline RunOptions& RunOptions::SetTerminate() { + ThrowOnError(GetApi().RunOptionsSetTerminate(p_)); + return *this; +} + +inline RunOptions& RunOptions::UnsetTerminate() { + ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_)); + return *this; +} + +inline SessionOptions::SessionOptions() { + ThrowOnError(GetApi().CreateSessionOptions(&p_)); +} + +inline SessionOptions SessionOptions::Clone() const { + OrtSessionOptions* out; + ThrowOnError(GetApi().CloneSessionOptions(p_, &out)); + return SessionOptions{out}; +} + +inline SessionOptions& SessionOptions::SetIntraOpNumThreads(int intra_op_num_threads) { + ThrowOnError(GetApi().SetIntraOpNumThreads(p_, intra_op_num_threads)); + return *this; +} + +inline SessionOptions& SessionOptions::SetInterOpNumThreads(int inter_op_num_threads) { + ThrowOnError(GetApi().SetInterOpNumThreads(p_, inter_op_num_threads)); + return *this; +} + +inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(p_, graph_optimization_level)); + return *this; +} + +inline SessionOptions& SessionOptions::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) { + ThrowOnError(GetApi().SetOptimizedModelFilePath(p_, optimized_model_filepath)); + return *this; +} + +inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { + ThrowOnError(GetApi().EnableProfiling(p_, profile_file_prefix)); + return *this; +} + +inline SessionOptions& SessionOptions::DisableProfiling() { + ThrowOnError(GetApi().DisableProfiling(p_)); + return *this; +} + +inline SessionOptions& SessionOptions::EnableMemPattern() { + ThrowOnError(GetApi().EnableMemPattern(p_)); + return *this; +} + +inline SessionOptions& SessionOptions::DisableMemPattern() { + ThrowOnError(GetApi().DisableMemPattern(p_)); + return *this; +} + +inline SessionOptions& SessionOptions::EnableCpuMemArena() { + ThrowOnError(GetApi().EnableCpuMemArena(p_)); + return *this; +} + +inline SessionOptions& SessionOptions::DisableCpuMemArena() { + ThrowOnError(GetApi().DisableCpuMemArena(p_)); + return *this; +} + +inline SessionOptions& SessionOptions::SetExecutionMode(ExecutionMode execution_mode) { + ThrowOnError(GetApi().SetSessionExecutionMode(p_, execution_mode)); + return *this; +} + +inline SessionOptions& SessionOptions::SetLogId(const char* logid) { + ThrowOnError(GetApi().SetSessionLogId(p_, logid)); + return *this; +} + +inline SessionOptions& SessionOptions::SetLogSeverityLevel(int level) { + ThrowOnError(GetApi().SetSessionLogSeverityLevel(p_, level)); + return *this; +} + +inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) { + ThrowOnError(GetApi().AddCustomOpDomain(p_, custom_op_domain)); + return *this; +} + +inline SessionOptions& SessionOptions::AddConfigEntry(const char* config_key, const char* config_value) { + ThrowOnError(GetApi().AddSessionConfigEntry(p_, config_key, config_value)); + return *this; +} + +inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_)); +} + +inline Session::Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) { + ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &p_)); +} + +inline std::vector Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, size_t output_names_count) { + std::vector output_values; + for (size_t i = 0; i < output_names_count; i++) + output_values.emplace_back(nullptr); + Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_names_count); + return output_values; +} + +inline void Session::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, + const char* const* output_names, Value* output_values, size_t output_count) { + static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely"); + auto ort_input_values = reinterpret_cast(const_cast(input_values)); + auto ort_output_values = reinterpret_cast(output_values); + ThrowOnError(GetApi().Run(p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values)); +} + +inline void Session::Run(const RunOptions& run_options, const IoBinding& io_binding) { + ThrowOnError(GetApi().RunWithBinding(p_, run_options, io_binding)); +} + +inline size_t Session::GetInputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetInputCount(p_, &out)); + return out; +} + +inline size_t Session::GetOutputCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOutputCount(p_, &out)); + return out; +} + +inline size_t Session::GetOverridableInitializerCount() const { + size_t out; + ThrowOnError(GetApi().SessionGetOverridableInitializerCount(p_, &out)); + return out; +} + +inline char* Session::GetInputName(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out)); + return out; +} + +inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out)); + return out; +} + +inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out)); + return out; +} + +inline char* Session::EndProfiling(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out)); + return out; +} + +inline uint64_t Session::GetProfilingStartTimeNs() const { + uint64_t out; + ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(p_, &out)); + return out; +} + +inline ModelMetadata Session::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().SessionGetModelMetadata(p_, &out)); + return ModelMetadata{out}; +} + +inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out)); + return out; +} + +inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const { + char* out; + ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out)); + return out; +} + +inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const { + char** out; + ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + return out; +} + +inline int64_t ModelMetadata::GetVersion() const { + int64_t out; + ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out)); + return out; +} + +inline TypeInfo Session::GetInputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetInputTypeInfo(p_, index, &out)); + return TypeInfo{out}; +} + +inline TypeInfo Session::GetOutputTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOutputTypeInfo(p_, index, &out)); + return TypeInfo{out}; +} + +inline TypeInfo Session::GetOverridableInitializerTypeInfo(size_t index) const { + OrtTypeInfo* out; + ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(p_, index, &out)); + return TypeInfo{out}; +} + +inline ONNXTensorElementDataType TensorTypeAndShapeInfo::GetElementType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetTensorElementType(p_, &out)); + return out; +} + +inline size_t TensorTypeAndShapeInfo::GetElementCount() const { + size_t out; + ThrowOnError(GetApi().GetTensorShapeElementCount(p_, &out)); + return static_cast(out); +} + +inline size_t TensorTypeAndShapeInfo::GetDimensionsCount() const { + size_t out; + ThrowOnError(GetApi().GetDimensionsCount(p_, &out)); + return out; +} + +inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values_count) const { + ThrowOnError(GetApi().GetDimensions(p_, values, values_count)); +} + +inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const { + ThrowOnError(GetApi().GetSymbolicDimensions(p_, values, values_count)); +} + +inline std::vector TensorTypeAndShapeInfo::GetShape() const { + std::vector out(GetDimensionsCount(), 0); + GetDimensions(out.data(), out.size()); + return out; +} + +inline Unowned TypeInfo::GetTensorTypeAndShapeInfo() const { + const OrtTensorTypeAndShapeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToTensorInfo(p_, &out)); + return Unowned(const_cast(out)); +} + +inline Unowned TypeInfo::GetSequenceTypeInfo() const { + const OrtSequenceTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(p_, &out)); + return Unowned{const_cast(out)}; +} + +inline TypeInfo SequenceTypeInfo::GetSequenceElementType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetSequenceElementType(p_, &output)); + return TypeInfo{output}; +} + +inline Unowned TypeInfo::GetMapTypeInfo() const { + const OrtMapTypeInfo* out; + ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(p_, &out)); + return Unowned{const_cast(out)}; +} + +inline ONNXTensorElementDataType MapTypeInfo::GetMapKeyType() const { + ONNXTensorElementDataType out; + ThrowOnError(GetApi().GetMapKeyType(p_, &out)); + return out; +} + +inline TypeInfo MapTypeInfo::GetMapValueType() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetMapValueType(p_, &output)); + return TypeInfo{output}; +} + +inline ONNXType TypeInfo::GetONNXType() const { + ONNXType out; + ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(p_, &out)); + return out; +} + +template +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { + return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); + return Value{out}; +} + +template +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { + return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); +} + +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); + return Value{out}; +} + +inline Value Value::CreateMap(Value& keys, Value& values) { + OrtValue* out; + OrtValue* inputs[2] = {keys, values}; + ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out)); + return Value{out}; +} + +inline Value Value::CreateSequence(std::vector& values) { + OrtValue* out; + std::vector values_ort{values.data(), values.data() + values.size()}; + ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out)); + return Value{out}; +} + +template +inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) { + OrtValue* out; + ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out)); + return Value{out}; +} + +template +inline void Value::GetOpaqueData(const char* domain, const char* type_name, T& out) const { + ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, p_, &out, sizeof(T))); +} + +inline bool Value::IsTensor() const { + int out; + ThrowOnError(GetApi().IsTensor(p_, &out)); + return out != 0; +} + +inline size_t Value::GetCount() const { + size_t out; + ThrowOnError(GetApi().GetValueCount(p_, &out)); + return out; +} + +inline Value Value::GetValue(int index, OrtAllocator* allocator) const { + OrtValue* out; + ThrowOnError(GetApi().GetValue(p_, index, allocator, &out)); + return Value{out}; +} + +inline size_t Value::GetStringTensorDataLength() const { + size_t out; + ThrowOnError(GetApi().GetStringTensorDataLength(p_, &out)); + return out; +} + +inline size_t Value::GetStringTensorElementLength(size_t element_index) const { + size_t out; + ThrowOnError(GetApi().GetStringTensorElementLength(p_, element_index, &out)); + return out; +} + +inline void Value::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const { + ThrowOnError(GetApi().GetStringTensorContent(p_, buffer, buffer_length, offsets, offsets_count)); +} + +inline void Value::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const { + ThrowOnError(GetApi().GetStringTensorElement(p_, buffer_length, element_index, buffer)); +} + +inline void Value::FillStringTensor(const char* const* s, size_t s_len) { + ThrowOnError(GetApi().FillStringTensor(p_, s, s_len)); +} + +inline void Value::FillStringTensorElement(const char* s, size_t index) { + ThrowOnError(GetApi().FillStringTensorElement(p_, s, index)); +} + +template +T* Value::GetTensorMutableData() { + T* out; + ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out)); + return out; +} + +template +const T* Value::GetTensorData() const { + T* out; + ThrowOnError(GetApi().GetTensorMutableData(p_, (void**)&out)); + return out; +} + +template +inline T& Value::At(const std::vector& location) { + static_assert(!std::is_same::value, "this api does not support std::string"); + T* out; + ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out)); + return *out; +} + +inline TypeInfo Value::GetTypeInfo() const { + OrtTypeInfo* output; + ThrowOnError(GetApi().GetTypeInfo(p_, &output)); + return TypeInfo{output}; +} + +inline TensorTypeAndShapeInfo Value::GetTensorTypeAndShapeInfo() const { + OrtTensorTypeAndShapeInfo* output; + ThrowOnError(GetApi().GetTensorTypeAndShape(p_, &output)); + return TensorTypeAndShapeInfo{output}; +} + +// +// Custom OP API Inlines +// +inline void CustomOpApi::ThrowOnError(OrtStatus* status) { + Ort::ThrowOnError(api_, status); +} + +template <> +inline float CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + float out; + ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out)); + return out; +} + +template <> +inline int64_t CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + int64_t out; + ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out)); + return out; +} + +template <> +inline std::string CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::string out; + OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size); + + // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string + if (api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) { + api_.ReleaseStatus(status); + out.resize(size); + ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size)); + out.resize(size - 1); // remove the terminating character '\0' + } else { + ThrowOnError(status); + } + return out; +} + +inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) { + OrtTensorTypeAndShapeInfo* out; + ThrowOnError(api_.GetTensorTypeAndShape(value, &out)); + return out; +} + +inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + size_t out; + ThrowOnError(api_.GetTensorShapeElementCount(info, &out)); + return out; +} + +inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) { + ONNXTensorElementDataType out; + ThrowOnError(api_.GetTensorElementType(info, &out)); + return out; +} + +inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + size_t out; + ThrowOnError(api_.GetDimensionsCount(info, &out)); + return out; +} + +inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { + ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length)); +} + +inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { + ThrowOnError(api_.SetDimensions(info, dim_values, dim_count)); +} + +template +inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) { + T* data; + ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast(&data))); + return data; +} + +template +inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) { + return GetTensorMutableData(const_cast(value)); +} + +inline std::vector CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) { + std::vector output(GetDimensionsCount(info)); + GetDimensions(info, output.data(), output.size()); + return output; +} + +inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) { + api_.ReleaseTensorTypeAndShapeInfo(input); +} + +inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) { + size_t out; + ThrowOnError(api_.KernelContext_GetInputCount(context, &out)); + return out; +} + +inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) { + const OrtValue* out; + ThrowOnError(api_.KernelContext_GetInput(context, index, &out)); + return out; +} + +inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) { + size_t out; + ThrowOnError(api_.KernelContext_GetOutputCount(context, &out)); + return out; +} + +inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) { + OrtValue* out; + ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out)); + return out; +} + +inline SessionOptions& SessionOptions::DisablePerSessionThreads() { + ThrowOnError(GetApi().DisablePerSessionThreads(p_)); + return *this; +} + +inline std::vector GetAvailableProviders() { + int len; + char** providers; + const OrtApi& api = GetApi(); + ThrowOnError(api.GetAvailableProviders(&providers, &len)); + std::vector available_providers(providers, providers + len); + ThrowOnError(api.ReleaseAvailableProviders(providers, len)); + return available_providers; +} +} // namespace Ort diff --git a/includes/onnxruntime/onnxruntime_session_options_config_keys.h b/includes/onnxruntime/onnxruntime_session_options_config_keys.h new file mode 100644 index 00000000..d0f07267 --- /dev/null +++ b/includes/onnxruntime/onnxruntime_session_options_config_keys.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* + * This file defines SessionOptions Config Keys and format of the Config Values. + * + * The Naming Convention for a SessionOptions Config Key, + * "[Area][.[SubArea1].[SubArea2]...].[Keyname]" + * Such as "ep.cuda.use_arena" + * The Config Key cannot be empty + * The maximum length of the Config Key is 128 + * + * The string format of a SessionOptions Config Value is defined individually for each Config. + * The maximum length of the Config Value is 1024 + */ + +// Key for disable PrePacking, +// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value) +static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking"; + +// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session +// will be used. Use this to override the usage of env allocators on a per session level. +static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators"; + +// Set to 'ORT' (case sensitive) to load an ORT format model. +// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT +static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format"; + +// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set. +// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'. +static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format"; \ No newline at end of file