[DML EP] Add DML implementation for BiasGelu (#13795)
### Description
Add a DML implementation for BiasGelu.
This commit is contained in:
Родитель
e0dcbc3832
Коммит
e9b92fdf33
|
@ -1118,6 +1118,7 @@ Do not modify directly.*
|
|||
| |
|
||||
|**Operator Domain:** *com.microsoft*||||
|
||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|
||||
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorBiasGelu : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorBiasGelu(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
|
||||
|
||||
// Broadcast bias to have the same dimensions as the input
|
||||
std::vector<uint32_t> inputTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, inputTensorShape);
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1);
|
||||
|
||||
TensorDesc biasInputTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
|
||||
DML_TENSOR_DESC biasInputDmlTensorDesc = biasInputTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_ACTIVATION_GELU_OPERATOR_DESC geluDesc = {};
|
||||
DML_OPERATOR_DESC geluOpDesc = { DML_OPERATOR_ACTIVATION_GELU, &geluDesc };
|
||||
|
||||
DML_ELEMENT_WISE_ADD1_OPERATOR_DESC addDesc = {};
|
||||
addDesc.ATensor = &inputDescs[0];
|
||||
addDesc.BTensor = &inputDescs[1];
|
||||
addDesc.FusedActivation = &geluOpDesc;
|
||||
addDesc.OutputTensor = &outputDescs[0];
|
||||
DML_OPERATOR_DESC addOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD1, &addDesc };
|
||||
|
||||
SetDmlOperatorDesc(addOpDesc, kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
// Associates the BiasGelu operator name with DmlOperatorBiasGelu via the
// creation-function macro (macro definition is not visible here).
DML_OP_DEFINE_CREATION_FUNCTION(BiasGelu, DmlOperatorBiasGelu);
|
||||
|
||||
} // namespace Dml
|
|
@ -233,6 +233,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Erf);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Where);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Shrink);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Gelu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(BiasGelu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(OneHot);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(EyeLike);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool);
|
||||
|
@ -714,6 +715,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
|
||||
// Contrib operators
|
||||
{REG_INFO_MS( 1, Gelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, BiasGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
|
||||
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
|
||||
|
|
|
@ -1557,6 +1557,7 @@ using ShapeInferenceHelper_ParametricSoftplus = GetOutputShapeAsInputShapeHelper
|
|||
using ShapeInferenceHelper_Dropout = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Shrink = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Gelu = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_BiasGelu = GetOutputShapeAsInputShapeHelper;
|
||||
|
||||
using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
@ -390,6 +390,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1;
|
||||
static const int sc_sinceVer_QLinearAdd = 1;
|
||||
static const int sc_sinceVer_Gelu = 1;
|
||||
static const int sc_sinceVer_BiasGelu = 1;
|
||||
static const int sc_sinceVer_FusedMatMul = 1;
|
||||
static const int sc_sinceVer_QLinearSigmoid = 1;
|
||||
static const int sc_sinceVer_Attention = 1;
|
||||
|
|
|
@ -113,7 +113,7 @@ TEST(BiasGeluTest, Two_One_Dim) {
|
|||
RunBiasGeluTest(input_a_data, input_b_data, {2, 4}, {4});
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
|
||||
TEST(BiasGeluTest, Two_One_Dim_fp16) {
|
||||
#ifdef USE_CUDA
|
||||
int min_cuda_architecture = 530;
|
||||
|
@ -187,6 +187,8 @@ TEST(BiasGeluTest, Two_One_Dim_bfloat16) {
|
|||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
#elif USE_DNNL
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
#elif USE_DML
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
#endif
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче