Add DirectML Execution Provider (#2057)

This change adds a new execution provider powered by [DirectML](https://aka.ms/DirectML).

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning on Windows. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers.

The DirectML execution provider is capable of greatly improving evaluation time of models using commodity GPU hardware, without sacrificing broad hardware support or requiring vendor-specific extensions to be installed.

**Note** that the DML EP code was moved verbatim from the existing WindowsAI project, which is why it doesn't yet conform to the onnxruntime coding style. This is something that can be fixed later; we would like to keep formatting/whitespace changes to a minimum for the time being to make it easier to port fixes from WindowsAI to ORT during this transition.

Summary of changes:
* Initial commit of DML EP files under onnxruntime/core/providers/dml
* Add cmake entries for building the DML EP and for pulling down the DirectML redist using nuget
* Add a submodule dependency on the Windows Implementation Library (WIL)
* Add docs under docs/execution_providers/DirectML-ExecutionProvider.md
* Add support for DML EP to provider tests and perf tests
* Add support for DML EP to fns_candy_style_transfer sample
* Add entries to the C ABI for instantiating the DML EP
This commit is contained in:
Adrian Tsai 2019-10-15 06:13:07 -07:00 коммит произвёл GitHub
Родитель b101f1bcee
Коммит 4090d0d0de
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
150 изменённых файлов: 29239 добавлений и 36 удалений

4
.gitmodules поставляемый
Просмотреть файл

@ -43,3 +43,7 @@
[submodule "cmake/external/onnx-tensorrt"]
path = cmake/external/onnx-tensorrt
url = https://github.com/onnx/onnx-tensorrt.git
[submodule "cmake/external/wil"]
path = cmake/external/wil
url = https://github.com/microsoft/wil

Просмотреть файл

@ -87,6 +87,7 @@ The complete list of build options can be found by running `./build.sh (or ./bui
* [Intel OpenVINO](#openvino)
* [Android NNAPI](#Android)
* [Nuphar](#Nuphar)
* [DirectML](#DirectML)
**Options**
* [OpenMP](#OpenMP)
@ -387,6 +388,16 @@ For Linux (e.g. Ubuntu 16.04), install libopenblas-dev package
---
### DirectML
To build onnxruntime with the [DirectML execution provider](./docs/execution_providers/DirectML-ExecutionProvider.md) included, supply the `--use_dml` parameter to build.bat. e.g.
build.bat --use_dml
The DirectML execution provider supports building for both x64 and x86 architectures. DirectML is only supported on Windows.
---
## Architectures
### x86
- For Windows, just add --x86 argument when launching build.bat

10
NuGet.config Normal file
Просмотреть файл

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<solution>
<add key="disableSourceControlIntegration" value="true" />
</solution>
<packageSources>
<add key="NuGet Official" value="https://api.nuget.org/v3/index.json" />
<add key="onnxruntime_public" value="https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime_public/nuget/v3/index.json" />
</packageSources>
</configuration>

Просмотреть файл

@ -51,6 +51,7 @@ ONNX Runtime supports both CPU and GPU. Using various graph optimizations and ac
Currently ONNX Runtime supports the following accelerators:
* MLAS (Microsoft Linear Algebra Subprograms)
* [DirectML](./docs/execution_providers/DirectML-ExecutionProvider.md)
* [MKL-DNN](./docs/execution_providers/MKL-DNN-ExecutionProvider.md) - [subgraph optimization](./docs/execution_providers/MKL-DNN-Subgraphs.md)
* MKL-ML
* [Intel nGraph](./docs/execution_providers/nGraph-ExecutionProvider.md)

Просмотреть файл

@ -3769,3 +3769,29 @@ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-----
microsoft/wil
MIT License
Copyright (c) Microsoft Corporation. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE

Просмотреть файл

@ -390,8 +390,17 @@
"type": "git"
}
},
{
{
"component": {
"git": {
"commitHash": "e8c599bca6c56c44b6730ad93f6abbc9ecd60fc1",
"repositoryUrl": "https://github.com/microsoft/wil"
},
"type": "git"
}
},
{
"component":{
"type": "other",
"Other": {
"Name": "Go",

Просмотреть файл

@ -83,6 +83,7 @@ option(onnxruntime_USE_EIGEN_THREADPOOL "Use eigen threadpool. Otherwise OpenMP
option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir")
option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
option(onnxruntime_USE_DML "Build with DirectML support" OFF)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE)
#nsync tests failed on Mac Build
@ -653,6 +654,15 @@ if (onnxruntime_ENABLE_MICROSOFT_INTERNAL)
add_definitions(-DMICROSOFT_INTERNAL)
endif()
if (onnxruntime_USE_DML)
if(NOT WIN32)
message(FATAL_ERROR "The DirectML execution provider is only supported when building for Windows.")
endif()
add_definitions(-DUSE_DML=1)
include(dml)
endif()
#names in this var must match the directory names under onnxruntime/core/providers
set(ONNXRUNTIME_PROVIDER_NAMES cpu)

39
cmake/external/dml.cmake поставляемый Normal file
Просмотреть файл

@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
if (NOT(MSVC) OR NOT(WIN32))
message(FATAL_ERROR "NuGet packages are only supported for MSVC on Windows.")
endif()
# Retrieve the latest version of nuget
include(ExternalProject)
ExternalProject_Add(nuget
PREFIX nuget
URL "https://dist.nuget.org/win-x86-commandline/v5.3.0/nuget.exe"
DOWNLOAD_NO_EXTRACT 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND "")
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
set(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/packages)
# Restore nuget packages, which will pull down the DirectML redist package
add_custom_command(
OUTPUT restore_packages.stamp
DEPENDS ${PACKAGES_CONFIG} ${NUGET_CONFIG}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/nuget/src/nuget restore ${PACKAGES_CONFIG} -PackagesDirectory ${PACKAGES_DIR} -ConfigFile ${NUGET_CONFIG}
COMMAND ${CMAKE_COMMAND} -E touch restore_packages.stamp
VERBATIM)
add_custom_target(RESTORE_PACKAGES ALL DEPENDS restore_packages.stamp)
add_dependencies(RESTORE_PACKAGES nuget)
list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES RESTORE_PACKAGES)
else()
include_directories(${dml_INCLUDE_DIR})
link_directories(${dml_LIB_DIR})
endif()

1
cmake/external/wil поставляемый Submodule

@ -0,0 +1 @@
Subproject commit e8c599bca6c56c44b6730ad93f6abbc9ecd60fc1

Просмотреть файл

@ -64,6 +64,7 @@ target_link_libraries(onnxruntime PRIVATE
${PROVIDERS_TENSORRT}
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
${PROVIDERS_DML}
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util

Просмотреть файл

@ -21,6 +21,13 @@ if(NOT onnxruntime_USE_AUTOML)
)
endif()
if(NOT onnxruntime_USE_DML)
list(REMOVE_ITEM onnxruntime_graph_src
"${ONNXRUNTIME_ROOT}/core/graph/dml_ops/*.h"
"${ONNXRUNTIME_ROOT}/core/graph/dml_ops/*.cc"
)
endif()
file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/defs/*.cc"
)

Просмотреть файл

@ -68,6 +68,10 @@ if(onnxruntime_USE_NNAPI)
set(PROVIDERS_NNAPI onnxruntime_providers_nnapi)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES nnapi)
endif()
if(onnxruntime_USE_DML)
set(PROVIDERS_DML onnxruntime_providers_dml)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES dml)
endif()
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs})
set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime_providers_srcs})
@ -492,6 +496,39 @@ if (onnxruntime_USE_NNAPI)
set_target_properties(onnxruntime_providers_nnapi PROPERTIES LINKER_LANGUAGE CXX)
endif()
if (onnxruntime_USE_DML)
file(GLOB_RECURSE onnxruntime_providers_dml_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/dml/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/dml/*.cpp"
"${ONNXRUNTIME_ROOT}/core/providers/dml/*.cc"
)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dml_cc_srcs})
add_library(onnxruntime_providers_dml ${onnxruntime_providers_dml_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_dml onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
add_dependencies(onnxruntime_providers_dml ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_providers_dml PRIVATE ${ONNXRUNTIME_ROOT} ${ONNXRUNTIME_ROOT}/../cmake/external/wil/include)
target_link_libraries(onnxruntime_providers_dml ${CMAKE_CURRENT_BINARY_DIR}/packages/DirectML.0.0.1/build/DirectML.targets)
target_link_libraries(onnxruntime_providers_dml d3d12.lib dxgi.lib)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll")
# The DML EP requires C++17
set_target_properties(onnxruntime_providers_dml PROPERTIES CXX_STANDARD 17)
set_target_properties(onnxruntime_providers_dml PROPERTIES CXX_STANDARD_REQUIRED ON)
target_compile_definitions(onnxruntime_providers_dml PRIVATE ONNX_NAMESPACE=onnx ONNX_ML LOTUS_LOG_THRESHOLD=2 LOTUS_ENABLE_STDERR_LOGGING PLATFORM_WINDOWS)
target_compile_definitions(onnxruntime_providers_dml PRIVATE UNICODE _UNICODE NOMINMAX)
if (MSVC)
target_compile_definitions(onnxruntime_providers_dml PRIVATE _SILENCE_CXX17_ITERATOR_BASE_CLASS_DEPRECATION_WARNING)
target_compile_options(onnxruntime_providers_dml PRIVATE "/W3")
endif()
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dml DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
set_target_properties(onnxruntime_providers_dml PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime")
endif()
if (onnxruntime_ENABLE_MICROSOFT_INTERNAL)
include(onnxruntime_providers_internal.cmake)
endif()

Просмотреть файл

@ -74,6 +74,7 @@ set(onnxruntime_pybind11_state_libs
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
${PROVIDERS_NNAPI}
${PROVIDERS_DML}
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util

Просмотреть файл

@ -219,6 +219,10 @@ if(onnxruntime_USE_AUTOML)
list(APPEND onnxruntime_test_providers_dependencies automl_featurizers)
endif()
if(onnxruntime_USE_DML)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_dml)
endif()
file(GLOB_RECURSE onnxruntime_test_tvm_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/tvm/*.h"
"${ONNXRUNTIME_ROOT}/test/tvm/*.cc"
@ -250,6 +254,7 @@ set(ONNXRUNTIME_TEST_LIBS
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
${PROVIDERS_NNAPI}
${PROVIDERS_DML}
onnxruntime_optimizer
onnxruntime_providers
onnxruntime_util

Просмотреть файл

@ -0,0 +1,110 @@
# DirectML Execution Provider (Preview)
DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning on Windows. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers.
When used standalone, the DirectML API is a low-level DirectX 12 library and is suitable for high-performance, low-latency applications such as frameworks, games, and other real-time applications. The seamless interoperability of DirectML with Direct3D 12 as well as its low overhead and conformance across hardware makes DirectML ideal for accelerating machine learning when both high performance is desired, and the reliability and predictabiltiy of results across hardware is critical.
The *DirectML Execution Provider* is an optional component of ONNX Runtime that uses DirectML to accelerate inference of ONNX models. The DirectML execution provider is capable of greatly improving evaluation time of models using commodity GPU hardware, without sacrificing broad hardware support or requiring vendor-specific extensions to be installed.
The DirectML Execution Provider is currently in preview.
## Table of contents
- [DirectML Execution Provider (Preview)](#directml-execution-provider-preview)
- [Table of contents](#table-of-contents)
- [Minimum requirements](#minimum-requirements)
- [Building from source](#building-from-source)
- [Using the DirectML execution provider](#using-the-directml-execution-provider)
- [`OrtSessionOptionsAppendExecutionProvider_DML` function](#ortsessionoptionsappendexecutionproviderdml-function)
- [`OrtSessionOptionsAppendExecutionProviderEx_DML` function](#ortsessionoptionsappendexecutionproviderexdml-function)
- [ONNX opset support](#onnx-opset-support)
- [Multi-threading and supported session options](#multi-threading-and-supported-session-options)
- [Samples](#samples)
- [See also](#see-also)
## Minimum requirements
The DirectML execution provider requires any DirectX 12 capable device. Almost all commercially-available graphics cards released in the last several years support DirectX 12. Examples of compatible hardware include:
* NVIDIA Kepler (GTX 600 series) and above
* AMD GCN 1st Gen (Radeon HD 7000 series) and above
* Intel Haswell (4th-gen core) HD Integrated Graphics and above
DirectML is compatible with Windows 10, version 1709 (10.0.16299; RS3, "Fall Creators Update") and newer.
## Building from source
For general information about building onnxruntime, see [BUILD.md](../../BUILD.md).
Requirements for building the DirectML execution provider:
1. Visual Studio 2017 toolchain (see [cmake configuration instructions](../../BUILD.md))
2. [The Windows 10 SDK (10.0.18362.0) for Windows 10, version 1903](https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk) (or newer)
To build onnxruntime with the DML EP included, supply the `--use_dml` parameter to `build.bat`. e.g.
build.bat --config RelWithDebInfo --build_shared_lib --parallel --use_dml
The DirectML execution provider supports building for both x64 (default) and x86 architectures.
Note that building onnxruntime with the DirectML execution provider enabled causes the the DirectML redistributable package to be automatically downloaded as part of the build. This package contains a pre-release version of DirectML, and its use is governed by a license whose text may be found as part of the NuGet package.
## Using the DirectML execution provider
When using the [C API](../C_API.md) with a DML-enabled build of onnxruntime (see [Building from source](#building-from-source)), the DirectML execution provider can be enabled using one of the two factory functions included in `include/onnxruntime/core/providers/dml/dml_provider_factory.h`.
### `OrtSessionOptionsAppendExecutionProvider_DML` function
Creates a DirectML Execution Provider which executes on the hardware adapter with the given `device_id`, also known as the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by [IDXGIFactory::EnumAdapters](https://docs.microsoft.com/windows/win32/api/dxgi/nf-dxgi-idxgifactory-enumadapters). A `device_id` of 0 always corresponds to the default adapter, which is typically the primary display GPU installed on the system. A negative `device_id` is invalid.
OrtStatus* OrtSessionOptionsAppendExecutionProvider_DML(
_In_ OrtSessionOptions* options,
int device_id
);
### `OrtSessionOptionsAppendExecutionProviderEx_DML` function
Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12 command queue. The DirectML device and D3D12 command queue must have the same parent [ID3D12Device](https://docs.microsoft.com/windows/win32/api/d3d12/nn-d3d12-id3d12device), or an error will be returned. The D3D12 command queue must be of type `DIRECT` or `COMPUTE` (see [D3D12_COMMAND_LIST_TYPE](https://docs.microsoft.com/windows/win32/api/d3d12/ne-d3d12-d3d12_command_list_type)). If this function succeeds, the inference session once created will maintain a strong reference on both the `dml_device` and `command_queue` objects.
OrtStatus* OrtSessionOptionsAppendExecutionProviderEx_DML(
_In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device,
_In_ ID3D12CommandQueue* cmd_queue
);
**See Also**
[DMLCreateDevice function](https://docs.microsoft.com/windows/win32/api/directml/nf-directml-dmlcreatedevice)
[ID3D12Device::CreateCommandQueue method](https://docs.microsoft.com/windows/win32/api/d3d12/nf-d3d12-id3d12device-createcommandqueue)
[Direct3D 12 programming guide](https://docs.microsoft.com/windows/win32/direct3d12/directx-12-programming-guide)
### ONNX opset support
The DirectML execution provider currently supports ONNX opset 9 ([ONNX v1.4](https://github.com/onnx/onnx/releases/tag/v1.4.0)). Evaluating models which require a higher opset version is not supported, and may produce unexpected results.
### Multi-threading and supported session options
The DirectML execution provider does not support the use of memory pattern optimizations or parallel execution in onnxruntime. When supplying session options during InferenceSession creation, these options must be disabled or an error will be returned.
If using the onnxruntime C API, you must call `DisableMemPattern` and `SetSessionExecutionMode` functions to set the options required by the DirectML execution provider.
See [onnxruntime\include\onnxruntime\core\session\onnxruntime_c_api.h](..\..\include\onnxruntime\core\session\onnxruntime_c_api.h).
OrtStatus*(ORT_API_CALL* DisableMemPattern)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* SetSessionExecutionMode)(_Inout_ OrtSessionOptions* options, ExecutionMode execution_mode)NO_EXCEPTION;
If creating the onnxruntime InferenceSession object directly, you must set the appropriate fields on the `onnxruntime::SessionOptions` struct. Specifically, `execution_mode` must be set to `ExecutionMode::ORT_SEQUENTIAL`, and `enable_mem_pattern` must be `false`.
Additionally, as the DirectML execution provider does not support parallel execution, it does not support multi-threaded calls to `Run` on the same inference session. That is, if an inference session using the DirectML execution provider, only one thread may call `Run` at a time. Multiple threads are permitted to call `Run` simultaneously if they operate on different inference session objects.
## Samples
A complete sample of onnxruntime using the DirectML execution provider can be found under [samples/c_cxx/fns_candy_style_transfer](../../samples/c_cxx/fns_candy_style_transfer).
## See also
[DirectML documentation \(docs.microsoft.com\)](https://docs.microsoft.com/en-us/windows/win32/direct3d12/dml)

Просмотреть файл

@ -20,6 +20,7 @@ constexpr const char* kMLDomain = "ai.onnx.ml";
constexpr const char* kMSDomain = "com.microsoft";
constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
constexpr const char* kMSAutoMLDomain = "com.microsoft.automl";
constexpr const char* kMSDmlDomain = "com.microsoft.dml";
constexpr const char* kNGraphDomain = "com.intel.ai";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
@ -30,5 +31,5 @@ constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider";
constexpr const char* kBrainSliceExecutionProvider = "BrainSliceExecutionProvider";
constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";
constexpr const char* kNnapiExecutionProvider = "NnapiExecutionProvider";
constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider";
} // namespace onnxruntime

Просмотреть файл

@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma warning(push)
#pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#include <d3d12.h>
#pragma warning(pop)
#ifdef __cplusplus
#include <DirectML.h>
#else
struct IDMLDevice;
typedef struct IDMLDevice IDMLDevice;
#endif
// Windows pollutes the macro space, causing a build break in constants.h.
#undef OPTIONAL
#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
/**
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
#ifdef __cplusplus
}
#endif

Просмотреть файл

@ -0,0 +1,343 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/constants.h"
#include "core/graph/dml_ops/dml_defs.h"
#include "core/graph/op.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "core/providers/dml/OperatorAuthorHelper/Attributes.h"
namespace ONNX_NAMESPACE {
void convPoolShapeInference(
ONNX_NAMESPACE::InferenceContext& ctx,
bool use_dilation, bool require_kernel_shape,
int input1Idx,
int input2Idx);
void convTransposeShapeInference(InferenceContext& ctx);
} // namespace ONNX_NAMESPACE
namespace onnxruntime {
namespace dml {
using ONNX_NAMESPACE::AttributeProto;
using ONNX_NAMESPACE::OpSchema;
using ONNX_NAMESPACE::OPTIONAL;
void RegisterDmlSchemas() {
MS_DML_OPERATOR_SCHEMA(FusedConv)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Conv+Activation)DOC")
.Input(0, "X", "", "T")
.Input(1, "W", "", "T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL)
.Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
.Attr("strides", "", AttributeProto::INTS, OPTIONAL)
.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
.Attr("pads", "", AttributeProto::INTS, OPTIONAL)
.Attr("group", "", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
});
MS_DML_OPERATOR_SCHEMA(FusedConvTranspose)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused ConvTranspose+Activation)DOC")
.Input(0, "X", "", "T")
.Input(1, "W", "", "T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL)
.Attr("output_shape", "", AttributeProto::INTS, OPTIONAL)
.Attr("output_padding", "", AttributeProto::INTS, OPTIONAL)
.Attr("dilations", "", AttributeProto::INTS, OPTIONAL)
.Attr("strides", "", AttributeProto::INTS, OPTIONAL)
.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"))
.Attr("pads", "", AttributeProto::INTS, OPTIONAL)
.Attr("group", "", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction(
[](ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::convTransposeShapeInference(ctx); });
MS_DML_OPERATOR_SCHEMA(FusedInstanceNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused InstanceNormalization+Activation)DOC")
.Attr("epsilon", "", AttributeProto::FLOAT, 1e-5f)
.Input(0, "input", "", "T")
.Input(1, "scale", "", "T")
.Input(2, "B", "", "T")
.Output(0, "output", "", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
});
MS_DML_OPERATOR_SCHEMA(FusedBatchNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused BatchNormalization+Activation)DOC")
.NumOutputs({1, 5})
.Attr("spatial", "", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("epsilon", "", AttributeProto::FLOAT, 1e-5f)
.Attr("momentum", "", AttributeProto::FLOAT, 0.9f)
.Input(0, "X", "", "T")
.Input(1, "scale", "", "T")
.Input(2, "B", "", "T")
.Input(3, "mean", "", "T")
.Input(4, "var", "", "T")
.Output(0, "Y", "", "T")
.Output(1, "mean", "", "T", OpSchema::Optional)
.Output(2, "var", "", "T", OpSchema::Optional)
.Output(3, "saved_mean", "", "T", OpSchema::Optional)
.Output(4, "saved_var", "", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
// TODO in training mode, it may be possible to infer some of
// the other outputs as well.
});
MS_DML_OPERATOR_SCHEMA(FusedMeanVarianceNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused MeanVarianceNormalization+Activation)DOC")
.Attr("across_channels", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("normalize_variance", "", AttributeProto::INT, static_cast<int64_t>(1))
.Input(0, "input", "", "T")
.Output(0, "output", "", "T")
.TypeConstraint( "T", { "tensor(float16)", "tensor(float)", "tensor(double)" }, "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
MS_DML_OPERATOR_SCHEMA(FusedGemm)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Gemm+Activation)DOC")
.Input(0, "A", "", "T")
.Input(1, "B", "", "T")
.Input(2, "C", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr("transA", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("transB", "", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("alpha", "", AttributeProto::FLOAT, 1.0f)
.Attr("beta", "", AttributeProto::FLOAT, 1.0f)
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (hasNInputShapes(ctx, 2)) {
auto transAAttr = ctx.getAttribute("transA");
bool transA =
transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
auto transBAttr = ctx.getAttribute("transB");
bool transB =
transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
auto& first_input_shape = getInputShape(ctx, 0);
auto& second_input_shape = getInputShape(ctx, 1);
if (first_input_shape.dim_size() != 2)
fail_shape_inference("First input does not have rank 2");
if (second_input_shape.dim_size() != 2)
fail_shape_inference("Second input does not have rank 2");
updateOutputShape(
ctx,
0,
{first_input_shape.dim(transA ? 1 : 0),
second_input_shape.dim(transB ? 0 : 1)});
}
});
MS_DML_OPERATOR_SCHEMA(FusedMatMul)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused MatMul+Activation)DOC")
.Input(0, "A", "", "T")
.Input(1, "B", "", "T")
.Output(0, "Y", "", "T")
.TypeConstraint( "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 2)) {
return;
}
const auto shape0 = ctx.getInputType(0)->tensor_type().shape();
const auto shape1 = ctx.getInputType(1)->tensor_type().shape();
if (shape0.dim_size() == 0 || shape1.dim_size() == 0) {
fail_shape_inference("Input tensors of wrong rank (0).");
}
ONNX_NAMESPACE::TensorShapeProto shapeL, shapeR;
// First promote each shape to at least rank-2. This logic is
// specific to matmul, not generic broadcasting.
{
if (shape0.dim_size() == 1) {
shapeL.add_dim()->set_dim_value(1);
*shapeL.add_dim() = shape0.dim(0);
} else {
*shapeL.mutable_dim() = shape0.dim();
}
if (shape1.dim_size() == 1) {
*shapeR.add_dim() = shape1.dim(0);
shapeR.add_dim()->set_dim_value(1);
} else {
*shapeR.mutable_dim() = shape1.dim();
}
}
// Check for compatible matrix multiply dimensions
{
auto dimL = shapeL.dim(shapeL.dim_size() - 1);
auto dimR = shapeR.dim(shapeR.dim_size() - 2);
if (dimL.has_dim_value() && dimR.has_dim_value() &&
dimL.dim_value() != dimR.dim_value()) {
fail_shape_inference(
"Incompatible dimensions for matrix multiplication");
;
}
}
ONNX_NAMESPACE::TensorShapeProto resultShape;
// Now call out to generic multidimensional broadcasting for
// the broadcastable prefixes.
{
ONNX_NAMESPACE::TensorShapeProto prefixShapeL, prefixShapeR;
for (int i = 0; i < shapeL.dim_size() - 2; ++i) {
*prefixShapeL.add_dim() = shapeL.dim(i);
}
for (int i = 0; i < shapeR.dim_size() - 2; ++i) {
*prefixShapeR.add_dim() = shapeR.dim(i);
}
bidirectionalBroadcastShapeInference(
prefixShapeL, prefixShapeR, resultShape);
}
// Back to matmul-specific. Add the trailing dimensions back in.
{
if (shape0.dim_size() != 1) {
*resultShape.add_dim() = shapeL.dim(shapeL.dim_size() - 2);
}
if (shape1.dim_size() != 1) {
*resultShape.add_dim() = shapeR.dim(shapeR.dim_size() - 1);
}
}
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() =
resultShape;
});
MS_DML_OPERATOR_SCHEMA(FusedAdd)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Add+Activation)DOC")
.Input(0, "A", "", "T")
.Input(1, "B", "", "T")
.Output(0, "C", "", "T")
.TypeConstraint("T", OpSchema::numeric_types_for_math_reduction(), "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
MS_DML_OPERATOR_SCHEMA(FusedSum)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Sum+Activation)DOC")
.Input(0, "data_0", "", "T", OpSchema::Variadic)
.Output(0, "sum", "", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "")
.Attr(AttrName::FusedActivation, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationDomain, "", onnx::AttributeProto::STRING)
.Attr(AttrName::FusedActivationSinceVersion, "", onnx::AttributeProto::INT)
.Attr(AttrName::FusedAlpha, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedBeta, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedGamma, "", onnx::AttributeProto::FLOAT, false)
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
int num_inputs = static_cast<int>(ctx.getNumInputs());
std::vector<const ONNX_NAMESPACE::TensorShapeProto*> shapes;
for (int i = 0; i < num_inputs; ++i) {
auto input_type = ctx.getInputType(i);
if (nullptr == input_type || !input_type->has_tensor_type() ||
!input_type->tensor_type().has_shape()) {
return;
}
shapes.push_back(&input_type->tensor_type().shape());
}
multidirectionalBroadcastShapeInference(
shapes,
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
}
} // namespace dml
} // namespace onnxruntime

Просмотреть файл

@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/onnx_protobuf.h"
namespace onnxruntime {
namespace dml {
#define MS_DML_OPERATOR_SCHEMA(name) \
MS_DML_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define MS_DML_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) \
MS_DML_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define MS_DML_OPERATOR_SCHEMA_UNIQ(Counter, name) \
static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \
op_schema_register_once##name##Counter) ONNX_UNUSED = \
ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)
#define MS_DML_OPERATOR_SCHEMA_ELSEWHERE(name, schema_func) \
MS_DML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func)
#define MS_DML_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \
MS_DML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func)
#define MS_DML_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \
static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \
op_schema_register_once##name##Counter) ONNX_UNUSED = \
schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__))
void RegisterDmlSchemas();
} // namespace dml
} // namespace onnxruntime

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
interface IMLOperatorRegistry;
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "IWinmlExecutionProvider.h"
namespace onnxruntime
{
class IExecutionProvider;
class IAllocator;
class CustomRegistry;
class InferenceSession;
class KernelRegistry;
}
enum class AllocatorRoundingMode
{
Disabled = 0,
Enabled = 1,
};
namespace Dml
{
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands = true);
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);
void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode);
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider);
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider);
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider);
onnxruntime::common::Status CopyTensor(
onnxruntime::IExecutionProvider* provider,
const onnxruntime::Tensor& src, onnxruntime::Tensor& dst
);
void* CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource);
void FreeGPUAllocation(void* ptr);
void RegisterDmlOperators(IMLOperatorRegistry* registry);
onnxruntime::common::Status RegisterDmlGraphTransformer(
onnxruntime::InferenceSession* session,
std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry
);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <unordered_map>
#include <functional>
#include <variant>
#include <optional>
#include "core/framework/op_kernel.h"
struct AbstractOperatorDesc;
interface IMLOperatorTensor;
namespace onnxruntime
{
class KernelDef;
class Node;
}
namespace winrt::Windows::AI::MachineLearning::implementation
{
interface __declspec(uuid("5b19a18a-5ed5-4df2-a363-21b89380a698"))
IWinmlExecutionProvider : public IUnknown
{
public:
// Hold a reference to an object until preceding work in the queue is complete. This
// only needs to be handled by providers which hide the asynchronous nature of
// computation, and involve resoures which cannot be automatically by work in the
// the provider's underlying queues.
virtual void QueueReference(IUnknown *object) = 0;
virtual void GetShadowCopyIfRequired(
bool isInternalOperator,
IUnknown* data,
IUnknown** dataCopy) const = 0;
virtual void GetABIDataInterface(
bool isInternalOperator,
IUnknown* data,
IUnknown** abiData) const = 0;
virtual uint64_t TryGetPooledAllocationId(
IUnknown* data,
bool isInternalOperator) = 0;
virtual void GetABIExecutionInterface(
bool isInternalOperator,
IUnknown** abiExecutionObject) const = 0;
// Whether TransitionResourcesForOperator should be called before and after executing
// an operator registered to this provider with the specified flags
virtual bool TransitionsRequiredForOperator(bool isInternalOperator) = 0;
// If TransitionsRequiredForOperator returns true, should be called before and after executing
// an operator to transition its resources to and from the appropriate state.
virtual void TransitionResourcesForOperator(
bool isBeforeOp,
uint32_t resourceCount,
IUnknown** resources) = 0;
// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
virtual void Close() = 0;
};
struct DmlOperatorParams
{
Microsoft::WRL::ComPtr<IDMLOperator> op;
std::unique_ptr<AbstractOperatorDesc> desc;
};
// This is the counterpart to the MLOperatorKernelDmlProperties ABI struct which owns its memory and uses containers.
struct DmlGraphNodeCreateInfo
{
bool initialized = false;
// Mapping between DML in/out indices and kernel in/out indices
std::vector<uint32_t> kernelInputIndices;
std::vector<uint32_t> kernelOutputIndices;
Microsoft::WRL::ComPtr<IDMLOperator> op;
std::unique_ptr<AbstractOperatorDesc> desc;
bool allowHalfPrecisionComputation = false;
};
using MLOperatorTensorGetter = std::function<Microsoft::WRL::ComPtr<IMLOperatorTensor>(uint32_t index)>;
using GraphNodeFactory = std::function<void(
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;
struct GraphNodeFactoryRegistration
{
GraphNodeFactory factory;
std::optional<uint32_t> requiredInputCount;
std::vector<uint32_t> requiredConstantCpuInputs;
bool requiresFloatFormatsExceptConstInputs = false;
};
using GraphNodeFactoryMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<GraphNodeFactoryRegistration>>;
}

Просмотреть файл

@ -0,0 +1,794 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//! \file MLOperatorAuthor.h
#pragma once
#if defined(__cplusplus)
#if (!defined(_MSC_VER)) || (_MSC_VER >= 1700)
#if !defined(COM_NO_WINDOWS_H)
#include <unknwn.h>
#endif /* !defined(COM_NO_WINDOWS_H) */
#include <cstdint>
#include <winapifamily.h>
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP)
static_assert(sizeof(bool) == 1, "Unsupported size for bool type");
//! \enum MLOperatorAttributeType
//! \brief Specifies the type of an attribute.
enum class MLOperatorAttributeType : uint32_t
{
//! Undefined (unused)
Undefined = 0,
//! 32 bit floating point
Float = 2,
//! 64 bit integer
Int = 3,
//! String
String = 4,
//! Array of 32 bit floating point values
FloatArray = 7,
//! Array of 64 bit floating integer values
IntArray = 8,
//! Array of string values
StringArray = 9
};
//! \enum MLOperatorTensorDataType
//! \brief Specifies the data type of a tensor.
//! Each data type numerically matches corresponding ONNX types.
enum class MLOperatorTensorDataType : uint32_t
{
//! Undefined (unused).
Undefined = 0,
//! IEEE 32 bit floating point
Float = 1,
//! 8 bit unsigned integer
UInt8 = 2,
//! 8 bit signed integer
Int8 = 3,
//! 16 bit unsigned integer
UInt16 = 4,
//! 16 bit signed integer
Int16 = 5,
//! 32 bit signed integer
Int32 = 6,
//! 64 bit signed integer
Int64 = 7,
//! String (unsupported)
String = 8,
//! 8 bit boolean. Values other than zero and one result in undefined behavior.
Bool = 9,
//! IEEE 16 bit floating point
Float16 = 10,
//! 64 bit double-precision floating point
Double = 11,
//! 32 bit unsigned integer
UInt32 = 12,
//! 64 bit unsigned integer
UInt64 = 13,
//! 64 bit Complex type (unsupported)
Complex64 = 14,
//! 128 bit complex type (unsupported)
Complex128 = 15
};
//! \enum MLOperatorEdgeType
//! \brief Specifies the types of an input or output edge of an operator.
enum class MLOperatorEdgeType : uint32_t
{
Undefined = 0,
Tensor = 1,
};
//! \struct MLOperatorEdgeDescription
//! \brief Specifies the properties of an input or output edge of an operator.
struct MLOperatorEdgeDescription
{
//! The type of the edge.
MLOperatorEdgeType edgeType;
union
{
uint64_t reserved;
//! The data type of a tensor. Used when edgeType is set to Tensor.
MLOperatorTensorDataType tensorDataType;
};
};
//! \interface IMLOperatorAttributes
//! \brief Represents the values of an operator's attributes, as determined by a model using the operator.
//! This interface is called by implementations of custom operator kernels, and by implementations
//! of shape and type inferrers.
interface DECLSPEC_UUID("4B1B1759-EC40-466C-AAB4-BEB5347FD24C") DECLSPEC_NOVTABLE
IMLOperatorAttributes : IUnknown
{
//! Gets the count of elements in an attribute.
//! This may be used to determine if an attribute exists, and to determine the
//! count of elements within an attribute of an array type.
STDMETHOD(GetAttributeElementCount)(
_In_z_ const char* name,
MLOperatorAttributeType type,
_Out_ uint32_t* elementCount
) const noexcept PURE;
//! Gets the value of an attribute element which is of a numeric type.
//! For attributes which are of array types, this method queries
//! an individual element within the attribute at the specified index.
STDMETHOD(GetAttribute)(
_In_z_ const char* name,
MLOperatorAttributeType type,
uint32_t elementCount,
size_t elementByteSize,
_Out_writes_bytes_(elementCount * elementByteSize) void* value
) const noexcept PURE;
//! Gets the length of an attribute element which is of a string type.
//! For attributes which are string arrays, this method queries
//! the size of an individual element within the attribute at the
//! specified index.
//! The string is in UTF-8 format. The size includes the null termination character.
STDMETHOD(GetStringAttributeElementLength)(
_In_z_ const char* name,
uint32_t elementIndex,
_Out_ uint32_t* attributeElementByteSize
) const noexcept PURE;
//! Gets the value of an attribute element which is of a string type.
//! For attributes which are string arrays, this method queries
//! the value of an individual element within the attribute at the
//! specified index.
//! The string is in UTF-8 format. The size includes the null termination character.
STDMETHOD(GetStringAttributeElement)(
_In_z_ const char* name,
uint32_t elementIndex,
uint32_t attributeElementByteSize,
_Out_writes_(attributeElementByteSize) char* attributeElement
) const noexcept PURE;
};
//! \interface IMLOperatorTensorShapeDescription
//! \brief Represents the set of input and output tensor shapes of an operator.
//! This interface is called by the factory objects registered to create kernels.
//! It is available to these factory objects unless corresponding kernels are
//! registered using the MLOperatorKernelOptions::AllowDynamicInputShapes flag.
interface DECLSPEC_UUID("F20E8CBE-3B28-4248-BE95-F96FBC6E4643") DECLSPEC_NOVTABLE
IMLOperatorTensorShapeDescription : IUnknown
{
//! Gets the number of dimensions of a tensor input of the operator.
//! Returns an error if the input at the specified index is not a tensor.
STDMETHOD(GetInputTensorDimensionCount)(
uint32_t inputIndex,
_Out_ uint32_t* dimensionCount
) const noexcept PURE;
//! Gets the sizes of dimensions of an input tensor of the operator.
//! Returns an error if the input at the specified index is not a tensor.
STDMETHOD(GetInputTensorShape)(
uint32_t inputIndex,
uint32_t dimensionCount,
_Out_writes_(dimensionCount) uint32_t* dimensions
) const noexcept PURE;
//! Returns true if output shapes may be queried using GetOutputTensorDimensionCount
//! and GetOutputTensorShape. This is true if the kernel was registered with a
//! shape inferrer.
STDMETHOD_(bool, HasOutputShapeDescription)() const noexcept PURE;
//! Gets the number of dimensions of a tensor output of the operator.
//! Returns an error if the output at the specified index is not a tensor.
STDMETHOD(GetOutputTensorDimensionCount)(
uint32_t outputIndex,
_Out_ uint32_t* dimensionCount
) const noexcept PURE;
//! Gets the sizes of dimensions of a tensor output of the operator.
//! Returns an error if the output at the specified index is not a tensor.
STDMETHOD(GetOutputTensorShape)(
uint32_t outputIndex,
uint32_t dimensionCount,
_Out_writes_(dimensionCount) uint32_t* dimensions
) const noexcept PURE;
};
//! \interface IMLOperatorKernelCreationContext
//! \brief Provides information about an operator's usage while kernels are being created.
interface DECLSPEC_UUID("5459B53D-A0FC-4665-ADDD-70171EF7E631") DECLSPEC_NOVTABLE
IMLOperatorKernelCreationContext : public IMLOperatorAttributes
{
//! Gets the number of inputs to the operator.
STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE;
//! Gets the number of outputs to the operator.
STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE;
//! Returns true if an input to the operator is valid.
//! This always returns true if within GetInputCount except for optional inputs.
STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE;
//! Returns true if an output to the operator is valid.
//! This always returns true if within GetOutputCount except for optional outputs.
STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE;
//! Gets the description of the specified input edge of the operator.
STDMETHOD(GetInputEdgeDescription)(
uint32_t inputIndex,
_Out_ MLOperatorEdgeDescription* edgeDescription
) const noexcept PURE;
//! Gets the description of the specified output edge of the operator.
STDMETHOD(GetOutputEdgeDescription)(
uint32_t outputIndex,
_Out_ MLOperatorEdgeDescription* edgeDescription
) const noexcept PURE;
//! Returns true if the description of input and output shapes connected to
//! operator edges may be queried using GetTensorShapeDescription.
//! This returns true unless the operator was registered using
//! the MLOperatorKernelOptions::AllowDynamicInputShapes flag.
STDMETHOD_(bool, HasTensorShapeDescription)() const noexcept PURE;
//! Gets the description of input and output shapes connected to
//! operator edges.
STDMETHOD(GetTensorShapeDescription)(
_COM_Outptr_ IMLOperatorTensorShapeDescription** shapeDescription
) const noexcept PURE;
//! Returns an object whose supported interfaces vary based on the kernel type.
//! For kernels registered with MLOperatorExecutionType::Cpu, executionObject will
//! be set to nullptr.
//! For kernels registered with MLOperatorExecutionType::D3D12, executionObject will
//! support the ID3D12GraphicsCommandList interface.
STDMETHOD_(void, GetExecutionInterface)(
_COM_Outptr_result_maybenull_ IUnknown** executionObject
) const noexcept PURE;
};
//! \interface IMLOperatorTensor
//! \brief Representation of a tensor used during computation of custom operator kernels.
interface DECLSPEC_UUID("7FE41F41-F430-440E-AECE-54416DC8B9DB") DECLSPEC_NOVTABLE
IMLOperatorTensor : IUnknown
{
//! Gets the number of dimensions in the tensor. This may be zero.
STDMETHOD_(uint32_t, GetDimensionCount)() const noexcept PURE;
//! Gets the size of dimensions in the tensor.
STDMETHOD(GetShape)(
uint32_t dimensionCount,
_Out_writes_(dimensionCount) uint32_t* dimensions
) const noexcept PURE;
//! Gets the data type of the tensor.
STDMETHOD_(MLOperatorTensorDataType, GetTensorDataType)() const noexcept PURE;
//! Indicates whether the memory used by the tensor is CPU-addressable.
//! This is true when kernels are registered using MLOperatorExecutionType::Cpu.
STDMETHOD_(bool, IsCpuData)() const noexcept PURE;
//! Whether the contents of the tensor are represented by an interface type,
//! or byte-addressable memory. This returns true when kernels are registered
//! using MLOperatorExecutionType::D3D12.
STDMETHOD_(bool, IsDataInterface)() const noexcept PURE;
//! Returns a pointer to byte-addressable memory for the tensor. This may be
//! used when IsDataInterface returns false, because the kernel was
//! registered using MLOperatorExecutionType::Cpu. The data size is derived
//! from the tensor's shape. It is fully packed in memory.
STDMETHOD_(void*, GetData)() noexcept PURE;
//! Gets an interface pointer for the tensor. This may be
//! used when IsDataInterface returns true, because the kernel was
//! registered using MLOperatorExecutionType::D3D12. The dataInterface
//! object supports the ID3D12Resource interface, and is a GPU buffer.
STDMETHOD_(void, GetDataInterface)(
_COM_Outptr_result_maybenull_ IUnknown** dataInterface
) noexcept PURE;
};
//! \interface IMLOperatorKernelContext
//! \brief Provides information about an operator's usage while kernels are being computed.
interface DECLSPEC_UUID("82536A28-F022-4769-9D3F-8B278F84C0C3") DECLSPEC_NOVTABLE
IMLOperatorKernelContext : IUnknown
{
//! Gets the input tensor of the operator at the specified index.
//! This sets tensor to nullptr for optional inputs which do not exist.
//! Returns an error if the input at the specified index is not a tensor.
STDMETHOD(GetInputTensor)(
uint32_t inputIndex,
_COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor
) const noexcept PURE;
//! Gets the output tensor of the operator at the specified index.
//! This sets tensor to nullptr for optional outputs which do not exist.
//! If the operator kernel was registered without a shape inference method,
//! then the overload of GetOutputTensor which consumes the tensor's shape must
//! be called instead. Returns an error if the output at the specified index is
//! not a tensor.
STDMETHOD(GetOutputTensor)(
uint32_t outputIndex,
_COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor
) noexcept PURE;
//! Gets the output tensor of the operator at the specified index, while declaring
//! its shape.
//! This returns nullptr for optional outputs which do not exist.
//! If the operator kernel was registered with a shape inference method,
//! then the overload of GetOutputTensor which doesn't consume a shape may also
//! be called. Returns an error if the output at the specified index is
//! not a tensor.
STDMETHOD(GetOutputTensor)(
uint32_t outputIndex,
uint32_t dimensionCount,
_In_reads_(dimensionCount) const uint32_t* dimensionSizes,
_COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor
) noexcept PURE;
//! Allocates temporary data which will be usable as intermediate memory for the duration
//! of a call to IMLOperatorKernel::Compute. This may be used by kernels
//! registered using MLOperatorExecutionType::D3D12. The data
//! object supports the ID3D12Resource interface, and is a GPU buffer.
STDMETHOD(AllocateTemporaryData)(size_t size, _COM_Outptr_ IUnknown** data) const = 0;
//! Returns an object whose supported interfaces vary based on the kernel type.
//! For kernels registered with MLOperatorExecutionType::Cpu, executionObject will
//! be set to nullptr.
//! For kernels registered with MLOperatorExecutionType::D3D12, executionObject will
//! support the ID3D12GraphicsCommandList interface. This may be a different object
//! than was provided to IMLOperatorKernelCreationContext::GetExecutionInterface
//! when the kernel instance was created.
STDMETHOD_(void, GetExecutionInterface)(
_Outptr_result_maybenull_ IUnknown** executionObject
) const noexcept PURE;
};
//! \interface IMLOperatorKernel
//! \brief Implemented by custom operator kernels.
//! A factory which creates interfaces of this interface is supplied when
//! registering custom operator kernels using IMLOperatorKernelFactory::RegisterOperatorKernel.
interface DECLSPEC_UUID("11C4B4A0-B467-4EAA-A1A6-B961D8D0ED79") DECLSPEC_NOVTABLE
IMLOperatorKernel : IUnknown
{
//! Computes the outputs of the kernel. The implementation of this method
//! should be thread-safe. The same instance of the kernel may be computed
//! simultaneously on different threads.
STDMETHOD(Compute)(IMLOperatorKernelContext* context) noexcept PURE;
};
//! \enum MLOperatorParameterOptions
//! \brief Specifies option flags of input and output edges of operators.
//! These options are used while defining custom operator schema.
enum class MLOperatorParameterOptions : uint32_t
{
//! There is a single instance of the input or output.
Single = 0,
//! The input or output may be omitted.
Optional = 1,
//! The number of instances of the operator is variable. Variadic parameters
//! must be last among the set of inputs or outputs.
Variadic = 2,
};
DEFINE_ENUM_FLAG_OPERATORS(MLOperatorParameterOptions);
//! \enum MLOperatorSchemaEdgeTypeFormat
//! \brief Specifies the manner in which types of input and output edges are described.
//! This is used within MLOperatorSchemaEdgeDescription while defining custom operator schema.
enum class MLOperatorSchemaEdgeTypeFormat
{
//! The type is defined using MLOperatorEdgeDescription.
EdgeDescription = 0,
//! The type is defined by a type string constructed as in ONNX operator schema.
Label = 1,
};
//! \struct MLOperatorSchemaEdgeDescription
//! \brief Specifies information about an input or output edge of an operator.
//! This is used while defining custom operator schema.
struct MLOperatorSchemaEdgeDescription
{
//! Options of the parameter, including whether it is optional or variadic.
MLOperatorParameterOptions options;
//! The manner in which the type constraints and type mapping are defined.
MLOperatorSchemaEdgeTypeFormat typeFormat;
union
{
const void* reserved;
//! A type label string constructed as in ONNX operator schema. For example, "T".
//! This is used when typeFormat is MLOperatorSchemaEdgeTypeFormat::Label.
_Field_z_ const char* typeLabel;
//! A structure describing type support.
//! This is used when typeFormat is MLOperatorSchemaEdgeTypeFormat::EdgeDescription.
MLOperatorEdgeDescription edgeDescription;
};
};
//! \struct MLOperatorEdgeTypeConstraint
//! \brief Specifies constraints upon the types of edges supported in custom operator kernels
//! and schema. The provided type label string corresponds to type labels in the ONNX
//! specification for the same operator. For custom schema, it corresponds to type labels
//! specified within MLOperatorSchemaEdgeDescription when registering the operator's schema.
struct MLOperatorEdgeTypeConstraint
{
//! The label of the type for which the constraint is being defined.
//! This is constructed as in ONNX operator schema. For example, "T".
_Field_z_ const char* typeLabel;
//! The set of allowed types for the constraint.
_Field_size_opt_(allowedTypeCount) const MLOperatorEdgeDescription* allowedTypes;
uint32_t allowedTypeCount;
};
// Legacy alias.
using MLOperatorEdgeTypeConstrant = MLOperatorEdgeTypeConstraint;
//! \interface IMLOperatorShapeInferenceContext
//! \brief Provides information about an operator's usage while shape inferrers are being invoked.
interface DECLSPEC_UUID("105B6B29-5408-4A68-9959-09B5955A3492") DECLSPEC_NOVTABLE
IMLOperatorShapeInferenceContext : public IMLOperatorAttributes
{
//! Gets the number of inputs to the operator.
STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE;
//! Gets the number of outputs to the operator.
STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE;
//! Returns true if an input to the operator is valid.
//! This always returns true except for optional inputs and invalid indices.
STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE;
//! Returns true if an output to the operator is valid.
//! This always returns true except for optional outputs and invalid indices.
STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE;
//! Gets the description of the specified input edge of the operator.
STDMETHOD(GetInputEdgeDescription)(
uint32_t inputIndex,
_Out_ MLOperatorEdgeDescription* edgeDescription
) const noexcept PURE;
//! Gets the number of dimensions of a tensor output of the operator.
STDMETHOD(GetInputTensorDimensionCount)(
uint32_t inputIndex,
_Out_ uint32_t* dimensionCount
) const noexcept PURE;
//! Gets the sizes of dimensions of an input tensor of the operator.
//! Returns an error if the input at the specified index is not a tensor.
STDMETHOD(GetInputTensorShape)(
uint32_t inputIndex,
uint32_t dimensionCount,
_Out_writes_(dimensionCount) uint32_t* dimensions
) const noexcept PURE;
//! Sets the inferred shape of an output tensor.
//! Returns an error if the output at the specified index is not a tensor.
STDMETHOD(SetOutputTensorShape)(
uint32_t outputIndex,
uint32_t dimensionCount,
const uint32_t* dimensions
) noexcept PURE;
};
//! \interface IMLOperatorTypeInferenceContext
//! \brief Provides information about an operator's usage while type inferrers are being invoked.
interface DECLSPEC_UUID("EC893BB1-F938-427B-8488-C8DCF775F138") DECLSPEC_NOVTABLE
IMLOperatorTypeInferenceContext : public IMLOperatorAttributes
{
//! Gets the number of inputs to the operator.
STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE;
//! Gets the number of outputs to the operator.
STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE;
//! Returns true if an input to the operator is valid.
//! This always returns true except for optional inputs.
STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE;
//! Returns true if an output to the operator is valid.
//! This always returns true except for optional outputs.
STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE;
//! Gets the description of the specified input edge of the operator.
STDMETHOD(GetInputEdgeDescription)(
uint32_t inputIndex,
_Out_ MLOperatorEdgeDescription* edgeDescription
) const noexcept PURE;
//! Sets the inferred type of an output edge.
STDMETHOD(SetOutputEdgeDescription)(
uint32_t outputIndex,
const MLOperatorEdgeDescription* edgeDescription
) const noexcept PURE;
};
//! \interface IMLOperatorTypeInferrer
//! \brief Implemented by type inferrers to infer types of an operator's output edges.
//! Type inferrers must be provided when registering schema of custom operators if
//! the MLOperatorSchemaDescription structure cannot express how output types are
//! determined. For example, such as when an attribute of the operator determines
//! the data type of one of that operator's outputs.
interface DECLSPEC_UUID("781AEB48-9BCB-4797-BF77-8BF455217BEB") DECLSPEC_NOVTABLE
IMLOperatorTypeInferrer : IUnknown
{
//! Called to infer types of an operator's output edges
STDMETHOD(InferOutputTypes)(
IMLOperatorTypeInferenceContext* context
) noexcept PURE;
};
//! \interface IMLOperatorShapeInferrer
//! \brief Implemented by shape inferrers to infer shapes of an operator's
//! output tensor edges. Shape inferrers may be provided when registering custom
//! operator kernels to improve performance and to enable the kernel to query
//! the shape of its output tensors when it is created and computed. Shape
//! inferrers may also be provided when registering custom operator schema to
//! improve model validation.
interface DECLSPEC_UUID("540BE5BE-A6C9-40EE-83F6-D2B8B40A7798") DECLSPEC_NOVTABLE
IMLOperatorShapeInferrer : IUnknown
{
//! Called to infer shapes of an operator's output edges.
STDMETHOD(InferOutputShapes)(
IMLOperatorShapeInferenceContext* context
) noexcept PURE;
};
//! \struct MLOperatorAttribute
//! \brief Specifies the name and properties of an attribute of a custom operator.
//! This is used when registering custom operator kernels and custom operator schema.
struct MLOperatorAttribute
{
//! NULL-terminated UTF-8 string representing the name of the attribute in the
//! associated operator type.
_Field_z_ const char* name;
//! The type of the attribute in the associated operator type.
MLOperatorAttributeType type;
//! Whether the attribute is required in any model using the associated operator type.
bool required;
};
//! \struct MLOperatorAttributeNameValue
//! \brief Specifies the name and value(s) of an attribute of a custom operator.
//! This is used when registering custom operator kernels and custom operator schema.
struct MLOperatorAttributeNameValue
{
//! NULL-terminated UTF-8 string representing the name of the attribute in the
//! associated operator type.
_Field_z_ const char* name;
//! The type of the attribute in the associated operator type.
MLOperatorAttributeType type;
//! The number of elements in the attribute value. This must be one, except for attributes
//! which are of array types.
uint32_t valueCount;
union
{
const void* reserved;
//! 64 bit integer value(s). Used when the type field is
//! MLOperatorAttributeType::Int or MLOperatorAttributeType::IntArray.
_Field_size_(valueCount) const int64_t* ints;
//! NULL-terminated UTF-8 string value(s). Used when the type field is
//! MLOperatorAttributeType::String or MLOperatorAttributeType::StringArray.
_Field_size_(valueCount) const char* const* strings;
//! 32 bit floating point value(s). Used when the type field is
//! MLOperatorAttributeType::Float or MLOperatorAttributeType::FloatArray.
_Field_size_(valueCount) const float* floats;
};
};
//! \struct MLOperatorSchemaDescription
//! \brief Description of a custom operator schema used to register that schema.
struct MLOperatorSchemaDescription
{
//! NULL-terminated UTF-8 string representing the name of the operator.
_Field_z_ const char* name;
//! The operator set version at which this operator was introduced or last changed.
int32_t operatorSetVersionAtLastChange;
//! An array containing the descriptions of the operator's input edges.
_Field_size_opt_(inputCount) const MLOperatorSchemaEdgeDescription* inputs;
//! The number of inputs of the operator.
uint32_t inputCount;
//! An array containing the descriptions of the operator's output edges.
_Field_size_opt_(outputCount) const MLOperatorSchemaEdgeDescription* outputs;
//! The number of outputs of the operator.
uint32_t outputCount;
//! An array of type constraints. Each constraint restricts input and outputs
//! associated with a type label string to one or more edge types.
_Field_size_opt_(typeConstraintCount) const MLOperatorEdgeTypeConstraint* typeConstraints;
//! The number of type constraints provided.
uint32_t typeConstraintCount;
//! The set of attributes supported by the operator type.
_Field_size_opt_(attributeCount) const MLOperatorAttribute* attributes;
//! The number of provided attributes.
uint32_t attributeCount;
//! The default values of attributes. These will be applied when the attributes are missing
//! in a model containing the operator type.
_Field_size_opt_(defaultAttributeCount) const MLOperatorAttributeNameValue* defaultAttributes;
//! The number of provided default attribute values.
uint32_t defaultAttributeCount;
};
//! \struct MLOperatorSetId
//! \brief Specifies the identity of an operator set.
struct MLOperatorSetId
{
//! The domain of the operator, for example, "ai.onnx.ml", or an empty string
//! for the ONNX domain.
_Field_z_ const char* domain;
//! The version of the operator domain.
int32_t version;
};
//! \enum MLOperatorKernelOptions
//! \brief Specifies options used when registering custom operator kernels.
enum class MLOperatorKernelOptions : uint32_t
{
None = 0,
//! Specifies whether the shapes of input tensors are allowed to vary among invocations
//! of an operator kernel instance. If this is not set, kernel instances may query input
//! tensor shapes during creation, and front-load initialization work which depends
//! on those shapes. Setting this may improve performance if shapes vary dynamically between
//! inference operations, and the kernel implementation handles this efficiently.
AllowDynamicInputShapes = 1,
};
DEFINE_ENUM_FLAG_OPERATORS(MLOperatorKernelOptions);
//! \enum MLOperatorExecutionType
//! \brief Specifies whether a kernel uses the CPU or GPU for computation.
enum class MLOperatorExecutionType : uint32_t
{
Undefined = 0,
Cpu = 1,
D3D12 = 2
};
//! \struct MLOperatorKernelDescription
//! \brief Description of a custom operator kernel used to register that schema.
struct MLOperatorKernelDescription
{
//! NULL-terminated UTF-8 string representing the name of the operator's domain.
_Field_z_ const char* domain;
//! NULL-terminated UTF-8 string representing the name of the operator.
_Field_z_ const char* name;
//! The minimum version of the operator sets for which this kernel is valid.
//! The maximum version is inferred based on registrations of operator set schema for
//! subsequent versions of the same domain.
int32_t minimumOperatorSetVersion;
//! Specifies whether a kernel uses the CPU or GPU for computation.
MLOperatorExecutionType executionType;
//! An array of type constraints. Each constraint restricts input and outputs
//! associated with a type label string to one or more edge types.
_Field_size_opt_(typeConstraintCount) const MLOperatorEdgeTypeConstraint* typeConstraints;
//! The number of type constraints provided.
uint32_t typeConstraintCount;
//! The default values of attributes. These will be applied when the attributes are missing
//! in a model containing the operator type.
_Field_size_opt_(defaultAttributeCount) const MLOperatorAttributeNameValue* defaultAttributes;
//! The number of provided default attribute values.
uint32_t defaultAttributeCount;
//! Options for the kernel which apply to all execution provider types.
MLOperatorKernelOptions options;
//! Reserved for additional options. Must be zero.
uint32_t executionOptions;
};
//! \interface IMLOperatorKernelFactory
//! \brief Implemented by the author of a custom operator kernel to create instances of that kernel.
interface DECLSPEC_UUID("EF15AD6F-0DC9-4908-AB35-A575A30DFBF8") DECLSPEC_NOVTABLE
IMLOperatorKernelFactory : IUnknown
{
//! Creates an instance of the associated operator kernel, given information about the operator's
//! usage within a model described in the provided context object.
STDMETHOD(CreateKernel)(
IMLOperatorKernelCreationContext* context,
_COM_Outptr_ IMLOperatorKernel** kernel
) noexcept PURE;
};
//! \interface IMLOperatorRegistry
//! \brief Represents an instance of a registry for custom operator kernel and schema.
//! Custom operators may be used with Windows.AI.MachineLearning APIs by returning
//! instances of IMLOperatorRegistry through ILearningModelOperatorProviderNative.
interface DECLSPEC_UUID("2AF9DD2D-B516-4672-9AB5-530C208493AD") DECLSPEC_NOVTABLE
IMLOperatorRegistry : IUnknown
{
//! Registers a set of custom operator schema comprising an operator set. Operator sets follow
//! the ONNX versioning design. Callers should provide schema for all operators that have changed
//! between the specified baseline version and the version specified within operatorSetId. This
//! prevents older versions of kernels from being used in models which import the newer operator
//! set version. A type inferrer must be provided if the MLOperatorSchemaDescription structure
//! cannot express how output types are determined. A shape inferrer may optionally be provided
//! to enable model validation.
STDMETHOD(RegisterOperatorSetSchema)(
const MLOperatorSetId* operatorSetId,
int32_t baselineVersion,
_In_reads_opt_(schemaCount) const MLOperatorSchemaDescription* const* schema,
uint32_t schemaCount,
_In_opt_ IMLOperatorTypeInferrer* typeInferrer,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer
) const noexcept PURE;
//! Registers a custom operator kernel.
//! A shape inferrer may optionally be provided. This may improve performance and enables
//! the kernel to query the shape of its output tensors when it is created and computed.
STDMETHOD(RegisterOperatorKernel)(
const MLOperatorKernelDescription* operatorKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer
) const noexcept PURE;
};
extern "C"
{
//! \fn MLCreateOperatorRegistry
//! Creates an instance of IMLOperatorRegistry which may be used to register custom
//! operator kernel and custom operator schema.
HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry);
}
#endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */
#endif /* defined(__cplusplus) */
#endif /* defined(_MSC_VER) && (_MSC_VER >= 1700) */

Просмотреть файл

@ -0,0 +1,511 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "AbiCustomRegistry.h"
namespace winrt::Windows::AI::MachineLearning::implementation
{
AbiCustomRegistry::AbiCustomRegistry() :
m_kernelRegistry(std::make_shared<onnxruntime::CustomRegistry>()),
m_graphNodeFactoryMap(std::make_shared<GraphNodeFactoryMap>())
{
}
onnx::OpSchema::FormalParameterOption AbiCustomRegistry::ConvertFormalParameterOption(MLOperatorParameterOptions options)
{
switch (options)
{
case MLOperatorParameterOptions::Single:
return onnx::OpSchema::FormalParameterOption::Single;
case MLOperatorParameterOptions::Optional:
return onnx::OpSchema::FormalParameterOption::Optional;
case MLOperatorParameterOptions::Variadic:
return onnx::OpSchema::FormalParameterOption::Variadic;
default:
THROW_HR(E_NOTIMPL);
}
}
// Convert edge types from the ABI types to ONNX strings
std::string AbiCustomRegistry::ConvertFormalParameterType(const MLOperatorSchemaEdgeDescription& formalParameter)
{
ML_CHECK_BOOL(formalParameter.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label ||
formalParameter.typeFormat == MLOperatorSchemaEdgeTypeFormat::EdgeDescription);
if (formalParameter.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label)
{
return formalParameter.typeLabel;
} else
{
return ToTypeString(formalParameter.edgeDescription);
}
}
// Convert type constraints from the ABI types to ONNX strings
std::vector<std::string> ConvertTypeConstraintTypes(const MLOperatorEdgeTypeConstrant& constraint)
{
std::vector<std::string> ret;
ret.reserve(constraint.allowedTypeCount);
for (uint32_t i = 0; i < constraint.allowedTypeCount; ++i)
{
ret.emplace_back(ToTypeString(constraint.allowedTypes[i]));
}
return ret;
}
// Convert attributes and defaults from the ABI to ONNX schema
void AbiCustomRegistry::SetAttributesAndDefaults(onnx::OpSchema& schema, const MLOperatorSchemaDescription& abiSchema)
{
// Create a map with default attributes
std::map<std::string, const MLOperatorAttributeNameValue*> defaultAttributes;
for (uint32_t attributeIndex = 0; attributeIndex < abiSchema.defaultAttributeCount; ++attributeIndex)
{
const MLOperatorAttributeNameValue& defaultAttribute = abiSchema.defaultAttributes[attributeIndex];
defaultAttributes[defaultAttribute.name] = &defaultAttribute;
}
// Set each attribute along with default values, looked up by name, if available
for (uint32_t attributeIndex = 0; attributeIndex < abiSchema.attributeCount; ++attributeIndex)
{
const MLOperatorAttribute& attribute = abiSchema.attributes[attributeIndex];
auto defaultVal = defaultAttributes.find(attribute.name);
if (defaultVal == defaultAttributes.end())
{
schema.Attr(attribute.name, "", ToProto(attribute.type), attribute.required);
}
else
{
ML_CHECK_BOOL(!attribute.required);
ML_CHECK_BOOL(attribute.type == defaultVal->second->type);
uint32_t defaultCount = defaultVal->second->valueCount;
switch (attribute.type)
{
case MLOperatorAttributeType::Float:
ML_CHECK_BOOL(defaultCount == 1);
schema.Attr(attribute.name, "", ToProto(attribute.type), defaultVal->second->floats[0]);
break;
case MLOperatorAttributeType::Int:
ML_CHECK_BOOL(defaultCount == 1);
schema.Attr(attribute.name, "", ToProto(attribute.type), defaultVal->second->ints[0]);
break;
case MLOperatorAttributeType::String:
ML_CHECK_BOOL(defaultCount == 1);
schema.Attr(attribute.name, "", ToProto(attribute.type), std::string(defaultVal->second->strings[0]));
break;
case MLOperatorAttributeType::FloatArray:
{
std::vector<float> defaultVals(defaultVal->second->floats, defaultVal->second->floats + defaultCount);
schema.Attr(attribute.name, "", ToProto(attribute.type), defaultVals);
break;
}
case MLOperatorAttributeType::IntArray:
{
std::vector<int64_t> defaultVals(defaultVal->second->ints, defaultVal->second->ints + defaultCount);
schema.Attr(attribute.name, "", ToProto(attribute.type), defaultVals);
break;
}
case MLOperatorAttributeType::StringArray:
{
std::vector<std::string> defaultVals(defaultVal->second->strings, defaultVal->second->strings + defaultCount);
schema.Attr(attribute.name, "", ToProto(attribute.type), defaultVals);
break;
}
case MLOperatorAttributeTypeTensor:
// Tensor is too complex to express a default value. Default checking is done by the operator code.
__fallthrough;
default:
ML_CHECK_BOOL(false);
break;
}
// Remove the default attribute from the map, to later ensure defaults matched attributes
defaultAttributes.erase(attribute.name);
}
}
ML_CHECK_BOOL(defaultAttributes.empty());
}
// Convert a schema from the ABI to ONNX type
onnx::OpSchema AbiCustomRegistry::ConvertOpSchema(
_In_z_ const char* domain,
const MLOperatorSchemaDescription& abiSchema,
IMLOperatorTypeInferrer* typeInferrer,
IMLOperatorShapeInferrer* shapeInferrer
)
{
// Set the op schema name, domain, and version
onnx::OpSchema schema(abiSchema.name, "", 0);
schema.SetDomain(domain);
schema.SinceVersion(abiSchema.operatorSetVersionAtLastChange);
// ONNX fails if using an empty string for edge names, although their names don't
// matter for us.
const char* emptyName = " ";
// Populate inputs
for (uint32_t inputIndex = 0; inputIndex < abiSchema.inputCount; ++inputIndex)
{
schema.Input(
inputIndex,
emptyName,
"",
ConvertFormalParameterType(abiSchema.inputs[inputIndex]),
ConvertFormalParameterOption(abiSchema.inputs[inputIndex].options));
}
// Populate outputs
for (uint32_t outputIndex = 0; outputIndex < abiSchema.outputCount; ++outputIndex)
{
schema.Output(
outputIndex,
emptyName,
"",
ConvertFormalParameterType(abiSchema.outputs[outputIndex]),
ConvertFormalParameterOption(abiSchema.outputs[outputIndex].options));
}
// Populate type constraints
for (uint32_t constraintIndex = 0; constraintIndex < abiSchema.typeConstraintCount; ++constraintIndex)
{
schema.TypeConstraint(
abiSchema.typeConstraints[constraintIndex].typeLabel,
ConvertTypeConstraintTypes(abiSchema.typeConstraints[constraintIndex]),
"");
}
// Set attribute defaults
SetAttributesAndDefaults(schema, abiSchema);
// Set an inferencing method
if (shapeInferrer || typeInferrer)
{
ComPtr<IMLOperatorShapeInferrer> shapeInferrerCapture = shapeInferrer;
ComPtr<IMLOperatorTypeInferrer> typeInferrerCapture = typeInferrer;
schema.TypeAndShapeInferenceFunction([=](onnx::InferenceContext& ctx)
{
// Constant CPU inputs cannot currently be specified through the public ABI for schema registration.
gsl::span<const uint32_t> requiredConstantCpuInputs;
onnxruntime::OpNodeProtoHelper<onnx::InferenceContext> nodeInfo(&ctx);
ComPtr<MLSchemaInferenceContext> abiContext = wil::MakeOrThrow<MLSchemaInferenceContext>(&nodeInfo, &ctx, requiredConstantCpuInputs);
// Do type inference
if (typeInferrerCapture)
{
THROW_IF_FAILED(typeInferrerCapture->InferOutputTypes(abiContext.Get()));
}
// Do shape inference if all input tensor shapes are known
if (shapeInferrerCapture && InputTensorShapesDefinedOnNode(nodeInfo))
{
THROW_IF_FAILED(shapeInferrerCapture->InferOutputShapes(abiContext.Get()));
}
abiContext->Close();
});
}
return schema;
}
HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorSetSchema(
const MLOperatorSetId* opSetId,
int baseline_version,
const MLOperatorSchemaDescription* const* schema,
uint32_t schemaCount,
_In_opt_ IMLOperatorTypeInferrer* typeInferrer,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept try
{
std::vector<onnx::OpSchema> schemaVector;
schemaVector.reserve(schemaCount);
// Convert schema to ONNX types and accumulate them in a vector
for (uint32_t i = 0; i < schemaCount; ++i)
{
schemaVector.emplace_back(ConvertOpSchema(opSetId->domain, *schema[i], typeInferrer, shapeInferrer));
}
// Multiple registries are used to avoid having different versions of the same domain in a single
// registry, which Lotus doesn't support.
auto registryKey = std::pair<int, int>(baseline_version, opSetId->version);
auto registryIter = m_customRegistryOpsetVerMap.find(registryKey);
if (registryIter == m_customRegistryOpsetVerMap.end())
{
m_customRegistryOpsetVerMap[registryKey] = std::make_shared<onnxruntime::CustomRegistry>();
}
// Register the operator set with Lotus
// TODO - Split apart multiple op-sets with a common domain into multiple registries, as required by Lotus
// for correct lookup (Bug 4662).
THROW_IF_NOT_OK(m_customRegistryOpsetVerMap[registryKey]->RegisterOpSet(
schemaVector,
opSetId->domain,
baseline_version,
opSetId->version));
return S_OK;
}
CATCH_RETURN();
// Convert the list of attribute defaults in a kernel registration into a
// map of AttributeValue entries, which own their own memory
AttributeMap AbiCustomRegistry::GetDefaultAttributes(
const MLOperatorKernelDescription* opKernel
)
{
AttributeMap ret;
for (uint32_t i = 0; i < opKernel->defaultAttributeCount; ++i)
{
const MLOperatorAttributeNameValue &apiAttr = opKernel->defaultAttributes[i];
AttributeValue attr;
attr.type = apiAttr.type;
switch(apiAttr.type)
{
case MLOperatorAttributeType::Float:
ML_CHECK_BOOL(apiAttr.valueCount == 1);
__fallthrough;
case MLOperatorAttributeType::FloatArray:
attr.floats.assign(&apiAttr.floats[0], apiAttr.floats + apiAttr.valueCount);
attr.floats.assign(&apiAttr.floats[0], apiAttr.floats + apiAttr.valueCount);
break;
case MLOperatorAttributeType::String:
ML_CHECK_BOOL(apiAttr.valueCount == 1);
__fallthrough;
case MLOperatorAttributeType::StringArray:
attr.strings.assign(&apiAttr.strings[0], &apiAttr.strings[apiAttr.valueCount]);
break;
case MLOperatorAttributeType::Int:
ML_CHECK_BOOL(apiAttr.valueCount == 1);
__fallthrough;
case MLOperatorAttributeType::IntArray:
attr.ints.assign(&apiAttr.ints[0], &apiAttr.ints[apiAttr.valueCount]);
break;
case MLOperatorAttributeTypeTensor:
// Tensor is too complex to express a default value. Default checking is done by the operator code.
__fallthrough;
default:
THROW_HR(E_INVALIDARG);
}
ret[apiAttr.name] = attr;
}
return ret;
}
HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
const MLOperatorKernelDescription* opKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept
{
return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, false, false, false);
}
HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
const MLOperatorKernelDescription* opKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
bool isInternalOperator,
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try
{
// Verify that invalid flags are not passed
if ((opKernel->options & ~MLOperatorKernelOptions::AllowDynamicInputShapes) !=
MLOperatorKernelOptions::None)
{
return E_INVALIDARG;
}
// Translate flags
bool requiresInputShapesAtCreation = (opKernel->options & MLOperatorKernelOptions::AllowDynamicInputShapes) == MLOperatorKernelOptions::None;
bool requiresOutputShapesAtCreation = !!shapeInferrer;
// Verify allowed combinations of flags are used
if (!requiresInputShapesAtCreation && requiresOutputShapesAtCreation)
{
return E_INVALIDARG;
}
const char* providerType = nullptr;
if (opKernel->executionOptions != 0)
{
return E_INVALIDARG;
}
if (opKernel->executionType == MLOperatorExecutionType::Cpu)
{
providerType = onnxruntime::kCpuExecutionProvider;
}
else if (opKernel->executionType == MLOperatorExecutionType::D3D12)
{
providerType = onnxruntime::kDmlExecutionProvider;
}
else
{
return E_INVALIDARG;
}
// Set the name, domain, version, and provider
onnxruntime::KernelDefBuilder builder;
builder.SetName(opKernel->name);
builder.SetDomain(opKernel->domain)
.SinceVersion(opKernel->minimumOperatorSetVersion)
.Provider(providerType);
std::string_view name(opKernel->name);
if (name == "MemcpyToHost")
{
builder.OutputMemoryType<::OrtMemType::OrtMemTypeCPUOutput>(0);
}
else if (name == "MemcpyFromHost")
{
builder.InputMemoryType<::OrtMemType::OrtMemTypeCPUInput>(0);
}
std::vector<uint32_t> constantCpuInputCapture;
constantCpuInputCapture.assign(requiredConstantCpuInputs, requiredConstantCpuInputs + constantCpuInputCount);
for (uint32_t inputIndex : constantCpuInputCapture)
{
builder.InputMemoryType<::OrtMemType::OrtMemTypeCPUInput>(inputIndex);
}
if (canAliasFirstInput)
{
builder.Alias(0, 0);
}
// Set type constraints
for (uint32_t i = 0; i < opKernel->typeConstraintCount; ++i)
{
std::vector<onnxruntime::MLDataType> types;
types.reserve(opKernel->typeConstraints[i].allowedTypeCount);
for (uint32_t j = 0; j < opKernel->typeConstraints[i].allowedTypeCount; ++j)
{
// TODO - handle non-tensor types
if (opKernel->typeConstraints[i].allowedTypes[j].edgeType != MLOperatorEdgeType::Tensor)
{
THROW_IF_FAILED(E_NOTIMPL);
}
types.push_back(ToTensorDataType(opKernel->typeConstraints[i].allowedTypes[j].tensorDataType));
}
builder.TypeConstraint(opKernel->typeConstraints[i].typeLabel, types);
}
ComPtr<IMLOperatorKernelFactory> kernelFactoryCapture = operatorKernelFactory;
ComPtr<IMLOperatorShapeInferrer> shapeInferrerCapture = shapeInferrer;
AttributeMap defaultAttributesCapture = GetDefaultAttributes(opKernel);
auto lotusKernelCreateFn = [
kernelFactoryCapture,
requiresInputShapesAtCreation,
requiresOutputShapesAtCreation,
isInternalOperator,
constantCpuInputCapture,
shapeInferrerCapture,
defaultAttributesCapture
](const onnxruntime::OpKernelInfo& info) -> onnxruntime::OpKernel*
{
return new AbiOpKernel(
kernelFactoryCapture.Get(),
info,
requiresInputShapesAtCreation,
requiresOutputShapesAtCreation,
isInternalOperator,
constantCpuInputCapture,
shapeInferrerCapture.Get(),
&defaultAttributesCapture);
};
onnxruntime::KernelCreateInfo create_info(builder.Build(), lotusKernelCreateFn);
if (supportsGraph)
{
// Only internal operators support usage in DML graphs
if (!isInternalOperator)
{
THROW_HR(E_INVALIDARG);
}
auto registration = std::make_shared<GraphNodeFactoryRegistration>();
registration->factory =
[kernelFactoryCapture,
requiresInputShapesAtCreation,
requiresOutputShapesAtCreation,
shapeInferrerCapture,
defaultAttributesCapture,
constantCpuInputCapture](const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo)
{
onnxruntime::ProtoHelperNodeContext nodeContext(node);
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
// Use the same list of required constant inputs for the shape inferrer and the kernel.
EdgeShapes outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
&protoHelper,
executionHandle,
true,
&outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,
constantInputGetter);
Microsoft::WRL::ComPtr<IMLOperatorKernel> kernel;
THROW_IF_FAILED(kernelFactoryCapture->CreateKernel(kernelInfoWrapper.Get(), kernel.GetAddressOf()));
kernelInfoWrapper->Close();
};
if (requiredInputCountForGraph)
{
registration->requiredInputCount = *requiredInputCountForGraph;
}
registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
registration->requiredConstantCpuInputs = constantCpuInputCapture;
(*m_graphNodeFactoryMap)[create_info.kernel_def.get()] = registration;
}
m_kernelRegistry->RegisterCustomKernel(create_info);
return S_OK;
}
CATCH_RETURN();
}

Просмотреть файл

@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h"
namespace WRL
{
template <typename... TInterfaces>
using Base = Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
TInterfaces...
>;
}
namespace winrt::Windows::AI::MachineLearning::implementation
{
using namespace Microsoft::WRL;
class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegistryPrivate>
{
public:
AbiCustomRegistry();
HRESULT STDMETHODCALLTYPE RegisterOperatorSetSchema(
const MLOperatorSetId* opSetId,
int baseline_version,
const MLOperatorSchemaDescription* const* schema,
uint32_t schemaCount,
_In_opt_ IMLOperatorTypeInferrer* typeInferrer,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept override;
HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
const MLOperatorKernelDescription* operatorKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
bool isInternalOperator,
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;
HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
const MLOperatorKernelDescription* opKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept override;
std::list<std::shared_ptr<onnxruntime::CustomRegistry>> GetRegistries()
{
std::list<std::shared_ptr<onnxruntime::CustomRegistry>> registries;
for (auto& registry : m_customRegistryOpsetVerMap)
{
registries.push_back(registry.second);
}
registries.push_back(m_kernelRegistry);
return registries;
}
std::list<std::shared_ptr<onnxruntime::IOnnxRuntimeOpSchemaCollection>> GetSchemaRegistries()
{
std::list<std::shared_ptr<onnxruntime::IOnnxRuntimeOpSchemaCollection>> registries;
for (auto& registry : m_customRegistryOpsetVerMap)
{
registries.push_back(registry.second->GetOpschemaRegistry());
}
return registries;
}
std::shared_ptr<onnxruntime::CustomRegistry> GetLotusKernelRegistry()
{
return m_kernelRegistry;
}
std::shared_ptr<GraphNodeFactoryMap> GetGraphNodeFactoryMap() const
{
return m_graphNodeFactoryMap;
}
private:
static onnx::OpSchema ConvertOpSchema(
_In_z_ const char* domain,
const MLOperatorSchemaDescription& abiSchema,
IMLOperatorTypeInferrer* typeInferrer,
IMLOperatorShapeInferrer* shapeInferrer);
static std::string ConvertFormalParameterType(const MLOperatorSchemaEdgeDescription& formalParameter);
static onnx::OpSchema::FormalParameterOption ConvertFormalParameterOption(MLOperatorParameterOptions options);
static void SetAttributesAndDefaults(onnx::OpSchema& schema, const MLOperatorSchemaDescription& abiSchema);
static AttributeMap GetDefaultAttributes(const MLOperatorKernelDescription* opKernel);
std::shared_ptr<onnxruntime::CustomRegistry> m_kernelRegistry;
// Map between (baseline version, opset version) and registries. This ensures that no registry has multiple
// versions of the same domain within it. This works around limitations in Lotus op-set version arbitration
// (see LotusOpSchemaRegistry::GetSchemaAndHistory).
mutable std::map<std::pair<int, int>, std::shared_ptr<onnxruntime::CustomRegistry>> m_customRegistryOpsetVerMap;
// Map between Lotus KernelDefs and graph node factories used for fusing nodes for graph compilation
mutable std::shared_ptr<GraphNodeFactoryMap> m_graphNodeFactoryMap;
};
} // namespace winrt::Windows::AI::MachineLearning::implementation

Просмотреть файл

@ -0,0 +1,255 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "core/session/onnxruntime_c_api.h"
#include "BucketizedBufferAllocator.h"
// #define PRINT_OUTSTANDING_ALLOCATIONS
namespace Dml
{
AllocationInfo::~AllocationInfo()
{
if (m_owner)
{
m_owner->FreeResource(this, m_pooledResourceId);
}
}
BucketizedBufferAllocator::~BucketizedBufferAllocator()
{
#ifdef PRINT_OUTSTANDING_ALLOCATIONS
if (!m_outstandingAllocationsById.empty())
{
printf("BucketizedBufferAllocator outstanding allocation indices:\n");
for (auto& entry : m_outstandingAllocationsById)
{
printf("%u\n", static_cast<int>(entry.first));
}
printf("\n");
}
#endif
}
BucketizedBufferAllocator::BucketizedBufferAllocator(
ID3D12Device* device,
std::shared_ptr<ExecutionContext> context,
const D3D12_HEAP_PROPERTIES& heapProps,
D3D12_HEAP_FLAGS heapFlags,
D3D12_RESOURCE_FLAGS resourceFlags,
D3D12_RESOURCE_STATES initialState)
: m_device(device)
, m_context(context)
, m_heapProperties(heapProps)
, m_heapFlags(heapFlags)
, m_resourceFlags(resourceFlags)
, m_initialState(initialState)
{
}
/*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)
{
assert(size != 0);
// Each bucket is twice as large as the previous one, in ascending order
gsl::index index = static_cast<gsl::index>(ceil(log2(size)));
assert((1ull << index) >= size); // This must be true unless there were some strange rounding issues
// The smallest bucket is 2^n bytes large, where n = c_minResourceSizeExponent
index = std::max<gsl::index>(index, c_minResourceSizeExponent);
index -= c_minResourceSizeExponent;
return index;
}
/*static*/ uint64_t BucketizedBufferAllocator::GetBucketSizeFromIndex(gsl::index index)
{
return (1ull << (index + c_minResourceSizeExponent));
}
void* BucketizedBufferAllocator::Alloc(size_t size)
{
return Alloc(size, m_defaultRoundingMode);
}
void* BucketizedBufferAllocator::Alloc(size_t size, AllocatorRoundingMode roundingMode)
{
// For some reason lotus likes requesting 0 bytes of memory
size = std::max<size_t>(1, size);
ComPtr<ID3D12Resource> resource;
uint64_t resourceId = 0;
uint64_t bucketSize = 0;
// Use a pooled resource if the size (post rounding, if requested) matches a bucket size
if (m_defaultRoundingMode == AllocatorRoundingMode::Enabled || size == GetBucketSizeFromIndex(GetBucketIndexFromSize(size)))
{
Bucket* bucket = nullptr;
// Find the bucket for this allocation size
gsl::index bucketIndex = GetBucketIndexFromSize(size);
if (gsl::narrow_cast<gsl::index>(m_pool.size()) <= bucketIndex)
{
// Ensure there are sufficient buckets
m_pool.resize(bucketIndex + 1);
}
bucket = &m_pool[bucketIndex];
bucketSize = GetBucketSizeFromIndex(bucketIndex);
if (bucket->resources.empty())
{
// No more resources in this bucket - allocate a new one
THROW_IF_FAILED(m_device->CreateCommittedResource(
&m_heapProperties,
m_heapFlags,
&CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags),
m_initialState,
nullptr,
IID_PPV_ARGS(&resource)));
resourceId = ++m_currentResourceId;
}
else
{
// Retrieve a resource from the bucket
resource = std::move(bucket->resources.back().resource);
resourceId = bucket->resources.back().resourceId;
bucket->resources.pop_back();
}
}
else
{
// The allocation will not be pooled. Construct a new one
bucketSize = (size + 3) & ~3;
THROW_IF_FAILED(m_device->CreateCommittedResource(
&m_heapProperties,
m_heapFlags,
&CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags),
m_initialState,
nullptr,
IID_PPV_ARGS(&resource)));
resourceId = ++m_currentResourceId;
}
assert(resource->GetDesc().Width == bucketSize);
assert(resource != nullptr);
ComPtr<AllocationInfo> allocInfo = wil::MakeOrThrow<AllocationInfo>(
this,
++m_currentAllocationId,
resourceId,
resource.Get(),
size);
#if _DEBUG
m_outstandingAllocationsById[allocInfo->GetId()] = allocInfo.Get();
#endif
return allocInfo.Detach();
}
void BucketizedBufferAllocator::Free(void* p)
{
// Release Lotus's reference on the allocation. The allocation
// also inherits IUnknown, and once its final reference reaches zero
// it will call FreeResource
ComPtr<AllocationInfo> allocInfo;
allocInfo.Attach(static_cast<AllocationInfo*>(p));
}
void BucketizedBufferAllocator::FreeResource(void* p, uint64_t pooledResourceId)
{
AllocationInfo *allocInfo = static_cast<AllocationInfo*>(p);
assert(allocInfo != nullptr); // Can't free nullptr
if (allocInfo->GetOwner() != this)
{
// This allocation doesn't belong to this allocator!
THROW_HR(E_INVALIDARG);
}
// Free the resource to the pool if its size matches a bucket size
gsl::index bucketIndex = GetBucketIndexFromSize(allocInfo->GetRequestedSize());
if (GetBucketSizeFromIndex(bucketIndex) == allocInfo->GetResource()->GetDesc().Width)
{
assert(gsl::narrow_cast<gsl::index>(m_pool.size()) > bucketIndex);
// Return the resource to the bucket
Bucket* bucket = &m_pool[bucketIndex];
Resource resource = {std::move(allocInfo->DetachResource()), pooledResourceId};
bucket->resources.push_back(resource);
}
else
{
// Free the underlying allocation once queued work has completed.
m_context->QueueReference(allocInfo->GetResource());
allocInfo->DetachResource();
}
#if _DEBUG
assert(m_outstandingAllocationsById[allocInfo->GetId()] == allocInfo);
m_outstandingAllocationsById.erase(allocInfo->GetId());
#endif
// The allocation info is already destructing at this point
}
const ::OrtMemoryInfo& BucketizedBufferAllocator::Info() const
{
static const ::OrtMemoryInfo sc_info("DML allocator", ::OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0));
return sc_info;
}
const AllocationInfo* BucketizedBufferAllocator::DecodeDataHandle(const void* opaqueHandle)
{
const auto* allocInfo = static_cast<const AllocationInfo*>(opaqueHandle);
auto owner = allocInfo->GetOwner();
//The owner can be null if the resource was wrapped via CreateGPUAllocationFromD3DResource
if (owner != nullptr && owner != this)
{
// This allocation doesn't belong to this allocator!
THROW_HR(E_INVALIDARG);
}
return allocInfo;
}
void BucketizedBufferAllocator::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
m_defaultRoundingMode = roundingMode;
}
CPUAllocator::CPUAllocator(OrtMemType memType)
: m_allocatorInfo("DML CPU", ::OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, memType)
{
}
void* CPUAllocator::Alloc(size_t size) {
if (size <= 0)
{
return nullptr;
}
void* p = malloc(size);
return p;
}
void CPUAllocator::Free(void* p) {
free(p);
}
const ::OrtMemoryInfo& CPUAllocator::Info() const {
return m_allocatorInfo;
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,163 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/allocator.h"
#include "ExecutionContext.h"
namespace Dml
{
class CPUAllocator : public onnxruntime::IDeviceAllocator
{
public:
explicit CPUAllocator(OrtMemType memType);
void* Alloc(size_t size) override;
void Free(void* p) override;
const ::OrtMemoryInfo& Info() const override;
private:
OrtMemoryInfo m_allocatorInfo;
};
class BucketizedBufferAllocator;
class AllocationInfo : public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IUnknown>
{
public:
AllocationInfo(
BucketizedBufferAllocator* owner,
size_t id,
uint64_t pooledResourceId,
ID3D12Resource* resource,
size_t requestedSize)
: m_owner(owner)
, m_allocationId(id)
, m_pooledResourceId(pooledResourceId)
, m_resource(resource)
, m_requestedSize(requestedSize)
{}
~AllocationInfo();
BucketizedBufferAllocator* GetOwner() const
{
return m_owner;
}
ID3D12Resource* GetResource() const
{
return m_resource.Get();
}
ComPtr<ID3D12Resource> DetachResource() const
{
return std::move(m_resource);
}
size_t GetRequestedSize() const
{
return m_requestedSize;
}
size_t GetId() const
{
return m_allocationId;
}
uint64_t GetPooledResourceId() const
{
return m_pooledResourceId;
}
private:
BucketizedBufferAllocator* m_owner;
size_t m_allocationId; // For debugging purposes
uint64_t m_pooledResourceId = 0;
ComPtr<ID3D12Resource> m_resource;
// The size requested during Alloc(), which may be smaller than the physical resource size
size_t m_requestedSize;
};
// Implements a Lotus allocator for D3D12 heap buffers, using a bucket allocation strategy. The allocator
// maintains a set of fixed-size buckets, with each bucket containing one or more D3D12 buffers of that fixed size.
// All requested allocation sizes are rounded up to the nearest bucket size, which ensures minimal fragmentation
// while providing an upper bound on the amount of memory "wasted" with each allocation.
class BucketizedBufferAllocator : public onnxruntime::IAllocator
{
public:
~BucketizedBufferAllocator();
// Constructs a BucketizedBufferAllocator which allocates D3D12 committed resources with the specified heap properties,
// resource flags, and initial resource state.
BucketizedBufferAllocator(
ID3D12Device* device,
std::shared_ptr<ExecutionContext> context,
const D3D12_HEAP_PROPERTIES& heapProps,
D3D12_HEAP_FLAGS heapFlags,
D3D12_RESOURCE_FLAGS resourceFlags,
D3D12_RESOURCE_STATES initialState);
// Returns the information associated with an opaque allocation handle returned by IAllocator::Alloc.
const AllocationInfo* DecodeDataHandle(const void* opaqueHandle);
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
void Free(void* p) final;
const ::OrtMemoryInfo& Info() const final;
private:
static const uint32_t c_minResourceSizeExponent = 16; // 2^16 = 64KB
// The pool consists of a number of buckets, and each bucket contains a number of resources of the same size.
// The resources in each bucket are always sized as a power of two, and each bucket contains resources twice
// as large as the previous bucket.
struct Resource
{
ComPtr<ID3D12Resource> resource;
uint64_t resourceId;
};
struct Bucket
{
std::vector<Resource> resources;
};
static gsl::index GetBucketIndexFromSize(uint64_t size);
static uint64_t GetBucketSizeFromIndex(gsl::index index);
AllocationInfo* DecodeDataHandleInternal(void* opaqueHandle)
{
// Implement in terms of const version
return const_cast<AllocationInfo*>(DecodeDataHandle(static_cast<const void*>(opaqueHandle)));
}
friend class AllocationInfo;
void FreeResource(void* p, uint64_t resourceId);
ComPtr<ID3D12Device> m_device;
D3D12_HEAP_PROPERTIES m_heapProperties;
D3D12_HEAP_FLAGS m_heapFlags;
D3D12_RESOURCE_FLAGS m_resourceFlags;
D3D12_RESOURCE_STATES m_initialState;
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Enabled;
std::shared_ptr<ExecutionContext> m_context;
#if _DEBUG
// Useful for debugging; keeps track of all allocations that haven't been freed yet
std::map<size_t, AllocationInfo*> m_outstandingAllocationsById;
#endif
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "GpuEvent.h"
namespace Dml
{
// A fixed-size ring of command allocators. Each time an allocator is retrieved, the allocator will
// be reset if its previously recorded commands have finished executing on the GPU.
template <size_t AllocatorCount>
class CommandAllocatorRing
{
public:
CommandAllocatorRing(
ID3D12Device* device,
D3D12_COMMAND_LIST_TYPE commandListType,
GpuEvent initialEvent)
{
for (auto& info : m_commandAllocators)
{
THROW_IF_FAILED(device->CreateCommandAllocator(
commandListType,
IID_PPV_ARGS(&info.allocator)));
info.completionEvent = initialEvent;
}
}
ID3D12CommandAllocator* GetCurrentAllocator()
{
CommandAllocatorInfo& allocatorInfo = m_commandAllocators[m_currentCommandAllocator];
// Take the opportunity to reset the command allocator if possible.
if (allocatorInfo.completionEvent.IsSignaled())
{
THROW_IF_FAILED(allocatorInfo.Get()->Reset());
}
return m_commandAllocators[m_currentCommandAllocator].Get();
}
void AdvanceAllocator(GpuEvent completionEvent)
{
// Set the completion event for the current allocator so it can be reset eventually.
m_commandAllocators[m_currentCommandAllocator].completionEvent = completionEvent;
// Advance to the next allocator.
m_currentCommandAllocator = (m_currentCommandAllocator + 1) % AllocatorCount;
}
private:
struct CommandAllocatorInfo
{
ComPtr<ID3D12CommandAllocator> allocator;
// The event which will be signaled when the last command list submitted using this allocator
// completes execution on the GPU.
GpuEvent completionEvent = {};
ID3D12CommandAllocator* Get() const { return allocator.Get(); }
};
std::array<CommandAllocatorInfo, AllocatorCount> m_commandAllocators;
size_t m_currentCommandAllocator = 0;
};
}

Просмотреть файл

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "CommandQueue.h"
namespace Dml
{
CommandQueue::CommandQueue(ID3D12CommandQueue* existingQueue)
: m_queue(existingQueue)
, m_type(existingQueue->GetDesc().Type)
{
ComPtr<ID3D12Device> device;
THROW_IF_FAILED(m_queue->GetDevice(IID_PPV_ARGS(&device)));
THROW_IF_FAILED(device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&m_fence)));
}
void CommandQueue::ExecuteCommandList(ID3D12CommandList* commandList)
{
ExecuteCommandLists(gsl::make_span(&commandList, 1));
}
void CommandQueue::ExecuteCommandLists(gsl::span<ID3D12CommandList*> commandLists)
{
m_queue->ExecuteCommandLists(gsl::narrow<uint32_t>(commandLists.size()), commandLists.data());
++m_lastFenceValue;
THROW_IF_FAILED(m_queue->Signal(m_fence.Get(), m_lastFenceValue));
}
void CommandQueue::Wait(ID3D12Fence* fence, uint64_t value)
{
THROW_IF_FAILED(m_queue->Wait(fence, value));
++m_lastFenceValue;
THROW_IF_FAILED(m_queue->Signal(m_fence.Get(), m_lastFenceValue));
}
GpuEvent CommandQueue::GetCurrentCompletionEvent()
{
return GpuEvent{ m_lastFenceValue, m_fence };
}
GpuEvent CommandQueue::GetNextCompletionEvent()
{
return GpuEvent{ m_lastFenceValue + 1, m_fence };
}
void CommandQueue::QueueReference(IUnknown* object, bool waitForUnsubmittedWork)
{
// If the CommandQueue is closing, then m_queuedReferences is being cleared -- it is not OK
// to queue additional references at this time, since those references would be leaked. This
// affects any objects in m_queuedReferences whose destructors indirectly call QueueReference;
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
if (!m_closing)
{
QueuedReference queuedReference = {GetLastFenceValue(), object};
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
if (waitForUnsubmittedWork)
{
++queuedReference.fenceValue;
}
m_queuedReferences.push_back(queuedReference);
}
}
void CommandQueue::Close()
{
// Wait for flushed work:
assert(!m_closing);
m_closing = true;
GpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal();
m_queuedReferences.clear();
m_closing = false;
}
void CommandQueue::ReleaseCompletedReferences()
{
uint64_t completedValue = GetFence()->GetCompletedValue();
while (!m_queuedReferences.empty() && m_queuedReferences.front().fenceValue <= completedValue)
{
m_queuedReferences.pop_front();
}
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "GpuEvent.h"
namespace Dml
{
// Manages a D3D12 command queue and provides a waitable fence which is signaled with a monotonically increasing
// value once each execute completes on the GPU.
class CommandQueue
{
public:
// Creates a CommandQueue object that wraps an existing D3D12 queue.
CommandQueue(ID3D12CommandQueue* existingQueue);
D3D12_COMMAND_LIST_TYPE GetType() const { return m_type; }
ComPtr<ID3D12Fence> GetFence() const { return m_fence; }
uint64_t GetLastFenceValue() const { return m_lastFenceValue; }
void ExecuteCommandList(ID3D12CommandList* commandList);
void ExecuteCommandLists(gsl::span<ID3D12CommandList*> commandLists);
// Queues a wait to block the GPU until the specified fence is signaled to a given value.
void Wait(ID3D12Fence* fence, uint64_t value);
// Returns an event that will become signaled when everything submitted to the queue thus far has
// completed execution on the GPU.
GpuEvent GetCurrentCompletionEvent();
// Returns an event that will become signaled after the next ExecuteCommandLists call.
GpuEvent GetNextCompletionEvent();
void QueueReference(IUnknown* object, bool waitForUnsubmittedWork);
void Close();
void ReleaseCompletedReferences();
private:
struct QueuedReference
{
uint64_t fenceValue;
ComPtr<IUnknown> object;
};
std::deque<QueuedReference> m_queuedReferences;
ComPtr<ID3D12CommandQueue> m_queue;
D3D12_COMMAND_LIST_TYPE m_type;
ComPtr<ID3D12Fence> m_fence;
uint64_t m_lastFenceValue = 0;
bool m_closing = false;
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,128 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
DescriptorHeap::DescriptorHeap(ID3D12DescriptorHeap* heap) :
m_heap(heap),
m_capacity(heap->GetDesc().NumDescriptors),
m_headCpuHandle(heap->GetCPUDescriptorHandleForHeapStart()),
m_headGpuHandle(heap->GetGPUDescriptorHandleForHeapStart()),
m_heapFlags(heap->GetDesc().Flags)
{
ComPtr<ID3D12Device> device;
THROW_IF_FAILED(heap->GetDevice(IID_PPV_ARGS(&device)));
m_handleIncrementSize = device->GetDescriptorHandleIncrementSize(heap->GetDesc().Type);
}
std::optional<DescriptorRange> DescriptorHeap::TryAllocDescriptors(
uint32_t numDescriptors,
GpuEvent completionEvent,
D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags
)
{
// Bail if the desired heap creation flags are incompatible with the existing heap.
if (m_heapFlags != heapFlags)
{
return std::nullopt;
}
if ((m_completionEvent.fence != nullptr) && (m_completionEvent.IsSignaled()))
{
// This class always allocates descriptors from the end of the heap.
// If the most recent completion event is signaled, then all previous
// allocations have completed; the entire capacity is available to use.
m_size = 0;
m_headCpuHandle = m_heap->GetCPUDescriptorHandleForHeapStart();
m_headGpuHandle = m_heap->GetGPUDescriptorHandleForHeapStart();
}
// The caller will need to create a new heap if there is no space left in this one.
uint32_t spaceRemaining = m_capacity - m_size;
if (spaceRemaining < numDescriptors)
{
return std::nullopt;
}
DescriptorRange range = { m_heap.Get(), m_headCpuHandle, m_headGpuHandle };
m_size += numDescriptors;
m_completionEvent = completionEvent;
m_headCpuHandle.Offset(numDescriptors, m_handleIncrementSize);
m_headGpuHandle.Offset(numDescriptors, m_handleIncrementSize);
return range;
}
DescriptorPool::DescriptorPool(ID3D12Device* device, uint32_t initialCapacity) :
m_device(device),
m_initialHeapCapacity(initialCapacity)
{
CreateHeap(initialCapacity, D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE);
}
DescriptorRange DescriptorPool::AllocDescriptors(
uint32_t numDescriptors,
GpuEvent completionEvent,
D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags
)
{
// Attempt to allocate from an existing heap.
for (DescriptorHeap& heap : m_heaps)
{
auto descriptorRange = heap.TryAllocDescriptors(numDescriptors, completionEvent, heapFlags);
if (descriptorRange.has_value())
{
return descriptorRange.value();
}
}
// A new descriptor heap must be created.
uint32_t newHeapCapacity = std::max(numDescriptors, m_initialHeapCapacity);
CreateHeap(newHeapCapacity, heapFlags);
auto descriptorRange = m_heaps.back().TryAllocDescriptors(numDescriptors, completionEvent, heapFlags);
assert(descriptorRange.has_value());
return descriptorRange.value();
}
void DescriptorPool::Trim()
{
// Remove any heaps that are not pending execution.
auto it = std::remove_if(m_heaps.begin(), m_heaps.end(), [](const DescriptorHeap& heap) {
auto completionEvent = heap.GetLastCompletionEvent();
return !completionEvent.fence || completionEvent.IsSignaled();
});
m_heaps.erase(it, m_heaps.end());
}
void DescriptorPool::CreateHeap(uint32_t numDescriptors, D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags)
{
// This pool only manages CBV/SRV/UAV descriptors.
D3D12_DESCRIPTOR_HEAP_DESC desc = {};
desc.Flags = heapFlags;
desc.NumDescriptors = numDescriptors;
desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
ComPtr<ID3D12DescriptorHeap> heap;
THROW_IF_FAILED(m_device->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&heap)));
m_heaps.push_back(DescriptorHeap{heap.Get()});
}
uint32_t DescriptorPool::GetTotalCapacity() const
{
uint32_t capacity = 0;
for (auto& heap : m_heaps)
{
capacity += heap.GetCapacity();
}
return capacity;
}
}

Просмотреть файл

@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "GpuEvent.h"
namespace Dml
{
// A contiguous range of descriptors.
struct DescriptorRange
{
ID3D12DescriptorHeap* heap;
D3D12_CPU_DESCRIPTOR_HANDLE cpuHandle;
D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle;
};
// Wraps an ID3D12DescriptorHeap to allocate descriptor ranges.
class DescriptorHeap
{
public:
// Wraps an existing heap.
explicit DescriptorHeap(ID3D12DescriptorHeap* heap);
// Reserves descriptors from the end of the heap. Returns nullopt if there is
// no space left in the heap.
std::optional<DescriptorRange> TryAllocDescriptors(
uint32_t numDescriptors,
GpuEvent completionEvent,
D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE
);
GpuEvent GetLastCompletionEvent() const
{
return m_completionEvent;
}
uint32_t GetCapacity() const
{
return m_capacity;
}
private:
ComPtr<ID3D12DescriptorHeap> m_heap;
uint32_t m_capacity = 0;
uint32_t m_size = 0;
uint32_t m_handleIncrementSize = 0;
CD3DX12_CPU_DESCRIPTOR_HANDLE m_headCpuHandle;
CD3DX12_GPU_DESCRIPTOR_HANDLE m_headGpuHandle;
D3D12_DESCRIPTOR_HEAP_FLAGS m_heapFlags = D3D12_DESCRIPTOR_HEAP_FLAG_NONE;
// Most recent GPU completion event. Allocations are always done at the end,
// so there is no fragmentation of the heap.
GpuEvent m_completionEvent;
};
// Manages a pool of CBV/SRV/UAV descriptors.
class DescriptorPool
{
public:
DescriptorPool(ID3D12Device* device, uint32_t initialCapacity);
// Reserves a contiguous range of descriptors from a single descriptor heap. The
// lifetime of the referenced descriptor heap is managed by the DescriptorPool class.
// The caller must supply a GpuEvent that informs the pool when the reserved descriptors
// are no longer required.
DescriptorRange AllocDescriptors(
uint32_t numDescriptors,
GpuEvent completionEvent,
D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE
);
// Releases all descriptor heaps that contain only descriptors which have completed
// their work on the GPU.
void Trim();
// Returns the total capacity of all heaps.
uint32_t GetTotalCapacity() const;
private:
ComPtr<ID3D12Device> m_device;
std::vector<DescriptorHeap> m_heaps;
const uint32_t m_initialHeapCapacity;
void CreateHeap(uint32_t numDescriptors, D3D12_DESCRIPTOR_HEAP_FLAGS heapFlags);
};
}

Просмотреть файл

@ -0,0 +1,383 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "DmlCommandRecorder.h"
#include "CommandQueue.h"
#include "BucketizedBufferAllocator.h"
using namespace Dml;
DmlCommandRecorder::DmlCommandRecorder(
ID3D12Device* d3dDevice,
IDMLDevice* dmlDevice,
std::shared_ptr<CommandQueue> commandQueue)
: m_queue(std::move(commandQueue)),
m_d3dDevice(d3dDevice),
m_dmlDevice(dmlDevice),
m_descriptorPool(d3dDevice, 2048),
m_commandAllocatorRing(d3dDevice, m_queue->GetType(), m_queue->GetCurrentCompletionEvent())
{
THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer(0, nullptr, IID_PPV_ARGS(&m_initializer)));
THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder)));
}
void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
{
m_bufferAllocator = allocator;
}
void DmlCommandRecorder::InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
// Reset the initializer to reference the input operator.
IDMLCompiledOperator* ops[] = { op };
THROW_IF_FAILED(m_initializer->Reset(ARRAYSIZE(ops), ops));
DML_BINDING_PROPERTIES initBindingProps = m_initializer->GetBindingProperties();
DML_BINDING_PROPERTIES execBindingProps = op->GetBindingProperties();
const uint32_t numDescriptors = initBindingProps.RequiredDescriptorCount;
DescriptorRange descriptorRange = m_descriptorPool.AllocDescriptors(
numDescriptors,
m_queue->GetNextCompletionEvent());
// Create a binding table for initialization.
DML_BINDING_TABLE_DESC bindingTableDesc = {};
bindingTableDesc.Dispatchable = m_initializer.Get();
bindingTableDesc.CPUDescriptorHandle = descriptorRange.cpuHandle;
bindingTableDesc.GPUDescriptorHandle = descriptorRange.gpuHandle;
bindingTableDesc.SizeInDescriptors = numDescriptors;
ComPtr<IDMLBindingTable> bindingTable;
THROW_IF_FAILED(m_dmlDevice->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&bindingTable)));
// Create a temporary resource for initializing the op, if it's required.
UINT64 temporaryResourceSize = initBindingProps.TemporaryResourceSize;
if (temporaryResourceSize > 0)
{
auto allocator = m_bufferAllocator.lock();
// Allocate and immediately free a temporary buffer. The buffer resource will still be
// alive (managed by the pool); freeing allows the resource to be shared with other operators.
void* tempResourceHandle = allocator->Alloc(temporaryResourceSize, AllocatorRoundingMode::Enabled);
if (!tempResourceHandle)
{
THROW_HR(E_OUTOFMEMORY);
}
ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource();
allocator->Free(tempResourceHandle);
// Bind the temporary resource.
DML_BUFFER_BINDING bufferBinding = { buffer, 0, temporaryResourceSize };
DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding };
bindingTable->BindTemporaryResource(&bindingDesc);
}
// Bind inputs, if provided.
if (inputArrayBinding.Type != DML_BINDING_TYPE_NONE)
{
// An operator with inputs to bind MUST use a BUFFER_ARRAY.
assert(inputArrayBinding.Type == DML_BINDING_TYPE_BUFFER_ARRAY);
bindingTable->BindInputs(1, &inputArrayBinding);
}
// Bind the persistent resource, which is an output of initialization.
if (persistentResourceBinding.Type != DML_BINDING_TYPE_NONE)
{
// Persistent resources MUST be bound as buffers.
assert(persistentResourceBinding.Type == DML_BINDING_TYPE_BUFFER);
bindingTable->BindOutputs(1, &persistentResourceBinding);
}
// Record the initialization work.
SetDescriptorHeap(descriptorRange.heap);
m_recorder->RecordDispatch(m_currentCommandList.Get(), m_initializer.Get(), bindingTable.Get());
m_operationsRecordedInCurrentCommandList = true;
// Barrier if there's an output (i.e. persistent resource), or if any temps are used.
if ((persistentResourceBinding.Type != DML_BINDING_TYPE_NONE) ||
(temporaryResourceSize > 0))
{
m_currentCommandList->ResourceBarrier(1, &CD3DX12_RESOURCE_BARRIER::UAV(nullptr));
}
}
void DmlCommandRecorder::ExecuteOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
DML_BINDING_PROPERTIES execBindingProps = op->GetBindingProperties();
const uint32_t numDescriptors = execBindingProps.RequiredDescriptorCount;
DescriptorRange descriptorRange = m_descriptorPool.AllocDescriptors(
numDescriptors,
m_queue->GetNextCompletionEvent());
// Create a binding table for execution.
DML_BINDING_TABLE_DESC bindingTableDesc = {};
bindingTableDesc.Dispatchable = op;
bindingTableDesc.CPUDescriptorHandle = descriptorRange.cpuHandle;
bindingTableDesc.GPUDescriptorHandle = descriptorRange.gpuHandle;
bindingTableDesc.SizeInDescriptors = numDescriptors;
ComPtr<IDMLBindingTable> bindingTable;
THROW_IF_FAILED(m_dmlDevice->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&bindingTable)));
// Create a temporary resource for executing the op, if it's required.
UINT64 temporaryResourceSize = execBindingProps.TemporaryResourceSize;
if (temporaryResourceSize > 0)
{
auto allocator = m_bufferAllocator.lock();
// Allocate and immediately free a temporary buffer. The buffer resource will still be
// alive (managed by the pool); freeing allows the resource to be shared with other operators.
void* tempResourceHandle = allocator->Alloc(temporaryResourceSize, AllocatorRoundingMode::Enabled);
if (!tempResourceHandle)
{
THROW_HR(E_OUTOFMEMORY);
}
ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource();
allocator->Free(tempResourceHandle);
// Bind the temporary resource.
DML_BUFFER_BINDING bufferBinding = { buffer, 0, temporaryResourceSize };
DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding };
bindingTable->BindTemporaryResource(&bindingDesc);
}
if (persistentResourceBinding.Type != DML_BINDING_TYPE_NONE)
{
bindingTable->BindPersistentResource(&persistentResourceBinding);
}
bindingTable->BindInputs(gsl::narrow<uint32_t>(inputBindings.size()), inputBindings.data());
bindingTable->BindOutputs(gsl::narrow<uint32_t>(outputBindings.size()), outputBindings.data());
// Record the execution work.
SetDescriptorHeap(descriptorRange.heap);
m_recorder->RecordDispatch(m_currentCommandList.Get(), op, bindingTable.Get());
m_operationsRecordedInCurrentCommandList = true;
// Barrier all outputs.
m_currentCommandList->ResourceBarrier(1, &CD3DX12_RESOURCE_BARRIER::UAV(nullptr));
}
void DmlCommandRecorder::CopyBufferRegion(
ID3D12Resource* dstBuffer,
uint64_t dstOffset,
ID3D12Resource* srcBuffer,
uint64_t srcOffset,
uint64_t byteCount)
{
m_currentCommandList->CopyBufferRegion(dstBuffer, dstOffset, srcBuffer, srcOffset, byteCount);
m_operationsRecordedInCurrentCommandList = true;
}
void DmlCommandRecorder::FillBufferWithPattern(
ID3D12Resource* dstBuffer,
gsl::span<const std::byte> value /* Data type agnostic value, treated as raw bits */)
{
// The fill pattern for ClearUnorderedAccessViewUint is 16 bytes.
union
{
uint32_t integers[4];
std::byte bytes[16];
} fillPattern = {};
assert(ARRAYSIZE(fillPattern.bytes) == 16);
assert(value.size() <= ARRAYSIZE(fillPattern.bytes)); // No element is expected larger than 128 bits (e.g. complex128).
if (!value.empty())
{
assert(ARRAYSIZE(fillPattern.bytes) % value.size() == 0); // Should fit evenly into 16 bytes (e.g. uint8, float16, uint32, float64...).
// Repeat the value multiple times into the pattern buffer.
size_t valueIndex = 0;
for (std::byte& p : fillPattern.bytes)
{
p = value[valueIndex++];
valueIndex = (valueIndex == value.size()) ? 0 : valueIndex;
}
}
// Else just leave fill pattern as zeroes.
// Create a RAW buffer UAV over the resource.
D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
uavDesc.Format = DXGI_FORMAT_R32_TYPELESS;
uavDesc.Buffer.NumElements = gsl::narrow<uint32_t>(dstBuffer->GetDesc().Width / sizeof(uint32_t));
uavDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
const uint32_t neededDescriptorCount = 1;
DescriptorRange descriptorRangeCpu = m_descriptorPool.AllocDescriptors(neededDescriptorCount, m_queue->GetNextCompletionEvent(), D3D12_DESCRIPTOR_HEAP_FLAG_NONE);
DescriptorRange descriptorRangeGpu = m_descriptorPool.AllocDescriptors(neededDescriptorCount, m_queue->GetNextCompletionEvent(), D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE);
m_d3dDevice->CreateUnorderedAccessView(dstBuffer, nullptr, &uavDesc, descriptorRangeCpu.cpuHandle);
m_d3dDevice->CreateUnorderedAccessView(dstBuffer, nullptr, &uavDesc, descriptorRangeGpu.cpuHandle);
SetDescriptorHeap(descriptorRangeGpu.heap);
// Record a ClearUAV onto the command list.
m_currentCommandList->ClearUnorderedAccessViewUint(
descriptorRangeGpu.gpuHandle,
descriptorRangeCpu.cpuHandle,
dstBuffer,
fillPattern.integers,
0,
nullptr);
m_operationsRecordedInCurrentCommandList = true;
// Barrier all outputs.
m_currentCommandList->ResourceBarrier(1, &CD3DX12_RESOURCE_BARRIER::UAV(nullptr));
}
void DmlCommandRecorder::ExecuteCommandList(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
)
{
THROW_IF_FAILED(m_currentCommandList->Close());
if (m_operationsRecordedInCurrentCommandList)
{
m_pendingCommandLists.push_back(m_currentCommandList.Get());
m_pendingCommandListsCacheable.push_back(true);
}
else
{
m_cachedCommandLists.push_back(m_currentCommandList.Get());
}
m_currentCommandList = nullptr;
m_operationsRecordedInCurrentCommandList = false;
m_pendingCommandLists.push_back(commandList);
m_pendingCommandListsCacheable.push_back(false);
// Remember the descriptor heap and apply it to the next command list
auto heap = m_currentDescriptorHeap;
m_currentDescriptorHeap = nullptr;
Open();
// The caller can re-use relevent resources after the next set of work to be
// flushed has completed. Its command list hasn't been executed yet, just batched.
GpuEvent gpuEvent = m_queue->GetNextCompletionEvent();
gpuEvent.fence.CopyTo(fence);
*completionValue = gpuEvent.fenceValue;
// Trigger a flush of the command list, with the assumption that it contains enough GPU work that this
// will help parallelize GPU work with subsequent CPU work. This policy is related to the choice of
// minNodeCountToReuseCommandList within FusedGraphKernel, so both should be tuned together.
CloseAndExecute();
Open();
SetDescriptorHeap(heap);
}
ComPtr<ID3D12GraphicsCommandList> DmlCommandRecorder::GetCommandList()
{
// Assume operations are added by the caller after this returns
m_operationsRecordedInCurrentCommandList = true;
return m_currentCommandList;
}
void DmlCommandRecorder::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
m_currentCommandList->ResourceBarrier(gsl::narrow_cast<uint32_t>(barriers.size()), barriers.data());
m_operationsRecordedInCurrentCommandList = true;
}
void DmlCommandRecorder::AddUAVBarrier()
{
auto barrier = CD3DX12_RESOURCE_BARRIER::UAV(nullptr);
m_currentCommandList->ResourceBarrier(1, &barrier);
m_operationsRecordedInCurrentCommandList = true;
}
void DmlCommandRecorder::Open()
{
assert(m_currentDescriptorHeap == nullptr);
ID3D12CommandAllocator* allocator = m_commandAllocatorRing.GetCurrentAllocator();
if (m_cachedCommandLists.empty())
{
THROW_IF_FAILED(m_d3dDevice->CreateCommandList(
0,
m_queue->GetType(),
m_commandAllocatorRing.GetCurrentAllocator(),
nullptr,
IID_PPV_ARGS(&m_currentCommandList)));
}
else
{
m_currentCommandList = m_cachedCommandLists.front();
m_cachedCommandLists.pop_front();
THROW_IF_FAILED(m_currentCommandList->Reset(allocator, nullptr));
}
// The current command allocator will become eligible for reset once this command list completes execution
m_commandAllocatorRing.AdvanceAllocator(m_queue->GetNextCompletionEvent());
}
void DmlCommandRecorder::CloseAndExecute()
{
THROW_IF_FAILED(m_currentCommandList->Close());
if (m_operationsRecordedInCurrentCommandList)
{
m_pendingCommandLists.push_back(m_currentCommandList.Get());
m_pendingCommandListsCacheable.push_back(true);
}
else
{
m_cachedCommandLists.push_back(m_currentCommandList.Get());
}
m_currentCommandList = nullptr;
m_operationsRecordedInCurrentCommandList = false;
if (!m_pendingCommandLists.empty())
{
// Close and execute the command list
m_queue->ExecuteCommandLists(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(m_pendingCommandLists.data()), m_pendingCommandLists.size()));
assert(m_pendingCommandLists.size() == m_pendingCommandListsCacheable.size());
for (size_t i = 0; i < m_pendingCommandLists.size(); ++i)
{
if (m_pendingCommandListsCacheable[i])
{
m_cachedCommandLists.push_back(m_pendingCommandLists[i]);
}
}
m_pendingCommandLists.clear();
m_pendingCommandListsCacheable.clear();
}
// The descriptor heap must be set on the command list the next time it's opened.
m_currentDescriptorHeap = nullptr;
// Fail early if something horrifying happens
THROW_IF_FAILED(m_dmlDevice->GetDeviceRemovedReason());
THROW_IF_FAILED(m_d3dDevice->GetDeviceRemovedReason());
}
void DmlCommandRecorder::SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap)
{
if (descriptorHeap != nullptr && descriptorHeap != m_currentDescriptorHeap)
{
m_currentDescriptorHeap = descriptorHeap;
ID3D12DescriptorHeap* descriptorHeaps[] = { descriptorHeap };
m_currentCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);
}
}

Просмотреть файл

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ICommandRecorder.h"
#include "CommandAllocatorRing.h"
namespace Dml
{
class CommandQueue;
class BucketizedBufferAllocator;
class DmlCommandRecorder : public ICommandRecorder
{
public:
DmlCommandRecorder(
ID3D12Device* d3dDevice,
IDMLDevice* device,
std::shared_ptr<CommandQueue> commandQueue);
void InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding);
void ExecuteOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings);
void CopyBufferRegion(
ID3D12Resource* dstBuffer,
uint64_t dstOffset,
ID3D12Resource* srcBuffer,
uint64_t srcOffset,
uint64_t byteCount);
void FillBufferWithPattern(
ID3D12Resource* dstBuffer,
gsl::span<const std::byte> value /* Data type agnostic value, treated as raw bits */);
void ExecuteCommandList(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue);
ComPtr<ID3D12GraphicsCommandList> GetCommandList();
void ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers);
void AddUAVBarrier();
void Open() final;
void CloseAndExecute() final;
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
private:
std::shared_ptr<CommandQueue> m_queue;
ComPtr<ID3D12Device> m_d3dDevice;
ComPtr<IDMLDevice> m_dmlDevice;
ComPtr<IDMLOperatorInitializer> m_initializer;
ComPtr<IDMLCommandRecorder> m_recorder;
// Descriptors are allocated from a pool. The current heap pointer is only used to avoid redundantly
// setting the same heap; it does not have ownership of the heap object.
DescriptorPool m_descriptorPool;
ID3D12DescriptorHeap* m_currentDescriptorHeap = nullptr;
// The weak pointer avoids a circular reference from context->recorder->allocator->context
std::weak_ptr<BucketizedBufferAllocator> m_bufferAllocator;
CommandAllocatorRing<2> m_commandAllocatorRing;
// The command list currently being recorded into, and whether any command have been recorded yet.
ComPtr<ID3D12GraphicsCommandList> m_currentCommandList;
bool m_operationsRecordedInCurrentCommandList = false;
// Command lists which have been batched up for execution. The values in
// m_pendingCommandListsCacheable indicate whether they can be moved into this
// class's cache after execution, versus if they belong to the caller and were
// passed to ExecuteCommandList.
std::vector<ComPtr<ID3D12GraphicsCommandList>> m_pendingCommandLists;
std::vector<bool> m_pendingCommandListsCacheable;
// A pool of cached command lists which may be re-used.
std::deque<ComPtr<ID3D12GraphicsCommandList>> m_cachedCommandLists;
void SetDescriptorHeap(ID3D12DescriptorHeap* descriptorHeap);
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept
{
switch (tensorDataType)
{
case MLOperatorTensorDataType::Float: return DML_TENSOR_DATA_TYPE_FLOAT32;
case MLOperatorTensorDataType::UInt8: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Int8: return DML_TENSOR_DATA_TYPE_INT8;
case MLOperatorTensorDataType::UInt16: return DML_TENSOR_DATA_TYPE_UINT16;
case MLOperatorTensorDataType::Int16: return DML_TENSOR_DATA_TYPE_INT16;
case MLOperatorTensorDataType::Int32: return DML_TENSOR_DATA_TYPE_INT32;
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT32; // Stride is used to access lower 32-bits.
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Complex128: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Undefined:
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;;
};
}
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType)
{
DML_TENSOR_DATA_TYPE dmlTensorDataType = GetDmlDataTypeFromMlDataTypeNoThrow(tensorDataType);
if (dmlTensorDataType == DML_TENSOR_DATA_TYPE_UNKNOWN)
{
ML_INVALID_ARGUMENT("MLOperatorTensorDataType has no equivalent data type in DML.");
}
return dmlTensorDataType;
}
MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType)
{
switch (tensorDataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT32: return MLOperatorTensorDataType::Float;
case DML_TENSOR_DATA_TYPE_UINT8: return MLOperatorTensorDataType::UInt8;
case DML_TENSOR_DATA_TYPE_INT8: return MLOperatorTensorDataType::Int8;
case DML_TENSOR_DATA_TYPE_UINT16: return MLOperatorTensorDataType::UInt16;
case DML_TENSOR_DATA_TYPE_INT16: return MLOperatorTensorDataType::Int16;
case DML_TENSOR_DATA_TYPE_INT32: return MLOperatorTensorDataType::Int32;
case DML_TENSOR_DATA_TYPE_FLOAT16: return MLOperatorTensorDataType::Float16;
case DML_TENSOR_DATA_TYPE_UINT32: return MLOperatorTensorDataType::UInt32;
default: ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
};
}
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
{
return ComputeElementCountFromDimensions(dimensions) * GetByteSizeFromMlDataType(tensorDataType);
}
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor)
{
uint32_t dimensionCount = 0;
dimensionCount = tensor.GetDimensionCount();
ML_CHECK_VALID_ARGUMENT(dimensionCount <= MaximumDimensionCount, "Dimensions are beyond supported count.");
std::array<DimensionType, MaximumDimensionCount> dimensions;
THROW_IF_FAILED(tensor.GetShape(dimensionCount, /*out*/ dimensions.data()));
return ComputeByteSizeFromDimensions(gsl::make_span(dimensions.data(), dimensionCount), tensor.GetTensorDataType());
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <assert.h>
#include "core/providers/dml/OperatorAuthorHelper/Common.h"
namespace Dml
{
using namespace OperatorHelper;
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX;
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType);
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept;
MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType);
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
/** Calculates the minimum number of bytes required to store a buffer tensor with the specified type, sizes, and
strides. The formula can be expressed as the following:
IndexOfLastElement = dot(Sizes - 1, Strides);
MinimumImpliedSizeInBytes = roundup((IndexOfLastElement + 1) * ElementSizeInBytes, 4)
In other words, the minimum size of a tensor is the index of the one-past-the-end element, multiplied by the
element size (e.g. 2 bytes for a FLOAT16 tensor). Additionally DirectML requires that all buffers bound must have
a total size which is DWORD-aligned, and hence the minimum implied size in bytes must be rounded up to the nearest
4-byte boundary.
*/
inline UINT64 DMLCalcBufferTensorSize(
DML_TENSOR_DATA_TYPE dataType,
UINT dimensionCount,
_In_reads_(dimensionCount) const UINT* sizes,
_In_reads_opt_(dimensionCount) const UINT* strides)
{
UINT elementSizeInBytes = 0;
switch (dataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT32:
case DML_TENSOR_DATA_TYPE_UINT32:
case DML_TENSOR_DATA_TYPE_INT32:
elementSizeInBytes = 4;
break;
case DML_TENSOR_DATA_TYPE_FLOAT16:
case DML_TENSOR_DATA_TYPE_UINT16:
case DML_TENSOR_DATA_TYPE_INT16:
elementSizeInBytes = 2;
break;
case DML_TENSOR_DATA_TYPE_UINT8:
case DML_TENSOR_DATA_TYPE_INT8:
elementSizeInBytes = 1;
break;
default:
return 0; // Invalid data type
}
UINT64 minimumImpliedSizeInBytes = 0;
if (!strides)
{
minimumImpliedSizeInBytes = sizes[0];
for (UINT i = 1; i < dimensionCount; ++i)
{
minimumImpliedSizeInBytes *= sizes[i];
}
minimumImpliedSizeInBytes *= elementSizeInBytes;
}
else
{
UINT indexOfLastElement = 0;
for (UINT i = 0; i < dimensionCount; ++i)
{
indexOfLastElement += (sizes[i] - 1) * strides[i];
}
minimumImpliedSizeInBytes = (indexOfLastElement + 1) * elementSizeInBytes;
}
// Round up to the nearest 4 bytes.
minimumImpliedSizeInBytes = (minimumImpliedSizeInBytes + 3) & ~3ui64;
return minimumImpliedSizeInBytes;
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
HRESULT MapLotusErrorToHRESULT(onnxruntime::common::Status status)
{
switch (status.Code())
{
case onnxruntime::common::StatusCode::OK:
return S_OK;
case onnxruntime::common::StatusCode::FAIL:
return E_FAIL;
case onnxruntime::common::StatusCode::INVALID_ARGUMENT:
return E_INVALIDARG;
case onnxruntime::common::StatusCode::NO_SUCHFILE:
return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
case onnxruntime::common::StatusCode::NO_MODEL:
return __HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
case onnxruntime::common::StatusCode::ENGINE_ERROR:
return E_FAIL;
case onnxruntime::common::StatusCode::RUNTIME_EXCEPTION:
return E_FAIL;
case onnxruntime::common::StatusCode::INVALID_PROTOBUF:
return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
case onnxruntime::common::StatusCode::MODEL_LOADED:
return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
case onnxruntime::common::StatusCode::NOT_IMPLEMENTED:
return E_NOTIMPL;
case onnxruntime::common::StatusCode::INVALID_GRAPH:
return __HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
case onnxruntime::common::StatusCode::EP_FAIL:
return __HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
default:
return E_FAIL;
}
}
}

Просмотреть файл

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#define THROW_IF_NOT_OK(status) \
do { \
auto _status = status; \
if (!_status.IsOK()) \
{ \
THROW_HR(Dml::MapLotusErrorToHRESULT(_status)); \
} \
} while (0)
namespace Dml
{
HRESULT MapLotusErrorToHRESULT(onnxruntime::common::Status status);
}

Просмотреть файл

@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "ExecutionContext.h"
#include "CommandQueue.h"
namespace Dml
{
ExecutionContext::ExecutionContext(
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue
)
: m_queue(std::make_shared<CommandQueue>(queue))
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
{
THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}
void ExecutionContext::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
{
m_dmlRecorder.SetAllocator(allocator);
}
void ExecutionContext::CopyBufferRegion(
ID3D12Resource* dstBuffer,
uint64_t dstOffset,
D3D12_RESOURCE_STATES dstState,
ID3D12Resource* srcBuffer,
uint64_t srcOffset,
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
std::vector<D3D12_RESOURCE_BARRIER> barriers;
if (!(dstState & D3D12_RESOURCE_STATE_COPY_DEST))
{
barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dstBuffer, dstState, D3D12_RESOURCE_STATE_COPY_DEST));
}
if (!(srcState & D3D12_RESOURCE_STATE_COPY_SOURCE))
{
barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(srcBuffer, srcState, D3D12_RESOURCE_STATE_COPY_SOURCE));
}
if (!barriers.empty())
{
m_dmlRecorder.ResourceBarrier(barriers);
}
m_dmlRecorder.CopyBufferRegion(dstBuffer, dstOffset, srcBuffer, srcOffset, byteCount);
// Reset barrier state
if (!barriers.empty())
{
for (auto& barrier : barriers)
{
std::swap(barrier.Transition.StateBefore, barrier.Transition.StateAfter);
}
m_dmlRecorder.ResourceBarrier(barriers);
}
}
void ExecutionContext::FillBufferWithPattern(
ID3D12Resource* dstBuffer,
gsl::span<const std::byte> value /* Data type agnostic value, treated as raw bits */)
{
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.FillBufferWithPattern(dstBuffer, value);
}
void ExecutionContext::ExecuteCommandList(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
}
void ExecutionContext::InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
}
void ExecutionContext::ExecuteOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
}
void ExecutionContext::AddUAVBarrier()
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.AddUAVBarrier();
}
void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ResourceBarrier(barriers);
}
void ExecutionContext::GetCommandListForRecording(ID3D12GraphicsCommandList** commandList)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.GetCommandList().CopyTo(commandList);
}
void ExecutionContext::Wait(ID3D12Fence* fence, uint64_t value)
{
assert(!m_closed);
Flush();
m_queue->Wait(fence, value);
ReleaseCompletedReferences();
}
void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
{
assert(!m_closed);
// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
// ordering of operations on the command queue.
if (m_currentRecorder != newRecorder)
{
Flush();
m_currentRecorder = newRecorder;
if (m_currentRecorder != nullptr)
{
m_currentRecorder->Open();
}
}
}
void ExecutionContext::Flush()
{
assert(!m_closed);
if (!m_currentRecorder)
{
// Nothing to flush
return;
}
m_currentRecorder->CloseAndExecute();
ReleaseCompletedReferences();
// Just submitted our command list, so we have neither DML or D3D12 work recorded on any of our command lists.
m_currentRecorder = nullptr;
}
void ExecutionContext::QueueReference(IUnknown* object)
{
assert(!m_closed);
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
m_queue->QueueReference(object, waitForUnsubmittedWork);
}
void ExecutionContext::Close()
{
assert(!m_closed);
// Discard unflushed work and clear queued references. This prevents the circular reference:
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
m_queue->Close();
m_currentRecorder = nullptr;
m_closed = true;
}
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
{
assert(!m_closed);
GpuEvent event = m_queue->GetCurrentCompletionEvent();
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
const bool unflushedWorkExists = (m_currentRecorder != nullptr);
if (unflushedWorkExists)
{
++event.fenceValue;
}
return event;
}
void ExecutionContext::ReleaseCompletedReferences()
{
assert(!m_closed);
m_queue->ReleaseCompletedReferences();
}
D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
{
assert(!m_closed);
return m_queue->GetType();
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "GpuEvent.h"
#include "ICommandRecorder.h"
#include "DmlCommandRecorder.h"
namespace Dml
{
class CommandQueue;
// Asynchronously performs GPU work, and automatically manages command list recording and submission to queues.
// Work submitted to the ExecutionContext is typically recorded onto a command list and may not immediately begin
// execution on the GPU. Call Flush() to force all recorded work to be submitted to the command queue for execution
// on the GPU.
class ExecutionContext
{
public:
// Constructs an ExecutionContext that executes on the supplied queue.
ExecutionContext(
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue);
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
// COPY_DEST if necessary.
void CopyBufferRegion(
ID3D12Resource* dstBuffer,
uint64_t dstOffset,
D3D12_RESOURCE_STATES dstState,
ID3D12Resource* srcBuffer,
uint64_t srcOffset,
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount);
void FillBufferWithPattern(
ID3D12Resource* dstBuffer,
gsl::span<const std::byte> value /* Data type agnostic value, treated as raw bits */);
void InitializeOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding);
void ExecuteOperator(
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings);
void ExecuteCommandList(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
);
void AddUAVBarrier();
void ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers);
void GetCommandListForRecording(ID3D12GraphicsCommandList** commandList);
// See ID3D12CommandQueue::Wait
void Wait(ID3D12Fence* fence, uint64_t value);
// Forces all queued work to begin executing on the GPU. This method returns immediately and does not wait
// for the submitted work to complete execution on the GPU.
void Flush();
// Returns an event which will become signaled when everything submitted to the execution context thus far has
// completed execution on the GPU, including work that has yet to be flushed to the queue.
GpuEvent GetCurrentCompletionEvent();
// Adds a reference which will be released when queued GPU work is completed
void QueueReference(IUnknown* object);
// Release any accumulated references who corresponding GPU fence values have
// been reached.
void ReleaseCompletedReferences();
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
private:
ComPtr<ID3D12Device> m_d3dDevice;
void SetCommandRecorder(ICommandRecorder* newRecorder);
std::shared_ptr<CommandQueue> m_queue;
ICommandRecorder* m_currentRecorder = nullptr;
// Up to one of these is active at a time
DmlCommandRecorder m_dmlRecorder;
bool m_closed = false;
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,721 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "IExecutionProvider.h"
#include "ExecutionProvider.h"
#include "PooledUploadHeap.h"
#include "ReadbackHeap.h"
#include "ExecutionContext.h"
#include "BucketizedBufferAllocator.h"
#include "MLOperatorAuthorImpl.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "AbiCustomRegistry.h"
#include "GraphPartitioner.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/framework/compute_capability.h"
#ifdef ERROR
#undef ERROR
#endif
#include "core/session/inference_session.h"
#define ERROR 0
#include "core/session/onnxruntime_c_api.h"
#include <wil/wrl.h>
#include <dxgi1_6.h>
#define ENABLE_GRAPH_COMPILATION
using namespace winrt::Windows::AI::MachineLearning::implementation;
namespace Dml
{
using namespace onnxruntime::common;
ExecutionProvider::~ExecutionProvider()
{
if (m_impl)
{
m_impl->Close();
}
}
static void CreateDmlKernelRegistry(
_Outptr_ std::shared_ptr<onnxruntime::KernelRegistry>* registry,
_Outptr_ std::shared_ptr<const GraphNodeFactoryMap>* graphNodeFactoryMap)
{
ComPtr<AbiCustomRegistry> abiRegistry = wil::MakeOrThrow<AbiCustomRegistry>();
Dml::RegisterDmlOperators(abiRegistry.Get());
assert(abiRegistry->GetRegistries().size() == 1);
auto customRegistry = *abiRegistry->GetRegistries().begin();
*registry = customRegistry->GetKernelRegistry();
*graphNodeFactoryMap = abiRegistry->GetGraphNodeFactoryMap();
}
ExecutionProvider::ExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider)
{
D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type;
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
{
// DML requires either DIRECT or COMPUTE command queues.
THROW_HR(E_INVALIDARG);
}
ComPtr<ID3D12Device> device;
THROW_IF_FAILED(commandQueue->GetDevice(IID_PPV_ARGS(&device)));
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), commandQueue, enableMetacommands);
// Register the allocators with ORT, through concrete ORT methods on the IExecutionProvider base class
InsertAllocator(m_impl->GetGpuAllocator());
InsertAllocator(m_impl->GetCpuInputAllocator());
InsertAllocator(m_impl->GetCpuOutputAllocator());
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
ExecutionProvider::GetCapability(
const onnxruntime::GraphViewer& graph,
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const
{
#ifdef ENABLE_GRAPH_COMPILATION
return m_impl->GetCapability(graph, kernel_registries);
#endif
return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_registries);
}
void ExecutionProviderImpl::Close()
{
m_context->Close();
}
HRESULT __stdcall ExecutionProviderImpl::AllocatePooledResource(
size_t size,
AllocatorRoundingMode roundingMode,
ID3D12Resource **d3dResource,
IUnknown** pooledResource
) const noexcept try
{
ComPtr<IUnknown> allocation;
allocation.Attach(static_cast<IUnknown* >(m_allocator->Alloc(size, roundingMode)));
const auto* allocInfo = m_allocator->DecodeDataHandle(allocation.Get());
ComPtr<ID3D12Resource> resource = allocInfo->GetResource();
resource.CopyTo(d3dResource);
*pooledResource = allocation.Detach();
return S_OK;
}
CATCH_RETURN();
ID3D12Resource* __stdcall ExecutionProviderImpl::DecodeResource(void* allocation) const noexcept
{
try
{
const AllocationInfo* allocInfo = m_allocator->DecodeDataHandle(allocation);
return allocInfo->GetResource();
}
catch(...)
{
return nullptr;
}
}
ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands)
: m_d3d12Device(d3d12Device),
m_dmlDevice(dmlDevice),
m_areMetacommandsEnabled(enableMetacommands)
{
D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {};
D3D_FEATURE_LEVEL featureLevelsList[] = {
D3D_FEATURE_LEVEL_1_0_CORE,
D3D_FEATURE_LEVEL_11_0,
D3D_FEATURE_LEVEL_11_1,
D3D_FEATURE_LEVEL_12_0,
D3D_FEATURE_LEVEL_12_1
};
featureLevels.NumFeatureLevels = ARRAYSIZE(featureLevelsList);
featureLevels.pFeatureLevelsRequested = featureLevelsList;
THROW_IF_FAILED(d3d12Device->CheckFeatureSupport(
D3D12_FEATURE_FEATURE_LEVELS,
&featureLevels,
sizeof(featureLevels)
));
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE);
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
// Create an allocator for D3D12 buffers used to hold tensor data. The returned buffers from the allocator
// should be DEFAULT heap buffers which can be used as UAVs, and which start in UAV state.
m_allocator = std::make_shared<BucketizedBufferAllocator>(
m_d3d12Device.Get(),
m_context,
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT),
D3D12_HEAP_FLAG_NONE,
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
m_context->SetAllocator(m_allocator);
m_uploadHeap = std::make_unique<PooledUploadHeap>(m_d3d12Device.Get(), m_context);
m_readbackHeap = std::make_unique<ReadbackHeap>(m_d3d12Device.Get(), m_context);
// CPU Allocator used to create buffers for the MemcpyFromHost operator.
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
m_cpuOutputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUOutput);
CreateDmlKernelRegistry(&m_kernelRegistry, &m_graphNodeFactoryMap);
}
HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept
{
return m_d3d12Device.CopyTo(d3dDevice);
}
HRESULT __stdcall ExecutionProviderImpl::GetDmlDevice(_COM_Outptr_ IDMLDevice** dmlDevice) const noexcept
{
return m_dmlDevice.CopyTo(dmlDevice);
}
HRESULT __stdcall ExecutionProviderImpl::ExecuteCommandList(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
) const noexcept try
{
assert(!m_closed);
m_context->ExecuteCommandList(commandList, fence, completionValue);
return S_OK;
}
CATCH_RETURN();
HRESULT __stdcall ExecutionProviderImpl::AddUAVBarrier() const noexcept try
{
assert(!m_closed);
m_context->AddUAVBarrier();
return S_OK;
}
CATCH_RETURN();
HRESULT __stdcall ExecutionProviderImpl::InitializeOperator(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<const DML_BUFFER_BINDING> inputBindings
) const noexcept try
{
assert(!m_closed);
bool hasInputsToBind = false;
std::vector<DML_BUFFER_BINDING> inputBufferBindings(inputBindings.size());
for (gsl::index i = 0; i < inputBindings.size(); i++)
{
if (inputBindings[i].Buffer)
{
hasInputsToBind = true;
inputBufferBindings[i] = { inputBindings[i].Buffer, inputBindings[i].Offset, inputBindings[i].SizeInBytes };
}
}
DML_BINDING_DESC persistentResourceBindingDesc =
persistentResourceBinding
? DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER, persistentResourceBinding }
: DML_BINDING_DESC{ DML_BINDING_TYPE_NONE, nullptr };
DML_BUFFER_ARRAY_BINDING inputBufferArrayDesc;
inputBufferArrayDesc.BindingCount = gsl::narrow_cast<uint32_t>(inputBufferBindings.size());
inputBufferArrayDesc.Bindings = inputBufferBindings.data();
DML_BINDING_DESC inputArrayBindingDesc = hasInputsToBind ?
DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER_ARRAY, &inputBufferArrayDesc } :
DML_BINDING_DESC{ DML_BINDING_TYPE_NONE, nullptr };
m_context->InitializeOperator(
op,
persistentResourceBindingDesc,
inputArrayBindingDesc);
return S_OK;
}
CATCH_RETURN();
HRESULT __stdcall ExecutionProviderImpl::ExecuteOperator(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<IMLOperatorTensor*> inputTensors,
gsl::span<IMLOperatorTensor*> outputTensors
) const noexcept try
{
assert(!m_closed);
auto FillBindings = [this](auto& bufferBindings, auto& bindingDescs, auto& tensors)
{
for (IMLOperatorTensor* tensor : tensors)
{
if (tensor)
{
assert(tensor->IsDataInterface());
const AllocationInfo* allocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(tensor).GetDataInterface().Get());
ID3D12Resource* resource = allocInfo->GetResource();
D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc();
bufferBindings.push_back({ resource, 0, resourceDesc.Width });
bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() });
}
else
{
bufferBindings.push_back({ nullptr, 0, 0 });
bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
}
}
};
std::vector<DML_BUFFER_BINDING> inputBufferBindings;
inputBufferBindings.reserve(inputTensors.size());
std::vector<DML_BINDING_DESC> inputBindings;
inputBindings.reserve(inputTensors.size());
FillBindings(inputBufferBindings, inputBindings, inputTensors);
std::vector<DML_BUFFER_BINDING> outputBufferBindings;
outputBufferBindings.reserve(outputTensors.size());
std::vector<DML_BINDING_DESC> outputBindings;
outputBindings.reserve(outputTensors.size());
FillBindings(outputBufferBindings, outputBindings, outputTensors);
THROW_IF_FAILED(ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings));
return S_OK;
}
CATCH_RETURN();
HRESULT __stdcall ExecutionProviderImpl::ExecuteOperator(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<DML_BINDING_DESC> inputTensors,
gsl::span<DML_BINDING_DESC> outputTensors
) const noexcept try
{
assert(!m_closed);
DML_BINDING_DESC persistentResourceBindingDesc =
persistentResourceBinding
? DML_BINDING_DESC{ DML_BINDING_TYPE_BUFFER, persistentResourceBinding }
: DML_BINDING_DESC{ DML_BINDING_TYPE_NONE, nullptr };
m_context->ExecuteOperator(
op,
persistentResourceBindingDesc,
inputTensors,
outputTensors);
return S_OK;
}
CATCH_RETURN();
static gsl::span<const std::byte> AsByteSpan(const void* data, size_t sizeInBytes)
{
return gsl::make_span(static_cast<const std::byte*>(data), sizeInBytes);
}
static gsl::span<std::byte> AsByteSpan(void* data, size_t sizeInBytes)
{
return gsl::make_span(static_cast<std::byte*>(data), sizeInBytes);
}
HRESULT __stdcall ExecutionProviderImpl::CopyTensor(IMLOperatorTensor* dst, IMLOperatorTensor* src) const noexcept try
{
assert(!m_closed);
const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst);
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src)); // Tensors must be the same size
if (src->IsCpuData() && !dst->IsCpuData())
{
//
// CPU -> GPU copy (upload)
//
const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get());
ID3D12Resource* dstData = dstAllocInfo->GetResource();
const void* srcData = src->GetData();
const uint64_t dstOffset = 0;
const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state
m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(srcData, dataSizeInBytes));
}
else if (!src->IsCpuData() && dst->IsCpuData())
{
//
// GPU -> CPU copy (readback)
//
void* dstData = dst->GetData();
const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get());
ID3D12Resource* srcData = srcAllocInfo->GetResource();
const uint64_t srcOffset = 0;
const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state
// Performs a blocking call to synchronize and read back data from the GPU into the destination buffer
m_readbackHeap->ReadbackFromGpu(AsByteSpan(dstData, dataSizeInBytes), srcData, srcOffset, srcState);
}
else if (!src->IsCpuData() && !dst->IsCpuData())
{
//
// GPU -> GPU copy
//
const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get());
const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get());
ID3D12Resource* srcData = srcAllocInfo->GetResource();
ID3D12Resource* dstData = dstAllocInfo->GetResource();
m_context->CopyBufferRegion(dstData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, srcData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, dataSizeInBytes);
}
else
{
// CPU -> CPU copies not supported
THROW_HR(E_INVALIDARG);
}
return S_OK;
}
CATCH_RETURN();
HRESULT STDMETHODCALLTYPE ExecutionProviderImpl::FillTensorWithPattern(
IMLOperatorTensor* dst,
gsl::span<const std::byte> value // Data type agnostic value, treated as raw bits
) const noexcept try
{
const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get());
ID3D12Resource* dstData = dstAllocInfo->GetResource();
m_context->FillBufferWithPattern(dstData, value);
return S_OK;
}
CATCH_RETURN();
HRESULT __stdcall ExecutionProviderImpl::UploadToResource(ID3D12Resource* dstData, const void* srcData, uint64_t srcDataSize) const noexcept try
{
assert(!m_closed);
m_uploadHeap->BeginUploadToGpu(dstData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, AsByteSpan(srcData, srcDataSize));
return S_OK;
}
CATCH_RETURN();
uint32_t ExecutionProviderImpl::GetSuppportedDeviceDataTypeMask() const
{
// The DML provider registers all supported kernels up-front regardless of actual device capability,
// but this is problematic later when executing the graph because DirectML will fail to create
// the operator, and by that late phase, it's long past too late to recover. So, this function queries
// the actual type capabilities so the partitioner may assigns nodes to the CPU if the GPU cannot
// handle them, similar to the fallback in CUDAExecutionProvider::GetCapability for certain RNN/GRU/Conv
// attributes.
uint32_t deviceTypeMask = 0u;
// Form the bitmask of all supported data types.
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};
THROW_IF_FAILED(m_dmlDevice->CheckFeatureSupport(
DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
sizeof(dataTypeQuery),
&dataTypeQuery,
sizeof(dataTypeSupport),
&dataTypeSupport
));
deviceTypeMask |= (dataTypeSupport.IsSupported << i);
}
return deviceTypeMask;
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
ExecutionProviderImpl::GetCapability(
const onnxruntime::GraphViewer& graph,
const std::vector<const onnxruntime::KernelRegistry*>& registries) const
{
std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_";
uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask();
return PartitionGraph(graph, *m_graphNodeFactoryMap, registries, deviceDataTypeMask, m_kernelRegistry.get(), partitionKernelPrefix);
}
Status ExecutionProviderImpl::CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const
{
assert(!m_closed);
auto provider = const_cast<ExecutionProviderImpl*>(this);
TensorWrapper destInternal(
&dst,
strcmp(dst.Location().name, onnxruntime::CPU) && !(dst.Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || dst.Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput),
provider,
true);
TensorWrapper srcInternal(
const_cast<onnxruntime::Tensor*>(&src),
strcmp(src.Location().name, onnxruntime::CPU) && !(src.Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || src.Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput),
provider,
true);
THROW_IF_FAILED(CopyTensor(&destInternal, &srcInternal));
return onnxruntime::common::Status::OK();
}
Status ExecutionProviderImpl::WaitForGpuCompletion()
{
assert(!m_closed);
Flush();
m_context->GetCurrentCompletionEvent().WaitForSignal();
m_context->ReleaseCompletedReferences();
return Status::OK();
}
void __stdcall ExecutionProviderImpl::Flush() const
{
assert(!m_closed);
m_context->Flush();
}
void ExecutionProviderImpl::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
m_allocator->SetDefaultRoundingMode(roundingMode);
}
void ExecutionProviderImpl::ReleaseCompletedReferences()
{
m_context->ReleaseCompletedReferences();
}
void ExecutionProviderImpl::TrimUploadHeap()
{
m_uploadHeap->Trim();
}
void ExecutionProviderImpl::QueueReference(IUnknown* object)
{
assert(!m_closed);
m_context->QueueReference(object);
}
void ExecutionProviderImpl::GetShadowCopyIfRequired(
bool isInternalOperator,
IUnknown* data,
IUnknown** dataCopy) const
{
assert(!m_closed);
*dataCopy = data;
data->AddRef();
}
void ExecutionProviderImpl::GetABIDataInterface(
bool isInternalOperator,
IUnknown* data,
IUnknown** abiData) const
{
assert(!m_closed);
if (isInternalOperator)
{
*abiData = data;
data->AddRef();
}
else
{
ComPtr<ID3D12Resource> resource = m_allocator->DecodeDataHandle(data)->GetResource();
*abiData = resource.Detach();
}
}
uint64_t ExecutionProviderImpl::TryGetPooledAllocationId(
IUnknown* data,
bool isInternalOperator)
{
assert(!isInternalOperator);
return m_allocator->DecodeDataHandle(data)->GetPooledResourceId();
}
void ExecutionProviderImpl::GetABIExecutionInterface(
bool isInternalOperator,
IUnknown** abiExecutionObject) const
{
assert(!m_closed);
if (isInternalOperator)
{
ComPtr<IUnknown> thisPtr = const_cast<IExecutionProvider*>(static_cast<const IExecutionProvider*>(this));
*abiExecutionObject = thisPtr.Detach();
}
else
{
ComPtr<ID3D12GraphicsCommandList> commandList;
m_context->GetCommandListForRecording(commandList.GetAddressOf());
*abiExecutionObject = commandList.Detach();
}
}
bool ExecutionProviderImpl::TransitionsRequiredForOperator(
bool isInternalOperator
)
{
// External operators receive resources in Common state, while internal operators receive
// them in UAV state. Resources are otherwise kept in UAV state (or are promotable to UAV).
return !isInternalOperator;
}
void ExecutionProviderImpl::TransitionResourcesForOperator(
bool isBeforeOp,
uint32_t resourceCount,
IUnknown** resources
)
{
std::vector<D3D12_RESOURCE_BARRIER> barriers;
barriers.reserve(resourceCount);
for (uint32_t i = 0; i < resourceCount; ++i)
{
ComPtr<ID3D12Resource> resource;
THROW_IF_FAILED(resources[i]->QueryInterface(resource.GetAddressOf()));
// Custom operators receive resources in Common state and must return them to Common
// state when finished. Resources are otherwise kept in UAV state (or are promotable to UAV).
barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(
resource.Get(),
isBeforeOp ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS : D3D12_RESOURCE_STATE_COMMON,
isBeforeOp ? D3D12_RESOURCE_STATE_COMMON : D3D12_RESOURCE_STATE_UNORDERED_ACCESS
));
}
if (!barriers.empty())
{
m_context->ResourceBarrier(barriers);
}
}
D3D12_COMMAND_LIST_TYPE __stdcall ExecutionProviderImpl::GetCommandListTypeForQueue() const
{
return m_context->GetCommandListTypeForQueue();
}
bool __stdcall ExecutionProviderImpl::IsMcdmDevice() const noexcept
{
return m_isMcdmDevice;
}
bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept
{
return m_areMetacommandsEnabled;
}
std::shared_ptr<onnxruntime::IAllocator> ExecutionProviderImpl::GetGpuAllocator()
{
return m_allocator;
}
std::shared_ptr<onnxruntime::IAllocator> ExecutionProviderImpl::GetCpuInputAllocator()
{
return m_cpuInputAllocator;
}
std::shared_ptr<onnxruntime::IAllocator> ExecutionProviderImpl::GetCpuOutputAllocator()
{
return m_cpuOutputAllocator;
}
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands)
{
return std::make_unique<Dml::ExecutionProvider>(dmlDevice, commandQueue, enableMetacommands);
}
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr)
{
Dml::BucketizedBufferAllocator* pAllocationInfo = static_cast<Dml::BucketizedBufferAllocator*>(allocator);
return pAllocationInfo->DecodeDataHandle(ptr)->GetResource();
}
void FlushContext(onnxruntime::IExecutionProvider* provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->Flush();
}
void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->SetDefaultRoundingMode(roundingMode);
}
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->ReleaseCompletedReferences();
}
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->TrimUploadHeap();
}
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->WaitForGpuCompletion();
}
onnxruntime::common::Status CopyTensor(
onnxruntime::IExecutionProvider* provider,
const onnxruntime::Tensor& src,
onnxruntime::Tensor& dst
)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
return dmlexecutionprovider->GetImpl()->CopyTensor(src, dst);
}
void* CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource)
{
uint64_t pooledResourceId = 0; // Not a pooled resource
ComPtr<AllocationInfo> allocInfo = wil::MakeOrThrow<AllocationInfo>(nullptr, 0, pooledResourceId, pResource, (size_t)pResource->GetDesc().Width);
return allocInfo.Detach();
}
void FreeGPUAllocation(void* ptr)
{
ComPtr<AllocationInfo> allocInfo;
allocInfo.Attach(static_cast<AllocationInfo*>(ptr));
}
onnxruntime::common::Status RegisterDmlGraphTransformer(onnxruntime::InferenceSession* session, std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry)
{
auto graphTransformer = std::make_unique<Dml::GraphTransformer>(std::string(onnxruntime::kDmlExecutionProvider), dmlRegistry);
return session->RegisterGraphTransformer(std::move(graphTransformer), onnxruntime::TransformerLevel::Level1);
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,284 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "GraphTransformer.h"
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
#include <wrl/client.h>
#include <wrl/implements.h>
namespace WRL
{
template <typename... TInterfaces>
using Base = Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
TInterfaces...
>;
}
using namespace Microsoft::WRL;
namespace Dml
{
class PooledUploadHeap;
class ReadbackHeap;
class ExecutionContext;
class BucketizedBufferAllocator;
class CPUAllocator;
class ExecutionProvider;
class ExecutionProviderImpl : public WRL::Base<Dml::IExecutionProvider,
winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider>
{
public:
explicit ExecutionProviderImpl::ExecutionProviderImpl(
IDMLDevice* dmlDevice,
ID3D12Device* d3d12Device,
ID3D12CommandQueue* queue,
bool enableMetacommands = true);
void ReleaseCompletedReferences();
void TrimUploadHeap();
public: // implements Dml::IExecutionProvider
STDMETHOD(GetD3DDevice)(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept final;
STDMETHOD(GetDmlDevice)(_COM_Outptr_ IDMLDevice** dmlDevice) const noexcept final;
STDMETHOD(ExecuteCommandList)(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
) const noexcept final;
STDMETHOD(AddUAVBarrier)() const noexcept final;
STDMETHOD(InitializeOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<const DML_BUFFER_BINDING> inputBindings
) const noexcept final;
STDMETHOD(ExecuteOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<IMLOperatorTensor*> inputTensors,
gsl::span<IMLOperatorTensor*> outputTensors
) const noexcept final;
STDMETHOD(ExecuteOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<DML_BINDING_DESC> inputTensors,
gsl::span<DML_BINDING_DESC> outputTensors
) const noexcept final;
STDMETHOD(CopyTensor)(IMLOperatorTensor* dst, IMLOperatorTensor* src) const noexcept final;
STDMETHOD(FillTensorWithPattern)(
IMLOperatorTensor* dst,
gsl::span<const std::byte> value
) const noexcept final;
STDMETHOD(UploadToResource)(ID3D12Resource* dstData, const void* srcData, uint64_t srcDataSize) const noexcept final;
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
GetCapability(
const onnxruntime::GraphViewer& graph,
const std::vector<const onnxruntime::KernelRegistry*>& registries
) const;
uint32_t GetSuppportedDeviceDataTypeMask() const;
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
onnxruntime::common::Status WaitForGpuCompletion();
// IWinmlExecutionProvider methods
void QueueReference(IUnknown* object) override;
void GetShadowCopyIfRequired(
bool isInternalOperator,
IUnknown* data,
IUnknown** dataCopy) const override;
void GetABIDataInterface(
bool isInternalOperator,
IUnknown* data,
IUnknown** abiData) const override;
uint64_t TryGetPooledAllocationId(
IUnknown* data,
bool isInternalOperator) override;
void GetABIExecutionInterface(
bool isInternalOperator,
IUnknown** abiExecutionObject) const override;
bool TransitionsRequiredForOperator(
bool isInternalOperator
) override;
void TransitionResourcesForOperator(
bool isBeforeOp,
uint32_t resourceCount,
IUnknown** resources
) override;
STDMETHOD_(D3D12_COMMAND_LIST_TYPE, GetCommandListTypeForQueue)() const override;
STDMETHOD_(void, Flush)() const override;
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// Allocate a resource from pools. Releasing pooledResource returns it to the pool.
STDMETHOD(AllocatePooledResource)(
size_t size,
AllocatorRoundingMode roundingMode,
ID3D12Resource **d3dResource,
IUnknown* *pooledResource
) const noexcept final;
STDMETHOD_(ID3D12Resource*, DecodeResource)(void* allocation) const noexcept final;
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const
{
return m_kernelRegistry;
}
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
std::shared_ptr<onnxruntime::IAllocator> GetGpuAllocator();
std::shared_ptr<onnxruntime::IAllocator> GetCpuInputAllocator();
std::shared_ptr<onnxruntime::IAllocator> GetCpuOutputAllocator();
private:
void Initialize(ID3D12CommandQueue* queue, ExecutionProvider& executionProvider);
ComPtr<ID3D12Device> m_d3d12Device;
ComPtr<IDMLDevice> m_dmlDevice;
bool m_isMcdmDevice = false;
bool m_areMetacommandsEnabled = true;
std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
std::shared_ptr<BucketizedBufferAllocator> m_allocator;
std::shared_ptr<CPUAllocator> m_cpuInputAllocator;
std::shared_ptr<CPUAllocator> m_cpuOutputAllocator;
std::shared_ptr<onnxruntime::KernelRegistry> m_kernelRegistry;
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap> m_graphNodeFactoryMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;
bool m_closed = false;
};
class DataTransfer : public onnxruntime::IDataTransfer
{
public:
DataTransfer() = delete;
DataTransfer(ExecutionProviderImpl* impl) : m_impl(impl)
{
}
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const final
{
return CopyTensor(src, dst, 0);
}
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst, int exec_queue_id) const final
{
assert(exec_queue_id == 0);
return m_impl->CopyTensor(src, dst);
}
bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
{
return (srcDevice.Type() == OrtDevice::GPU) ||
(dstDevice.Type() == OrtDevice::GPU);
}
private:
ComPtr<ExecutionProviderImpl> m_impl;
};
class ExecutionProvider : public onnxruntime::IExecutionProvider
{
public:
virtual ~ExecutionProvider();
ExecutionProvider() = delete;
explicit ExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
bool enableMetacommands = true
);
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final
{
return std::make_unique<DataTransfer>(m_impl.Get());
}
const void* GetExecutionHandle() const noexcept final
{
return m_impl.Get();
}
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const final
{
return m_impl->GetKernelRegistry();
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const final;
// Not to be confused with IExecutionProvider::Sync() const. The DML provider handles
// synchronization when copying inputs and outputs, therefore doesn't override the
// default ORT method, which does nothin.
onnxruntime::common::Status WaitForGpuCompletion()
{
return m_impl->WaitForGpuCompletion();
}
void Flush()
{
return m_impl->Flush();
}
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
return m_impl->SetDefaultRoundingMode(roundingMode);
}
void ReleaseCompletedReferences()
{
return m_impl->ReleaseCompletedReferences();
}
void TrimUploadHeap()
{
m_impl->TrimUploadHeap();
}
ExecutionProviderImpl* GetImpl()
{
return m_impl.Get();
}
void MetacommandsEnabled()
{
m_impl->MetacommandsEnabled();
}
private:
ComPtr<ExecutionProviderImpl> m_impl;
};
} // namespace Dml

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
class OperatorField;
struct AbstractOperatorDesc
{
const DML_OPERATOR_SCHEMA* schema = nullptr;
std::vector<OperatorField> fields;
AbstractOperatorDesc() = default;
AbstractOperatorDesc(const DML_OPERATOR_SCHEMA* schema, std::vector<OperatorField>&& fields)
: schema(schema)
, fields(std::move(fields))
{}
std::vector<DmlBufferTensorDesc*> GetInputTensors()
{
return GetTensors<DmlBufferTensorDesc, DML_SCHEMA_FIELD_KIND_INPUT_TENSOR>();
}
std::vector<const DmlBufferTensorDesc*> GetInputTensors() const
{
return GetTensors<const DmlBufferTensorDesc, DML_SCHEMA_FIELD_KIND_INPUT_TENSOR>();
}
std::vector<DmlBufferTensorDesc*> GetOutputTensors()
{
return GetTensors<DmlBufferTensorDesc, DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR>();
}
std::vector<const DmlBufferTensorDesc*> GetOutputTensors() const
{
return GetTensors<const DmlBufferTensorDesc, DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR>();
}
private:
template <typename TensorType, DML_SCHEMA_FIELD_KIND Kind>
std::vector<TensorType*> GetTensors() const
{
std::vector<TensorType*> tensors;
for (auto& field : fields)
{
const DML_SCHEMA_FIELD* fieldSchema = field.GetSchema();
if (fieldSchema->Kind != Kind)
{
continue;
}
if (fieldSchema->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC)
{
auto& tensor = field.AsTensorDesc();
tensors.push_back(tensor ? const_cast<TensorType*>(&*tensor) : nullptr);
}
else if (fieldSchema->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY)
{
auto& tensorArray = field.AsTensorDescArray();
if (tensorArray)
{
for (auto& tensor : *tensorArray)
{
tensors.push_back(const_cast<TensorType*>(&tensor));
}
}
}
}
return tensors;
}
};

Просмотреть файл

@ -0,0 +1,218 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
union ActivationOperatorDescUnion
{
DML_ACTIVATION_IDENTITY_OPERATOR_DESC identity;
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
DML_ACTIVATION_LINEAR_OPERATOR_DESC linear;
DML_ACTIVATION_LOG_SOFTMAX_OPERATOR_DESC logSoftmax;
DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC parameterizedRelu;
DML_ACTIVATION_PARAMETRIC_SOFTPLUS_OPERATOR_DESC parametricSoftplus;
DML_ACTIVATION_RELU_OPERATOR_DESC relu;
DML_ACTIVATION_SCALED_TANH_OPERATOR_DESC scaledTanh;
DML_ACTIVATION_SCALED_ELU_OPERATOR_DESC scaledElu;
DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoid;
DML_ACTIVATION_SOFTMAX_OPERATOR_DESC softmax;
DML_ACTIVATION_SOFTPLUS_OPERATOR_DESC softplus;
DML_ACTIVATION_SOFTSIGN_OPERATOR_DESC softsign;
DML_ACTIVATION_TANH_OPERATOR_DESC tanh;
DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC thresholdedRelu;
DML_ACTIVATION_SHRINK_OPERATOR_DESC shrink;
};
struct ActivationOperatorDesc
{
ActivationOperatorDescUnion params;
DML_OPERATOR_TYPE activationType;
DML_OPERATOR_DESC GetDmlDesc() const
{
switch (activationType)
{
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, &params.elu };
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax };
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity };
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return { activationType, &params.leakyRelu };
case DML_OPERATOR_ACTIVATION_LINEAR: return { activationType, &params.linear };
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return { activationType, &params.logSoftmax };
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return { activationType, &params.parameterizedRelu };
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return { activationType, &params.parametricSoftplus };
case DML_OPERATOR_ACTIVATION_RELU: return { activationType, &params.relu };
case DML_OPERATOR_ACTIVATION_SCALED_ELU: return { activationType, &params.scaledElu };
case DML_OPERATOR_ACTIVATION_SCALED_TANH: return { activationType, &params.scaledTanh };
case DML_OPERATOR_ACTIVATION_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_SOFTMAX: return { activationType, &params.softmax };
case DML_OPERATOR_ACTIVATION_SOFTPLUS: return { activationType, &params.softplus };
case DML_OPERATOR_ACTIVATION_SOFTSIGN: return { activationType, &params.softsign };
case DML_OPERATOR_ACTIVATION_TANH: return { activationType, &params.tanh };
case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return { activationType, &params.thresholdedRelu };
case DML_OPERATOR_ACTIVATION_SHRINK: return { activationType, &params.shrink };
default: THROW_HR(E_INVALIDARG);
}
}
};
// DML_BUFFER_TENSOR_DESC (DML_TENSOR_TYPE_BUFFER)
struct DmlBufferTensorDesc
{
DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN;
DML_TENSOR_FLAGS flags = DML_TENSOR_FLAG_NONE;
std::vector<uint32_t> sizes;
std::optional<std::vector<uint32_t>> strides;
uint64_t totalTensorSizeInBytes = 0;
uint32_t guaranteedBaseOffsetAlignment = 0;
DmlBufferTensorDesc() = default;
/*implicit*/ DmlBufferTensorDesc(const DML_BUFFER_TENSOR_DESC& desc)
: dataType(desc.DataType)
, flags(desc.Flags)
, sizes(desc.Sizes, desc.Sizes + desc.DimensionCount)
, totalTensorSizeInBytes(desc.TotalTensorSizeInBytes)
, guaranteedBaseOffsetAlignment(desc.GuaranteedBaseOffsetAlignment)
{
if (desc.Strides)
{
strides.emplace(desc.Strides, desc.Strides + desc.DimensionCount);
}
}
// Constructs a DmlBufferTensorDesc from a generic DML_TENSOR_DESC. The type must be DML_TENSOR_TYPE_BUFFER.
/*implicit*/ DmlBufferTensorDesc(const DML_TENSOR_DESC& desc)
: DmlBufferTensorDesc(*static_cast<const DML_BUFFER_TENSOR_DESC*>(desc.Desc))
{
assert(desc.Type == DML_TENSOR_TYPE_BUFFER);
}
};
template <size_t Size>
class StackAllocator
{
public:
StackAllocator() = default;
// Non-copiable, non-movable
StackAllocator(const StackAllocator&) = delete;
StackAllocator& operator=(const StackAllocator&) = delete;
StackAllocator(StackAllocator&&) = delete;
StackAllocator& operator=(StackAllocator&&) = delete;
template <typename T>
T* Allocate(size_t count = 1)
{
static_assert(std::is_trivial_v<T>,
"This class may only be used to allocate trivial types, as it does not invoke constructors.");
// Allocate from the fixed bucket before falling back to dynamic
Bucket* lastBucket = m_dynamic.empty() ? static_cast<Bucket*>(&m_fixed) : static_cast<Bucket*>(&m_dynamic.back());
size_t sizeInBytes = sizeof(T) * count;
void* memory = lastBucket->TryAllocate(sizeInBytes, alignof(T));
if (!memory)
{
// Not enough capacity remains; allocate a new dynamic bucket
size_t minimumSize = sizeInBytes;
m_dynamic.emplace_back(minimumSize);
memory = m_dynamic.back().TryAllocate(sizeInBytes, alignof(T));
}
assert(memory != nullptr);
return reinterpret_cast<T*>(memory);
}
void Reset()
{
m_fixed.allocatedSize = 0;
m_dynamic.clear();
}
private:
struct Bucket
{
void* data;
size_t allocatedSize;
size_t capacity;
Bucket() = default;
// Non-copiable, non-movable
Bucket(const Bucket&) = delete;
Bucket& operator=(const Bucket&) = delete;
Bucket(Bucket&&) = delete;
Bucket& operator=(Bucket&&) = delete;
template <typename T>
static T RoundUpToMultiple(T value, T multiple)
{
static_assert(std::is_integral_v<T>);
T remainder = value % multiple;
if (remainder != 0)
{
value += multiple - remainder;
}
return value;
}
void* TryAllocate(size_t sizeInBytes, size_t alignment)
{
size_t alignedOffset = RoundUpToMultiple(allocatedSize, alignment);
size_t newAllocatedSize = alignedOffset + sizeInBytes;
if (newAllocatedSize > capacity)
{
return nullptr; // Not enough capacity
}
allocatedSize = newAllocatedSize;
return static_cast<byte*>(data) + alignedOffset;
}
};
struct FixedBucket : Bucket
{
std::array<byte, Size> stack;
FixedBucket()
{
this->data = stack.data();
this->allocatedSize = 0;
this->capacity = stack.size();
}
};
struct DynamicBucket : Bucket
{
explicit DynamicBucket(size_t minimumSize)
{
this->allocatedSize = 0;
this->capacity = RoundUpToMultiple<size_t>(minimumSize, 4096); // Round up to nearest page granularity
this->data = VirtualAlloc(nullptr, this->capacity, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
THROW_LAST_ERROR_IF_NULL(this->data);
}
~DynamicBucket()
{
if (data)
{
(void)VirtualFree(data, 0, MEM_RELEASE);
}
}
};
// This allocator first retrieves memory from a fixed-size stack-allocated array before falling back to dynamically
// allocated memory if the fixed stack array is exhausted.
FixedBucket m_fixed;
std::deque<DynamicBucket> m_dynamic;
};

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
using ApiAttributeVariant = std::variant<
const DML_TENSOR_DESC*,
const DML_OPERATOR_DESC*,
UINT,
INT,
FLOAT,
const UINT*,
const FLOAT*,
const DML_SCALE_BIAS*,
DML_SIZE_2D
>;
namespace OperatorFieldTypes
{
using TensorDesc = std::optional<DmlBufferTensorDesc>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC
using TensorDescArray = std::optional<std::vector<DmlBufferTensorDesc>>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY
using OperatorDesc = std::optional<AbstractOperatorDesc>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC
using OperatorDescArray = std::optional<std::vector<AbstractOperatorDesc>>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY
using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT
using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT
using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT
using UIntArray = std::optional<std::vector<uint32_t>>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY
using FloatArray = std::optional<std::vector<float>>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY
using ScaleBias = std::optional<DML_SCALE_BIAS>; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS
using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D
}
using OperatorFieldVariant = std::variant<
OperatorFieldTypes::TensorDesc,
OperatorFieldTypes::TensorDescArray,
OperatorFieldTypes::OperatorDesc,
OperatorFieldTypes::OperatorDescArray,
OperatorFieldTypes::UInt,
OperatorFieldTypes::Int,
OperatorFieldTypes::Float,
OperatorFieldTypes::UIntArray,
OperatorFieldTypes::FloatArray,
OperatorFieldTypes::ScaleBias,
OperatorFieldTypes::Size2D
>;
class OperatorField
{
public:
OperatorField() = default;
explicit OperatorField(const DML_SCHEMA_FIELD* schema, OperatorFieldVariant&& data)
: m_schema(schema)
, m_data(std::move(data))
{
assert(m_schema->Type == (DML_SCHEMA_FIELD_TYPE)m_data.index());
}
const DML_SCHEMA_FIELD* GetSchema() const
{
return m_schema;
}
const OperatorFieldVariant& GetData() const
{
return m_data;
}
const OperatorFieldTypes::TensorDesc& AsTensorDesc() const { return std::get<OperatorFieldTypes::TensorDesc>(m_data); }
OperatorFieldTypes::TensorDesc& AsTensorDesc() { return std::get<OperatorFieldTypes::TensorDesc>(m_data); }
const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get<OperatorFieldTypes::TensorDescArray>(m_data); }
OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get<OperatorFieldTypes::TensorDescArray>(m_data); }
const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get<OperatorFieldTypes::OperatorDesc>(m_data); }
OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get<OperatorFieldTypes::OperatorDesc>(m_data); }
const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get<OperatorFieldTypes::OperatorDescArray>(m_data); }
OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get<OperatorFieldTypes::OperatorDescArray>(m_data); }
const OperatorFieldTypes::UInt& AsUInt() const { return std::get<OperatorFieldTypes::UInt>(m_data); }
OperatorFieldTypes::UInt& AsUInt() { return std::get<OperatorFieldTypes::UInt>(m_data); }
const OperatorFieldTypes::Int& AsInt() const { return std::get<OperatorFieldTypes::Int>(m_data); }
OperatorFieldTypes::Int& AsInt() { return std::get<OperatorFieldTypes::Int>(m_data); }
const OperatorFieldTypes::Float& AsFloat() const { return std::get<OperatorFieldTypes::Float>(m_data); }
OperatorFieldTypes::Float& AsFloat() { return std::get<OperatorFieldTypes::Float>(m_data); }
const OperatorFieldTypes::UIntArray& AsUIntArray() const { return std::get<OperatorFieldTypes::UIntArray>(m_data); }
OperatorFieldTypes::UIntArray& AsUIntArray() { return std::get<OperatorFieldTypes::UIntArray>(m_data); }
const OperatorFieldTypes::FloatArray& AsFloatArray() const { return std::get<OperatorFieldTypes::FloatArray>(m_data); }
OperatorFieldTypes::FloatArray& AsFloatArray() { return std::get<OperatorFieldTypes::FloatArray>(m_data); }
const OperatorFieldTypes::ScaleBias& AsScaleBias() const { return std::get<OperatorFieldTypes::ScaleBias>(m_data); }
OperatorFieldTypes::ScaleBias& AsScaleBias() { return std::get<OperatorFieldTypes::ScaleBias>(m_data); }
const OperatorFieldTypes::Size2D& AsSize2D() const { return std::get<OperatorFieldTypes::Size2D>(m_data); }
OperatorFieldTypes::Size2D& AsSize2D() { return std::get<OperatorFieldTypes::Size2D>(m_data); }
private:
const DML_SCHEMA_FIELD* m_schema;
OperatorFieldVariant m_data;
};

Просмотреть файл

@ -0,0 +1,345 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace SchemaHelpers
{
inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc);
inline OperatorFieldTypes::TensorDesc ToOperatorFieldType(const DML_TENSOR_DESC* value)
{
return value ? OperatorFieldTypes::TensorDesc(*value) : std::nullopt;
}
inline OperatorFieldTypes::TensorDescArray ToOperatorFieldType(const DML_TENSOR_DESC* values, uint32_t count)
{
OperatorFieldTypes::TensorDescArray field;
if (values && count != 0)
{
field.emplace(count);
for (uint32_t i = 0; i < count; ++i)
{
(*field)[i] = values[i];
}
}
return field;
}
inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value)
{
return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt;
}
inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count)
{
OperatorFieldTypes::OperatorDescArray field;
if (values && count != 0)
{
field.emplace(count);
for (uint32_t i = 0; i < count; ++i)
{
(*field)[i] = ConvertOperatorDesc(values[i]);
}
}
return field;
}
inline OperatorFieldTypes::UInt ToOperatorFieldType(uint32_t value)
{
return value;
}
inline OperatorFieldTypes::Int ToOperatorFieldType(int32_t value)
{
return value;
}
inline OperatorFieldTypes::Float ToOperatorFieldType(float value)
{
return value;
}
inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count)
{
OperatorFieldTypes::UIntArray field;
if (values && count != 0)
{
field.emplace(count);
std::copy_n(values, count, field->begin());
}
return field;
}
inline OperatorFieldTypes::FloatArray ToOperatorFieldType(const float* values, uint32_t count)
{
OperatorFieldTypes::FloatArray field;
if (values && count != 0)
{
field.emplace(count);
std::copy_n(values, count, field->begin());
}
return field;
}
inline OperatorFieldTypes::ScaleBias ToOperatorFieldType(const DML_SCALE_BIAS* value)
{
return value ? OperatorFieldTypes::ScaleBias(*value) : std::nullopt;
}
inline OperatorFieldTypes::Size2D ToOperatorFieldType(DML_SIZE_2D value)
{
return value;
}
class StructFieldWriter
{
public:
explicit StructFieldWriter(gsl::span<byte> dst)
: m_dst(dst)
, m_bytesWritten(0)
{}
template <typename T>
void Write(const T& value)
{
static_assert(std::is_trivial_v<T>, "Only trivial types are supported.");
size_t dstOffset = RoundUpToMultiple(m_bytesWritten, alignof(T));
size_t newBytesWritten = dstOffset + sizeof(value);
assert(newBytesWritten <= gsl::narrow_cast<size_t>(m_dst.size()));
memcpy(m_dst.data() + dstOffset, &value, sizeof(value));
m_bytesWritten = newBytesWritten;
}
private:
template <typename T>
T RoundUpToMultiple(T value, T multiple)
{
static_assert(std::is_integral_v<T>);
T remainder = value % multiple;
if (remainder != 0)
{
value += multiple - remainder;
}
return value;
}
gsl::span<byte> m_dst;
size_t m_bytesWritten;
};
template <size_t N>
DML_BUFFER_TENSOR_DESC MakeBufferTensorDesc(const DmlBufferTensorDesc& src, StackAllocator<N>* allocator)
{
size_t dimensionCount = src.sizes.size();
auto* sizes = allocator->Allocate<UINT>(dimensionCount);
std::copy_n(src.sizes.begin(), dimensionCount, sizes);
UINT* strides = nullptr;
if (src.strides)
{
strides = allocator->Allocate<UINT>(dimensionCount);
std::copy_n(src.strides->begin(), dimensionCount, strides);
}
DML_BUFFER_TENSOR_DESC dst;
dst.DataType = src.dataType;
dst.Flags = src.flags;
dst.Sizes = sizes;
dst.Strides = strides;
dst.DimensionCount = static_cast<UINT>(dimensionCount);
dst.TotalTensorSizeInBytes = src.totalTensorSizeInBytes;
dst.GuaranteedBaseOffsetAlignment = src.guaranteedBaseOffsetAlignment;
return dst;
}
template <size_t N>
DML_TENSOR_DESC MakeTensorDesc(const DmlBufferTensorDesc& src, StackAllocator<N>* allocator)
{
auto* desc = allocator->Allocate<DML_BUFFER_TENSOR_DESC>();
*desc = MakeBufferTensorDesc(src, allocator);
DML_TENSOR_DESC dst;
dst.Type = DML_TENSOR_TYPE_BUFFER;
dst.Desc = desc;
return dst;
}
template <size_t N>
DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator<N>* allocator);
template <size_t N>
void WriteOperatorDescField(const OperatorField& field, StructFieldWriter* dst, StackAllocator<N>* allocator)
{
const DML_SCHEMA_FIELD& schema = *field.GetSchema();
switch (schema.Type)
{
case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC:
{
DML_TENSOR_DESC* desc = nullptr;
const auto& value = field.AsTensorDesc();
if (value)
{
desc = allocator->Allocate<DML_TENSOR_DESC>();
*desc = MakeTensorDesc(*value, allocator);
}
dst->Write(desc);
} break;
case DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY:
{
DML_TENSOR_DESC* descs = nullptr;
const auto& values = field.AsTensorDescArray();
if (values)
{
descs = allocator->Allocate<DML_TENSOR_DESC>(values->size());
for (size_t i = 0; i < values->size(); ++i)
{
descs[i] = MakeTensorDesc((*values)[i], allocator);
}
}
dst->Write(descs);
} break;
case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC:
{
DML_OPERATOR_DESC* desc = nullptr;
const auto& value = field.AsOperatorDesc();
if (value)
{
desc = allocator->Allocate<DML_OPERATOR_DESC>();
*desc = ConvertOperatorDesc(*value, allocator);
}
dst->Write(desc);
} break;
case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY:
{
DML_OPERATOR_DESC* descs = nullptr;
const auto& values = field.AsOperatorDescArray();
if (values)
{
descs = allocator->Allocate<DML_OPERATOR_DESC>(values->size());
for (size_t i = 0; i < values->size(); ++i)
{
descs[i] = ConvertOperatorDesc((*values)[i], allocator);
}
}
dst->Write(descs);
} break;
case DML_SCHEMA_FIELD_TYPE_UINT:
{
uint32_t value = field.AsUInt();
dst->Write(value);
} break;
case DML_SCHEMA_FIELD_TYPE_INT:
{
int32_t value = field.AsInt();
dst->Write(value);
} break;
case DML_SCHEMA_FIELD_TYPE_FLOAT:
{
float value = field.AsFloat();
dst->Write(value);
} break;
case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY:
{
uint32_t* arrayPtr = nullptr;
const auto& values = field.AsUIntArray();
if (values)
{
arrayPtr = allocator->Allocate<uint32_t>(values->size());
std::copy(values->begin(), values->end(), arrayPtr);
}
dst->Write(arrayPtr);
} break;
case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY:
{
float* arrayPtr = nullptr;
const auto& values = field.AsFloatArray();
if (values)
{
arrayPtr = allocator->Allocate<float>(values->size());
std::copy(values->begin(), values->end(), arrayPtr);
}
dst->Write(arrayPtr);
} break;
case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS:
{
DML_SCALE_BIAS* scaleBias = nullptr;
const auto& value = field.AsScaleBias();
if (value)
{
scaleBias = allocator->Allocate<DML_SCALE_BIAS>();
*scaleBias = *value;
}
dst->Write(scaleBias);
} break;
case DML_SCHEMA_FIELD_TYPE_SIZE_2D:
{
DML_SIZE_2D value = field.AsSize2D();
dst->Write(value);
} break;
default:
assert(false);
THROW_HR(E_UNEXPECTED);
}
}
template <size_t N>
DML_OPERATOR_DESC ConvertOperatorDesc(const AbstractOperatorDesc& abstractDesc, StackAllocator<N>* allocator)
{
const DML_OPERATOR_SCHEMA& schema = *abstractDesc.schema;
// Retrieve the size of the ABI operator desc struct
size_t abiDescSizeInBytes = ApiTraits::OperatorTypeVisitor(schema.OperatorType, [](auto tag) {
using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs
return sizeof(T);
});
// Allocate a blob of bytes to hold the struct
byte* abiDesc = allocator->Allocate<byte>(abiDescSizeInBytes);
// Use the schema to write data into the blob
StructFieldWriter writer(gsl::make_span(abiDesc, abiDescSizeInBytes));
for (const OperatorField& field : abstractDesc.fields)
{
WriteOperatorDescField(field, &writer, allocator);
}
return DML_OPERATOR_DESC{ schema.OperatorType, abiDesc };
}
} // namespace SchemaHelpers

Просмотреть файл

@ -0,0 +1,763 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "MLOperatorAuthorImpl.h"
#include "FusedGraphKernel.h"
using namespace winrt::Windows::AI::MachineLearning::implementation;
namespace Dml
{
template <typename T>
static T AlignToPow2(T offset, T alignment)
{
static_assert(std::is_unsigned_v<T>);
assert(alignment != 0);
assert((alignment & (alignment - 1)) == 0);
return (offset + alignment - 1) & ~(alignment - 1);
}
class FusedGraphKernel : public onnxruntime::OpKernel
{
public:
FusedGraphKernel() = delete;
FusedGraphKernel(
const onnxruntime::OpKernelInfo& kernelInfo,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap) : OpKernel(kernelInfo)
{
// Get the graph for the function which was created according to the computational
// capacity returned by the execution provider's graph partitioner
auto& node = kernelInfo.node();
THROW_HR_IF(E_UNEXPECTED, node.NodeType() != onnxruntime::Node::Type::Fused);
auto func = node.GetFunctionBody();
const onnxruntime::Graph& graph = func->Body();
// Get the shapes for outputs of the overall graph. These should be static, because
// the partitioner checked that each node has static shapes before fusing into a
// graph partition.
THROW_HR_IF(E_UNEXPECTED, !TryGetStaticOutputShapes(node, m_outputShapes));
// Get the execution provider interfaces
m_executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle();
if (m_executionHandle)
{
// We assume the execution object inherits IUnknown as its first base
ComPtr<IUnknown> providerExecutionObject = const_cast<IUnknown*>(static_cast<const IUnknown*>(m_executionHandle));
// Get the WinML-specific execution provider interface from the execution object.
THROW_IF_FAILED(providerExecutionObject.As(&m_provider));
THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider));
}
TranslateAndCompileGraph(kernelInfo, graph, node.InputDefs(), node.OutputDefs(), graphNodePropertyMap, transferredInitializerMap);
}
void TranslateAndCompileGraph(
const onnxruntime::OpKernelInfo& kernelInfo,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
)
{
ComPtr<IDMLDevice> device;
THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf()));
ComPtr<IDMLDevicePreview> devicePreview;
THROW_IF_FAILED(device.As(&devicePreview));
const uint32_t graphInputCount = kernelInfo.GetInputCount();
auto gpuGraphInputConstnessGetter = [&kernelInfo, &fusedNodeInputDefs, &transferredInitializerMap](uint32_t index)
{
// Transferred initializers are uploaded to GPU memory
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
if (iter != transferredInitializerMap.end())
{
return true;
}
// If an initializer wasn't transferred, the constant input may be available from ORT
const onnxruntime::Tensor* inputTensor = nullptr;
if (!kernelInfo.TryGetConstantInput(index, &inputTensor) || inputTensor == nullptr)
{
return false;
}
// Check that the constant ORT input is in GPU memory
if (!strcmp(inputTensor->Location().name, onnxruntime::CPU) ||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput ||
inputTensor->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput)
{
return false;
}
return true;
};
m_inputsConstant.resize(graphInputCount);
for (uint32_t i = 0; i < graphInputCount; ++i)
{
m_inputsConstant[i] = gpuGraphInputConstnessGetter(i);
}
GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc(
kernelInfo,
m_inputsConstant,
transferredInitializerMap,
graph,
fusedNodeInputDefs,
fusedNodeOutputDefs,
graphNodePropertyMap,
device.Get(),
m_executionHandle);
// Determine the last input which uses an initializer, so initializers can be freed incrementally
// while processing each input in order.
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
for (uint32_t i = 0; i < graphInputCount; i++)
{
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
if (iter != transferredInitializerMap.end())
{
initializerToLastInputIndexMap[&iter->second] = i;
}
}
// Walk through each graph edge and mark used inputs
m_inputsUsed.assign(graphInputCount, false);
for (const DML_PREVIEW_INPUT_GRAPH_EDGE& edge : graphDesc.inputEdges)
{
m_inputsUsed[edge.GraphInputIndex] = true;
}
// Populate input bindings for operator initialization
std::vector<ComPtr<ID3D12Resource>> initInputResources; // For lifetime control
std::vector<DML_BUFFER_BINDING> initInputBindings(graphInputCount);
m_nonOwnedGraphInputsFromInitializers.resize(graphInputCount);
std::vector<ComPtr<ID3D12Resource>> initializeResourceRefs;
for (uint32_t i = 0; i < initInputBindings.size(); i++)
{
// If the input isn't actually used by the graph, nothing ever needs to be bound (either for
// initialization or execution). So just throw away the transferred initializer and skip this input.
if (!m_inputsUsed[i])
{
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
continue;
}
// Look for the initializer among those transferred from the graph during partitioning
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
if (iter != transferredInitializerMap.end())
{
std::byte* tensorPtr = nullptr;
size_t tensorByteSize = 0;
std::unique_ptr<std::byte[]> unpackedTensor;
auto& initializer = iter->second;
// The tensor may be stored as raw data or in typed fields.
if (initializer.has_raw_data())
{
tensorPtr = (std::byte*)(initializer.raw_data().c_str());
tensorByteSize = initializer.raw_data().size();
}
else
{
std::tie(unpackedTensor, tensorByteSize) = winrt::Windows::AI::MachineLearning::implementation::UnpackTensor(initializer);
tensorPtr = unpackedTensor.get();
}
// Tensor sizes in DML must be a multiple of 4 bytes large.
tensorByteSize = AlignToPow2<size_t>(tensorByteSize, 4);
if (!m_inputsConstant[i])
{
// Store the resource to use during execution
ComPtr<ID3D12Resource> defaultBuffer = CreateResource(tensorPtr, tensorByteSize);
m_nonOwnedGraphInputsFromInitializers[i] = defaultBuffer;
initializeResourceRefs.push_back(std::move(defaultBuffer));
}
else
{
ComPtr<ID3D12Resource> initializeInputBuffer;
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
if (m_provider->IsMcdmDevice())
{
initializeInputBuffer = CreateResource(tensorPtr, tensorByteSize);
}
else
{
initializeInputBuffer = CreateCpuResource(tensorPtr, tensorByteSize);
}
// Set the binding for operator initialization to the buffer
initInputBindings[i].Buffer = initializeInputBuffer.Get();
initInputBindings[i].SizeInBytes = tensorByteSize;
initializeResourceRefs.push_back(std::move(initializeInputBuffer));
}
// Free the initializer if this is the last usage of it.
if (initializerToLastInputIndexMap[&initializer] == i)
{
transferredInitializerMap.erase(iter);
}
}
else if (m_inputsConstant[i])
{
const onnxruntime::Tensor* inputTensor = nullptr;
THROW_HR_IF(E_UNEXPECTED, !kernelInfo.TryGetConstantInput(i, &inputTensor));
uint64_t allocId;
UnwrapTensor(inputTensor, &initInputBindings[i].Buffer, &allocId);
initInputBindings[i].SizeInBytes = initInputBindings[i].Buffer->GetDesc().Width;
initInputBindings[i].Buffer->Release(); // Avoid holding an additional reference
initInputResources.push_back(initInputBindings[i].Buffer);
}
}
// All initializers should have been consumed and freed above
assert(transferredInitializerMap.empty());
std::vector<DML_PREVIEW_OPERATOR_GRAPH_NODE> dmlOperatorGraphNodes(graphDesc.nodes.size());
std::vector<DML_PREVIEW_GRAPH_NODE> dmlGraphNodes(graphDesc.nodes.size());
std::vector<DML_PREVIEW_GRAPH_EDGE> dmlInputEdges(graphDesc.inputEdges.size());
std::vector<DML_PREVIEW_GRAPH_EDGE> dmlOutputEdges(graphDesc.outputEdges.size());
std::vector<DML_PREVIEW_GRAPH_EDGE> dmlIntermediateEdges(graphDesc.intermediateEdges.size());
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
dmlOperatorGraphNodes[i] = DML_PREVIEW_OPERATOR_GRAPH_NODE{ graphDesc.nodes[i].op.Get() };
dmlGraphNodes[i] = DML_PREVIEW_GRAPH_NODE{ DML_PREVIEW_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i] };
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
{
dmlInputEdges[i] = DML_PREVIEW_GRAPH_EDGE{ DML_PREVIEW_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i] };
}
for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i)
{
dmlOutputEdges[i] = DML_PREVIEW_GRAPH_EDGE{ DML_PREVIEW_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i] };
}
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i)
{
dmlIntermediateEdges[i] = DML_PREVIEW_GRAPH_EDGE{ DML_PREVIEW_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i] };
}
DML_PREVIEW_GRAPH_DESC dmlGraphDesc = {};
dmlGraphDesc.InputCount = graphInputCount;
dmlGraphDesc.OutputCount = kernelInfo.GetOutputCount();
dmlGraphDesc.NodeCount = gsl::narrow_cast<uint32_t>(dmlGraphNodes.size());
dmlGraphDesc.Nodes = dmlGraphNodes.data();
dmlGraphDesc.InputEdgeCount = gsl::narrow_cast<uint32_t>(dmlInputEdges.size());
dmlGraphDesc.InputEdges = dmlInputEdges.data();
dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast<uint32_t>(dmlOutputEdges.size());
dmlGraphDesc.OutputEdges = dmlOutputEdges.data();
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE;
if (graphDesc.reuseCommandList)
{
executionFlags |= DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE;
}
// Query DML execution provider to see if metacommands is enabled
if (!m_provider->MetacommandsEnabled())
{
executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS;
}
THROW_IF_FAILED(devicePreview->CompileGraph(
&dmlGraphDesc,
executionFlags,
IID_PPV_ARGS(&m_compiledExecutionPlanOperator)));
// Allocate a persistent resource and initialize the operator
UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize;
if (persistentResourceSize > 0)
{
THROW_IF_FAILED(m_provider->AllocatePooledResource(
persistentResourceSize,
AllocatorRoundingMode::Disabled,
m_persistentResource.GetAddressOf(),
m_persistentResourceAllocatorUnk.GetAddressOf()));
m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize };
}
THROW_IF_FAILED(m_provider->InitializeOperator(
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(initInputBindings)));
// Queue references to objects which must be kept alive until resulting GPU work completes
m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());
m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
std::for_each(
initializeResourceRefs.begin(),
initializeResourceRefs.end(),
[&](ComPtr<ID3D12Resource>& resource){ m_winmlProvider->QueueReference(resource.Get()); }
);
if (graphDesc.reuseCommandList)
{
BuildReusableCommandList();
}
}
onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override
{
// Only re-use the cached command list if its prior execution is complete on the GPU.
// This requirement can be avoided by mantaining ring buffers.
if (!m_graphicsCommandList ||
(m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue))
{
// Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator
OpKernelContextWrapper contextWrapper(
kernelContext,
Info().GetExecutionProvider(),
true,
nullptr);
THROW_IF_FAILED(m_provider->AddUAVBarrier());
// Get input resources for execution, excluding those which were specified as owned by DML and provided
// at initialization instead.
std::vector<ComPtr<IMLOperatorTensor>> inputTensors(kernelContext->InputCount());
std::vector<ID3D12Resource*> inputPtrs(kernelContext->InputCount());
for (int i = 0; i < kernelContext->InputCount(); ++i)
{
if (!m_inputsUsed[i])
{
continue;
}
if (m_nonOwnedGraphInputsFromInitializers[i])
{
inputPtrs[i] = m_nonOwnedGraphInputsFromInitializers[i].Get();
}
else if (!m_inputsConstant[i])
{
THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf()));
inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get());
}
}
ExecuteOperator(
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
inputPtrs,
contextWrapper.GetOutputTensors(m_outputShapes));
THROW_IF_FAILED(m_provider->AddUAVBarrier());
// Queue references to objects which must be kept alive until resulting GPU work completes
m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get());
m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
}
else
{
ExecuteReusableCommandList(kernelContext);
}
return onnxruntime::Status::OK();
}
void ExecuteOperator(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<ID3D12Resource*> inputTensors,
gsl::span<IMLOperatorTensor*> outputTensors) const
{
auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span<IMLOperatorTensor*>& tensors)
{
for (IMLOperatorTensor* tensor : tensors)
{
if (tensor)
{
assert(tensor->IsDataInterface());
ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get());
D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc();
bufferBindings.push_back({ resource, 0, resourceDesc.Width });
bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() });
}
else
{
bufferBindings.push_back({ nullptr, 0, 0 });
bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
}
}
};
auto FillBindingsFromBuffers = [this](auto& bufferBindings, auto& bindingDescs, gsl::span<ID3D12Resource*>& resources)
{
for (ID3D12Resource* resource : resources)
{
if (resource)
{
D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc();
bufferBindings.push_back({ resource, 0, resourceDesc.Width });
bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() });
}
else
{
bufferBindings.push_back({ nullptr, 0, 0 });
bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr });
}
}
};
std::vector<DML_BUFFER_BINDING> inputBufferBindings;
inputBufferBindings.reserve(inputTensors.size());
std::vector<DML_BINDING_DESC> inputBindings;
inputBindings.reserve(inputTensors.size());
FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors);
std::vector<DML_BUFFER_BINDING> outputBufferBindings;
outputBufferBindings.reserve(outputTensors.size());
std::vector<DML_BINDING_DESC> outputBindings;
outputBindings.reserve(outputTensors.size());
FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors);
THROW_IF_FAILED(m_provider->ExecuteOperator(
op,
persistentResourceBinding,
inputBindings,
outputBindings));
}
private:
void BuildReusableCommandList()
{
ComPtr<IDMLDevice> device;
THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf()));
DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();
D3D12_DESCRIPTOR_HEAP_DESC desc = {};
desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
desc.NumDescriptors = execBindingProps.RequiredDescriptorCount;
desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
ComPtr<ID3D12Device> d3dDevice;
THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&m_heap)));
// Create a binding table for execution.
DML_BINDING_TABLE_DESC bindingTableDesc = {};
bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get();
bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart();
bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart();
bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount;
THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable)));
ComPtr<ID3D12CommandAllocator> allocator;
THROW_IF_FAILED(d3dDevice->CreateCommandAllocator(
m_provider->GetCommandListTypeForQueue(),
IID_PPV_ARGS(&allocator)));
ComPtr<ID3D12CommandList> commandList;
THROW_IF_FAILED(d3dDevice->CreateCommandList(
0,
m_provider->GetCommandListTypeForQueue(),
allocator.Get(),
nullptr,
IID_PPV_ARGS(&commandList)));
THROW_IF_FAILED(commandList.As(&m_graphicsCommandList));
if (m_persistentResource)
{
DML_BINDING_DESC persistentResourceBindingDesc =
{ DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr };
m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc);
}
ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() };
m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps);
ComPtr<IDMLCommandRecorder> recorder;
THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf())));
recorder->RecordDispatch(commandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get());
THROW_IF_FAILED(m_graphicsCommandList->Close());
}
void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const
{
DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties();
std::vector<DML_BUFFER_BINDING> inputBindings(kernelContext->InputCount());
std::vector<DML_BINDING_DESC> inputBindingDescs(kernelContext->InputCount());
OpKernelContextWrapper contextWrapper(
kernelContext,
Info().GetExecutionProvider(),
true,
nullptr);
// Populate input bindings, excluding those which were specified as owned by DML and provided
// at initialization instead.
m_inputBindingAllocIds.resize(inputBindings.size());
bool inputBindingsChanged = false;
for (uint32_t i = 0; i < inputBindings.size(); ++i)
{
if (!m_inputsConstant[i] && m_inputsUsed[i])
{
if (m_nonOwnedGraphInputsFromInitializers[i])
{
inputBindings[i].Buffer = m_nonOwnedGraphInputsFromInitializers[i].Get();
inputBindings[i].SizeInBytes = m_nonOwnedGraphInputsFromInitializers[i]->GetDesc().Width;
inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
}
else
{
const onnxruntime::Tensor* tensor = kernelContext->Input<onnxruntime::Tensor>(i);
uint64_t allocId;
UnwrapTensor(tensor, &inputBindings[i].Buffer, &allocId);
inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId);
inputBindings[i].Buffer->Release(); // Avoid holding an additional reference
inputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]};
m_inputBindingAllocIds[i] = allocId;
}
}
}
if (inputBindingsChanged)
{
m_bindingTable->BindInputs(gsl::narrow_cast<uint32_t>(inputBindingDescs.size()), inputBindingDescs.data());
}
// Populate Output bindings
std::vector<DML_BUFFER_BINDING> outputBindings(kernelContext->OutputCount());
std::vector<DML_BINDING_DESC> outputBindingDescs(kernelContext->OutputCount());
m_outputBindingAllocIds.resize(outputBindings.size());
bool outputBindingsChanged = false;
for (uint32_t i = 0; i < outputBindings.size(); ++i)
{
std::vector<int64_t> outputDims;
outputDims.reserve(m_outputShapes.GetShape(i).size());
for (uint32_t dimSize : m_outputShapes.GetShape(i))
{
outputDims.push_back(dimSize);
}
onnxruntime::Tensor* tensor = kernelContext->Output(
static_cast<int>(i),
onnxruntime::TensorShape::ReinterpretBaseType(outputDims)
);
uint64_t allocId;
UnwrapTensor(tensor, &outputBindings[i].Buffer, &allocId);
outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId);
outputBindings[i].Buffer->Release(); // Avoid holding an additional reference
outputBindings[i].SizeInBytes = AlignToPow2<size_t>(tensor->SizeInBytes(), 4);
outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]};
m_outputBindingAllocIds[i] = allocId;
}
if (outputBindingsChanged)
{
m_bindingTable->BindOutputs(gsl::narrow_cast<uint32_t>(outputBindingDescs.size()), outputBindingDescs.data());
}
if (execBindingProps.TemporaryResourceSize > 0)
{
// Allocate temporary data which will automatically be freed when the GPU work
// which is scheduled up to the point that this method returns has completed.
ComPtr<IUnknown> tempAlloc;
uint64_t tempAllocId = 0;
THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(execBindingProps.TemporaryResourceSize, tempAlloc.GetAddressOf(), &tempAllocId));
ComPtr<IUnknown> tempResourceUnk;
m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk);
// Bind the temporary resource.
ComPtr<ID3D12Resource> tempResource;
THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf()));
DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize};
DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding };
if (!tempAllocId || m_tempBindingAllocId != tempAllocId)
{
m_bindingTable->BindTemporaryResource(&tempBindingDesc);
}
m_tempBindingAllocId = tempAllocId;
}
// Execute the command list and if it succeeds, update the fence value at which this command may be
// re-used.
ComPtr<ID3D12Fence> fence;
uint64_t completionValue;
THROW_IF_FAILED(m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue));
m_fence = fence;
m_completionValue = completionValue;
// Queue references to objects which must be kept alive until resulting GPU work completes
m_winmlProvider->QueueReference(m_graphicsCommandList.Get());
m_winmlProvider->QueueReference(m_heap.Get());
m_winmlProvider->QueueReference(m_bindingTable.Get());
m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get());
}
void UnwrapTensor(const onnxruntime::Tensor* tensor, ID3D12Resource** resource, uint64_t* allocId) const
{
IUnknown* allocationUnk = static_cast<IUnknown*>(const_cast<void*>(tensor->DataRaw()));
ComPtr<IUnknown> resourceUnk;
m_winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk);
*allocId = m_winmlProvider->TryGetPooledAllocationId(allocationUnk, 0);
THROW_IF_FAILED(resourceUnk->QueryInterface(resource));
}
ComPtr<ID3D12Resource> CreateResource(const std::byte* tensorPtr, size_t tensorByteSize) const
{
ComPtr<ID3D12Resource> buffer;
D3D12_HEAP_PROPERTIES heapProperties = {
D3D12_HEAP_TYPE_DEFAULT,
D3D12_CPU_PAGE_PROPERTY_UNKNOWN,
D3D12_MEMORY_POOL_UNKNOWN,
0,
0
};
D3D12_RESOURCE_DESC resourceDesc = {
D3D12_RESOURCE_DIMENSION_BUFFER,
0,
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
1,
1,
1,
DXGI_FORMAT_UNKNOWN,
{ 1, 0 },
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
};
ComPtr<ID3D12Device> d3dDevice;
THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
&heapProperties,
D3D12_HEAP_FLAG_NONE,
&resourceDesc,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
nullptr,
IID_PPV_ARGS(buffer.GetAddressOf())
));
THROW_IF_FAILED(m_provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize));
return buffer;
}
ComPtr<ID3D12Resource> CreateCpuResource(const std::byte* tensorPtr, size_t tensorByteSize) const
{
ComPtr<ID3D12Resource> buffer;
D3D12_HEAP_PROPERTIES heapProperties = {
D3D12_HEAP_TYPE_CUSTOM,
D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE,
D3D12_MEMORY_POOL_L0,
0,
0
};
D3D12_RESOURCE_DESC resourceDesc = {
D3D12_RESOURCE_DIMENSION_BUFFER,
0,
static_cast<uint64_t>((tensorByteSize + 3) & ~3),
1,
1,
1,
DXGI_FORMAT_UNKNOWN,
{ 1, 0 },
D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
};
ComPtr<ID3D12Device> d3dDevice;
THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf()));
THROW_IF_FAILED(d3dDevice->CreateCommittedResource(
&heapProperties,
D3D12_HEAP_FLAG_NONE,
&resourceDesc,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
nullptr,
IID_PPV_ARGS(buffer.GetAddressOf())
));
// Map the buffer and copy the data
void* bufferData = nullptr;
D3D12_RANGE range = {0, tensorByteSize};
THROW_IF_FAILED(buffer->Map(0, &range, &bufferData));
memcpy(bufferData, tensorPtr, tensorByteSize);
buffer->Unmap(0, &range);
return buffer;
}
ComPtr<IDMLCompiledOperator> m_compiledExecutionPlanOperator;
std::vector<bool> m_inputsUsed;
const void* m_executionHandle = nullptr;
ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
ComPtr<Dml::IExecutionProvider> m_provider;
EdgeShapes m_outputShapes;
// Re-usable command list, supporting descriptor heap, and DML binding table to update that heap.
ComPtr<ID3D12GraphicsCommandList> m_graphicsCommandList;
ComPtr<ID3D12DescriptorHeap> m_heap;
ComPtr<IDMLBindingTable> m_bindingTable;
std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
ComPtr<ID3D12Resource> m_persistentResource;
ComPtr<IUnknown> m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator
// Bindings from previous executions of a re-used command list
mutable std::vector<uint64_t> m_inputBindingAllocIds;
mutable std::vector<uint64_t> m_outputBindingAllocIds;
mutable uint64_t m_tempBindingAllocId = 0;
// Fence tracking the status of the command list's last execution, and whether its descriptor heap
// can safely be updated.
mutable ComPtr<ID3D12Fence> m_fence;
mutable uint64_t m_completionValue = 0;
std::vector<uint8_t> m_inputsConstant;
std::vector<ComPtr<ID3D12Resource>> m_nonOwnedGraphInputsFromInitializers;
};
onnxruntime::OpKernel* CreateFusedGraphKernel(
const onnxruntime::OpKernelInfo& info,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
)
{
return new FusedGraphKernel(info, graphNodePropertyMap, transferredInitializerMap);
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "GraphDescBuilder.h"
namespace Dml
{
onnxruntime::OpKernel* CreateFusedGraphKernel(
const onnxruntime::OpKernelInfo& info,
const std::unordered_map<std::string, GraphNodeProperties> &graphNodePropertyMap,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap
);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace Dml
{
// Represents a fence which will be signaled at some point (usually by the GPU).
struct GpuEvent
{
uint64_t fenceValue;
ComPtr<ID3D12Fence> fence;
bool IsSignaled() const
{
return fence->GetCompletedValue() >= fenceValue;
}
// Blocks until IsSignaled returns true.
void WaitForSignal() const
{
if (IsSignaled())
return; // early-out
wil::unique_handle h(CreateEvent(nullptr, TRUE, FALSE, nullptr));
THROW_LAST_ERROR_IF(!h);
THROW_IF_FAILED(fence->SetEventOnCompletion(fenceValue, h.get()));
WaitForSingleObject(h.get(), INFINITE);
}
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,277 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "GraphDescBuilder.h"
using namespace winrt::Windows::AI::MachineLearning::implementation;
namespace Dml::GraphDescBuilder
{
// TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies.
// This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel
// for a function, and the graph for that function exposed as a kernel property. When the ordering
// mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed.
static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
{
// The suffix used when inserting mem copies is equal to the below, followed by an incrementing number.
const char* suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_");
if (suffix)
{
return std::string(
fusedNodeArgeName.begin(),
fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str())
);
}
return fusedNodeArgeName;
}
const std::string& GetUniqueNodeName(const onnxruntime::Node& node)
{
// The node's name is optional, and it might be re-created with a different index
// and pointer after partitioning occurs. Use the name of the node's first valid
// output as the unique identifier for the node itself.
for (const auto* arg : node.OutputDefs())
{
if (arg->Exists())
{
return arg->Name();
}
}
assert(false);
THROW_HR(E_UNEXPECTED);
}
GraphDesc BuildGraphDesc(
const onnxruntime::OpKernelInfo& kernelInfo,
gsl::span<const uint8_t> isConstGpuGraphInput,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle)
{
struct NodeAndIndex
{
uint32_t nodeIndex; // The index of the node itself
uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node)
};
// Map from Lotus node argument names to the new node and index where it will be produced
std::unordered_map<std::string, NodeAndIndex> nameToNodeAndIndexMap;
// Map from Lotus node argument names to input indices of the fused kernel node.
std::unordered_map<std::string, uint32_t> nameToFusedNodeInputIndex;
for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex)
{
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(
GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
if (!graphInput)
{
// This is a workaround for when node inputs get manipulated by transformers outside of our control,
// which then causes them to have a different name. If that happens we can't figure out how to
// correlate inputs to the fused graph index. This likely requires a higher-level fix, but for now
// just bail early.
THROW_HR(E_UNEXPECTED);
}
nameToFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast<uint32_t>(inputIndex));
}
StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC
std::vector<NodeInfo> graphNodes;
std::vector<DML_PREVIEW_INPUT_GRAPH_EDGE> graphInputEdges;
std::vector<DML_PREVIEW_INTERMEDIATE_GRAPH_EDGE> graphIntermediateEdges;
std::vector<DML_PREVIEW_OUTPUT_GRAPH_EDGE> graphOutputEdges;
// Get the topological sorting of Lotus nodes
// paulm: breaking change from LOTUS that removed GetNodesInTopologicalOrder from Graph
onnxruntime::GraphViewer viewer(graph);
const std::vector<onnxruntime::NodeIndex>& orderedNodeIndices = viewer.GetNodesInTopologicalOrder();
// Avoid using separate command lists for small graphs. This value can be reduced by tuning the
// flushing behavior of DmlCommandRecorder. Its current behavior is to assume that graphs contain
// enough GPU work to be worth flushing immediately.
const uint32_t minNodeCountToReuseCommandList = 5;
bool reuseCommandList = false;
if (orderedNodeIndices.size() >= minNodeCountToReuseCommandList)
{
reuseCommandList = true;
}
auto constantCpuGraphInputGetter = [&fusedNodeInputDefs, &transferredInitializerMap](const std::string& argName)
{
ComPtr<OnnxTensorWrapper> tensorWrapper;
auto iter = transferredInitializerMap.find(argName);
if (iter != transferredInitializerMap.end())
{
tensorWrapper = wil::MakeOrThrow<OnnxTensorWrapper>(&iter->second);
}
return tensorWrapper;
};
// Iterate through each node and create a corresponding node in the new graph
for (size_t sortedNodeIndex : orderedNodeIndices)
{
const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex);
const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second;
const auto& requiredConstantCpuInputs = graphNodeProps.graphNodeFactoryRegistration->requiredConstantCpuInputs;
MLOperatorTensorGetter constantCpuNodeInputGetter = [&node, &constantCpuGraphInputGetter, &requiredConstantCpuInputs](uint32_t inputIndex)
{
ComPtr<IMLOperatorTensor> tensor = nullptr;
// Check whether this specific node requested support for constant CPU inputs
if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end())
{
const onnxruntime::NodeArg* arg = node.InputDefs()[inputIndex];
tensor = constantCpuGraphInputGetter(arg->Name());
}
return tensor;
};
DmlGraphNodeCreateInfo graphNodeInfo;
graphNodeProps.graphNodeFactoryRegistration->factory(
node,
constantCpuNodeInputGetter,
executionHandle,
&graphNodeInfo
);
// Determine the number of valid inputs and outputs of this node. The graph currently supports opererators
// with unused inputs and outputs only at the end of each list.
uint32_t validOpInputCount = 0;
uint32_t validOpOutputCount = 0;
for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i)
{
if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpInputCount == 0);
++validOpInputCount;
}
}
for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i)
{
if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpOutputCount == 0);
++validOpOutputCount;
}
}
uint32_t nodeIndex = gsl::narrow_cast<uint32_t>(graphNodes.size());
AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy
// Retrieve lists of input and output tensor descs. These point into the opDesc, which allows us to modify
// the tensor descs through these pointers.
std::vector<DmlBufferTensorDesc*> inputTensorDescs = opDesc.GetInputTensors();
std::vector<DmlBufferTensorDesc*> outputTensorDescs = opDesc.GetOutputTensors();
// Set connections of the new node
for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex)
{
uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex];
const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex];
if (arg->Exists())
{
auto iter = nameToFusedNodeInputIndex.find(arg->Name());
if (iter != nameToFusedNodeInputIndex.end())
{
// This is a graph input
const uint32_t fusedNodeInputIndex = iter->second;
DML_PREVIEW_INPUT_GRAPH_EDGE edge = {};
edge.GraphInputIndex = fusedNodeInputIndex;
edge.ToNodeIndex = nodeIndex;
edge.ToNodeInputIndex = inputIndex;
graphInputEdges.push_back(edge);
// If this is a constant input, set the appropriate flags on the desc
if (isConstGpuGraphInput[fusedNodeInputIndex])
{
DmlBufferTensorDesc* tensorDesc = inputTensorDescs[inputIndex];
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
}
}
else
{
const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name());
DML_PREVIEW_INTERMEDIATE_GRAPH_EDGE edge = {};
edge.FromNodeIndex = inputNodeAndIndex.nodeIndex;
edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex;
edge.ToNodeIndex = nodeIndex;
edge.ToNodeInputIndex = inputIndex;
graphIntermediateEdges.push_back(edge);
}
}
}
// Store the new node for lookup when downstream nodes consume it.
for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex)
{
uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex];
const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex];
if (arg->Exists())
{
nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex{ nodeIndex, outputIndex };
}
}
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(opDesc, &allocator);
ComPtr<IDMLOperator> op;
THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
NodeInfo nodeInfo = {};
nodeInfo.op = std::move(op);
graphNodes.push_back(std::move(nodeInfo));
}
assert(graphNodes.size() == orderedNodeIndices.size());
// Add graph output nodes, which might be in a different order from the encapsulating node
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex)
{
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
DML_PREVIEW_OUTPUT_GRAPH_EDGE edge = {};
edge.FromNodeIndex = outputNodeAndIndex.nodeIndex;
edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex;
edge.GraphOutputIndex = gsl::narrow_cast<uint32_t>(outputIndex);
graphOutputEdges.push_back(edge);
}
GraphDesc graphDesc{};
graphDesc.nodes = std::move(graphNodes);
graphDesc.inputEdges = std::move(graphInputEdges);
graphDesc.outputEdges = std::move(graphOutputEdges);
graphDesc.intermediateEdges = std::move(graphIntermediateEdges);
graphDesc.reuseCommandList = reuseCommandList;
return graphDesc;
}
}

Просмотреть файл

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "MLOperatorAuthorImpl.h"
namespace Dml
{
struct GraphNodeProperties
{
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryRegistration>
graphNodeFactoryRegistration;
// These are currently passed from the partitioning step since the only DML operators current
// supporting graph nodes don't customize the order of edges or shapes, other than coercing
// dimension count. This will change as the supported set of operators as graph nodes increases.
winrt::Windows::AI::MachineLearning::implementation::EdgeShapes inputShapes;
winrt::Windows::AI::MachineLearning::implementation::EdgeShapes outputShapes;
};
namespace GraphDescBuilder
{
// Gets a unique name for the node which survives recreation and graph manipulations between the point
// that graph partitioning occurs and kernel creation happens
const std::string& GetUniqueNodeName(const onnxruntime::Node& node);
struct NodeInfo
{
Microsoft::WRL::ComPtr<IDMLOperator> op;
};
struct GraphDesc
{
std::vector<NodeInfo> nodes;
std::vector<DML_PREVIEW_INPUT_GRAPH_EDGE> inputEdges;
std::vector<DML_PREVIEW_OUTPUT_GRAPH_EDGE> outputEdges;
std::vector<DML_PREVIEW_INTERMEDIATE_GRAPH_EDGE> intermediateEdges;
bool reuseCommandList;
};
GraphDesc BuildGraphDesc(
const onnxruntime::OpKernelInfo& kernelInfo,
gsl::span<const uint8_t> isConstGpuGraphInput,
std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap,
const onnxruntime::Graph& graph,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeInputDefs,
const onnxruntime::ConstPointerContainer<std::vector<onnxruntime::NodeArg*>>& fusedNodeOutputDefs,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle);
}
}

Просмотреть файл

@ -0,0 +1,812 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "IExecutionProvider.h"
#include "ExecutionProvider.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "FusedGraphKernel.h"
#include "GraphDescBuilder.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/framework/compute_capability.h"
#include <wil/wrl.h>
#include <dxgi1_6.h>
#include "GraphPartitioner.h"
//#define PRINT_PARTITON_INFO
using namespace winrt::Windows::AI::MachineLearning::implementation;
namespace Dml
{
GraphPartition* GraphPartition::GetRootMergedPartition()
{
return m_mergedPartition ? m_mergedPartition->GetRootMergedPartition() : this;
}
std::vector<onnxruntime::NodeIndex>& GraphPartition::GetNodeIndices()
{
assert(this == GetRootMergedPartition());
return m_nodeIndices;
}
std::set<std::string>& GraphPartition::GetInputs()
{
assert(this == GetRootMergedPartition());
return m_inputs;
}
std::set<std::string>& GraphPartition::GetOutputs()
{
assert(this == GetRootMergedPartition());
return m_outputs;
}
bool GraphPartition::IsFinalized()
{
assert(this == GetRootMergedPartition());
return m_finalized;
}
void GraphPartition::SetFinalized()
{
m_finalized = true;
}
bool GraphPartition::IsDmlPartition()
{
assert(this == GetRootMergedPartition());
return m_isDmlPartition;
}
bool GraphPartition::IsDmlGraphPartition()
{
assert(this == GetRootMergedPartition());
return m_isDmlGraphPartition;
}
void GraphPartition::SetIsDmlPartition(bool isDmlPartition)
{
assert(this == GetRootMergedPartition());
m_isDmlPartition = isDmlPartition;
}
void GraphPartition::SetIsDmlGraphPartition(bool isDmlGraphPartition)
{
assert(this == GetRootMergedPartition());
m_isDmlGraphPartition = isDmlGraphPartition;
}
void GraphPartition::AddNodeIndex(onnxruntime::NodeIndex index)
{
assert(!IsFinalized());
assert(std::find(m_nodeIndices.begin(), m_nodeIndices.end(), index) == m_nodeIndices.end());
m_nodeIndices.push_back(index);
}
void GraphPartition::AddInput(const std::string& name)
{
assert(!IsFinalized());
assert(this == GetRootMergedPartition());
m_inputs.insert(name);
}
void GraphPartition::AddOutput(const std::string& name)
{
assert(this == GetRootMergedPartition());
m_outputs.insert(name);
}
void GraphPartition::Merge(gsl::span<GraphPartition*> partitionsToMerge)
{
assert(this == GetRootMergedPartition());
for (GraphPartition* partitionToMerge : partitionsToMerge)
{
if (partitionToMerge == this)
{
continue;
}
assert(!partitionToMerge->IsFinalized());
assert(partitionToMerge->IsDmlPartition() == IsDmlPartition());
assert(partitionToMerge->IsDmlGraphPartition() == IsDmlGraphPartition());
partitionToMerge->m_mergedPartition = this;
m_nodeIndices.insert(m_nodeIndices.begin(), partitionToMerge->m_nodeIndices.begin(), partitionToMerge->m_nodeIndices.end());
m_inputs.insert(partitionToMerge->m_inputs.begin(), partitionToMerge->m_inputs.end());
m_outputs.insert(partitionToMerge->m_outputs.begin(), partitionToMerge->m_outputs.end());
}
}
// Adds the outputs of a node to the specified partition
void AddNodeOutputsToPartitionMap(
const onnxruntime::Node& node,
GraphPartition* partition,
std::unordered_map<std::string, GraphPartition*>& nodeNameToPartitionMap
)
{
for (uint32_t i = 0; i < node.OutputDefs().size(); ++i)
{
const auto* arg = node.OutputDefs()[i];
if (arg->Exists())
{
nodeNameToPartitionMap[arg->Name()] = partition;
}
}
};
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats)
{
if (arg->Exists())
{
const onnx::TypeProto* typeProto = arg->TypeAsProto();
if (typeProto->value_case() == onnx::TypeProto::kTensorType)
{
const onnx::TypeProto_Tensor tensorType = typeProto->tensor_type();
if (tensorType.has_elem_type())
{
// TODO: Remove this by handling zeroing on the output of fused graph nodes and handling of non-float
// types in DML's identity operator, which is used for strided copies.
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::UInt64 ||
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) == MLOperatorTensorDataType::Int64)
{
return false;
}
if (requiresFloatFormats)
{
if (ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float &&
ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type())) != MLOperatorTensorDataType::Float16)
{
return false;
}
}
}
}
}
return true;
}
bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const GraphNodeFactoryRegistration& registration)
{
for (size_t i = 0; i < node.InputDefs().size(); ++i)
{
bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) !=
registration.requiredConstantCpuInputs.end();
if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.requiresFloatFormatsExceptConstInputs))
{
return false;
}
}
for (auto arg : node.OutputDefs())
{
if (!NodeArgSupportedInGraph(arg, registration.requiresFloatFormatsExceptConstInputs))
{
return false;
}
}
return true;
}
bool DoesNodeContainSupportedDataTypes(
const onnxruntime::Node& node,
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
)
{
// Assume data types are supported until proven otherwise.
bool nodeContainsSupportedDataTypes = true;
// Callback to check each node's data type.
std::function<void(const onnxruntime::NodeArg& nodeArg, bool isInput)> nodeCallback = [&](const onnxruntime::NodeArg& nodeArg, bool isInput) -> void
{
// Get the tensor element data type for this node, comparing against what the device actually supports.
// Use the enumeration from the proto instead of nodeArg.Type() which returns a string.
const ::onnx::TypeProto* typeProto = nodeArg.TypeAsProto();
if (typeProto != nullptr && typeProto->has_tensor_type())
{
const ::onnx::TypeProto_Tensor& tensorTypeProto = typeProto->tensor_type();
if (tensorTypeProto.has_elem_type())
{
MLOperatorTensorDataType onnxElementType = static_cast<MLOperatorTensorDataType>(tensorTypeProto.elem_type());
DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType);
if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN)
{
if ((1 << dmlElementType) & supportedDeviceDataTypeMask)
{
// Leave nodeContainsSupportedDataTypes alone, since data type is supported.
return;
}
}
}
}
// Else it's not supported (non-tensors, opaque data types, unsupported data types...).
nodeContainsSupportedDataTypes = false;
};
// Check whether the node uses any data types which are unsupported by the device.
node.ForEachDef(nodeCallback);
return nodeContainsSupportedDataTypes;
}
// Gets properties of the registration for a node
void GetRegistrationProperties(
const onnxruntime::GraphViewer& graph,
const onnxruntime::Node& node,
const std::vector<const onnxruntime::KernelRegistry*>& dmlRegistries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
const GraphNodeFactoryMap& graphNodeFactoryMap,
_Inout_ std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& dmlNodePropertyMap,
_Inout_ std::unordered_set<std::string>& requiredInitializerMap,
_Out_ bool* isDmlNode,
_Out_ bool* isDmlGraphNode
)
{
*isDmlNode = false;
*isDmlGraphNode = false;
// Find the highest priority DML registry supporting this node, and get its highest-priority
// registration. Determine if that registration supports usage as a graph node.
for (auto registry : dmlRegistries)
{
const onnxruntime::KernelCreateInfo* createInfo = registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider);
// Check whether the node uses any data types which are unsupported by the device.
bool nodeContainsSupportedDataTypes = DoesNodeContainSupportedDataTypes(node, supportedDeviceDataTypeMask);
if (createInfo && nodeContainsSupportedDataTypes)
{
*isDmlNode = true;
// Get the kernel creation info for the registration, and check if it carries the property
// set during registration of kernels that support DML graph node usage.
auto& graphNodeProperty = dmlNodePropertyMap.insert(std::make_pair(&node, GraphNodeProperties()));
// Ensure that shape information is known statically for the inputs and outputs of the node,
// which is required for MLGraph compilation.
auto graphNodeFactorMapIter = graphNodeFactoryMap.find(createInfo->kernel_def.get());
if (graphNodeFactorMapIter != graphNodeFactoryMap.end() &&
NodeTensorTypesSupportedInGraph(node, *graphNodeFactorMapIter->second))
{
bool requiredCpuInputsConstant = true;
for (uint32_t inputIndex : graphNodeFactorMapIter->second->requiredConstantCpuInputs)
{
const onnx::TensorProto* tensor = nullptr;
const std::string& inputName = node.InputDefs()[inputIndex]->Name();
if (!graph.GetInitializedTensor(inputName, tensor))
{
requiredCpuInputsConstant = false;
break;
}
requiredInitializerMap.insert(inputName);
}
std::optional<uint32_t> requiredInputCount = graphNodeFactorMapIter->second->requiredInputCount;
if (requiredCpuInputsConstant &&
TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) &&
TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) &&
(requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()))
{
*isDmlGraphNode = true;
graphNodeProperty.first->second.graphNodeFactoryRegistration = graphNodeFactorMapIter->second;
}
}
break;
}
}
}
// Creates a partition for a node which is not a DML graph node, and finalizes partitions
// which are inputs of the new partition.
std::unique_ptr<GraphPartition> CreateNonGraphNodePartitionAndFinalizeInputs(
const onnxruntime::Node& node,
bool isDmlNode,
std::unordered_map<std::string, GraphPartition*>& nodeNameToPartitionMap
)
{
std::unique_ptr<GraphPartition> partition = std::make_unique<GraphPartition>();
partition->SetIsDmlGraphPartition(false);
partition->SetIsDmlPartition(isDmlNode);
partition->AddNodeIndex(node.Index());
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
{
const auto* arg = node.InputDefs()[i];
if (arg->Exists())
{
const std::string& argName = arg->Name();
if (nodeNameToPartitionMap.find(argName) != nodeNameToPartitionMap.end())
{
// Finalize the partition which contains an input to a non-DML-graph partition.
// The connections from that partition to other partitions, such as this one,
// must become outputs of that partition. As subsequent downstream nodes of
// the finalized partition are visited, other outputs will subsequently be
// added to the partition, too.
GraphPartition* inputPartition = nodeNameToPartitionMap[argName]->GetRootMergedPartition();
inputPartition->SetFinalized();
inputPartition->AddOutput(argName);
}
partition->AddInput(argName);
}
}
partition->SetFinalized();
AddNodeOutputsToPartitionMap(node, partition.get(), nodeNameToPartitionMap);
return partition;
}
// Get the partitions which are inputs to the specified node and which are not finalized.
std::vector<GraphPartition*> GetNonFinalizedInputPartitions(
const onnxruntime::Node& node,
std::unordered_map<std::string, GraphPartition*>& nodeNameToPartitionMap
)
{
std::vector<GraphPartition*> inputNonFinalPartitions;
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
{
const auto* arg = node.InputDefs()[i];
if (arg->Exists())
{
const std::string& argName = arg->Name();
if (nodeNameToPartitionMap.find(argName) == nodeNameToPartitionMap.end())
{
// Must be source node
continue;
}
GraphPartition* inputPartition = nodeNameToPartitionMap[argName]->GetRootMergedPartition();
if (!inputPartition->IsFinalized())
{
inputNonFinalPartitions.push_back(inputPartition);
}
}
}
return inputNonFinalPartitions;
}
// Add graph outputs of the new node to a partition.
void AddGraphOutputsFromNodeToPartition(
const onnxruntime::Node& node,
const std::set<std::string>& graphOutputs,
GraphPartition* partition
)
{
for (uint32_t i = 0; i < node.OutputDefs().size(); ++i)
{
const auto* arg = node.OutputDefs()[i];
if (arg->Exists())
{
if (graphOutputs.find(arg->Name()) != graphOutputs.end())
{
partition->AddOutput(arg->Name());
}
}
}
}
std::unique_ptr<GraphPartition> CreateNewPartitionWithFinalizedInputPartitions(
const onnxruntime::Node& node,
const std::set<std::string>& graphOutputs,
std::unordered_map<std::string, GraphPartition*>& nodeNameToPartitionMap
)
{
std::unique_ptr<GraphPartition> partition = std::make_unique<GraphPartition>();
partition->SetIsDmlGraphPartition(true);
partition->SetIsDmlPartition(true);
partition->AddNodeIndex(node.Index());
// Inputs of the partition are added when partitions are created and extended when
// nodes are added with inputs which are not inside the partition
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
{
const auto* arg = node.InputDefs()[i];
if (arg->Exists())
{
partition->AddInput(arg->Name());
auto& inputPartition = nodeNameToPartitionMap.find(arg->Name());
if (inputPartition != nodeNameToPartitionMap.end())
{
inputPartition->second->GetRootMergedPartition()->AddOutput(arg->Name());
}
}
}
// Outputs of the partition are initially set to node outputs which are also
// graph outputs. They are extended when adding other node with the graph
// outputs from those nodes. They are also extended when a partition
// consumes an input from the current partition.
AddGraphOutputsFromNodeToPartition(node, graphOutputs, partition.get());
AddNodeOutputsToPartitionMap(node, partition.get(), nodeNameToPartitionMap);
return partition;
}
std::unique_ptr<onnxruntime::ComputeCapability> ComputationCapacityFromPartition(
GraphPartition* partition,
uint32_t partitionIndex,
const onnxruntime::GraphViewer& graph,
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>&& graphNodePropertyMap,
onnxruntime::KernelRegistry* registryForPartitionKernels,
const std::string& partitionKernelPrefix,
std::shared_ptr<std::unordered_map<std::string, onnx::TensorProto>> transferredInitializerMap)
{
std::unique_ptr<onnxruntime::IndexedSubGraph> subGraph = std::make_unique<onnxruntime::IndexedSubGraph>();
if (partition->IsDmlGraphPartition())
{
assert(partition->IsDmlGraphPartition());
// Create a definition for the node. The name must be unique.
auto def = std::make_unique<onnxruntime::IndexedSubGraph::MetaDef>();
def->name = std::string("DmlFusedNode_") + partitionKernelPrefix + std::to_string(partitionIndex);
def->domain = "DmlFusedNodeDomain";
def->since_version = 1;
def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end());
def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end());
// Populate properties which will be passed to OpKernel for this graph via the function below
std::unordered_map<std::string, GraphNodeProperties> partitionNodePropsMap;
for (auto nodeIndex : partition->GetNodeIndices())
{
const onnxruntime::Node* node = graph.GetNode(nodeIndex);
#ifdef PRINT_PARTITON_INFO
printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str());
#endif
partitionNodePropsMap.insert(std::make_pair(
GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node])));
}
#ifdef PRINT_PARTITON_INFO
printf("\n");
#endif
auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap](const onnxruntime::OpKernelInfo& info) mutable ->onnxruntime::OpKernel*
{
return CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap);
};
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
onnxruntime::KernelDefBuilder builder;
builder.SetName(def->name)
.SetDomain(def->domain)
.SinceVersion(def->since_version)
.Provider(onnxruntime::kDmlExecutionProvider);
registryForPartitionKernels->Register(builder, fused_kernel_func);
subGraph->SetMetaDef(std::move(def));
}
subGraph->nodes = std::move(partition->GetNodeIndices());
return std::make_unique<onnxruntime::ComputeCapability>(std::move(subGraph));
}
// Whether any operator in the model contains a subgraph. This is true
// if the graph being partitioned is itself within a subgraph, or contains
// an operator with a subgraph.
bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph)
{
if (graph.IsSubgraph())
{
return true;
}
const std::vector<onnxruntime::NodeIndex>& toplogicalOrder = graph.GetNodesInTopologicalOrder();
for (size_t nodeIndex : toplogicalOrder)
{
const onnxruntime::Node& node = *graph.GetNode(nodeIndex);
if (node.ContainsSubgraph())
{
return true;
}
}
return false;
}
//
// A simple graph partitioning algorithm is used:
//
// - If a node has any input which is already in a graph, and that graph is not finalized,
// then the node and all such input graphs are merged.
//
// - Once a node has an output which cannot be merged with its graph, its graph is marked
// as final, which disallows its future extensions. This ensures that no indirect
// downstream dependencies of the external output node are later merged.
//
std::vector<std::unique_ptr<GraphPartition>>
BuildPartitions(
const onnxruntime::GraphViewer& graph,
const GraphNodeFactoryMap& graphNodeFactoryMap,
const std::vector<const onnxruntime::KernelRegistry*>& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
std::function<void(const onnxruntime::Node&)> onNodeUnsupportedInGraph)
{
// Nodes are uniquely identified by the name of their first output argument
std::vector<std::unique_ptr<GraphPartition>> partitions;
std::unordered_map<std::string, GraphPartition*> nodeNameToPartitionMap;
// Get the list of node indices in toplogical order, so nodes are visited before.
// downstream nodes consuming them.
const std::vector<onnxruntime::NodeIndex>& toplogicalOrder = graph.GetNodesInTopologicalOrder();
// Construct sets with graph inputs and outputs for fast lookup later.
std::set<std::string> graphInputs;
std::set<std::string> graphOutputs;
for (const auto* arg : graph.GetInputsIncludingInitializers())
{
graphInputs.insert(arg->Name());
}
// If a model contains an intializer which is not also a graph input, it will not be returned
// by GetInputsIncludingInitializers above. Such models would be invalid, however they loaded
// in RS5. For compatibility, this ensures that such models continue to load. This is
// verified by an ONNX conformance test for Add.
for (const auto& arg : graph.GetAllInitializedTensors())
{
// This adds the initializer to the input set if it didn't already exist.
graphInputs.insert(arg.first);
}
for (const auto* arg : graph.GetOutputs())
{
graphOutputs.insert(arg->Name());
}
// Check whether this graph is a subgraph, or contains any node with a subgraph.
bool modelUsesSubgraph = ModelUsesSubgraph(graph);
// Build up partitions while traversing the graph.
for (size_t nodeIndex : toplogicalOrder)
{
const onnxruntime::Node& node = *graph.GetNode(nodeIndex);
// Whether the node is implemented through DML.
bool isDmlNode = false;
// Whether the node is implemented through DML and as a graph node, meaning it
// can generate DML operations through a private interface for use as an MLGraph node.
bool isDmlGraphNode = false;
// Get the registration properties above and populate nodeNameToPartitionMap.
GetRegistrationProperties(
graph,
node,
registries,
supportedDeviceDataTypeMask,
graphNodeFactoryMap,
graphNodePropertyMap,
requiredInitializerMap,
/*out*/ &isDmlNode,
/*out*/ &isDmlGraphNode
);
// Add a unique partition if graph node usage is not supported.
//
// Partitioning is disabled in models with subgraphs to work around issues with implicit inputs.
// The partitioning algorithm does not currently consider such inputs. Transfering shared initializers
// for partitions could also cause problems. Note, operators with subgraphs are currently not efficient
// anyhow due to CPU/GPU copies.
if (modelUsesSubgraph || !isDmlGraphNode)
{
if (onNodeUnsupportedInGraph)
{
onNodeUnsupportedInGraph(node);
}
partitions.push_back(CreateNonGraphNodePartitionAndFinalizeInputs(node, isDmlNode, nodeNameToPartitionMap));
continue;
}
std::vector<GraphPartition*> inputNonFinalPartitions = GetNonFinalizedInputPartitions(node, nodeNameToPartitionMap);
if (inputNonFinalPartitions.empty())
{
partitions.push_back(CreateNewPartitionWithFinalizedInputPartitions(node, graphOutputs, nodeNameToPartitionMap));
}
else
{
// Arbitrarily pick the first non-final partition found among the inputs, and add this node
// and its output arguments to that partition.
GraphPartition* firstNonFinalInputPartition = inputNonFinalPartitions[0]->GetRootMergedPartition();
firstNonFinalInputPartition->AddNodeIndex(node.Index());
AddNodeOutputsToPartitionMap(node, firstNonFinalInputPartition, nodeNameToPartitionMap);
// Add inputs for the new node which span partitions
for (uint32_t i = 0; i < node.InputDefs().size(); ++i)
{
const auto* arg = node.InputDefs()[i];
if (arg->Exists())
{
auto& inputPartition = nodeNameToPartitionMap.find(arg->Name());
// Add the input of the current node into the partition which the node will be merged into.
// Skip this if the input is already merged into the same partition or is not finalized,
// and so will be subsequently merged below.
if (inputPartition != nodeNameToPartitionMap.end() &&
inputPartition->second->GetRootMergedPartition() != firstNonFinalInputPartition &&
inputPartition->second->GetRootMergedPartition()->IsFinalized())
{
// Add this input of the current node as an output of the final partition to which
// it belongs.
inputPartition->second->GetRootMergedPartition()->AddOutput(arg->Name());
firstNonFinalInputPartition->AddInput(arg->Name());
}
if (graphInputs.find(arg->Name()) != graphInputs.end())
{
firstNonFinalInputPartition->AddInput(arg->Name());
}
}
}
// Add graph outputs of the new node
AddGraphOutputsFromNodeToPartition(node, graphOutputs, firstNonFinalInputPartition);
// Merge each other non-finalized input partition into the first one
if (inputNonFinalPartitions.size() > 1)
{
firstNonFinalInputPartition->Merge(gsl::span<GraphPartition*>(&inputNonFinalPartitions[1], inputNonFinalPartitions.size() - 1));
}
}
}
return partitions;
}
std::unordered_map<const onnx::TensorProto*, std::vector<uint32_t>>
GetInitializerToPartitionMap(
const onnxruntime::GraphViewer& graph,
gsl::span<std::unique_ptr<GraphPartition>> partitions
)
{
std::unordered_map<const onnx::TensorProto*, std::vector<uint32_t>> initializerPartitionMap;
for (uint32_t partitionIndex = 0; partitionIndex < gsl::narrow_cast<uint32_t>(partitions.size()); ++partitionIndex)
{
auto& partition = partitions[partitionIndex];
// Skip partitions which have been merged into other partitions
if (partition->GetRootMergedPartition() != partition.get())
{
continue;
}
std::unordered_map<std::string, onnx::TensorProto> transferredInitializerMap;
for (const std::string& input : partition->GetInputs())
{
const onnx::TensorProto* tensor = nullptr;
if (graph.GetInitializedTensor(input, tensor))
{
initializerPartitionMap[tensor].push_back(partitionIndex);
}
}
}
return initializerPartitionMap;
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
PartitionGraph(
const onnxruntime::GraphViewer& graph,
const GraphNodeFactoryMap& graphNodeFactoryMap,
const std::vector<const onnxruntime::KernelRegistry*>& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
onnxruntime::KernelRegistry* registryForPartitionKernels,
const std::string& partitionKernelPrefix
)
{
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>> result;
// Initializers needed by any graph partition
std::unordered_set<std::string> requiredInitializerMap;
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap;
std::vector<std::unique_ptr<GraphPartition>> partitions = BuildPartitions(
graph,
graphNodeFactoryMap,
registries,
supportedDeviceDataTypeMask,
graphNodePropertyMap,
requiredInitializerMap);
// Create a map between each initialized tensor and the partition(s) it is part of.
auto initializerPartitionMap = GetInitializerToPartitionMap(graph, partitions);
for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex)
{
auto& partition = partitions[partitionIndex];
if (partition->GetRootMergedPartition() != partition.get() ||
!partition->IsDmlPartition())
{
continue;
}
// Create a map which will store by name each initializer which should be transferred to the
// partition. This prevents OnnxRuntime from allocating GPU resources and uploading those initializers,
// so the partiton's kernel can do so. In the process, it will pre-process weights while consuming a CPU
// backed resource, avoiding an extra set of GPU resources in memory.
// A shared pointer is used so the functor and contained initializer captures can be cheaply copied within ORT.
auto transferredInitializerMap = std::make_shared<std::unordered_map<std::string, onnx::TensorProto>>();
for (const auto& input : partition->GetInputs())
{
if (partition->IsDmlGraphPartition())
{
const onnx::TensorProto* tensor = nullptr;
if (graph.GetInitializedTensor(input, tensor))
{
// It's only safe to transfer tensors which are used by this partition alone.
auto iter = initializerPartitionMap.find(tensor);
assert(iter != initializerPartitionMap.end());
if (iter->second.size() > 1)
{
bool inputConstant = false;
if (requiredInitializerMap.find(input) != requiredInitializerMap.end())
{
// The kernel relies on this input to be initialized, and it should be small enough to copy
// cheaply. FusedGraphKernel only handles constant CPU inputs through transferred initializers,
// rather than ORT, to avoid mismatches in policy or implementation causing failures.
(*transferredInitializerMap)[input] = const_cast<onnx::TensorProto&>(*tensor);
}
continue;
}
// Transfer the initializer
auto& graphTensor = const_cast<onnx::TensorProto&>(*tensor);
onnx::TensorProto partitionTensor;
graphTensor.Swap(&partitionTensor);
(*transferredInitializerMap)[input] = std::move(partitionTensor);
const_cast<onnxruntime::InitializedTensorSet&>(graph.GetAllInitializedTensors()).erase(graph.GetAllInitializedTensors().find(input));
}
}
}
result.push_back(ComputationCapacityFromPartition(
partition.get(),
partitionIndex,
graph,
std::move(graphNodePropertyMap),
registryForPartitionKernels,
partitionKernelPrefix,
transferredInitializerMap));
}
return result;
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h"
namespace Dml
{
class GraphPartition
{
public:
GraphPartition() = default;
~GraphPartition() = default;
GraphPartition* GetRootMergedPartition();
std::vector<onnxruntime::NodeIndex>& GetNodeIndices();
std::set<std::string>& GetInputs();
std::set<std::string>& GetOutputs();
bool IsFinalized();
void SetFinalized();
bool IsDmlPartition();
bool IsDmlGraphPartition();
void SetIsDmlPartition(bool isDmlPartition);
void SetIsDmlGraphPartition(bool isDmlGraphPartition);
void AddNodeIndex(onnxruntime::NodeIndex index);
void AddInput(const std::string& name);
void AddOutput(const std::string& name);
void Merge(gsl::span<GraphPartition*> partitionsToMerge);
private:
std::vector<onnxruntime::NodeIndex> m_nodeIndices;
std::set<std::string> m_inputs;
std::set<std::string> m_outputs;
bool m_finalized = false;
bool m_isDmlGraphPartition = false;
bool m_isDmlPartition = false;
// If not null, this partition has been merged into another, and that partition
// should be used instead.
GraphPartition* m_mergedPartition = nullptr;
};
std::vector<std::unique_ptr<GraphPartition>>
BuildPartitions(
const onnxruntime::GraphViewer& graph,
const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap& graphNodeFactoryMap,
const std::vector<const onnxruntime::KernelRegistry*>& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties>& graphNodePropertyMap,
std::unordered_set<std::string>& requiredInitializerMap,
std::function<void(const onnxruntime::Node&)> onNodeUnsupportedInGraph = nullptr);
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
PartitionGraph(
const onnxruntime::GraphViewer& graph,
const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap& graphNodeFactoryMap,
const std::vector<const onnxruntime::KernelRegistry*>& registries,
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
onnxruntime::KernelRegistry* registryForPartitionKernels,
const std::string& partitionKernelPrefix
);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,201 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
#include "GraphTransformer.h"
#include "Operators/OperatorRegistration.h"
#include "Operators/OperatorUtility.h"
#include "core/providers/dml/OperatorAuthorHelper/Attributes.h"
#include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h"
#include "core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h"
#include "core/framework/kernel_registry.h"
#include "core/graph/graph_utils.h"
namespace Dml
{
GraphTransformer::GraphTransformer(const std::string& name, std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry)
: onnxruntime::GraphTransformer(name),
m_registry(dmlRegistry)
{
}
onnxruntime::common::Status GraphTransformer::ApplyImpl(
onnxruntime::Graph& graph,
bool& modified,
int graph_level) const
{
modified = false;
// Perform fusion
{
bool transformModifiedGraph = false;
PerformOperatorFusion(&graph, &transformModifiedGraph);
modified |= transformModifiedGraph;
if (modified)
{
ORT_RETURN_IF_ERROR(graph.Resolve());
}
}
return onnxruntime::common::Status::OK();
}
static std::string GetUniqueNodeName(const onnxruntime::Node* node)
{
std::stringstream ss;
ss << '#' << node->Index();
if (!node->Name().empty())
{
ss << " \'" << node->Name() << '\'';
}
return ss.str();
}
void GraphTransformer::PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const
{
struct NodeToAdd
{
std::string name;
std::string description;
std::string opType;
std::string domain;
onnxruntime::NodeAttributes attributes;
std::string activationOpType;
std::string activationOpDomain;
int activationOpVersion;
onnxruntime::NodeAttributes activationAttributes;
std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
};
// Defer adding new nodes to the graph until after we're done iterating over it, because we can't mutate the
// graph while iterating over it
std::vector<NodeToAdd> nodesToAdd;
for (auto& node : graph->Nodes())
{
// We need to predict whether the nodes will be assigned to the DML transformer by Lotus,
// which occurs in IExecutionProvider::GetCapability.
if (!m_registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider))
{
// Can't fuse nodes that don't belong to this execution provider
continue;
}
// The number of nodes which use the result of this convolution as input
const auto outputNodeCount = std::distance(node.OutputEdgesBegin(), node.OutputEdgesEnd());
if (outputNodeCount != 1)
{
// Can only fuse nodes whose only output feeds into a single activation - if multiple nodes use the
// output of this one, we can't fuse it.
continue;
}
const auto& outputNode = *node.OutputNodesBegin();
// We need to predict whether the nodes will be assigned to the DML transformer by Lotus,
// which occurs in IExecutionProvider::GetCapability.
if (!m_registry->TryFindKernel(outputNode, onnxruntime::kDmlExecutionProvider))
{
// Can't fuse nodes that don't belong to this execution provider
continue;
}
if (outputNode.InputDefs().size() != 1)
{
// Can only fuse activation functions that take a single input
continue;
}
auto fusedOpProperties = FusionHelpers::TryGetFusedOp(
node.OpType(),
node.Domain(),
node.Op()->SinceVersion(),
gsl::narrow_cast<uint32_t>(node.InputDefs().size()),
outputNode.OpType(),
outputNode.Domain(),
outputNode.Op()->SinceVersion());
if (!fusedOpProperties)
{
// These operators can't be fused
continue;
}
const auto& fuseableNode = node;
const auto& activationNode = outputNode;
// Fusable nodes only produce one output
assert(fuseableNode.OutputDefs().size() == 1);
// Activation only produces one output
assert(activationNode.OutputDefs().size() == 1);
// Add a new node that represents the combination of the fuseable node and the activation node.
NodeToAdd fusedNode;
fusedNode.name = "fused op (" + GetUniqueNodeName(&fuseableNode) + ") + (" + GetUniqueNodeName(&activationNode) + ")";
fusedNode.description = "";
fusedNode.opType = fusedOpProperties->opType;
fusedNode.activationOpType = activationNode.OpType();
fusedNode.activationOpDomain = activationNode.Domain();
fusedNode.activationOpVersion = activationNode.Op()->SinceVersion();
fusedNode.domain = fusedOpProperties->domain;
// Make a copy of the attributes of both nodes
fusedNode.attributes = fuseableNode.GetAttributes();
fusedNode.activationAttributes = activationNode.GetAttributes();
// Inputs to the fused node are the inputs to the fuseable node
for (const auto *input : fuseableNode.InputDefs()) {
fusedNode.inputs.push_back(graph->GetNodeArg(input->Name()));
}
// Outputs from the fused node are the outputs to the activation node
for (const auto *output : activationNode.OutputDefs()){
fusedNode.outputs.push_back(graph->GetNodeArg(output->Name()));
}
nodesToAdd.push_back(std::move(fusedNode));
onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, const_cast<onnxruntime::Node&>(fuseableNode));
onnxruntime::graph_utils::RemoveNodeOutputEdges(*graph, const_cast<onnxruntime::Node&>(activationNode));
// Remove the fuseable and activation nodes - they're replaced by the fused node
bool nodesRemoved = false;
nodesRemoved = graph->RemoveNode(fuseableNode.Index());
nodesRemoved &= graph->RemoveNode(activationNode.Index());
THROW_HR_IF(E_UNEXPECTED, !nodesRemoved);
*modified = true;
}
for (auto& nodeToAdd : nodesToAdd)
{
auto& node = graph->AddNode(
nodeToAdd.name,
nodeToAdd.opType,
nodeToAdd.description,
nodeToAdd.inputs,
nodeToAdd.outputs,
&nodeToAdd.attributes,
nodeToAdd.domain);
// Add a dynamic attribute to the fuseable operator to specify activation
node.AddAttribute(AttrName::FusedActivation, nodeToAdd.activationOpType);
node.AddAttribute(AttrName::FusedActivationDomain, nodeToAdd.activationOpDomain);
node.AddAttribute(AttrName::FusedActivationSinceVersion, static_cast<int64_t>(nodeToAdd.activationOpVersion));
// Copy all attributes from activation into the fuseable node (with new names)
for (auto& attribute : nodeToAdd.activationAttributes)
{
// Change the name of the attribute to its fused node version
std::string fusedAttributeName = Dml::FusionHelpers::GetFusedAttributeName(attribute.first);
attribute.second.set_name(fusedAttributeName);
node.AddAttribute(fusedAttributeName, attribute.second);
}
}
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// Lotus framework headers for onnxruntime::IExecutionProvider (not part of the operator ABI).
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/framework/op_kernel.h"
#include "core/optimizer/graph_transformer.h"
namespace Dml
{
// Applies transforms to a Lotus graph. The graph transformer is responsible for setting the execution provider
// on the graph nodes which DML supports.
class GraphTransformer : public onnxruntime::GraphTransformer
{
public:
GraphTransformer(const std::string& name, std::shared_ptr<onnxruntime::KernelRegistry> dmlRegistry);
private:
onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level = 0) const final;
private:
void PerformOperatorFusion(onnxruntime::Graph* graph, bool* modified) const;
std::shared_ptr<onnxruntime::KernelRegistry> m_registry;
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace Dml
{
class ICommandRecorder
{
public:
virtual ~ICommandRecorder() = default;
virtual void Open() = 0;
// Forces all queued work to begin executing on the GPU. This method returns immediately and does not wait
// for the submitted work to complete execution on the GPU.
virtual void CloseAndExecute() = 0;
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
namespace Dml
{
struct Binding
{
// Non-null if required at the stage where it is used, i.e. Initialization
IMLOperatorTensor* tensor;
UINT64 sizeInBytes;
};
// DML specific interface into the execution provider, which avoids any dependencies with
// internal Lotus data types.
interface __declspec(uuid("b2488edb-fad2-4704-a6d2-5b5b129d4b8e"))
IExecutionProvider : public IUnknown
{
public:
STDMETHOD(GetD3DDevice)(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept = 0;
STDMETHOD(GetDmlDevice)(_COM_Outptr_ IDMLDevice** dmlDevice) const noexcept = 0;
STDMETHOD(ExecuteCommandList)(
ID3D12GraphicsCommandList* commandList,
_Outptr_ ID3D12Fence** fence,
_Out_ uint64_t* completionValue
) const noexcept = 0;
STDMETHOD(AddUAVBarrier)() const noexcept = 0;
STDMETHOD(InitializeOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<const DML_BUFFER_BINDING> inputTensors
) const noexcept = 0;
STDMETHOD(ExecuteOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<IMLOperatorTensor*> inputTensors,
gsl::span<IMLOperatorTensor*> outputTensors
) const noexcept = 0;
STDMETHOD(ExecuteOperator)(
IDMLCompiledOperator* op,
_In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding,
gsl::span<DML_BINDING_DESC> inputTensors,
gsl::span<DML_BINDING_DESC> outputTensors
) const noexcept = 0;
STDMETHOD(CopyTensor)(IMLOperatorTensor* dst, IMLOperatorTensor* src) const noexcept = 0;
STDMETHOD(FillTensorWithPattern)(
IMLOperatorTensor* dst,
gsl::span<const std::byte> value
) const noexcept = 0;
STDMETHOD(UploadToResource)(ID3D12Resource* dstData, const void* srcData, uint64_t srcDataSize) const noexcept = 0;
STDMETHOD_(D3D12_COMMAND_LIST_TYPE, GetCommandListTypeForQueue)() const noexcept = 0;
STDMETHOD_(void, Flush)() const noexcept = 0;
STDMETHOD_(ID3D12Resource*, DecodeResource)(void* allocation) const noexcept = 0;
STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0;
STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0;
};
} // namespace Dml

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,626 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h"
#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h"
#include "core/framework/op_kernel.h"
#include "core/framework/customregistry.h"
#include "core/framework/tensorprotoutils.h"
#include <wrl/client.h>
#include <wrl/implements.h>
interface IDMLOperator;
namespace WRL
{
template <typename... TInterfaces>
using Base = Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
TInterfaces...
>;
}
namespace winrt::Windows::AI::MachineLearning::implementation
{
using namespace Microsoft::WRL;
// Inline method querying whether tensor shapes are defined, during wrappers
// of shape inference callbacks.
template <class T>
bool InputTensorShapesDefinedOnNode(const onnxruntime::OpNodeProtoHelper<T>& nodeInfo)
{
uint32_t inputCount = nodeInfo.GetInputCount();
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
if (nodeInfo.GetInputType(inputIndex) && (nodeInfo.GetInputType(inputIndex)->value_case() == onnx::TypeProto::kTensorType))
{
if (!nodeInfo.GetInputType(inputIndex)->tensor_type().has_shape())
{
return false;
}
const auto& shape = nodeInfo.GetInputType(inputIndex)->tensor_type().shape();
for (int input_dim = 0; input_dim < shape.dim_size(); ++input_dim)
{
if (!shape.dim(input_dim).has_dim_value())
{
return false;
}
}
}
}
return true;
}
::MLOperatorTensorDataType ToMLTensorDataType(onnx::TensorProto_DataType type);
// Used for default values of attributes
struct AttributeValue
{
public:
size_t ElementCount() const;
void GetAttribute(
MLOperatorAttributeType type,
uint32_t elementCount,
size_t elementByteSize,
void* value) const;
const std::string* GetStringAttribute(
_In_z_ const char* name,
uint32_t elementIndex) const;
std::string name;
MLOperatorAttributeType type = MLOperatorAttributeType::Undefined;
std::vector<int64_t> ints;
std::vector<std::string> strings;
std::vector<float> floats;
};
using AttributeMap = std::map<std::string, AttributeValue>;
// Encapsulation of shapes across different edges of an operator. Non-tensor
// edges and unused edges have an empty array of dimensions.
class EdgeShapes
{
public:
EdgeShapes() = default;
EdgeShapes(size_t count) : m_shapes(count) {}
const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
{
return m_shapes[edgeIndex];
}
std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
{
return m_shapes[edgeIndex];
}
size_t EdgeCount() const { return m_shapes.size(); }
void Reset(size_t edge_count)
{
m_shapes.clear();
m_shapes.resize(edge_count);
}
bool operator!=(const EdgeShapes& other) const noexcept
{
return (m_shapes != other.m_shapes);
}
private:
std::vector<std::vector<uint32_t>> m_shapes;
};
// Base class for ABI objects which may be "Closed", at which point calls will predictably
// fail or return a dummy value. This is used for transient ABI context objects which
// are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes
// are not controlled by reference counts of the encapsulating object.
class Closable
{
public:
virtual void Close()
{
m_isClosed = true;
}
protected:
void VerifyNotClosed() const
{
if (m_isClosed)
{
THROW_HR(E_INVALIDARG);
}
}
bool IsClosed() const
{
return m_isClosed;
}
private:
bool m_isClosed = false;
};
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable
{
public:
OpNodeInfoWrapper() = delete;
OpNodeInfoWrapper(
const onnxruntime::OpNodeProtoHelper<NodeInfoImpl_t>* impl,
const EdgeShapes* inputShapesOverride,
const AttributeMap* defaultAttributes,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter) :
m_impl(impl),
m_inputShapesOverride(inputShapesOverride),
m_defaultAttributes(defaultAttributes),
m_constantInputGetter(constantInputGetter)
{
m_requiredConstantCpuInputs.assign(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end());
}
HRESULT STDMETHODCALLTYPE GetAttributeElementCount(
_In_z_ const char* name,
MLOperatorAttributeType type,
uint32_t* elementCount) const noexcept override;
template <MLOperatorAttributeType T>
HRESULT GetAttributeArrayHelper(
_In_z_ const char* name,
uint32_t elementCount,
uint32_t elementByteSize,
void* values) const;
HRESULT STDMETHODCALLTYPE GetAttribute(
_In_z_ const char* name,
MLOperatorAttributeType type,
uint32_t elementCount,
size_t elementByteSize,
void* value) const noexcept override;
HRESULT STDMETHODCALLTYPE GetStringAttributeElementLength(
_In_z_ const char* name,
uint32_t elementIndex,
uint32_t* attributeElementByteLength) const noexcept override;
HRESULT STDMETHODCALLTYPE GetStringAttributeElement(
_In_z_ const char* name,
uint32_t elementIndex,
uint32_t attributeElementByteLength,
char* attributeElement) const noexcept override;
HRESULT STDMETHODCALLTYPE GetTensorAttribute(
_In_z_ const char* name,
_COM_Outptr_ IMLOperatorTensor** tensor) const noexcept override;
uint32_t STDMETHODCALLTYPE GetInputCount() const noexcept override;
uint32_t STDMETHODCALLTYPE GetOutputCount() const noexcept override;
HRESULT STDMETHODCALLTYPE GetInputEdgeDescription(uint32_t inputIndex, MLOperatorEdgeDescription* edgeDesc) const noexcept override;
HRESULT STDMETHODCALLTYPE GetOutputEdgeDescription(uint32_t outputIndex, MLOperatorEdgeDescription* edgeDesc) const noexcept;
HRESULT STDMETHODCALLTYPE GetInputTensorDimensionCount(uint32_t inputIndex, uint32_t* dimensionCount) const noexcept;
HRESULT STDMETHODCALLTYPE GetInputTensorShape(uint32_t inputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept;
bool STDMETHODCALLTYPE IsInputValid(uint32_t inputIndex) const noexcept override;
bool STDMETHODCALLTYPE IsOutputValid(uint32_t outputIndex) const noexcept override;
HRESULT STDMETHODCALLTYPE GetConstantInputTensor(
uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept;
protected:
// Lifetime is managed by the caller and guaranteed to outlive this class
const onnxruntime::OpNodeProtoHelper<NodeInfoImpl_t>* m_impl = nullptr;
private:
template <MLOperatorAttributeType T>
HRESULT GetAttributeHelper(
const char* name,
uint32_t elementByteSize,
void* value) const;
const std::string* GetStringAttribute(
const char* name,
uint32_t elementIndex) const;
// May be null
const EdgeShapes* m_inputShapesOverride;
std::vector<uint32_t> m_requiredConstantCpuInputs;
MLOperatorTensorGetter m_constantInputGetter;
const AttributeMap* m_defaultAttributes = nullptr;
};
class TensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
{
public:
TensorWrapper() = default;
TensorWrapper(onnxruntime::Tensor* impl, bool is_data_handle, IWinmlExecutionProvider* provider, bool isInternalOperator);
uint32_t STDMETHODCALLTYPE GetDimensionCount() const noexcept override;
HRESULT STDMETHODCALLTYPE GetShape(
uint32_t dimensionCount,
uint32_t* dimensions) const noexcept override;
MLOperatorTensorDataType STDMETHODCALLTYPE GetTensorDataType() const noexcept override;
bool STDMETHODCALLTYPE IsCpuData() const noexcept override;
bool STDMETHODCALLTYPE IsDataInterface() const noexcept override;
void* STDMETHODCALLTYPE GetData() noexcept override;
void STDMETHODCALLTYPE GetDataInterface(IUnknown** dataInterface) noexcept override;
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
onnxruntime::Tensor* GetInterface() { return nullptr; }
private:
// Lifetime is managed by the caller and guaranteed to outlive this class
onnxruntime::Tensor* m_impl = nullptr;
ComPtr<IWinmlExecutionProvider> m_winmlExecutionProvider;
bool m_internalOperator = false;
void* m_tensorData = nullptr;
ComPtr<IUnknown> m_dataInterface;
bool m_isDataInterface = false;
// The returned data may be a converted shadow copy, and the piece of it which
// is returned may vary according to kernel registration options.
ComPtr<IUnknown> m_dataInterfaceOrShadowCopy;
ComPtr<IUnknown> m_abiDataInterface;
};
class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
{
public:
OnnxTensorWrapper() = default;
OnnxTensorWrapper(onnx::TensorProto* impl);
uint32_t STDMETHODCALLTYPE GetDimensionCount() const noexcept override;
HRESULT STDMETHODCALLTYPE GetShape(
uint32_t dimensionCount,
uint32_t* dimensions) const noexcept override;
MLOperatorTensorDataType STDMETHODCALLTYPE GetTensorDataType() const noexcept override;
bool STDMETHODCALLTYPE IsCpuData() const noexcept override;
bool STDMETHODCALLTYPE IsDataInterface() const noexcept override;
void* STDMETHODCALLTYPE GetData() noexcept override;
void STDMETHODCALLTYPE GetDataInterface(IUnknown** dataInterface) noexcept override;
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
onnxruntime::Tensor* GetInterface() { return nullptr; }
private:
size_t m_tensorByteSize = 0;
std::unique_ptr<std::byte[]> m_unpackedTensor;
std::byte* m_dataPtr = nullptr;
// Lifetime is managed by the caller and guaranteed to outlive this class
onnx::TensorProto* m_impl = nullptr;
};
class OpKernelInfoWrapper : public OpNodeInfoWrapper<
onnxruntime::ProtoHelperNodeContext,
WRL::Base<
Microsoft::WRL::ChainInterfaces<IMLOperatorKernelCreationContextPrivate, IMLOperatorKernelCreationContext>,
IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>,
onnxruntime::null_type>
{
public:
OpKernelInfoWrapper(
const onnxruntime::OpKernelInfo* kerneInfo,
IUnknown* abiExecutionObject,
const EdgeShapes* inputShapeOverrides,
const EdgeShapes* inferredOutputShapes,
bool allowInputShapeQuery,
bool allowOutputShapeQuery,
bool isInternalOperator,
const AttributeMap* defaultAttributes,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter
);
// HasTensorShapeDescription returns false if and only if the kernel is registered using
// MLOperatorKernelOptions::AllowDynamicInputTensorSizes. If this flag is specified and upstream
// shapes are known when the kernel is created, HasTensorShapeDescription still returns false.
bool STDMETHODCALLTYPE HasTensorShapeDescription() const noexcept override;
HRESULT STDMETHODCALLTYPE GetTensorShapeDescription(IMLOperatorTensorShapeDescription** shapeInfo) const noexcept override;
void STDMETHODCALLTYPE GetExecutionInterface(IUnknown** executionInterface) const noexcept override;
// IMLOperatorTensorShapeDescription methods.
HRESULT STDMETHODCALLTYPE GetOutputTensorDimensionCount(uint32_t inputIndex, uint32_t* dimensionCount) const noexcept;
bool STDMETHODCALLTYPE HasOutputShapeDescription() const noexcept override;
HRESULT STDMETHODCALLTYPE GetOutputTensorShape(uint32_t inputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept;
bool STDMETHODCALLTYPE IsDmlGraphNode() const noexcept override
{
return false;
}
HRESULT STDMETHODCALLTYPE SetDmlOperator(
IDMLOperator* op,
_In_ const DML_OPERATOR_DESC* desc,
_In_opt_ const MLOperatorKernelDmlProperties* dmlProperties
) const noexcept override
{
return E_NOTIMPL;
}
private:
// For shape info, in addition to the info
const EdgeShapes* m_inferredOutputShapes = nullptr;
bool m_allowInputShapeQuery = false;
bool m_allowOutputShapeQuery = false;
bool m_internalOperator = false;
ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
const onnxruntime::OpKernelInfo* m_impl = nullptr;
// The execution object returned through the ABI, which may vary according to kernel
// registration options.
ComPtr<IUnknown> m_abiExecutionObject;
};
// OpKernelInfo used for DML graph fusion. This uses the ONNX graph structures instead of ORT OpKernelInfo.
class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper<
onnxruntime::ProtoHelperNodeContext,
WRL::Base<
Microsoft::WRL::ChainInterfaces<IMLOperatorKernelCreationContextPrivate, IMLOperatorKernelCreationContext>,
IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>,
onnxruntime::null_type>
{
public:
DmlGraphOpKernelInfoWrapper(
const onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> * protoHelper,
const void* executionHandle,
bool isInternalOperator,
const EdgeShapes* inferredOutputShapes,
const AttributeMap* defaultAttributes,
DmlGraphNodeCreateInfo* graphNodeCreateInfo,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter
);
// HasTensorShapeDescription returns false if and only if the kernel is registered using
// MLOperatorKernelOptions::AllowDynamicInputTensorSizes. If this flag is specified and upstream
// shapes are known when the kernel is created, HasTensorShapeDescription still returns false.
bool STDMETHODCALLTYPE HasTensorShapeDescription() const noexcept override;
HRESULT STDMETHODCALLTYPE GetTensorShapeDescription(IMLOperatorTensorShapeDescription** shapeInfo) const noexcept override;
void STDMETHODCALLTYPE GetExecutionInterface(IUnknown** executionInterface) const noexcept override;
// IMLOperatorTensorShapeDescription methods.
HRESULT STDMETHODCALLTYPE GetOutputTensorDimensionCount(uint32_t inputIndex, uint32_t* dimensionCount) const noexcept;
bool STDMETHODCALLTYPE HasOutputShapeDescription() const noexcept override;
HRESULT STDMETHODCALLTYPE GetOutputTensorShape(uint32_t inputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept;
bool STDMETHODCALLTYPE IsDmlGraphNode() const noexcept override;
HRESULT STDMETHODCALLTYPE SetDmlOperator(
IDMLOperator* op,
_In_ const DML_OPERATOR_DESC* desc,
_In_opt_ const MLOperatorKernelDmlProperties* dmlProperties
) const noexcept override;
private:
void SetDmlProperties(_In_ const MLOperatorKernelDmlProperties* dmlProperties) const;
// For shape info, in addition to the info
const EdgeShapes* m_inferredOutputShapes = nullptr;
ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
// The execution object returned through the ABI, which may vary according to kernel
// registration options.
ComPtr<IUnknown> m_abiExecutionObject;
DmlGraphNodeCreateInfo* m_graphNodeCreateInfo = nullptr;
};
class OpKernelContextWrapper : public WRL::Base<IMLOperatorKernelContext>, public Closable
{
public:
~OpKernelContextWrapper();
OpKernelContextWrapper(onnxruntime::OpKernelContext* context, const onnxruntime::IExecutionProvider* provider, bool isInternalOperator, const EdgeShapes* outputShapes);
HRESULT STDMETHODCALLTYPE GetInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept override;
HRESULT STDMETHODCALLTYPE GetOutputTensor(uint32_t outputIndex, IMLOperatorTensor** tensor) noexcept override;
HRESULT STDMETHODCALLTYPE GetOutputTensor(uint32_t outputIndex, uint32_t dimensions, const uint32_t* dimensionSizes, IMLOperatorTensor** tensor) noexcept override;
HRESULT STDMETHODCALLTYPE AllocateTemporaryData(size_t size, IUnknown** data) const override;
HRESULT STDMETHODCALLTYPE AllocateTemporaryData(size_t size, IUnknown** data, uint64_t* allocId) const;
void STDMETHODCALLTYPE GetExecutionInterface(IUnknown** executionInterface) const noexcept override;
void Close() override;
std::vector<IMLOperatorTensor*> GetInputTensors();
std::vector<IMLOperatorTensor*> GetOutputTensors(const EdgeShapes& outputShapes);
protected:
void ClearTempAllocations();
void TransitionResourcesForOperatorIfRequired(bool isBeforeOp);
// Lifetime is managed by the caller and guaranteed to outlive this class
onnxruntime::OpKernelContext* m_impl = nullptr;
const EdgeShapes* m_outputShapes = nullptr;
std::vector<ComPtr<TensorWrapper>> m_inputTensors;
std::vector<ComPtr<TensorWrapper>> m_outputTensors;
const onnxruntime::IExecutionProvider* m_provider = nullptr;
ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
// The execution object returned to the kernel may vary according to kernel execution options
ComPtr<IUnknown> m_providerExecutionObject;
ComPtr<IUnknown> m_abiExecutionObject;
// Temporary allocations created by the kernel. These will be freed to the allocator following
// Compute being called on the kernel. This list is used to maintain their lifetime.
mutable std::vector<ComPtr<IUnknown>> m_temporaryAllocations;
mutable std::vector<ComPtr<IUnknown>> m_temporaryAbiAllocations;
};
class AbiOpKernel : public onnxruntime::OpKernel
{
public:
AbiOpKernel(
IMLOperatorKernelFactory* operatorFactory,
const onnxruntime::OpKernelInfo& kerneInfo,
bool requiresInputShapesAtCreation,
bool requiresOutputShapesAtCreation,
bool isInternalOperator,
gsl::span<const uint32_t> requiredConstantCpuInputs,
IMLOperatorShapeInferrer* shapeInferrer,
const AttributeMap* defaultAttributes);
onnxruntime::Status Compute(onnxruntime::OpKernelContext* context) const override;
protected:
bool RequiresLazyInitialization() const { return (m_operatorFactory != nullptr) && !m_lazyInitialized; };
void SetLazyInitialized() const { m_lazyInitialized = true; };
EdgeShapes GetInputShapes(onnxruntime::OpKernelContext* context) const;
bool InputTensorShapesDefined() const;
bool InputSizesInferencedFromSchema() const;
void InferAndVerifyOutputSizes(gsl::span<const uint32_t> requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter, const EdgeShapes* inputShapes, EdgeShapes& outputShapes) const;
bool m_requiresInputShapesAtCreation = false;
bool m_requiresOutputShapesAtCreation = false;
mutable Microsoft::WRL::ComPtr<IMLOperatorKernel> m_kernel;
// This is null unless the kernel requires lazy initialization
ComPtr<IMLOperatorKernelFactory> m_operatorFactory;
mutable volatile bool m_lazyInitialized = false;
ComPtr<IMLOperatorShapeInferrer> m_shapeInferrer;
// Used to determine whether anything has changed since creation when shapes or
// inputs treated as constant by the operator are not inferred / constant.
mutable EdgeShapes m_inputShapesOfKernelInference;
struct TensorContent
{
std::vector<uint32_t> shape;
MLOperatorTensorDataType type;
std::vector<std::byte> data;
};
mutable std::vector<TensorContent> m_constantInputTensorContentsOfKernel;
mutable std::mutex m_mutex;
mutable EdgeShapes m_inferredOutputShapes;
ComPtr<winrt::Windows::AI::MachineLearning::implementation::IWinmlExecutionProvider> m_winmlProvider;
bool m_internalOperator = false;
std::vector<uint32_t> m_requiredConstantCpuInputs;
// The execution object returned through the ABI may vary according to kernel
// registration options.
ComPtr<IUnknown> m_providerExecutionObject;
ComPtr<IUnknown> m_abiExecutionObject;
const AttributeMap* m_defaultAttributes = nullptr;
};
class MLSchemaInferenceContext final : public OpNodeInfoWrapper<
onnx::InferenceContext,
WRL::Base<
Microsoft::WRL::ChainInterfaces<IMLOperatorShapeInferenceContextPrivate, IMLOperatorShapeInferenceContext>,
IMLOperatorTypeInferenceContext, IMLOperatorAttributes, IMLOperatorAttributes1>,
onnxruntime::null_type>
{
public:
MLSchemaInferenceContext() = delete;
MLSchemaInferenceContext(
onnxruntime::OpNodeProtoHelper<onnx::InferenceContext>* info,
onnx::InferenceContext* ctx,
gsl::span<const uint32_t> requiredConstantCpuInputs
);
onnx::InferenceContext* GetContext() const
{
return m_context;
}
HRESULT STDMETHODCALLTYPE SetOutputEdgeDescription(uint32_t outputIndex, const MLOperatorEdgeDescription* edgeDesc) const noexcept override;
HRESULT STDMETHODCALLTYPE SetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, const uint32_t* dimensions) noexcept override;
private:
onnx::InferenceContext* m_context = nullptr;
};
class MLKernelInferenceContext final : public OpNodeInfoWrapper<
onnxruntime::ProtoHelperNodeContext,
WRL::Base<Microsoft::WRL::ChainInterfaces<IMLOperatorShapeInferenceContextPrivate, IMLOperatorShapeInferenceContext>, IMLOperatorAttributes, IMLOperatorAttributes1>,
onnxruntime::null_type>
{
public:
MLKernelInferenceContext() = delete;
MLKernelInferenceContext(
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext>* info,
const EdgeShapes* inputShapesOverride,
EdgeShapes& inferredOutputShapes,
const AttributeMap* defaultAttributes,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter) :
OpNodeInfoWrapper(info, inputShapesOverride, defaultAttributes, requiredConstantCpuInputs, constantInputGetter),
m_inferredOutputShapes(inferredOutputShapes)
{
}
HRESULT STDMETHODCALLTYPE SetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, const uint32_t* dimensions) noexcept override;
private:
EdgeShapes& m_inferredOutputShapes;
};
void InferAndVerifyOutputSizes(
const onnxruntime::Node& node,
const AttributeMap* defaultAttributes,
IMLOperatorShapeInferrer* shapeInferrer,
gsl::span<const uint32_t> requiredConstantCpuInputs,
MLOperatorTensorGetter& constantInputGetter,
const EdgeShapes* inputShapes,
EdgeShapes& outputShapes);
onnxruntime::MLDataType ToTensorDataType(::MLOperatorTensorDataType type);
std::string ToTypeString(MLOperatorEdgeDescription desc);
onnx::AttributeProto_AttributeType ToProto(MLOperatorAttributeType type);
bool TryGetStaticInputShapes(const onnxruntime::Node& node, EdgeShapes& inputShapes);
bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputShapes);
std::tuple<std::unique_ptr<std::byte[]>, size_t> UnpackTensor(const onnx::TensorProto& initializer);
} // namespace winrt::Windows::AI::MachineLearning::implementation

Просмотреть файл

@ -0,0 +1,446 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
ML_CHECK_HRESULT(m_executionProvider->GetDmlDevice(/*out*/ m_dmlDevice.GetAddressOf()));
}
void DmlOperator::SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelCreationContext& kernelInfo
)
{
// Initialize should only be called once.
assert(m_compiledOperator == nullptr);
// Create and compile the operator.
ComPtr<IDMLOperator> dmlOperator;
THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));
ComPtr<IMLOperatorKernelCreationContextPrivate> contextPrivate;
THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf()));
if (contextPrivate->IsDmlGraphNode())
{
// Create an edge list using sentinels for unused edges, as required by the SetDmlOperator ABI
auto ReplaceUnusedEdgeIndicesWithSentinel = [](gsl::span<const std::optional<uint32_t>> indices)
{
std::vector<uint32_t> ret;
ret.reserve(indices.size());
for (const std::optional<uint32_t>& index : indices)
{
ret.push_back(index.has_value() ? index.value() : std::numeric_limits<uint32_t>::max());
}
return ret;
};
MLOperatorKernelDmlProperties properties = {};
auto kernelInputIndices = ReplaceUnusedEdgeIndicesWithSentinel(m_kernelInputIndices);
properties.dmlInputCount = static_cast<uint32_t>(kernelInputIndices.size());
properties.kernelInputIndices = kernelInputIndices.data();
auto kernelOutputIndices = ReplaceUnusedEdgeIndicesWithSentinel(m_kernelOutputIndices);
properties.dmlOutputCount = static_cast<uint32_t>(kernelOutputIndices.size());
properties.kernelOutputIndices = kernelOutputIndices.data();
properties.allowHalfPrecisionComputation = AllowHalfPrecisionComputation();
THROW_IF_FAILED(contextPrivate->SetDmlOperator(dmlOperator.Get(), &operatorDesc, &properties));
}
else
{
DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));
UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize;
if (persistentResourceSize > 0)
{
THROW_IF_FAILED(m_executionProvider->AllocatePooledResource(
persistentResourceSize,
AllocatorRoundingMode::Enabled,
m_persistentResource.GetAddressOf(),
m_persistentResourcePoolingUnk.GetAddressOf()));
m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize };
}
std::vector<DML_BUFFER_BINDING> initializationInputBindings(m_kernelInputIndices.size());
THROW_IF_FAILED(m_executionProvider->InitializeOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(initializationInputBindings)));
}
}
void DmlOperator::SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelContext& kernelInfo
)
{
// Create and compile the operator.
// Unlike SetDmlOperatorDesc which takes a MLOperatorKernelCreationContext, it is okay to
// call this method more than once, since Compute may take different inputs each execution.
m_compiledOperator.Reset();
ComPtr<IDMLOperator> dmlOperator;
THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));
THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&m_compiledOperator)));
UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize;
if (persistentResourceSize > 0)
{
if (!m_persistentResource || m_persistentResource->GetDesc().Width < persistentResourceSize)
{
m_persistentResource = nullptr;
THROW_IF_FAILED(m_executionProvider->AllocatePooledResource(
persistentResourceSize,
AllocatorRoundingMode::Enabled,
m_persistentResource.GetAddressOf(),
m_persistentResourcePoolingUnk.GetAddressOf()));
}
m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize };
}
THROW_IF_FAILED(m_executionProvider->InitializeOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::span<const DML_BUFFER_BINDING>() // Empty input bindings since ownedByDml is not used.
));
}
void DmlOperator::Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices,
const std::optional<gsl::span<const uint32_t>> inputShape,
const std::optional<gsl::span<const uint32_t>> outputShape
)
{
if (kernelInputIndices)
{
m_kernelInputIndices = *kernelInputIndices;
}
else
{
m_kernelInputIndices.resize(kernelInfo.GetInputCount());
std::iota(m_kernelInputIndices.begin(), m_kernelInputIndices.end(), 0);
}
if (kernelOutputIndices)
{
m_kernelOutputIndices = *kernelOutputIndices;
}
else
{
m_kernelOutputIndices.resize(kernelInfo.GetOutputCount());
std::iota(m_kernelOutputIndices.begin(), m_kernelOutputIndices.end(), 0);
}
for (uint32_t i = 0; i < m_kernelInputIndices.size(); i++)
{
// Update m_kernelInputIndices to reflect optional tensors.
if (m_kernelInputIndices[i] == std::nullopt ||
!kernelInfo.IsInputValid(*m_kernelInputIndices[i]))
{
m_kernelInputIndices[i] = std::nullopt;
m_inputTensorDescs.push_back(TensorDesc());
}
else
{
m_inputTensorDescs.push_back(CreateTensorDescFromInput(
kernelInfo,
*m_kernelInputIndices[i],
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
inputShape,
NchwDimensionCount));
}
}
for (uint32_t i = 0; i < m_kernelOutputIndices.size(); i++)
{
// Update m_kernelOutputIndices to reflect optional tensors.
if (m_kernelOutputIndices[i] == std::nullopt ||
!kernelInfo.IsOutputValid(*m_kernelOutputIndices[i]))
{
m_kernelOutputIndices[i] = std::nullopt;
m_outputTensorDescs.push_back(TensorDesc());
}
else
{
m_outputTensorDescs.push_back(CreateTensorDescFromOutput(
kernelInfo,
*m_kernelOutputIndices[i],
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
outputShape));
}
}
}
void DmlOperator::Compute(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)));
}
bool DmlOperator::AllowHalfPrecisionComputation() const
{
// Most of our operators work with float data, but some do not. In those cases
// no input params are float tensors. This function returns true if the operator
// works with at least one float16 tensor and has no tensors of float32 type
bool usesFloat16Tensors = false;
for (const TensorDesc& desc : m_inputTensorDescs)
{
if (desc.GetDmlDataType() == DML_TENSOR_DATA_TYPE_FLOAT32)
{
return false;
}
if (desc.GetDmlDataType() == DML_TENSOR_DATA_TYPE_FLOAT16)
{
usesFloat16Tensors = true;
}
}
for (const auto& desc : m_outputTensorDescs)
{
if (desc.GetDmlDataType() == DML_TENSOR_DATA_TYPE_FLOAT32)
{
return false;
}
}
return usesFloat16Tensors;
}
DML_EXECUTION_FLAGS DmlOperator::GetExecutionFlags() const
{
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
if (AllowHalfPrecisionComputation())
{
flags |= DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION;
}
if (!m_executionProvider->MetacommandsEnabled())
{
flags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS;
}
return flags;
}
std::vector<IMLOperatorTensor*> DmlOperator::GetInputTensors(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors(m_kernelInputIndices.size());
for (uint32_t i = 0; i < inputTensors.size(); i++)
{
if (m_kernelInputIndices[i] != std::nullopt)
{
assert(m_inputTensorDescs[i].IsValid());
inputTensors[i] = kernelContext.GetInputTensor(*m_kernelInputIndices[i]).GetInterface().Get();
}
}
return inputTensors;
}
std::vector<IMLOperatorTensor*> DmlOperator::GetOutputTensors(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> outputTensors(m_kernelOutputIndices.size());
for (uint32_t i = 0; i < outputTensors.size(); i++)
{
if (m_kernelOutputIndices[i] != std::nullopt)
{
assert(m_outputTensorDescs[i].IsValid());
outputTensors[i] = kernelContext.GetOutputTensor(*m_kernelOutputIndices[i]).GetInterface().Get();
}
}
return outputTensors;
}
std::vector<IMLOperatorTensor*> DmlOperator::GetInputTensorsForExecute(const MLOperatorKernelContext& kernelContext)
{
return GetInputTensors(kernelContext);
}
std::vector<IMLOperatorTensor*> DmlOperator::GetOutputTensorsForExecute(const MLOperatorKernelContext& kernelContext)
{
return GetOutputTensors(kernelContext);
}
std::vector<DML_TENSOR_DESC> DmlOperator::GetDmlInputDescs()
{
std::vector<DML_TENSOR_DESC> descs(m_inputTensorDescs.size());
for (size_t i = 0; i < descs.size(); i++)
{
descs[i] = m_inputTensorDescs[i].GetDmlDesc();
}
return descs;
}
std::vector<DML_TENSOR_DESC> DmlOperator::GetDmlOutputDescs()
{
std::vector<DML_TENSOR_DESC> descs(m_outputTensorDescs.size());
for (size_t i = 0; i < descs.size(); i++)
{
descs[i] = m_outputTensorDescs[i].GetDmlDesc();
}
return descs;
}
ComPtr<IDMLCompiledOperator> DmlOperator::InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes)
{
// This fun little solution uses DML's element-wise shader with XOR to zero the memory of the passed-in
// tensor. This requires that the tensor's memory has been initialized (i.e. raw_mutable_data has been
// called, and there is a size to the tensor). The tensor is XOR'd with itself to produce zeros,
// and the operation is performed in-place on the same tensor.
// Treat the tensor as a 1D array of 32-bit UINTs.
uint32_t sizes[] = { 1, 1, 1, gsl::narrow<uint32_t>(tensorSizeInBytes / sizeof(uint32_t)) };
DML_BUFFER_TENSOR_DESC bufferDesc = {};
bufferDesc.DataType = DML_TENSOR_DATA_TYPE_UINT32;
bufferDesc.Sizes = sizes;
bufferDesc.DimensionCount = ARRAYSIZE(sizes);
bufferDesc.TotalTensorSizeInBytes = tensorSizeInBytes;
DML_TENSOR_DESC tensorDesc = { DML_TENSOR_TYPE_BUFFER, &bufferDesc };
DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC xorDesc = {};
xorDesc.ATensor = &tensorDesc;
xorDesc.BTensor = &tensorDesc;
xorDesc.OutputTensor = &tensorDesc;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR, &xorDesc };
ComPtr<IDMLOperator> dmlOperator;
THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperator)));
ComPtr<IDMLCompiledOperator> dmlCompiledOperator;
THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), GetExecutionFlags(), IID_PPV_ARGS(&dmlCompiledOperator)));
return dmlCompiledOperator;
}
void DmlOperator::ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor)
{
// Element-wise XOR takes two inputs and an output. We want in-place execution, so all three
// resources are the same.
IMLOperatorTensor* inputTensors[] = { tensor, tensor };
IMLOperatorTensor* outputTensors[] = { tensor };
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
compiledOperator,
nullptr, // persistent resource binding
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)
));
}
TensorDesc DmlOperator::CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
uint32_t coerceAxis,
int32_t placement,
int32_t leftAlignedDimensionCount,
std::optional<gsl::span<const uint32_t>> tensorShape,
uint32_t minDimensionCount
) const
{
if (!kernelInfo.IsInputValid(index))
{
// The tensor is optional.
return TensorDesc();
}
auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
std::vector<uint32_t> actualTensorShape;
if (kernelInfo.HasTensorShapeDescription())
{
actualTensorShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(index);
}
else
{
// The tensor has delayed shape determination.
return TensorDesc();
}
return TensorDesc(
edgeDesc.tensorDataType,
tensorShape ? *tensorShape : actualTensorShape,
actualTensorShape,
coerceAxis,
placement,
leftAlignedDimensionCount,
minDimensionCount,
0
);
}
TensorDesc DmlOperator::CreateTensorDescFromOutput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
uint32_t coerceAxis,
int32_t placement,
int32_t leftAlignedDimensionCount,
std::optional<gsl::span<const uint32_t>> tensorShape,
uint32_t minDimensionCount
) const
{
if (!kernelInfo.IsOutputValid(index))
{
// The tensor is optional.
return TensorDesc();
}
auto edgeDesc = kernelInfo.GetOutputEdgeDescription(index);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
if (!kernelInfo.HasTensorShapeDescription())
{
// The tensor has delayed shape determination.
return TensorDesc(edgeDesc.tensorDataType);
}
MLOperatorTensorShapeDescription outputShapeDescription = kernelInfo.GetTensorShapeDescription();
if (!outputShapeDescription.HasOutputShapeDescription())
{
// The tensor has delayed shape determination.
return TensorDesc();
}
auto outputShape = outputShapeDescription.GetOutputTensorShape(index);
return TensorDesc(
edgeDesc.tensorDataType,
tensorShape ? *tensorShape : outputShape,
tensorShape ? *tensorShape : outputShape,
coerceAxis,
placement,
leftAlignedDimensionCount,
minDimensionCount,
0
);
}
} // namespace Dml

Просмотреть файл

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "OperatorUtility.h"
namespace Dml
{
class DmlOperator
{
public:
DmlOperator(const MLOperatorKernelCreationContext& kernelInfo);
virtual ~DmlOperator() = default;
virtual void Compute(const MLOperatorKernelContext& kernelContext);
protected:
ComPtr<IExecutionProvider> m_executionProvider;
ComPtr<IDMLDevice> m_dmlDevice;
// Tensor descs ordered based on index arrays passed to Initialize
std::vector<TensorDesc> m_inputTensorDescs;
std::vector<TensorDesc> m_outputTensorDescs;
ComPtr<IDMLCompiledOperator> m_compiledOperator;
ComPtr<ID3D12Resource> m_persistentResource;
ComPtr<IUnknown> m_persistentResourcePoolingUnk; // Controls when the persistent resource is returned to the pool
std::optional<DML_BUFFER_BINDING> m_persistentResourceBinding;
void Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices = std::nullopt,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices = std::nullopt,
const std::optional<gsl::span<const uint32_t>> inputShape = std::nullopt,
const std::optional<gsl::span<const uint32_t>> outputShape = std::nullopt
);
bool AllowHalfPrecisionComputation() const;
DML_EXECUTION_FLAGS GetExecutionFlags() const;
void SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelCreationContext& kernelInfo
);
void SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelContext& kernelInfo
);
// Tensors ordered based on index arrays passed to Initialize
std::vector<IMLOperatorTensor*> GetInputTensors(const MLOperatorKernelContext& kernelContext);
std::vector<IMLOperatorTensor*> GetOutputTensors(const MLOperatorKernelContext& kernelContext);
// Retrieves the input/output tensors to be supplied to DirectML for execution. These differ from
// Get[Input|Output]Tensors in that they account for the binding requirements of DML, instead of
// unconditionally retrieving all input and output tensors.
std::vector<IMLOperatorTensor*> GetInputTensorsForExecute(const MLOperatorKernelContext& kernelContext);
std::vector<IMLOperatorTensor*> GetOutputTensorsForExecute(const MLOperatorKernelContext& kernelContext);
// Tensor descs ordered based on index arrays passed to Initialize
std::vector<DML_TENSOR_DESC> GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> GetDmlOutputDescs();
// Sets the memory of a tensor to all zeros.
//
// WinML requires int64_t for certain operators, like ArgMax and ArgMin. DML does not directly support
// int64_t as a tensor data type, because D3D does not support 64-bit integers. Currently, we "hack"
// support for int64_t WinML tensors using int32_t tensors with strides; the upper 32-bits are not used,
// since this hack is only used for unsigned values that require less than 32 bits. However, WinML
// will read the full 64-bit values. This means it is necessary to zero out the memory to ensure there
// are no uninitialized values in the upper 32-bit portion of the tensor memory.
//
ComPtr<IDMLCompiledOperator> InitializeZeroInt64Tensor(uint64_t tensorSizeInBytes);
void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);
TensorDesc CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
uint32_t coerceAxis = TensorAxis::DoNotCoerce,
int32_t placement = TensorAxis::W,
int32_t leftAlignedDimensionCount = TensorAxis::RightAligned,
std::optional<gsl::span<const uint32_t>> tensorShape = std::nullopt,
uint32_t minDimensionCount = NchwDimensionCount
) const;
TensorDesc CreateTensorDescFromOutput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
uint32_t coerceAxis = TensorAxis::DoNotCoerce,
int32_t placement = TensorAxis::W,
int32_t leftAlignedDimensionCount = TensorAxis::RightAligned,
std::optional<gsl::span<const uint32_t>> tensorShape = std::nullopt,
uint32_t minDimensionCount = NchwDimensionCount
) const;
private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
std::vector<std::optional<uint32_t>> m_kernelInputIndices;
std::vector<std::optional<uint32_t>> m_kernelOutputIndices;
};
} // namespace Dml

Просмотреть файл

@ -0,0 +1,171 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorActivation : public DmlOperator
{
public:
using Self = DmlOperatorActivation;
DmlOperatorActivation(
const MLOperatorKernelCreationContext& kernelCreationContext,
DML_OPERATOR_TYPE operatorType
)
: DmlOperator(kernelCreationContext)
{
// Activation has a single output which is mapped to the first kernel output. Specifying
// this manually avoids a problem when activation is used to implement dropout, which may
// have a 'mask' output which is unused during inference.
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0};
DmlOperator::Initialize(kernelCreationContext, std::nullopt, kernelOutputIndices);
ActivationOperatorDescUnion operatorDesc = {};
int coerceAxis = TensorAxis::DoNotCoerce;
switch (operatorType)
{
case DML_OPERATOR_ACTIVATION_ELU:
operatorDesc.elu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
break;
case DML_OPERATOR_ACTIVATION_SOFTMAX:
case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX:
case DML_OPERATOR_ACTIVATION_HARDMAX:
{
const uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0).size());
coerceAxis = HandleNegativeAxis(kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 1), onnxDimCount);
}
break;
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID:
operatorDesc.hardSigmoid.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
operatorDesc.hardSigmoid.Beta = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Beta, ActivationHelper::GetDefaultBeta(operatorType));
break;
case DML_OPERATOR_ACTIVATION_LEAKY_RELU:
operatorDesc.leakyRelu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
break;
case DML_OPERATOR_ACTIVATION_LINEAR:
operatorDesc.linear.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
operatorDesc.linear.Beta = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Beta, ActivationHelper::GetDefaultBeta(operatorType));
break;
case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS:
operatorDesc.parametricSoftplus.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
operatorDesc.parametricSoftplus.Beta = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Beta, ActivationHelper::GetDefaultBeta(operatorType));
break;
case DML_OPERATOR_ACTIVATION_SCALED_ELU:
operatorDesc.scaledElu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
operatorDesc.scaledElu.Gamma = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Gamma, 0.0f);
break;
case DML_OPERATOR_ACTIVATION_SCALED_TANH:
operatorDesc.scaledTanh.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
operatorDesc.scaledTanh.Beta = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Beta, ActivationHelper::GetDefaultBeta(operatorType));
break;
case DML_OPERATOR_ACTIVATION_SOFTPLUS:
operatorDesc.softplus.Steepness = 1.0f;
break;
case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU:
operatorDesc.thresholdedRelu.Alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, ActivationHelper::GetDefaultAlpha(operatorType));
break;
case DML_OPERATOR_ACTIVATION_SHRINK:
operatorDesc.shrink.Bias = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Bias, ActivationHelper::GetDefaultBias(operatorType));
operatorDesc.shrink.Threshold = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Lambda, ActivationHelper::GetDefaultLambda(operatorType));
break;
case DML_OPERATOR_ACTIVATION_IDENTITY:
case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU:
case DML_OPERATOR_ACTIVATION_RELU:
case DML_OPERATOR_ACTIVATION_SIGMOID:
case DML_OPERATOR_ACTIVATION_TANH:
case DML_OPERATOR_ACTIVATION_SOFTSIGN:
// No additional parameters to set.
break;
default:
assert(false);
break;
}
if (coerceAxis != TensorAxis::DoNotCoerce)
{
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, coerceAxis);
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, coerceAxis);
}
gsl::span<const uint32_t> outputSizes = m_outputTensorDescs[0].GetSizes();
std::vector<DML_TENSOR_DESC> inputDescs;
std::vector<DML_TENSOR_DESC> outputDescs;
if (operatorType == DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU)
{
// PRelu is unique and accepts its parameters as a second input tensor.
// The slope tensor is unidirectionally broadcastable. Reshape it based on the desired output sizes.
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelCreationContext, 1, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputSizes);
inputDescs = GetDmlInputDescs();
outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 2);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1);
operatorDesc.parameterizedRelu.InputTensor = &inputDescs[0];
operatorDesc.parameterizedRelu.SlopeTensor = &inputDescs[1];
operatorDesc.parameterizedRelu.OutputTensor = outputDescs.data();
}
else // All other activation descrptions are equivalent to Elu in layout.
{
inputDescs = GetDmlInputDescs();
outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() >= 1);
operatorDesc.elu.InputTensor = inputDescs.data();
operatorDesc.elu.OutputTensor = outputDescs.data();
}
DML_OPERATOR_DESC opDesc = { operatorType, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
// A specific type of operation for registration.
template <DML_OPERATOR_TYPE OperatorType>
class DmlOperatorActivationTemplate : public DmlOperatorActivation
{
public:
DmlOperatorActivationTemplate(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperatorActivation(kernelCreationContext, OperatorType)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Sigmoid, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SIGMOID>);
DML_OP_DEFINE_CREATION_FUNCTION(HardSigmoid, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_HARD_SIGMOID>);
DML_OP_DEFINE_CREATION_FUNCTION(Tanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_TANH>);
DML_OP_DEFINE_CREATION_FUNCTION(ScaledTanh, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SCALED_TANH>);
DML_OP_DEFINE_CREATION_FUNCTION(Relu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(LeakyRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_LEAKY_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(PRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(ThresholdedRelu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU>);
DML_OP_DEFINE_CREATION_FUNCTION(Elu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_ELU>);
DML_OP_DEFINE_CREATION_FUNCTION(Selu, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SCALED_ELU>);
DML_OP_DEFINE_CREATION_FUNCTION(Softsign, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SOFTSIGN>);
DML_OP_DEFINE_CREATION_FUNCTION(Softplus, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SOFTPLUS>);
DML_OP_DEFINE_CREATION_FUNCTION(ParametricSoftplus, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS>);
DML_OP_DEFINE_CREATION_FUNCTION(Dropout, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_IDENTITY>);
DML_OP_DEFINE_CREATION_FUNCTION(Softmax, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SOFTMAX>);
DML_OP_DEFINE_CREATION_FUNCTION(LogSoftmax, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_LOG_SOFTMAX>);
DML_OP_DEFINE_CREATION_FUNCTION(Hardmax, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_HARDMAX>);
DML_OP_DEFINE_CREATION_FUNCTION(Shrink, DmlOperatorActivationTemplate<DML_OPERATOR_ACTIVATION_SHRINK>);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorAffine : public DmlOperator
{
public:
using Self = DmlOperatorAffine;
DmlOperatorAffine(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo)
{
Initialize(kernelInfo);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SCALE_BIAS scaleBias = {};
scaleBias.Scale = kernelInfo.GetOptionalAttribute<float>(AttrName::Alpha, 0.0f);
scaleBias.Bias = kernelInfo.GetOptionalAttribute<float>(AttrName::Beta, 0.0f);
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.ScaleBias = &scaleBias;
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc}, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Affine, DmlOperatorAffine);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorBatchNormalization : public DmlOperator
{
// This order matches the ONNX schema.
enum OnnxInputIndex
{
X, // Input
Scale,
Bias,
Mean,
Variance,
Count,
};
public:
DmlOperatorBatchNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {X, Mean, Variance, Scale, Bias};
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs.size() == 5);
ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs.size() >= 1);
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, 0.0f);
const int spatial = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Spatial, 1);
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned);
// Massage each of these 1D tensors (of length C) into 4D tensors of the form [1,C,1,1].
for (uint32_t i = Scale; i < OnnxInputIndex::Count; ++i)
{
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned);
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_BATCH_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[X];
operatorDesc.MeanTensor = &inputDescs[Mean];
operatorDesc.VarianceTensor = &inputDescs[Variance];
operatorDesc.ScaleTensor = &inputDescs[Scale];
operatorDesc.BiasTensor = &inputDescs[Bias];
operatorDesc.OutputTensor = &outputDescs[0];
operatorDesc.Spatial = static_cast<BOOL>(spatial);
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_BATCH_NORMALIZATION, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(BatchNormalization, DmlOperatorBatchNormalization);
DML_OP_DEFINE_CREATION_FUNCTION(FusedBatchNormalization, DmlOperatorBatchNormalization);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorCast : public DmlOperator
{
public:
using Self = DmlOperatorCast;
DmlOperatorCast(
const MLOperatorKernelCreationContext& kernelInfo
) : DmlOperator(kernelInfo),
m_toDataType(static_cast<MLOperatorTensorDataType>(kernelInfo.GetAttribute<int64_t>(AttrName::To)))
{
Initialize(kernelInfo);
// Zero the output tensor's memory for 64-bit integer emulation with strides.
if (m_toDataType == MLOperatorTensorDataType::UInt64 || m_toDataType == MLOperatorTensorDataType::Int64)
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_CAST_OPERATOR_DESC castDesc = {};
castDesc.InputTensor = inputDescs.data();
castDesc.OutputTensor = outputDescs.data();
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
// Zero the output tensor's memory for 64-bit integer emulation with strides.
if (m_toDataType == MLOperatorTensorDataType::UInt64 || m_toDataType == MLOperatorTensorDataType::Int64)
{
assert(m_zeroOperator);
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensors[0]);
}
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)));
}
private:
MLOperatorTensorDataType m_toDataType;
ComPtr<IDMLCompiledOperator> m_zeroOperator;
};
DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorConcat : public DmlOperator, public ConcatHelper
{
public:
using Self = DmlOperatorConcat;
DmlOperatorConcat(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
ConcatHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelInfo);
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelInfo, m_inputTensorDescs.front().GetDimensionCount());
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_JOIN_OPERATOR_DESC joinDesc = {};
joinDesc.InputCount = gsl::narrow_cast<uint32_t>(inputDescs.size());
joinDesc.InputTensors = inputDescs.data();
joinDesc.OutputTensor = outputDescs.data();
joinDesc.Axis = dmlAxis;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_JOIN, &joinDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Concat, DmlOperatorConcat);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorConstantOfShape : public DmlOperator, public ConstantOfShapeHelper
{
public:
using Self = DmlOperatorConstantOfShape;
DmlOperatorConstantOfShape(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
ConstantOfShapeHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); // ignored shape tensor
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); // output tensor
std::vector<std::optional<uint32_t>> inputIndices = {}; // The shape tensor is not GPU bound.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
Initialize(kernelCreationContext, inputIndices, outputIndices);
// Read the tensor attribute for the output fill pattern.
if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor))
{
ComPtr<IMLOperatorKernelCreationContext> kernelCreationContextInterface = kernelCreationContext.GetInterface();
ComPtr<IMLOperatorAttributes1> attributes;
ComPtr<IMLOperatorTensor> valueTensor;
// Get the extended attributes to be able to access the constant tensor.
THROW_IF_FAILED(kernelCreationContextInterface.As(&attributes));
THROW_IF_FAILED(attributes->GetTensorAttribute(AttrName::Value, &valueTensor));
MLOperatorTensor wrappedValueTensor(valueTensor.Get());
// Read the raw bytes from the tensor, agnostic to data type, which becomes the GPU fill pattern.
ML_CHECK_VALID_ARGUMENT(wrappedValueTensor.IsCpuData());
const uint32_t elementCount = wrappedValueTensor.GetTotalElementCount();
ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element.
const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType());
const std::byte* rawData = static_cast<const std::byte*>(valueTensor->GetData());
valueBytes.assign(rawData, rawData + rawDataByteSize);
}
// Else valueBytes is empty, and the default fill pattern is 0.
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes));
}
private:
std::vector<std::byte> valueBytes;
};
DML_OP_DEFINE_CREATION_FUNCTION(ConstantOfShape, DmlOperatorConstantOfShape);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase
{
public:
using Self = DmlOperatorConvolution;
DmlOperatorConvolution(
const MLOperatorKernelCreationContext& kernelInfo,
DML_CONVOLUTION_MODE mode,
DML_CONVOLUTION_DIRECTION direction
)
: DmlOperator(kernelInfo),
ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), direction == DML_CONVOLUTION_DIRECTION_BACKWARD)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
DmlOperator::Initialize(kernelInfo, kernelInputIndices);
// Vibranium DirectML is limited to handle only 2D and 3D convolution (4D and 5D tensors). So for 1D tensors,
// massage the tensor descriptions. By default, the TensorDesc simply right aligns all the values up to 4D
// (padding the leading dimensions with 1's), but 1D tensors actually need to insert the 1 between C and W.
// e.g. [2,3,4] becomes [2,3,1,4]
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
// Bias is optional so only adjust it if it exists.
if (kernelInfo.GetInputCount() > 2)
{
uint32_t inputDimSize = kernelInfo.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
ML_CHECK_VALID_ARGUMENT(
inputDimSize >= 3 && inputDimSize <= 5,
"Bias can only be used with 3D/4D/5D tensors."
);
uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount();
// Resize the bias to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the C channel.
m_inputTensorDescs[2] = CreateTensorDescFromInput(
kernelInfo,
2,
TensorAxis::DoNotCoerce,
TensorAxis::C,
TensorAxis::LeftAligned,
std::nullopt,
dmlDimSize
);
}
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt);
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
// Form transient kernel arguments with spatial dimensions padded up to at least 2,
// since the DirectML API rejects 1D convolution. Leave the base m_kernel alone
// so that all output tensor size computations are correct.
KernelArgs kernelArgs(m_kernel, NchwSpatialDimensionCount);
// Zero the output padding before sending to DirectML. Although it was needed to compute
// the output size, we don't want DML to see the values, which should just be ignored.
memset(kernelArgs.outputPadding, 0, sizeof(kernelArgs.outputPadding));
DML_CONVOLUTION_OPERATOR_DESC convDesc = {};
convDesc.InputTensor = &inputDescs[0];
convDesc.FilterTensor = &inputDescs[1];
convDesc.BiasTensor = kernelInfo.GetInputCount() > 2 ? &inputDescs[2] : nullptr;
convDesc.OutputTensor = &outputDescs[0];
convDesc.Mode = mode;
convDesc.Direction = direction;
convDesc.DimensionCount = kernelArgs.spatialDimensionCount;
convDesc.Strides = kernelArgs.strides;
convDesc.Dilations = kernelArgs.dilations;
convDesc.StartPadding = kernelArgs.startPadding;
convDesc.EndPadding = kernelArgs.endPadding;
convDesc.OutputPadding = kernelArgs.outputPadding;
convDesc.GroupCount = m_groupCount;
convDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CONVOLUTION, &convDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
// A specific type of operation for registration.
template <DML_CONVOLUTION_MODE Mode, DML_CONVOLUTION_DIRECTION Direction>
class DmlOperatorConvolutionTemplate : public DmlOperatorConvolution
{
public:
DmlOperatorConvolutionTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorConvolution(kernelInfo, Mode, Direction)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Conv, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_FORWARD>);
DML_OP_DEFINE_CREATION_FUNCTION(ConvTranspose, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD>);
DML_OP_DEFINE_CREATION_FUNCTION(FusedConv, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_FORWARD>);
DML_OP_DEFINE_CREATION_FUNCTION(FusedConvTranspose, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD>);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
// Copies first input and ignores others. Used for operators which perform reshaping.
class DmlOperatorCopy : public DmlOperator
{
public:
using Self = DmlOperatorCopy;
DmlOperatorCopy(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
std::vector<std::optional<uint32_t>> kernelInputOutputIndices = {0};
Initialize(kernelInfo, kernelInputOutputIndices);
// DirectML requires the input & output dimensions to be identical, even if the
// element counts are the same. All this operator does is copy the resource and
// rearrange the dimensions, so we tell DML that the output dimensions are the
// same as the input dimensions.
m_outputTensorDescs.front() = TensorDesc(
m_outputTensorDescs.front().GetDmlDataType(),
m_inputTensorDescs.front().GetSizes()
);
ComPtr<IMLOperatorKernelCreationContextPrivate> contextPrivate;
THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf()));
if (contextPrivate->IsDmlGraphNode())
{
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc }, kernelInfo);
}
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
MLOperatorTensor inputTensor = kernelContext.GetInputTensor(0);
// Reshape the output tensor.
MLOperatorTensor outputTensor = kernelContext.GetOutputTensor(0);
// Avoid self copying.
if (inputTensor.GetDataInterface().Get() != outputTensor.GetDataInterface().Get())
{
// Copy elements from input tensor to output tensor.
THROW_IF_FAILED(m_executionProvider->CopyTensor(
outputTensor.GetInterface().Get(),
inputTensor.GetInterface().Get()));
}
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Copy, DmlOperatorCopy);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorCrop : public DmlOperator, public CropHelper
{
public:
DmlOperatorCrop(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
CropHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
// CropHelper coerces the input into 4D by this point.
auto outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
assert(outputShape.size() == NchwDimensionCount);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SLICE_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.DimensionCount = NchwDimensionCount;
opDesc.Offsets = m_offsets;
opDesc.Sizes = outputShape.data();
opDesc.Strides = c_strides;
SetDmlOperatorDesc({ DML_OPERATOR_SLICE, &opDesc}, kernelInfo);
}
static const uint32_t c_strides[NchwDimensionCount];
};
/*static*/ const uint32_t DmlOperatorCrop::c_strides[NchwDimensionCount] = {1, 1, 1, 1};
DML_OP_DEFINE_CREATION_FUNCTION(Crop, DmlOperatorCrop);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorDepthToSpace : public DmlOperator, public DepthToSpaceHelper
{
public:
DmlOperatorDepthToSpace(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
DepthToSpaceHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 1);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1);
DML_DEPTH_TO_SPACE_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.BlockSize = m_blockSize;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DEPTH_TO_SPACE, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(DepthToSpace, DmlOperatorDepthToSpace);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,555 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
bool AreAllStridesIdentical(gsl::span<const TensorDesc> tensorDescs)
{
const size_t tensorDescCount = tensorDescs.size();
for (size_t i = 1; i < tensorDescCount; ++i)
{
gsl::span<const uint32_t> stridesA = tensorDescs[i - 1].GetStrides();
gsl::span<const uint32_t> stridesB = tensorDescs[i].GetStrides();
if (stridesA.size() != stridesB.size() || !std::equal(stridesA.begin(), stridesA.end(), stridesB.begin()))
{
return false;
}
}
return true;
}
template <typename TOperatorDesc>
class DmlOperatorElementwiseUnary : public DmlOperator
{
public:
DmlOperatorElementwiseUnary(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
TOperatorDesc opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
SetDmlOperatorDesc({ ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};
template<typename T>
void SetFusedActivation(T& opDesc, const DML_OPERATOR_DESC* fusedActivation)
{
// Activation is only fused for sum operators, which have a template specialization
THROW_HR(E_INVALIDARG);
}
template<>
void SetFusedActivation(DML_ELEMENT_WISE_ADD1_OPERATOR_DESC& opDesc, const DML_OPERATOR_DESC* fusedActivation)
{
opDesc.FusedActivation = fusedActivation;
}
template <typename TOperatorDesc>
class DmlOperatorElementwiseBinary : public DmlOperator
{
public:
DmlOperatorElementwiseBinary(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
TOperatorDesc opDesc = {};
opDesc.ATensor = &inputDescs[0];
opDesc.BTensor = &inputDescs[1];
opDesc.OutputTensor = outputDescs.data();
DML_OPERATOR_DESC opDescDesc = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc};
if (fusedActivation != std::nullopt)
{
// Activation is only fused for two-input sum operators
THROW_HR_IF(E_INVALIDARG, opDescDesc.Type != DML_OPERATOR_ELEMENT_WISE_ADD1 || kernelInfo.GetInputCount() > 2);
SetFusedActivation(opDesc, &fusedActivationDmlDesc);
}
SetDmlOperatorDesc(opDescDesc, kernelInfo);
}
};
ComPtr<IDMLCompiledOperator> CreateSecondaryOperator(
IDMLDevice* dmlDevice,
DML_EXECUTION_FLAGS executionFlags,
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelCreationContext& kernelInfo
)
{
ComPtr<IDMLOperator> dmlOperator;
ComPtr<IDMLCompiledOperator> compiledOperator;
THROW_IF_FAILED(dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));
THROW_IF_FAILED(dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&compiledOperator)));
return compiledOperator;
}
template <typename TOperatorDesc>
class DmlOperatorElementwiseBinaryLoop : public DmlOperator
{
public:
DmlOperatorElementwiseBinaryLoop(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
const size_t inputCount = m_inputTensorDescs.size();
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
// Activation is only fused for two-input sum operators
THROW_HR_IF(E_INVALIDARG, fusedActivation != std::nullopt && inputCount != 2);
if (inputCount == 1)
{
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC identityDesc = {};
identityDesc.InputTensor = &inputDescs[0];
identityDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &identityDesc }, kernelInfo);
}
else
{
// Create a single operator that applies to pairwise to every two inputs,
// accumulated into the output tensor.
TOperatorDesc opDesc = {};
opDesc.ATensor = &inputDescs[0];
opDesc.BTensor = &inputDescs[1];
opDesc.OutputTensor = outputDescs.data();
DML_OPERATOR_DESC opDescDesc = { ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc};
if (fusedActivation != std::nullopt)
{
SetFusedActivation(opDesc, &fusedActivationDmlDesc);
}
SetDmlOperatorDesc(opDescDesc, kernelInfo);
// If the tensor strides differ between pairs, then it's unsafe to reuse the same operator
// for all pairs because the wrong stride would be used. So create operators for every additional
// pair after the first. Given tensors {A, B, C}, the first operator handles A&B, the secondary
// operator handles tensors B&C, and any additional after that would need another operator.
if (inputCount >= 2 && !AreAllStridesIdentical(m_inputTensorDescs))
{
const DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
gsl::span<const DML_TENSOR_DESC> remainingInputDescs = gsl::make_span(inputDescs);
remainingInputDescs = remainingInputDescs.subspan(2, remainingInputDescs.size() - 2);
for (const DML_TENSOR_DESC& tensorDesc : remainingInputDescs)
{
opDesc.ATensor = &tensorDesc;
opDesc.BTensor = &outputDescs[0];
// Already set - tOpDesc.OutputTensor = &outputDescs[0];
m_compiledOperators.push_back(CreateSecondaryOperator(m_dmlDevice.Get(), executionFlags, opDescDesc, kernelInfo));
}
}
}
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
// For 1 input, just return the input (identity).
if (m_inputTensorDescs.size() == 1)
{
DmlOperator::Compute(kernelContext);
return;
}
// Apply the operator to the first two inputs.
std::array<IMLOperatorTensor*, 2> inputTensors;
inputTensors[0] = kernelContext.GetInputTensor(0).GetInterface().Get();
inputTensors[1] = kernelContext.GetInputTensor(1).GetInterface().Get();
IMLOperatorTensor* outputTensor = kernelContext.GetOutputTensor(0).GetInterface().Get();
gsl::span<IMLOperatorTensor*> outputTensors{ &outputTensor, 1 };
// Combine the first two inputs and store the result in the output tensor.
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
outputTensors));
// For each input after the first two, accumulate into the output tensor.
for (size_t inputIndex = 2; inputIndex < m_inputTensorDescs.size(); ++inputIndex)
{
inputTensors[0] = kernelContext.GetInputTensor(gsl::narrow_cast<uint32_t>(inputIndex)).GetInterface().Get();
inputTensors[1] = outputTensors[0];
// Get the next operator for this pair, either reusing the first or using a distinct operator.
IDMLCompiledOperator* compiledOperator = m_compiledOperators.empty()
? m_compiledOperator.Get()
: m_compiledOperators[inputIndex - 2].Get();
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
compiledOperator,
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
outputTensors));
}
}
// If multiple compiled operators are needed, beyond m_compiledOperator, they are appended here.
// The size of the vector will either be empty if all tensor pairs have identical properties,
// or it will equal inputCount - 2, with the first operator in this vector corresponding to the
// 3rd input tensor combined with the output of the previous 2 input tensors.
std::vector<ComPtr<IDMLCompiledOperator>> m_compiledOperators;
};
class DmlOperatorElementwiseMean : public DmlOperator
{
// Used with 3+ inputs to divide each element by the number of input tensors.
ComPtr<IDMLCompiledOperator> m_compiledIdentityOp;
public:
DmlOperatorElementwiseMean(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
const size_t inputCount = m_inputTensorDescs.size();
if (inputCount == 1)
{
// For 1 input, just return the input
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC identityDesc = {};
identityDesc.InputTensor = &inputDescs[0];
identityDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &identityDesc }, kernelInfo);
}
else if (inputCount == 2)
{
// For 2 inputs, use DML's mean operator.
DML_ELEMENT_WISE_MEAN_OPERATOR_DESC meanDesc = {};
meanDesc.ATensor = &inputDescs[0];
meanDesc.BTensor = &inputDescs[1];
meanDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_MEAN, &meanDesc}, kernelInfo);
}
else
{
// For 3+ inputs, use several DML adds followed by a divide (identity with scale=1/InputCount).
assert(inputDescs.size() > 2);
DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc = {};
addDesc.ATensor = &inputDescs[0];
addDesc.BTensor = &inputDescs[1];
addDesc.OutputTensor = &outputDescs[0];
DML_OPERATOR_DESC addDescDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc};
SetDmlOperatorDesc(addDescDesc, kernelInfo);
if (!AreAllStridesIdentical(m_inputTensorDescs))
{
// Create operators for each input after the first two.
const DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
gsl::span<const DML_TENSOR_DESC> remainingInputDescs = gsl::make_span(inputDescs);
remainingInputDescs = remainingInputDescs.subspan(2, remainingInputDescs.size() - 2);
for (const DML_TENSOR_DESC& tensorDesc : remainingInputDescs)
{
addDesc.ATensor = &tensorDesc;
addDesc.BTensor = &outputDescs[0];
// Already set - addDesc.OutputTensor = &outputDescs[0];
m_compiledOperators.push_back(CreateSecondaryOperator(m_dmlDevice.Get(), executionFlags, addDescDesc, kernelInfo));
}
}
// Create division operation using reciprocal of input tensor count.
DML_SCALE_BIAS scaleBias = {};
scaleBias.Scale = 1.0f / inputDescs.size();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC identityDesc = {};
identityDesc.InputTensor = &outputDescs[0];
identityDesc.OutputTensor = &outputDescs[0];
identityDesc.ScaleBias = &scaleBias;
DML_OPERATOR_DESC identityDescDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &identityDesc };
ComPtr<IDMLOperator> identityOp;
THROW_IF_FAILED(m_dmlDevice->CreateOperator(&identityDescDesc, IID_PPV_ARGS(&identityOp)));
THROW_IF_FAILED(m_dmlDevice->CompileOperator(identityOp.Get(), GetExecutionFlags(), IID_PPV_ARGS(&m_compiledIdentityOp)));
}
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
// Where there's only a single element, just return the input (identity).
if (m_inputTensorDescs.size() == 1)
{
DmlOperator::Compute(kernelContext);
}
else if (!m_compiledIdentityOp)
{
// Use DML mean operator.
DmlOperator::Compute(kernelContext);
}
else
{
// Do N-1 adds followed by a division, where N is the number of inputs.
std::array<IMLOperatorTensor*, 2> inputTensors;
inputTensors[0] = kernelContext.GetInputTensor(0).GetInterface().Get();
inputTensors[1] = kernelContext.GetInputTensor(1).GetInterface().Get();
IMLOperatorTensor* outputTensor = kernelContext.GetOutputTensor(0).GetInterface().Get();
gsl::span<IMLOperatorTensor*> outputTensors{ &outputTensor, 1 };
// Add the first two inputs and store the result in the output tensor.
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
outputTensors));
// For each input after the first two, accumulate into the output tensor.
for (size_t inputIndex = 2; inputIndex < m_inputTensorDescs.size(); ++inputIndex)
{
inputTensors[0] = kernelContext.GetInputTensor(gsl::narrow_cast<uint32_t>(inputIndex)).GetInterface().Get();
inputTensors[1] = outputTensors[0];
// Get the next operator for this pair, either reusing the first or using a distinct operator.
IDMLCompiledOperator* compiledOperator = m_compiledOperators.empty()
? m_compiledOperator.Get()
: m_compiledOperators[inputIndex - 2].Get();
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
compiledOperator,
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
outputTensors));
}
// Dispatch the identity w/ scale operator in-place on the output.
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledIdentityOp.Get(),
nullptr, // persistent resoruce binding
outputTensors,
outputTensors));
}
}
// If multiple compiled operators are needed, beyond m_compiledOperator, they are appended here.
std::vector<ComPtr<IDMLCompiledOperator>> m_compiledOperators;
};
class DmlOperatorElementwiseClip : public DmlOperator
{
public:
DmlOperatorElementwiseClip(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.Min = kernelInfo.GetOptionalAttribute<float>(AttrName::Min, std::numeric_limits<float>::lowest());
opDesc.Max = kernelInfo.GetOptionalAttribute<float>(AttrName::Max, std::numeric_limits<float>::max());
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP, &opDesc}, kernelInfo);
}
};
class DmlOperatorElementwisePow : public DmlOperator
{
public:
DmlOperatorElementwisePow(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ExponentTensor = &inputDescs[1];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
}
};
template <typename TOperatorDesc>
class DmlOperatorElementwiseQLinear : public DmlOperator
{
public:
DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());
Initialize(kernelInfo, std::nullopt, std::nullopt, outputShape);
// If the axis attribute is explicitly provided, then broadcasting must be performed along that axis.
// So massage the actual shapes of the scale and zero-point tensors (1D with length equal to the input
// axis being broadcast to) into broadcastable shapes.
if (kernelInfo.HasAttribute(AttrName::Axis, MLOperatorAttributeType::Int))
{
const int32_t signedAxis = gsl::narrow_cast<int32_t>(kernelInfo.GetAttribute<int64_t>(AttrName::Axis));
const uint32_t axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount);
const uint32_t broadcastAxisLength = outputShape[axis];
// Explicitly reshape each of the inputs after the first input (scale and zero point tensors).
for (uint32_t index = 1, inputCount = gsl::narrow_cast<uint32_t>(m_inputTensorDescs.size()); index < inputCount; ++index)
{
auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
// Fix up the the tensor shape by filling with trailing ones. So input[2,3] with axis=0 and scale[2]
// becomes scale[2,1], so that broadcasting works correctly.
std::vector<uint32_t> adjustedInputTensorShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(index);
ML_CHECK_VALID_ARGUMENT(adjustedInputTensorShape.size() == 1);
ML_CHECK_VALID_ARGUMENT(adjustedInputTensorShape[0] == broadcastAxisLength);
adjustedInputTensorShape.insert(adjustedInputTensorShape.end(), outputShapeDimCount - 1 - axis, 1);
m_inputTensorDescs[index] = TensorDesc(
edgeDesc.tensorDataType,
gsl::make_span(outputShape),
gsl::make_span(adjustedInputTensorShape),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
}
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
TOperatorDesc opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};
class DmlOperatorElementwiseIf : public DmlOperator
{
public:
DmlOperatorElementwiseIf(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_IF_OPERATOR_DESC opDesc = {};
opDesc.ConditionTensor = &inputDescs[0];
opDesc.ATensor = &inputDescs[1];
opDesc.BTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IF, &opDesc }, kernelInfo);
}
};
// Unary operators:
DML_OP_DEFINE_CREATION_FUNCTION(Sqrt, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Reciprocal, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_RECIP_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Cos, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_COS_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sin, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SIN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Tan, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_TAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Acos, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ACOS_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Asin, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ASIN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Atan, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ATAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Exp, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_EXP_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Log, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_LOG_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Abs, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ABS_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Ceil, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_CEIL_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Floor, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Not, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sign, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SIGN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(IsNan, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SINH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Cosh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_COSH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Asinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ASINH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Acosh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Atanh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ATANH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Erf, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ERF_OPERATOR_DESC>);
// Binary operators:
DML_OP_DEFINE_CREATION_FUNCTION(Greater, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Less, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Equal, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(And, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Or, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Xor, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Add, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sub, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Mul, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Div, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>);
// Binary operators that support >2 inputs:
DML_OP_DEFINE_CREATION_FUNCTION(Sum, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_ADD_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Min, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_MIN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Max, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_MAX_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
// Operators with extra attributes:
DML_OP_DEFINE_CREATION_FUNCTION(Clip, DmlOperatorElementwiseClip);
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Where, DmlOperatorElementwiseIf);
// Fused operators:
DML_OP_DEFINE_CREATION_FUNCTION(FusedAdd, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(FusedSum, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorExpand : public DmlOperator, public ExpandHelper
{
public:
using Self = DmlOperatorExpand;
DmlOperatorExpand(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
ExpandHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
std::vector<std::optional<uint32_t>> inputIndices = { 0 }; // The second tensor is not bound to Identity operator.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
Initialize(kernelCreationContext, inputIndices, outputIndices);
TensorDesc inputTensorDesc =
TensorDesc(
kernelCreationContext.GetInputEdgeDescription(0).tensorDataType,
m_outputTensorDescs[0].GetDmlSizes(),
m_inputTensorDescs[0].GetDmlSizes(),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0);
TensorDesc outputTensorDesc =
TensorDesc(
kernelCreationContext.GetOutputEdgeDescription(0).tensorDataType,
m_outputTensorDescs[0].GetDmlSizes(),
m_outputTensorDescs[0].GetDmlSizes(),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0
);
m_inputTensorDescs[0] = inputTensorDesc;
m_outputTensorDescs[0] = outputTensorDesc;
// Create the operator with new shape after calling UpdateShape.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.OutputTensor = &outputDescs[0];
// identityDesc.ScaleBias left empty.
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &operatorDesc}, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Expand, DmlOperatorExpand);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorEyeLike : public DmlOperator
{
public:
DmlOperatorEyeLike(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1, "EyeLike expects 1 input.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "EyeLike expects 1 output.");
std::vector<std::optional<uint32_t>> inputIndices = {}; // Ignore the 1st input tensor for the GPU.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() <= 1);
assert(outputDescs.size() == 1);
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();;
std::vector<DimensionType> outputDimensions = outputTensorShapeDescription.GetOutputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(outputDimensions.size() <= OperatorHelper::NchwDimensionCount);
const int32_t diagonalOffset = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::K, 0);
DML_DIAGONAL_MATRIX_OPERATOR_DESC operatorDesc = {};
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Offset = diagonalOffset;
operatorDesc.Value = 1.0f;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DIAGONAL_MATRIX, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(EyeLike, DmlOperatorEyeLike);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorGather : public DmlOperator, public GatherHelper
{
public:
DmlOperatorGather(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
GatherHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "Gather expects 2 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Gather expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
auto outputTensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();;
std::vector<DimensionType> dataDimensions = outputTensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = outputTensorShapeDescription.GetInputTensorShape(1);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_GATHER_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
operatorDesc.IndexDimensions = gsl::narrow_cast<uint32_t>(indicesDimensions.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GATHER, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Gather, DmlOperatorGather);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorGemm : public DmlOperator, public GemmHelper
{
public:
DmlOperatorGemm(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
GemmHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 3);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
// Broadcast C tensor to the shape of the output tensor.
m_inputTensorDescs[2] = CreateTensorDescFromInput(
kernelInfo,
2,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)
);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
DML_GEMM_OPERATOR_DESC gemmDesc = {};
gemmDesc.ATensor = &inputDescs[0];
gemmDesc.BTensor = &inputDescs[1];
gemmDesc.CTensor = &inputDescs[2];
gemmDesc.OutputTensor = &outputDescs[0];
gemmDesc.TransA = (m_transA ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE);
gemmDesc.TransB = (m_transB ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE);
gemmDesc.Alpha = m_alpha;
gemmDesc.Beta = m_beta;
gemmDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GEMM, &gemmDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Gemm, DmlOperatorGemm);
DML_OP_DEFINE_CREATION_FUNCTION(FusedGemm, DmlOperatorGemm);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorInstanceNormalization : public DmlOperator
{
enum InputTensors
{
IN_X,
IN_SCALE,
IN_BIAS
};
void Shift1DInputsTensorDesc(const MLOperatorKernelCreationContext& kernelCreationContext, int index, int count, uint32_t destinationAxis)
{
for (int i = index; i != index + count; ++i)
{
// Shift a single dimension size to the C channel.
// e.g. [7] or [1,1,1,7] becomes [1,7,1,1]
TensorDesc& tensorDesc = m_inputTensorDescs[i];
gsl::span<const uint32_t> sizes = tensorDesc.GetSizes();
gsl::span<const uint32_t> lastDimension = sizes.last(1);
ML_CHECK_VALID_ARGUMENT(tensorDesc.GetDimensionCount() == OperatorHelper::NchwDimensionCount);
ML_CHECK_VALID_ARGUMENT(sizes.size() >=4 && sizes[N] == 1 && sizes[C] == 1 && sizes[H] == 1);
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, lastDimension);
}
}
public:
DmlOperatorInstanceNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2};
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
const std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
// Shift IN_SCALE and IN_BIAS input tensor descs {1, C, 1, 1} out of 1D tensors.
Shift1DInputsTensorDesc(kernelCreationContext, IN_SCALE, 2, C);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ScaleTensor = &inputDescs[1];
operatorDesc.BiasTensor = &inputDescs[2];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.CrossChannel = false;
operatorDesc.NormalizeVariance = true;
operatorDesc.Epsilon = epsilon;
operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(InstanceNormalization, DmlOperatorInstanceNormalization);
DML_OP_DEFINE_CREATION_FUNCTION(FusedInstanceNormalization, DmlOperatorInstanceNormalization);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorLocalResponseNormalization : public DmlOperator
{
public:
DmlOperatorLocalResponseNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
DmlOperator::Initialize(kernelCreationContext);
const int size = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Size, 0);
const float bias = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Bias, 0.0f);
const float alpha = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Alpha, 0.0f);
const float beta = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Beta, 0.0f);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 1);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1);
DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.CrossChannel = true; // crossChannel - ONNX only supports cross-channel.
operatorDesc.LocalSize = gsl::narrow_cast<uint32_t>(size);
operatorDesc.Alpha = alpha;
operatorDesc.Beta = beta;
operatorDesc.Bias = bias;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(LRN, DmlOperatorLocalResponseNormalization);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorLpNormalization : public DmlOperator
{
public:
DmlOperatorLpNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
DmlOperator::Initialize(kernelCreationContext);
const int onnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
// Valid values for p are 1 and 2.
int p = kernelCreationContext.GetOptionalAttribute<int>(AttrName::P, 2);
ML_CHECK_VALID_ARGUMENT(p >= 1 && p <= 2);
uint32_t dmlAxis = GetDmlAdjustedAxis(onnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_LP_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
operatorDesc.Epsilon = DefaultEpsilon;
operatorDesc.P = p;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_LP_NORMALIZATION, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(LpNormalization, DmlOperatorLpNormalization);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorMatMul : public DmlOperator
{
enum InputTensors { IN_A, IN_B };
public:
DmlOperatorMatMul(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo)
{
// MatMul has two inputs, but DML GEMM requires 3 input bindings (a null binding for the C Tensor).
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
std::vector<std::optional<uint32_t>> inputIndices = { 0, 1, std::nullopt };
DmlOperator::Initialize(kernelInfo, inputIndices);
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(1);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
// Get the padded input shapes and undo the effect of padding removal from the output shape
if (inputShape1.size() == 1)
{
inputShape1.push_back(1);
outputShape.push_back(1);
}
if (inputShape0.size() == 1)
{
inputShape0.insert(inputShape0.begin(), 1);
outputShape.insert(outputShape.end() - 1, 1);
}
// Remove the batch dimensions from each input, then re-add the broadcasted batch dimensions
// based on the output shape
inputShape0.erase(inputShape0.begin(), inputShape0.end() - 2);
inputShape1.erase(inputShape1.begin(), inputShape1.end() - 2);
inputShape0.insert(inputShape0.begin(), outputShape.begin(), outputShape.end() - 2);
inputShape1.insert(inputShape1.begin(), outputShape.begin(), outputShape.end() - 2);
// Initialize the input descriptions with broadcasting
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelInfo);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
DML_GEMM_OPERATOR_DESC gemmDesc = {};
gemmDesc.ATensor = &inputDescs[0];
gemmDesc.BTensor = &inputDescs[1];
gemmDesc.CTensor = nullptr;
gemmDesc.OutputTensor = &outputDescs[0];
gemmDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
gemmDesc.TransB = DML_MATRIX_TRANSFORM_NONE;
gemmDesc.Alpha = 1.0f;
gemmDesc.Beta = 0.0f;
gemmDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GEMM, &gemmDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(MatMul, DmlOperatorMatMul);
DML_OP_DEFINE_CREATION_FUNCTION(FusedMatMul, DmlOperatorMatMul);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorMaxUnpool : public DmlOperator
{
public:
using Self = DmlOperatorMaxUnpool;
DmlOperatorMaxUnpool(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
uint32_t inputCount = kernelCreationContext.GetInputCount();
ML_CHECK_VALID_ARGUMENT(inputCount == 2 || inputCount == 3, "MaxUnpool expects 2 or 3 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "MaxUnpool expects 1 output.");
std::vector<std::optional<uint32_t>> inputIndices = { 0, 1 }; // The 3rd tensor ('output_shape') is not bound, just 'X' and 'I' indices.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
assert(outputDescs.size() == 1);
DML_MAX_UNPOOLING_OPERATOR_DESC poolingDesc = {};
poolingDesc.InputTensor = &inputDescs[0];
poolingDesc.IndicesTensor = &inputDescs[1];
poolingDesc.OutputTensor = outputDescs.data();
DML_OPERATOR_DESC operaterDesc = {};
operaterDesc.Type = DML_OPERATOR_MAX_UNPOOLING;
operaterDesc.Desc = &poolingDesc;
SetDmlOperatorDesc(operaterDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(MaxUnpool, DmlOperatorMaxUnpool);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorMeanVarNormalization : public DmlOperator
{
public:
DmlOperatorMeanVarNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, std::nullopt, std::nullopt};
DmlOperator::Initialize(kernelCreationContext, kernelInputIndices);
const bool acrossChannels = (static_cast<bool>(kernelCreationContext.GetOptionalAttribute<int>(AttrName::AcrossChannels, 0)));
const bool normalizeVariance = (static_cast<bool>(kernelCreationContext.GetOptionalAttribute<int>(AttrName::NormalizeVariance, 0)));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::optional<ActivationOperatorDesc> fusedActivation = FusionHelpers::TryGetFusedActivationDesc(kernelCreationContext);
DML_OPERATOR_DESC fusedActivationDmlDesc = fusedActivation ? fusedActivation->GetDmlDesc() : DML_OPERATOR_DESC();
DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.ScaleTensor = nullptr;
operatorDesc.BiasTensor = nullptr;
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.CrossChannel = acrossChannels;
operatorDesc.NormalizeVariance = normalizeVariance;
operatorDesc.Epsilon = DefaultEpsilon;
operatorDesc.FusedActivation = fusedActivation ? &fusedActivationDmlDesc : nullptr;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(MeanVarianceNormalization, DmlOperatorMeanVarNormalization);
DML_OP_DEFINE_CREATION_FUNCTION(FusedMeanVarianceNormalization, DmlOperatorMeanVarNormalization);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorMemcpy : public DmlOperator
{
public:
using Self = DmlOperatorMemcpy;
DmlOperatorMemcpy(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1, "MemcpyFromHost/ToHost expects 1 input tensor.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "MemcpyFromHost/ToHost expects 1 output tensor.");
DmlOperator::Initialize(kernelCreationContext);
}
void Compute(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensors(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensors(kernelContext);
assert(inputTensors.size() == 1);
assert(outputTensors.size() == 1);
THROW_IF_FAILED(m_executionProvider->CopyTensor(
outputTensors.front(),
inputTensors.front()
));
}
private:
};
// MemcpyToHost is a special case which is hardcoded in MLOperatorAuthorImpl.cpp. If name changes this must be updated.
// Special case makes sure that the output resource is created using the CPU allocator.
DML_OP_DEFINE_CREATION_FUNCTION(MemcpyFromHost, DmlOperatorMemcpy);
DML_OP_DEFINE_CREATION_FUNCTION(MemcpyToHost, DmlOperatorMemcpy);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorNeg : public DmlOperator
{
public:
using Self = DmlOperatorNeg;
DmlOperatorNeg(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SCALE_BIAS scaleBias = {};
scaleBias.Scale = -1.0f;
scaleBias.Bias = 0.0f;
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.ScaleBias = &scaleBias;
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc}, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Neg, DmlOperatorNeg);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorOneHot : public DmlOperator, OneHotHelper
{
public:
using Self = DmlOperatorOneHot;
DmlOperatorOneHot(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
OneHotHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
std::vector<std::optional<uint32_t>> inputIndices = { 0, 2 }; // The second tensor ('depth') is not bound, just 'indices' and 'values'.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
// Unsqueeze the indices tensor by inserting a flat dimension of size 1,
// and compute the output tensor by expanding along the active axis.
// This way they are both size-compatible and directly consumable by DirectML.
std::vector<uint32_t> indicesDimensions;
indicesDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
indicesDimensions.insert(indicesDimensions.begin() + m_absoluteAxis, 1u);
// Update the tensor descriptions with new sizes.
m_inputTensorDescs[0] =
TensorDesc(
m_inputTensorDescs[0].GetMlOperatorDataType(),
gsl::make_span(indicesDimensions),
gsl::make_span(indicesDimensions),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0
);
m_outputTensorDescs[0] =
TensorDesc(
m_outputTensorDescs[0].GetMlOperatorDataType(),
gsl::make_span(m_outputDimensions),
gsl::make_span(m_outputDimensions),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0
);
// Adjust the axis so it's in DML's terms rather than the original ONNX indexing.
uint32_t dmlAxis = GetDmlAdjustedAxis(
m_absoluteAxis,
gsl::narrow_cast<uint32_t>(indicesDimensions.size()),
m_inputTensorDescs.front().GetDimensionCount()
);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ONE_HOT_OPERATOR_DESC operatorDesc = {};
operatorDesc.IndicesTensor = &inputDescs[0];
operatorDesc.ValuesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ONE_HOT, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(OneHot, DmlOperatorOneHot);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorPadding : public DmlOperator, public PaddingHelper
{
public:
DmlOperatorPadding(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
PaddingHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_startPadding.size()));
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_endPadding.size()));
// Pad the parameters to respect DML's requirements
m_startPadding.insert(
m_startPadding.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_startPadding.size()),
0);
m_endPadding.insert(
m_endPadding.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_endPadding.size()),
0);
// Convert padding mode.
DML_PADDING_MODE mode = DML_PADDING_MODE_CONSTANT;
std::string modeString = kernelInfo.GetOptionalAttribute<std::string>(AttrName::Mode, AttrValue::Reflect);
if (modeString == AttrValue::Constant)
{
mode = DML_PADDING_MODE_CONSTANT;
}
else if (modeString == AttrValue::Edge)
{
mode = DML_PADDING_MODE_EDGE;
}
else if (modeString == AttrValue::Reflect)
{
mode = DML_PADDING_MODE_REFLECTION;
}
else
{
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
}
float value = kernelInfo.GetOptionalAttribute<float>(AttrName::Value, 0.0f);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_PADDING_OPERATOR_DESC paddingDesc = {};
paddingDesc.InputTensor = inputDescs.data();
paddingDesc.OutputTensor = outputDescs.data();
paddingDesc.PaddingMode = mode;
paddingDesc.PaddingValue = value;
paddingDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_startPadding.size());
paddingDesc.StartPadding = m_startPadding.data();
paddingDesc.EndPadding = m_endPadding.data();
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING, &paddingDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Pad, DmlOperatorPadding);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,131 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
{
public:
using Self = DmlOperatorPooling;
DmlOperatorPooling(
const MLOperatorKernelCreationContext& kernelInfo,
DML_OPERATOR_TYPE function,
bool useGlobalPooling
)
: DmlOperator(kernelInfo),
PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling),
m_function(function)
{
DmlOperator::Initialize(kernelInfo);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
ML_CHECK_VALID_ARGUMENT(outputDescs.size() >= 1, "MaxPool output count must be >=1.");
assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));
// DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling
uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
{
size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;
for (int i = gsl::narrow_cast<int>(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
{
m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
m_kernel.windowSize[i] = 1;
m_kernel.strides[i + shift] = m_kernel.strides[i];
m_kernel.strides[i] = 1;
m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
m_kernel.startPadding[i] = 0;
m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
m_kernel.endPadding[i] = 0;
}
m_kernel.spatialDimensionCount = expectedSpatialDimCount;
}
auto SetOpDesc = [&](auto& poolingDesc)
{
poolingDesc.InputTensor = inputDescs.data();
poolingDesc.OutputTensor = outputDescs.data();
poolingDesc.DimensionCount = m_kernel.spatialDimensionCount;
poolingDesc.WindowSize = m_kernel.windowSize;
poolingDesc.Strides = m_kernel.strides;
poolingDesc.StartPadding = m_kernel.startPadding;
poolingDesc.EndPadding = m_kernel.endPadding;
DML_OPERATOR_DESC opDesc = {};
opDesc.Type = ApiTraits::OperatorDescTraits<std::remove_reference<decltype(poolingDesc)>::type>::Type;
opDesc.Desc = &poolingDesc;
SetDmlOperatorDesc(opDesc, kernelInfo);
};
switch (m_function)
{
case DML_OPERATOR_AVERAGE_POOLING:
{
DML_AVERAGE_POOLING_OPERATOR_DESC desc = {};
desc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);
SetOpDesc(desc);
break;
}
case DML_OPERATOR_LP_POOLING:
{
DML_LP_POOLING_OPERATOR_DESC desc = {};
desc.P = kernelInfo.GetOptionalAttribute<int>(AttrName::P, 2);
ML_CHECK_VALID_ARGUMENT(desc.P > 0);
SetOpDesc(desc);
break;
}
case DML_OPERATOR_MAX_POOLING:
case DML_OPERATOR_MAX_POOLING1:
{
if (outputDescs.size() > 1 && outputDescs[1].Desc != nullptr)
{
DML_MAX_POOLING1_OPERATOR_DESC desc = {};
desc.OutputIndicesTensor = &outputDescs[1];
SetOpDesc(desc);
}
else
{
// Use the old pooling command, which supports potential metacommands.
DML_MAX_POOLING_OPERATOR_DESC desc = {};
SetOpDesc(desc);
}
break;
}
}
}
private:
DML_OPERATOR_TYPE m_function;
};
// A specific type of operation for registration.
template <DML_OPERATOR_TYPE Function, bool UseGlobalPooling>
class DmlOperatorPoolingTemplate : public DmlOperatorPooling
{
public:
DmlOperatorPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorPooling(kernelInfo, Function, UseGlobalPooling)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate<DML_OPERATOR_AVERAGE_POOLING, true>);
DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING1, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate<DML_OPERATOR_MAX_POOLING, true>);
DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING, false>);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate<DML_OPERATOR_LP_POOLING, true>);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,410 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
// Base class for RNN ops (simple RNN, GRU, and LSTM).
//
class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
{
public:
using Self = DmlOperatorRecurrentBase;
DmlOperatorRecurrentBase(const MLOperatorKernelCreationContext& kernelInfo):
DmlOperator(kernelInfo),
RecurrentHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
}
void Initialize(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t sequenceLengthInputIndex,
gsl::span<const std::string> defaultActivations,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices = std::nullopt,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices = std::nullopt)
{
DmlOperator::Initialize(kernelInfo, kernelInputIndices, kernelOutputIndices);
m_direction = GetRNNDirection(kernelInfo);
InitActivationDescs(kernelInfo, /*inout*/ m_activationOpDescs, defaultActivations);
bool hasOutput = false;
for (const TensorDesc& desc : m_outputTensorDescs)
{
if (desc.IsValid())
{
hasOutput = true;
break;
}
}
if (!hasOutput)
{
ML_INVALID_ARGUMENT("At least one output should be requested.");
}
if (m_inputTensorDescs.size() > sequenceLengthInputIndex && m_inputTensorDescs[sequenceLengthInputIndex].IsValid())
{
m_inputTensorDescs[sequenceLengthInputIndex].ForceUnsignedDataType();
}
}
DML_RECURRENT_NETWORK_DIRECTION GetRNNDirection(const MLOperatorKernelCreationContext& kernelInfo)
{
std::string direction = kernelInfo.GetOptionalAttribute<std::string>(AttrName::Direction, AttrValue::DirectionForward);
if (direction == AttrValue::DirectionForward) { return DML_RECURRENT_NETWORK_DIRECTION_FORWARD; }
if (direction == AttrValue::DirectionReverse) { return DML_RECURRENT_NETWORK_DIRECTION_BACKWARD; }
if (direction == AttrValue::DirectionBidirectional) { return DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL; }
ML_INVALID_ARGUMENT("Unsupported direction"); // throws
}
void InitActivationDescs(const MLOperatorKernelCreationContext& kernelInfo, _Out_ std::vector<DML_OPERATOR_DESC>& descs, gsl::span<const std::string> defaultActivations)
{
std::vector<std::string> activations = kernelInfo.GetOptionalStringAttributeVector(AttrName::Activations);
if (activations.empty())
{
uint32_t loopCount = (m_direction == DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL) ? 2 : 1;
// Default value is set if none are given
for (uint32_t i = 0; i < loopCount; i++)
{
std::copy(defaultActivations.begin(), defaultActivations.end(), std::back_inserter(activations));
}
}
// resize the array to the correct direction count. The schema defaults to always be 2 elements which is wrong for single direction case.
activations.resize((m_direction == DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL) ? 2 * defaultActivations.size() : defaultActivations.size());
descs.resize(activations.size());
m_activationDescs.resize(activations.size());
// Some functions have additional parameters. It is assumed the alpha/beta values will
// be ordered by function, so this treats the respective operator attributes as stacks.
std::vector<float> alphas;
if (kernelInfo.HasAttribute(AttrName::ActivationAlpha, MLOperatorAttributeType::FloatArray))
{
alphas = kernelInfo.GetAttributeVector<float>(AttrName::ActivationAlpha);
}
std::vector<float> betas;
if (kernelInfo.HasAttribute(AttrName::ActivationBeta, MLOperatorAttributeType::FloatArray))
{
betas = kernelInfo.GetAttributeVector<float>(AttrName::ActivationBeta);
}
size_t currentAlpha = 0;
size_t currentBeta = 0;
auto NextAlpha = [&](DML_OPERATOR_TYPE function)
{
if (currentAlpha >= alphas.size())
{
return ActivationHelper::GetDefaultAlpha(function);
}
return alphas[currentAlpha++];
};
auto NextBeta = [&](DML_OPERATOR_TYPE function)
{
if (currentBeta >= betas.size())
{
return ActivationHelper::GetDefaultBeta(function);
}
return betas[currentBeta++];
};
for (size_t i = 0; i < activations.size(); ++i)
{
const std::string& activationName = activations[i];
DML_OPERATOR_DESC& desc = descs[i];
ActivationOperatorDescUnion& activationDesc = m_activationDescs[i];
desc.Desc = &activationDesc;
if (activationName == AttrValue::ActivationRelu)
{
desc.Type = DML_OPERATOR_ACTIVATION_RELU;
}
else if (activationName == AttrValue::ActivationLeakyRelu)
{
desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU;
activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationThresholdedRelu)
{
desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU;
activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationTanh)
{
desc.Type = DML_OPERATOR_ACTIVATION_TANH;
}
else if (activationName == AttrValue::ActivationScaledTanh)
{
desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH;
activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type);
activationDesc.scaledTanh.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationSigmoid)
{
desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID;
}
else if (activationName == AttrValue::ActivationSigmoidHard)
{
desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID;
activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type);
activationDesc.hardSigmoid.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationElu)
{
desc.Type = DML_OPERATOR_ACTIVATION_ELU;
activationDesc.elu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationSoftsign)
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN;
}
else if (activationName == AttrValue::ActivationSoftplus)
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS;
}
else
{
ML_INVALID_ARGUMENT("Unsupported activation function");
}
}
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
// Assume that enough GPU work has been queued up after the RNN operator that it is worth
// kicking it off, to enable subsequent CPU work to be parallelized with this GPU work.
__super::Compute(kernelContext);
m_executionProvider->Flush();
}
protected:
std::vector<DML_OPERATOR_DESC> m_activationOpDescs;
std::vector<ActivationOperatorDescUnion> m_activationDescs;
DML_RECURRENT_NETWORK_DIRECTION m_direction;
};
// Simple RNN
//
class DmlOperatorRecurrentNeuralNetwork : public DmlOperatorRecurrentBase
{
public:
DmlOperatorRecurrentNeuralNetwork(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorRecurrentBase(kernelInfo)
{
// HiddenInit and SequenceLengths are reverse with ONNX ordering
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 5, 4};
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0, 1};
std::array<std::string, 1> defaultActivations = {AttrValue::ActivationTanh};
DmlOperatorRecurrentBase::Initialize(kernelInfo, IN_SEQUENCE_LENGTHS, defaultActivations, kernelInputIndices, kernelOutputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_RNN_OPERATOR_DESC rnnDesc = {};
rnnDesc.InputTensor = &inputDescs[IN_X];
rnnDesc.WeightTensor = &inputDescs[IN_WEIGHTS];
rnnDesc.RecurrenceTensor = &inputDescs[IN_RECURRENCE];
rnnDesc.BiasTensor = (inputDescs[IN_BIAS].Desc != nullptr) ? &inputDescs[IN_BIAS] : nullptr;
rnnDesc.HiddenInitTensor = (inputDescs[IN_HIDDEN_INIT].Desc != nullptr) ? &inputDescs[IN_HIDDEN_INIT] : nullptr;
rnnDesc.SequenceLengthsTensor = (inputDescs[IN_SEQUENCE_LENGTHS].Desc != nullptr) ? &inputDescs[IN_SEQUENCE_LENGTHS] : nullptr;
rnnDesc.OutputSequenceTensor = (outputDescs[OUT_SEQUENCE].Desc != nullptr) ? &outputDescs[OUT_SEQUENCE] : nullptr;
rnnDesc.OutputSingleTensor = (outputDescs[OUT_SINGLE].Desc != nullptr) ? &outputDescs[OUT_SINGLE] : nullptr;
rnnDesc.ActivationDescCount = gsl::narrow_cast<uint32_t>(m_activationOpDescs.size());
rnnDesc.ActivationDescs = m_activationOpDescs.data();
rnnDesc.Direction = m_direction;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RNN, &rnnDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
private:
// Inputs in DML's order, which is different from ONNX.
enum InputTensors
{
IN_X, // X
IN_WEIGHTS, // W
IN_RECURRENCE, // R
IN_BIAS, // B
IN_HIDDEN_INIT, // initial_h
IN_SEQUENCE_LENGTHS, // sequence_lens
};
enum OutputTensors
{
OUT_SEQUENCE, // Y
OUT_SINGLE
};
};
// GRU
//
class DmlOperatorGatedRecurrentUnit : public DmlOperatorRecurrentBase
{
public:
DmlOperatorGatedRecurrentUnit(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorRecurrentBase(kernelInfo)
{
std::array<std::string, 2> defaultActivations = {AttrValue::ActivationSigmoid, AttrValue::ActivationTanh};
bool linearBeforeReset = kernelInfo.GetOptionalAttribute<int64_t>(AttrName::LinearBeforeReset, 0) != 0;
// HiddenInit and SequenceLengths are reverse with ONNX ordering
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 5, 4};
std::vector<std::optional<uint32_t>> kernelOutputIndices = {0, 1};
DmlOperatorRecurrentBase::Initialize(kernelInfo, IN_SEQUENCE_LENGTHS, defaultActivations, kernelInputIndices, kernelOutputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_GRU_OPERATOR_DESC rnnDesc = {};
rnnDesc.InputTensor = &inputDescs[IN_X];
rnnDesc.WeightTensor = &inputDescs[IN_WEIGHTS];
rnnDesc.RecurrenceTensor = &inputDescs[IN_RECURRENCE];
rnnDesc.BiasTensor = (inputDescs[IN_BIAS].Desc != nullptr) ? &inputDescs[IN_BIAS] : nullptr;
rnnDesc.HiddenInitTensor = (inputDescs[IN_HIDDEN_INIT].Desc != nullptr) ? &inputDescs[IN_HIDDEN_INIT] : nullptr;
rnnDesc.SequenceLengthsTensor = (inputDescs[IN_SEQUENCE_LENGTHS].Desc != nullptr) ? &inputDescs[IN_SEQUENCE_LENGTHS] : nullptr;
rnnDesc.OutputSequenceTensor = (outputDescs[OUT_SEQUENCE].Desc != nullptr) ? &outputDescs[OUT_SEQUENCE] : nullptr;
rnnDesc.OutputSingleTensor = (outputDescs[OUT_SINGLE].Desc != nullptr) ? &outputDescs[OUT_SINGLE] : nullptr;
rnnDesc.ActivationDescCount = gsl::narrow_cast<uint32_t>(m_activationOpDescs.size());
rnnDesc.ActivationDescs = m_activationOpDescs.data();
rnnDesc.Direction = m_direction;
rnnDesc.LinearBeforeReset = linearBeforeReset ? TRUE : FALSE;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_GRU, &rnnDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
private:
// Inputs in DML's order, which is different from ONNX.
enum InputTensors
{
IN_X,
IN_WEIGHTS,
IN_RECURRENCE,
IN_BIAS,
IN_HIDDEN_INIT,
IN_SEQUENCE_LENGTHS,
};
enum OutputTensors
{
OUT_SEQUENCE, // Y
OUT_SINGLE
};
};
// LSTM
//
class DmlOperatorLongShortTermUnit : public DmlOperatorRecurrentBase
{
public:
DmlOperatorLongShortTermUnit(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorRecurrentBase(kernelInfo)
{
std::array<std::string, 3> defaultActivations = {AttrValue::ActivationSigmoid, AttrValue::ActivationTanh, AttrValue::ActivationTanh};
bool useClipThreshold = kernelInfo.HasAttribute(AttrName::Clip, MLOperatorAttributeType::Float);
float clipThreshold = kernelInfo.GetOptionalAttribute<float>(AttrName::Clip, 0.0f);
bool coupleInputForget = kernelInfo.GetOptionalAttribute<bool>(AttrName::InputForget, false);
std::vector<std::optional<uint32_t>> kernelInputIndices =
{
0, // DML Input tensor is ONNX input 0
1, // DML Weight tensor is ONNX input 1
2, // DML Recurrence tensor is ONNX input 2
3, // DML Bias tensor is ONNX input 3
5, // DML HiddenInit tensor is ONNX input 5
6, // DML CellMem tensor is ONNX input 6
4, // DML SequenceLengths tensor is ONNX input 4
7 // DML Peephole tensor is ONNX input 7
};
std::vector<std::optional<uint32_t>> kernelOutputIndices =
{
0, // DML OutputSequence tensor is ONNX input 0
1, // DML OutputSingle tensor is ONNX input 1
2, // DML OutputCellSingle tensor is ONNX input 2
};
DmlOperatorRecurrentBase::Initialize(kernelInfo, IN_SEQUENCE_LENGTHS, defaultActivations, kernelInputIndices, kernelOutputIndices);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_LSTM_OPERATOR_DESC rnnDesc = {};
rnnDesc.InputTensor = &inputDescs[IN_X];
rnnDesc.WeightTensor = &inputDescs[IN_WEIGHTS];
rnnDesc.RecurrenceTensor = &inputDescs[IN_RECURRENCE];
rnnDesc.BiasTensor = (inputDescs[IN_BIAS].Desc != nullptr) ? &inputDescs[IN_BIAS] : nullptr;
rnnDesc.HiddenInitTensor = (inputDescs[IN_HIDDEN_INIT].Desc != nullptr) ? &inputDescs[IN_HIDDEN_INIT] : nullptr;
rnnDesc.CellMemInitTensor = (inputDescs[IN_CELL_GATE_INIT].Desc != nullptr) ? &inputDescs[IN_CELL_GATE_INIT] : nullptr;
rnnDesc.SequenceLengthsTensor = (inputDescs[IN_SEQUENCE_LENGTHS].Desc != nullptr) ? &inputDescs[IN_SEQUENCE_LENGTHS] : nullptr;
rnnDesc.PeepholeTensor = (inputDescs[IN_PEEPHOLE].Desc != nullptr) ? &inputDescs[IN_PEEPHOLE] : nullptr;
rnnDesc.OutputSequenceTensor = (outputDescs[OUT_SEQUENCE].Desc != nullptr) ? &outputDescs[OUT_SEQUENCE] : nullptr;
rnnDesc.OutputSingleTensor = (outputDescs[OUT_SINGLE].Desc != nullptr) ? &outputDescs[OUT_SINGLE] : nullptr;
rnnDesc.OutputCellSingleTensor = (outputDescs[OUT_CELL_SINGLE].Desc != nullptr) ? &outputDescs[OUT_CELL_SINGLE] : nullptr;
rnnDesc.ActivationDescCount = gsl::narrow_cast<uint32_t>(m_activationOpDescs.size());
rnnDesc.ActivationDescs = m_activationOpDescs.data();
rnnDesc.Direction = m_direction;
rnnDesc.UseClipThreshold = useClipThreshold ? TRUE : FALSE;
rnnDesc.ClipThreshold = clipThreshold;
rnnDesc.CoupleInputForget = coupleInputForget ? TRUE : FALSE;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_LSTM, &rnnDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
private:
// Inputs in DML's order, which is different from ONNX.
enum InputTensors
{
IN_X,
IN_WEIGHTS,
IN_RECURRENCE,
IN_BIAS,
IN_HIDDEN_INIT,
IN_CELL_GATE_INIT,
IN_SEQUENCE_LENGTHS,
IN_PEEPHOLE
};
enum OutputTensors
{
OUT_SEQUENCE, // Y
OUT_SINGLE,
OUT_CELL_SINGLE
};
};
DML_OP_DEFINE_CREATION_FUNCTION(RNN, DmlOperatorRecurrentNeuralNetwork);
DML_OP_DEFINE_CREATION_FUNCTION(GRU, DmlOperatorGatedRecurrentUnit);
DML_OP_DEFINE_CREATION_FUNCTION(LSTM, DmlOperatorLongShortTermUnit);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorReduce : public DmlOperator, public ReduceHelperBase
{
public:
DmlOperatorReduce(
const MLOperatorKernelCreationContext& kernelInfo,
DML_REDUCE_FUNCTION function
)
: DmlOperator(kernelInfo),
ReduceHelperBase(kernelInfo,
kernelInfo.GetTensorShapeDescription(),
(function != DML_REDUCE_FUNCTION_ARGMAX && function != DML_REDUCE_FUNCTION_ARGMIN))
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
// Zero the output tensor's memory for ArgMin & ArgMax, which produce INT64 output.
if ((function == DML_REDUCE_FUNCTION_ARGMAX) || (function == DML_REDUCE_FUNCTION_ARGMIN))
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}
std::vector<uint32_t> dmlAxes;
std::vector<DimensionType> reducedDims = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(0);
int dimOffset = gsl::narrow_cast<int>(OperatorHelper::NchwDimensionCount - reducedDims.size());
for (auto& dim : m_axes)
{
reducedDims[dim] = 1;
dmlAxes.push_back(static_cast<uint32_t>(dim + dimOffset));
}
if (!m_keepDims)
{
// DML doesn't know about keepDim and always assume the dim is preserved after reduce.
// So if m_keepDims is false, the ONNX output dim is different than DML tensor desc dim.
// ReduceSum example:
// input dims: {3, 2, 2}
// axes: 1
// keepDims: 0
//
// the ONNX output expect to be of dim {3, 2}, while DML expect the output tensor desc
// dim to be {3, 1, 2}.
//
m_outputTensorDescs[0] = CreateTensorDescFromOutput(
kernelInfo,
0,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
reducedDims);
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_REDUCE_OPERATOR_DESC reduceDesc = {};
reduceDesc.InputTensor = inputDescs.data();
reduceDesc.OutputTensor = outputDescs.data();
reduceDesc.Function = function;
reduceDesc.Axes = dmlAxes.data();
reduceDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_REDUCE, &reduceDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
if (m_zeroOperator)
{
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensors[0]);
}
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)));
}
private:
ComPtr<IDMLCompiledOperator> m_zeroOperator;
};
// A specific type of operation for registration.
template <DML_REDUCE_FUNCTION Function>
class DmlOperatorReduceTemplate : public DmlOperatorReduce
{
public:
DmlOperatorReduceTemplate(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperatorReduce(kernelInfo, Function)
{
}
};
DML_OP_DEFINE_CREATION_FUNCTION(ReduceSum, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_SUM>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceMean, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_AVERAGE>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceProd, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_MULTIPLY>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceLogSum, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_LOG_SUM>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceLogSumExp, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_LOG_SUM_EXP>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceSumSquare, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_SUM_SQUARE>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceL1, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_L1>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceL2, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_L2>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceMax, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_MAX>);
DML_OP_DEFINE_CREATION_FUNCTION(ReduceMin, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_MIN>);
DML_OP_DEFINE_CREATION_FUNCTION(ArgMax, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_ARGMAX>);
DML_OP_DEFINE_CREATION_FUNCTION(ArgMin, DmlOperatorReduceTemplate<DML_REDUCE_FUNCTION_ARGMIN>);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorResize : public DmlOperator, public ResizeHelper
{
public:
// Resample a multidimensional image to a new size.
DmlOperatorResize(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
ResizeHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(!m_scales.empty(), "Resize/Upsample expect scales, either a 2nd input tensors or 'scales' attribute.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Resize/Upsample expect 1 output tensor.");
// Use only the first input tensor. In the case of Resize or the later Upsample-v9,
// the second tensor is CPU based and should not be passed to Resize.
std::vector<std::optional<uint32_t>> inputIndices = { 0 };
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
// Because DirectML supports a limited number of dimensions, try to squeeze the dimension count
// to only those which actually matter. Models sometimes use a greater number of dimensions,
// even though those dimensions have no significance and can be elided (nop 1's), coercing the
// total dimension count back down to a supported value.
std::vector<uint32_t> squeezedInputShape = m_inputDimensions;
std::vector<uint32_t> squeezedOutputShape = m_outputDimensions;
std::vector<uint32_t> squeezableDimensionIndices;
std::vector<float> paddedScales = m_scales;
FindValueIndices<uint32_t>(gsl::make_span(m_outputDimensions), 1u, /*out*/ squeezableDimensionIndices);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ squeezedInputShape);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ paddedScales);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ squeezedOutputShape);
// Update the tensor descriptions.
MLOperatorTensorDataType inputTensorDataType = kernelCreationContext.GetInputEdgeDescription(0).tensorDataType;
auto inputTensorDesc = TensorDesc(inputTensorDataType, squeezedInputShape, squeezedInputShape, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, NchwDimensionCount, 0);
auto outputTensorDesc = TensorDesc(inputTensorDataType, squeezedOutputShape, squeezedOutputShape, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, NchwDimensionCount, 0);
m_inputTensorDescs[0] = inputTensorDesc;
m_outputTensorDescs[0] = outputTensorDesc;
// If the output tensor dimension count was right-aligned to a larger size,
// then ensure that scales has the same count as the tensor rank by inserting
// leading ones, since DirectML requires the scales to have the same count.
const uint32_t squeezedDimCount = gsl::narrow_cast<uint32_t>(squeezedOutputShape.size());
const uint32_t dmlCompatibleDimCount = outputTensorDesc.GetDimensionCount();
if (dmlCompatibleDimCount > squeezedDimCount)
{
paddedScales.insert(paddedScales.begin(), dmlCompatibleDimCount - squeezedDimCount, 1.0f);
}
std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "NEAREST");
DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode);
// Create the operator description.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_RESAMPLE_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InterpolationMode = interpolationMode;
operatorDesc.Scales = paddedScales.data();
operatorDesc.ScaleCount = gsl::narrow_cast<uint32_t>(paddedScales.size());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Resize, DmlOperatorResize);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample, DmlOperatorResize);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorRegionOfInterestPooling : public DmlOperator, public RoiPoolingHelper
{
public:
using Self = DmlOperatorRegionOfInterestPooling;
DmlOperatorRegionOfInterestPooling(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
RoiPoolingHelper(kernelInfo, kernelInfo.GetTensorShapeDescription()),
m_spatialScale(kernelInfo.GetOptionalAttribute<float>(AttrName::SpatialScale, 1.0f))
{
DmlOperator::Initialize(kernelInfo);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ROI_POOLING_OPERATOR_DESC poolingDesc = {};
poolingDesc.InputTensor = &inputDescs[0];
poolingDesc.ROITensor = &inputDescs[1];
poolingDesc.OutputTensor = &outputDescs[0];
poolingDesc.SpatialScale = m_spatialScale;
poolingDesc.PooledSize = { m_pooledSizeH, m_pooledSizeW };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_POOLING, &poolingDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
private:
float m_spatialScale = 1.0f;
};
DML_OP_DEFINE_CREATION_FUNCTION(MaxRoiPool, DmlOperatorRegionOfInterestPooling);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorScatter : public DmlOperator
{
public:
DmlOperatorScatter(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "Scatter expects 3 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Scatter expects 1 output.");
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 3);
assert(outputDescs.size() == 1);
m_inputTensorDescs[1].ForceUnsignedDataType();
auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();;
std::vector<DimensionType> dataDimensions = tensorShapeDescription.GetInputTensorShape(0);
std::vector<DimensionType> indicesDimensions = tensorShapeDescription.GetInputTensorShape(1);
std::vector<DimensionType> updatesDimensions = tensorShapeDescription.GetInputTensorShape(2);
std::vector<DimensionType> outputDimensions = tensorShapeDescription.GetInputTensorShape(0);
ML_CHECK_VALID_ARGUMENT(dataDimensions == outputDimensions);
ML_CHECK_VALID_ARGUMENT(indicesDimensions == updatesDimensions);
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() == indicesDimensions.size());
ML_CHECK_VALID_ARGUMENT(dataDimensions.size() <= OperatorHelper::NchwDimensionCount);
// Read the axis.
int onnxAxis = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Axis, 0);
uint32_t dmlAxis = GetDmlAdjustedAxis(onnxAxis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_SCATTER_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[0];
operatorDesc.IndicesTensor = &inputDescs[1];
operatorDesc.UpdatesTensor = &inputDescs[2];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SCATTER, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Scatter, DmlOperatorScatter);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorSlice : public DmlOperator, public SliceHelper
{
public:
DmlOperatorSlice(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
SliceHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
DmlOperator::Initialize(kernelInfo);
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_offsets.size()));
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_sizes.size()));
assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast<uint32_t>(m_strides.size()));
// Pad the parameters to respect DML's requirements
m_offsets.insert(
m_offsets.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_offsets.size()),
0);
m_sizes.insert(
m_sizes.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_sizes.size()),
1);
m_strides.insert(
m_strides.begin(),
m_inputTensorDescs[0].GetDimensionCount() - gsl::narrow_cast<uint32_t>(m_strides.size()),
1);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SLICE_OPERATOR_DESC sliceDesc = {};
sliceDesc.InputTensor = inputDescs.data();
sliceDesc.OutputTensor = outputDescs.data();
sliceDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_offsets.size());
sliceDesc.Offsets = m_offsets.data();
sliceDesc.Sizes = m_sizes.data();
sliceDesc.Strides = m_strides.data();
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SLICE, &sliceDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Slice, DmlOperatorSlice);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorSpaceToDepth : public DmlOperator, public SpaceToDepthHelper
{
public:
DmlOperatorSpaceToDepth(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
SpaceToDepthHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 1);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1);
DML_SPACE_TO_DEPTH_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.BlockSize = m_blockSize;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SPACE_TO_DEPTH, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(SpaceToDepth, DmlOperatorSpaceToDepth);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorSplit : public DmlOperator, public SplitHelper
{
public:
using Self = DmlOperatorSplit;
DmlOperatorSplit(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo),
SplitHelper(kernelInfo, kernelInfo.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1, "DML only supports split on a single input tensor.");
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() > 0, "Runtime error no output stream specified.");
DmlOperator::Initialize(kernelInfo);
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelInfo, m_inputTensorDescs.front().GetDimensionCount());
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SPLIT_OPERATOR_DESC splitDesc = {};
splitDesc.InputTensor = inputDescs.data();
splitDesc.OutputTensors = outputDescs.data();
splitDesc.OutputCount = gsl::narrow_cast<uint32_t>(outputDescs.size());
splitDesc.Axis = dmlAxis;
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_SPLIT, &splitDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Split, DmlOperatorSplit);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorTile : public DmlOperator, TileHelper
{
public:
using Self = DmlOperatorTile;
DmlOperatorTile(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
TileHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2, "Tile expects 2 input tensors.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Tile expects 1 output tensor.");
std::vector<std::optional<uint32_t>> inputIndices = { 0 }; // Use only the first tensor. The second tensor is CPU based and should not be passed to Tile.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
// Because DirectML supports a limited number of dimensions, try to squeeze the dimension count
// to only those which actually matter. Models sometimes use a greater number of dimensions,
// even though those dimensions have no significance and can be elided (nop 1's), coercing the
// total dimension count back down to a supported value.
std::vector<uint32_t> squeezedInputShape = m_inputDimensions;
std::vector<uint32_t> squeezedOutputShape = m_outputDimensions;
std::vector<uint32_t> squeezableDimensionIndices;
std::vector<uint32_t> paddedRepeatsData = m_repeatsData;
FindValueIndices<uint32_t>(gsl::make_span(squeezedOutputShape), 1u, /*out*/ squeezableDimensionIndices);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ squeezedInputShape);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ paddedRepeatsData);
RemoveValuesByIndex(squeezableDimensionIndices, /*keepOneValue*/ true, /*inout*/ squeezedOutputShape);
// Update the tensor descriptions.
MLOperatorTensorDataType inputTensorDataType = kernelCreationContext.GetInputEdgeDescription(0).tensorDataType;
auto inputTensorDesc = TensorDesc(inputTensorDataType, squeezedInputShape, squeezedInputShape, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, NchwDimensionCount, 0);
auto outputTensorDesc = TensorDesc(inputTensorDataType, squeezedOutputShape, squeezedOutputShape, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, NchwDimensionCount, 0);
m_inputTensorDescs[0] = inputTensorDesc;
m_outputTensorDescs[0] = outputTensorDesc;
// If the output tensor dimension count was right-aligned to a larger size,
// then ensure that repeat counts have the same count as the tensor rank by
// inserting leading ones, since DirectML requires them to have the same count.
const uint32_t squeezedDimCount = gsl::narrow_cast<uint32_t>(squeezedOutputShape.size());
const uint32_t dmlCompatibleDimCount = outputTensorDesc.GetDimensionCount();
if (dmlCompatibleDimCount > squeezedDimCount)
{
paddedRepeatsData.insert(paddedRepeatsData.begin(), dmlCompatibleDimCount - squeezedDimCount, 1);
}
// Create the operator description.
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_TILE_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.RepeatsCount = gsl::narrow_cast<uint32_t>(paddedRepeatsData.size());
operatorDesc.Repeats = paddedRepeatsData.data();
SetDmlOperatorDesc({ DML_OPERATOR_TILE, &operatorDesc }, kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(Tile, DmlOperatorTile);
} // namespace Dml

Просмотреть файл

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorTopK : public DmlOperator, public TopKHelper
{
public:
using Self = DmlOperatorTopK;
DmlOperatorTopK(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
TopKHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelCreationContext);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 1);
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 2);
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount());
DML_TOP_K_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputValueTensor = &outputDescs[0];
operatorDesc.OutputIndexTensor = &outputDescs[1];
operatorDesc.Axis = dmlAxis;
operatorDesc.K = m_k;
// Index tensor is always of type int64. We need to create an extra DML operator to
// initialize the tensor data.
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[1].GetBufferSizeInBytes());
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_TOP_K, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensors[1]);
THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)
));
}
private:
ComPtr<IDMLCompiledOperator> m_zeroOperator;
};
DML_OP_DEFINE_CREATION_FUNCTION(TopK, DmlOperatorTopK);
} // namespace Dml

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше