Upgrade dxdispatch to ORT 1.18 (#584)

* upgrade ort
* fl6.3 ops
* update version

Parent: 3e69d6a4c6
Commit: e3a75c1f58
@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.19)
-project(dxdispatch VERSION 0.17.3 LANGUAGES CXX)
+project(dxdispatch VERSION 0.18.0 LANGUAGES CXX)
 
 # ==============================================================================
 # External Libraries/Helpers
@@ -48,13 +48,13 @@ function(init_directml_cache_variables prefix)
 
     # <PREFIX>_DIRECTML_NUGET_VERSION
     set(${prefix}_DIRECTML_NUGET_VERSION
-        1.13.1
+        1.14.2
        CACHE STRING "Version of the DirectML NuGet package (TYPE == nuget)."
    )

    # <PREFIX>_DIRECTML_NUGET_HASH
    set(${prefix}_DIRECTML_NUGET_HASH
-        a38cef0d59f314fbcc0cd6551c5a762db7cfdaf8a977f85df32a0b1e279d3ba7
+        09253C0FB45E8A03313B6EB41ABD21A5607B8996B63AAE4A93A420EA1E5BD1AB
        CACHE STRING "SHA256 hash of the DirectML NuGet package (TYPE == nuget)."
    )
 
@@ -52,13 +52,13 @@ function(init_onnxruntime_cache_variables prefix)
 
    # <PREFIX>_ONNXRUNTIME_NUGET_VERSION
    set(${prefix}_ONNXRUNTIME_NUGET_VERSION
-        1.17.1
+        1.18.0
        CACHE STRING "Version of the ONNX Runtime NuGet package (TYPE == nuget)."
    )

    # <PREFIX>_ONNXRUNTIME_NUGET_HASH
    set(${prefix}_ONNXRUNTIME_NUGET_HASH
-        834E9F02A348BE0AE0FDF0E71DF59661B64072E6F89FD6DA19BCF74765F6574E
+        16D73AF3FC1EDD8392E8B6843FDEA281E89EE68A0C78DDE6325C30D20080EEE5
        CACHE STRING "SHA256 hash of the ONNX Runtime NuGet package (TYPE == nuget)."
    )
 
@@ -23,6 +23,8 @@ DML_TENSOR_DATA_TYPE ParseDmlTensorDataType(const rapidjson::Value& value)
     if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_FLOAT64") || !strcmp(valueString, "FLOAT64")) { return DML_TENSOR_DATA_TYPE_FLOAT64; }
     if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_UINT64") || !strcmp(valueString, "UINT64")) { return DML_TENSOR_DATA_TYPE_UINT64; }
     if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_INT64") || !strcmp(valueString, "INT64")) { return DML_TENSOR_DATA_TYPE_INT64; }
+    if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_UINT4") || !strcmp(valueString, "UINT4")) { return DML_TENSOR_DATA_TYPE_UINT4; }
+    if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_INT4") || !strcmp(valueString, "INT4")) { return DML_TENSOR_DATA_TYPE_INT4; }
     throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_TENSOR_DATA_TYPE.", valueString));
 }
 
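Note: the two added lines let model JSON name the 4-bit data types that ship with the DirectML 1.14 / feature level 6.3 update above. A minimal usage sketch, assuming it is compiled inside the dxdispatch project where ParseDmlTensorDataType, rapidjson, and a DirectML.h new enough to define the 4-bit enums are already available (the "DataType" field name here is illustrative only):

#include <rapidjson/document.h>
#include <DirectML.h>

// Defined above in this file; declared here only to keep the sketch self-contained.
DML_TENSOR_DATA_TYPE ParseDmlTensorDataType(const rapidjson::Value& value);

void TensorDataTypeExample()
{
    rapidjson::Document doc;

    doc.Parse(R"({"DataType": "UINT4"})");                      // short spelling
    auto a = ParseDmlTensorDataType(doc["DataType"]);           // DML_TENSOR_DATA_TYPE_UINT4

    doc.Parse(R"({"DataType": "DML_TENSOR_DATA_TYPE_INT4"})");  // full enum name also accepted
    auto b = ParseDmlTensorDataType(doc["DataType"]);           // DML_TENSOR_DATA_TYPE_INT4

    (void)a; (void)b;
}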
@@ -229,6 +231,10 @@ DML_OPERATOR_TYPE ParseDmlOperatorType(const rapidjson::Value& value)
     if (!strcmp(valueString, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(valueString, "MULTIHEAD_ATTENTION")) { return DML_OPERATOR_MULTIHEAD_ATTENTION; }
     if (!strcmp(valueString, "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING") || !strcmp(valueString, "QUANTIZED_LINEAR_AVERAGE_POOLING")) { return DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; }
     if (!strcmp(valueString, "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT") || !strcmp(valueString, "MATRIX_MULTIPLY_INTEGER_TO_FLOAT")) { return DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; }
+    if (!strcmp(valueString, "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2") || !strcmp(valueString, "MEAN_VARIANCE_NORMALIZATION2")) { return DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2; }
+    if (!strcmp(valueString, "DML_OPERATOR_MULTIHEAD_ATTENTION1") || !strcmp(valueString, "MULTIHEAD_ATTENTION1")) { return DML_OPERATOR_MULTIHEAD_ATTENTION1; }
+    if (!strcmp(valueString, "DML_OPERATOR_QUANTIZE") || !strcmp(valueString, "QUANTIZE")) { return DML_OPERATOR_QUANTIZE; }
+    if (!strcmp(valueString, "DML_OPERATOR_DEQUANTIZE") || !strcmp(valueString, "DEQUANTIZE")) { return DML_OPERATOR_DEQUANTIZE; }
     throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_OPERATOR_TYPE.", valueString));
 }
 
@@ -444,6 +450,7 @@ DML_FEATURE_LEVEL ParseDmlFeatureLevel(const rapidjson::Value& value)
     if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_0") || !strcmp(valueString, "6_0")) { return DML_FEATURE_LEVEL_6_0; }
     if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_1") || !strcmp(valueString, "6_1")) { return DML_FEATURE_LEVEL_6_1; }
     if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_2") || !strcmp(valueString, "6_2")) { return DML_FEATURE_LEVEL_6_2; }
+    if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_3") || !strcmp(valueString, "6_3")) { return DML_FEATURE_LEVEL_6_3; }
     throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_FEATURE_LEVEL.", valueString));
 }
 
@@ -572,6 +579,26 @@ DML_MULTIHEAD_ATTENTION_MASK_TYPE ParseDmlMultiheadAttentionMaskTypeField(const
     });
 }
 
+DML_QUANTIZATION_TYPE ParseDmlQuantizationType(const rapidjson::Value& value)
+{
+    if (value.GetType() != rapidjson::Type::kStringType)
+    {
+        throw std::invalid_argument("DML_QUANTIZATION_TYPE must be a string.");
+    }
+    auto valueString = value.GetString();
+    if (!strcmp(valueString, "DML_QUANTIZATION_TYPE_NONE") || !strcmp(valueString, "NONE")) { return DML_QUANTIZATION_TYPE_NONE; }
+    if (!strcmp(valueString, "DML_QUANTIZATION_TYPE_SCALE") || !strcmp(valueString, "SCALE")) { return DML_QUANTIZATION_TYPE_SCALE; }
+    if (!strcmp(valueString, "DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT") || !strcmp(valueString, "SCALE_ZERO_POINT")) { return DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT; }
+    throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_QUANTIZATION_TYPE.", valueString));
+}
+
+DML_QUANTIZATION_TYPE ParseDmlQuantizationTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, DML_QUANTIZATION_TYPE defaultValue)
+{
+    return ParseFieldHelper<DML_QUANTIZATION_TYPE>(object, fieldName, required, defaultValue, [](auto& value){
+        return ParseDmlQuantizationType(value);
+    });
+}
+
 // ====================================================================================================
 // DIRECTML FLAGS
 // ====================================================================================================
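Note: a hedged usage sketch of the new field parser, assuming the dxdispatch parser declarations (ParseDmlQuantizationTypeField, rapidjson) are in scope as above. Both the full enum spelling and the short form are accepted; unrecognized strings are rejected by ParseDmlQuantizationType with std::invalid_argument, while the required/default handling comes from ParseFieldHelper (not shown in this diff):

#include <rapidjson/document.h>
#include <DirectML.h>
#include <string_view>

// Declared above; repeated here only to keep the sketch self-contained.
DML_QUANTIZATION_TYPE ParseDmlQuantizationTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, DML_QUANTIZATION_TYPE defaultValue);

void QuantizationTypeExample(const rapidjson::Value& descJson)
{
    // Required field: e.g. "QuantizationType": "SCALE_ZERO_POINT"
    // or "DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT".
    DML_QUANTIZATION_TYPE required =
        ParseDmlQuantizationTypeField(descJson, "QuantizationType", /*required*/ true, {});

    // Optional field: falls back to the supplied default when absent.
    DML_QUANTIZATION_TYPE optional =
        ParseDmlQuantizationTypeField(descJson, "QuantizationType", /*required*/ false, DML_QUANTIZATION_TYPE_NONE);

    (void)required; (void)optional;
}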
@@ -4187,6 +4214,135 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MATRIX_MULTIPLY_I
     return bindPoints;
 }
 
+DML_OPERATOR_DESC* ParseDmlMeanVarianceNormalization2OperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
+{
+    if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
+    auto desc = allocator.Allocate<DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC>();
+    desc->InputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "InputTensor", allocator, true);
+    desc->ScaleTensor = fused ? nullptr : ParseDmlTensorDescField(value, "ScaleTensor", allocator, false);
+    desc->BiasTensor = fused ? nullptr : ParseDmlTensorDescField(value, "BiasTensor", allocator, false);
+    desc->OutputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
+    desc->AxisCount = ParseUInt32Field(value, "AxisCount", true);
+    desc->Axes = AsPointer(ParseUInt32ArrayField(value, "Axes", allocator, true));
+    desc->UseMean = ParseBoolField(value, "UseMean", true) ? 1 : 0;
+    desc->UseVariance = ParseBoolField(value, "UseVariance", true) ? 1 : 0;
+    desc->Epsilon = ParseFloat32Field(value, "Epsilon", true);
+    desc->FusedActivation = ParseDmlOperatorDescField(value, "FusedActivation", true, allocator, false);
+    auto opDesc = allocator.Allocate<DML_OPERATOR_DESC>();
+    opDesc->Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2;
+    opDesc->Desc = desc;
+    return opDesc;
+}
+
+Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC& desc)
+{
+    Model::DmlDispatchableDesc::BindPoints bindPoints = {};
+    bindPoints.inputs.push_back({"InputTensor", 1, true});
+    bindPoints.inputs.push_back({"ScaleTensor", 1, false});
+    bindPoints.inputs.push_back({"BiasTensor", 1, false});
+    bindPoints.outputs.push_back({"OutputTensor", 1, true});
+    return bindPoints;
+}
+
+DML_OPERATOR_DESC* ParseDmlMultiheadAttention1OperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
+{
+    if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
+    auto desc = allocator.Allocate<DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC>();
+    desc->QueryTensor = fused ? nullptr : ParseDmlTensorDescField(value, "QueryTensor", allocator, false);
+    desc->KeyTensor = fused ? nullptr : ParseDmlTensorDescField(value, "KeyTensor", allocator, false);
+    desc->ValueTensor = fused ? nullptr : ParseDmlTensorDescField(value, "ValueTensor", allocator, false);
+    desc->StackedQueryKeyTensor = fused ? nullptr : ParseDmlTensorDescField(value, "StackedQueryKeyTensor", allocator, false);
+    desc->StackedKeyValueTensor = fused ? nullptr : ParseDmlTensorDescField(value, "StackedKeyValueTensor", allocator, false);
+    desc->StackedQueryKeyValueTensor = fused ? nullptr : ParseDmlTensorDescField(value, "StackedQueryKeyValueTensor", allocator, false);
+    desc->BiasTensor = fused ? nullptr : ParseDmlTensorDescField(value, "BiasTensor", allocator, false);
+    desc->MaskTensor = fused ? nullptr : ParseDmlTensorDescField(value, "MaskTensor", allocator, false);
+    desc->RelativePositionBiasTensor = fused ? nullptr : ParseDmlTensorDescField(value, "RelativePositionBiasTensor", allocator, false);
+    desc->PastKeyTensor = fused ? nullptr : ParseDmlTensorDescField(value, "PastKeyTensor", allocator, false);
+    desc->PastValueTensor = fused ? nullptr : ParseDmlTensorDescField(value, "PastValueTensor", allocator, false);
+    desc->PastSequenceLengthsTensor = fused ? nullptr : ParseDmlTensorDescField(value, "PastSequenceLengthsTensor", allocator, false);
+    desc->OutputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
+    desc->OutputPresentKeyTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputPresentKeyTensor", allocator, false);
+    desc->OutputPresentValueTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputPresentValueTensor", allocator, false);
+    desc->Scale = ParseFloat32Field(value, "Scale", true);
+    desc->MaskFilterValue = ParseFloat32Field(value, "MaskFilterValue", true);
+    desc->QueryHeadCount = ParseUInt32Field(value, "QueryHeadCount", true);
+    desc->KeyValueHeadCount = ParseUInt32Field(value, "KeyValueHeadCount", true);
+    desc->MaskType = ParseDmlMultiheadAttentionMaskTypeField(value, "MaskType", true, {});
+    auto opDesc = allocator.Allocate<DML_OPERATOR_DESC>();
+    opDesc->Type = DML_OPERATOR_MULTIHEAD_ATTENTION1;
+    opDesc->Desc = desc;
+    return opDesc;
+}
+
+Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC& desc)
+{
+    Model::DmlDispatchableDesc::BindPoints bindPoints = {};
+    bindPoints.inputs.push_back({"QueryTensor", 1, false});
+    bindPoints.inputs.push_back({"KeyTensor", 1, false});
+    bindPoints.inputs.push_back({"ValueTensor", 1, false});
+    bindPoints.inputs.push_back({"StackedQueryKeyTensor", 1, false});
+    bindPoints.inputs.push_back({"StackedKeyValueTensor", 1, false});
+    bindPoints.inputs.push_back({"StackedQueryKeyValueTensor", 1, false});
+    bindPoints.inputs.push_back({"BiasTensor", 1, false});
+    bindPoints.inputs.push_back({"MaskTensor", 1, false});
+    bindPoints.inputs.push_back({"RelativePositionBiasTensor", 1, false});
+    bindPoints.inputs.push_back({"PastKeyTensor", 1, false});
+    bindPoints.inputs.push_back({"PastValueTensor", 1, false});
+    bindPoints.inputs.push_back({"PastSequenceLengthsTensor", 1, false});
+    bindPoints.outputs.push_back({"OutputTensor", 1, true});
+    bindPoints.outputs.push_back({"OutputPresentKeyTensor", 1, false});
+    bindPoints.outputs.push_back({"OutputPresentValueTensor", 1, false});
+    return bindPoints;
+}
+
+DML_OPERATOR_DESC* ParseDmlQuantizeOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
+{
+    if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
+    auto desc = allocator.Allocate<DML_QUANTIZE_OPERATOR_DESC>();
+    desc->InputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "InputTensor", allocator, true);
+    desc->QuantizationType = ParseDmlQuantizationTypeField(value, "QuantizationType", true, {});
+    desc->QuantizationTensorCount = ParseUInt32Field(value, "QuantizationTensorCount", true);
+    desc->QuantizationTensors = fused ? nullptr : AsPointer(ParseDmlTensorDescArrayField(value, "QuantizationTensors", allocator, true));
+    desc->OutputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
+    auto opDesc = allocator.Allocate<DML_OPERATOR_DESC>();
+    opDesc->Type = DML_OPERATOR_QUANTIZE;
+    opDesc->Desc = desc;
+    return opDesc;
+}
+
+Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_QUANTIZE_OPERATOR_DESC& desc)
+{
+    Model::DmlDispatchableDesc::BindPoints bindPoints = {};
+    bindPoints.inputs.push_back({"InputTensor", 1, true});
+    bindPoints.inputs.push_back({"QuantizationTensors", desc.QuantizationTensorCount, true});
+    bindPoints.outputs.push_back({"OutputTensor", 1, true});
+    return bindPoints;
+}
+
+DML_OPERATOR_DESC* ParseDmlDequantizeOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
+{
+    if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
+    auto desc = allocator.Allocate<DML_DEQUANTIZE_OPERATOR_DESC>();
+    desc->InputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "InputTensor", allocator, true);
+    desc->QuantizationType = ParseDmlQuantizationTypeField(value, "QuantizationType", true, {});
+    desc->QuantizationTensorCount = ParseUInt32Field(value, "QuantizationTensorCount", true);
+    desc->QuantizationTensors = fused ? nullptr : AsPointer(ParseDmlTensorDescArrayField(value, "QuantizationTensors", allocator, true));
+    desc->OutputTensor = fused ? nullptr : ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
+    auto opDesc = allocator.Allocate<DML_OPERATOR_DESC>();
+    opDesc->Type = DML_OPERATOR_DEQUANTIZE;
+    opDesc->Desc = desc;
+    return opDesc;
+}
+
+Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_DEQUANTIZE_OPERATOR_DESC& desc)
+{
+    Model::DmlDispatchableDesc::BindPoints bindPoints = {};
+    bindPoints.inputs.push_back({"InputTensor", 1, true});
+    bindPoints.inputs.push_back({"QuantizationTensors", desc.QuantizationTensorCount, true});
+    bindPoints.outputs.push_back({"OutputTensor", 1, true});
+    return bindPoints;
+}
+
 DML_OPERATOR_DESC* ParseDmlActivationEluOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
 {
     if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
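Note: to make the new field names concrete, here is a hedged sketch of the JSON shape ParseDmlQuantizeOperatorDesc accepts, driven through rapidjson. It assumes compilation inside dxdispatch where the parser and BucketAllocator are declared; the tensor-desc sub-objects ("DataType"/"Sizes") and BucketAllocator being default-constructible are assumptions based on how the rest of the project is written, not something shown in this diff:

#include <rapidjson/document.h>

void QuantizeDescExample()
{
    // Field names come directly from ParseDmlQuantizeOperatorDesc above;
    // tensor descs are abbreviated and illustrative.
    const char* json = R"({
        "InputTensor":             { "DataType": "FLOAT32", "Sizes": [1, 32] },
        "QuantizationType":        "SCALE_ZERO_POINT",
        "QuantizationTensorCount": 2,
        "QuantizationTensors": [
            { "DataType": "FLOAT32", "Sizes": [1, 1] },
            { "DataType": "UINT4",   "Sizes": [1, 1] }
        ],
        "OutputTensor":            { "DataType": "UINT4", "Sizes": [1, 32] }
    })";

    rapidjson::Document doc;
    doc.Parse(json);

    BucketAllocator allocator;  // assumption: default-constructible, as used elsewhere in dxdispatch
    DML_OPERATOR_DESC* op = ParseDmlQuantizeOperatorDesc(doc, /*fused*/ false, allocator);
    // op->Type == DML_OPERATOR_QUANTIZE; op->Desc points at the populated
    // DML_QUANTIZE_OPERATOR_DESC allocated above.
    (void)op;
}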
@@ -4905,6 +5061,10 @@ DML_OPERATOR_DESC* ParseDmlOperatorDesc(const rapidjson::Value& value, bool fuse
     if (!strcmp(type, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(type, "MULTIHEAD_ATTENTION")) return ParseDmlMultiheadAttentionOperatorDesc(descValue, fused, allocator);
     if (!strcmp(type, "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING") || !strcmp(type, "QUANTIZED_LINEAR_AVERAGE_POOLING")) return ParseDmlQuantizedLinearAveragePoolingOperatorDesc(descValue, fused, allocator);
     if (!strcmp(type, "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT") || !strcmp(type, "MATRIX_MULTIPLY_INTEGER_TO_FLOAT")) return ParseDmlMatrixMultiplyIntegerToFloatOperatorDesc(descValue, fused, allocator);
+    if (!strcmp(type, "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2") || !strcmp(type, "MEAN_VARIANCE_NORMALIZATION2")) return ParseDmlMeanVarianceNormalization2OperatorDesc(descValue, fused, allocator);
+    if (!strcmp(type, "DML_OPERATOR_MULTIHEAD_ATTENTION1") || !strcmp(type, "MULTIHEAD_ATTENTION1")) return ParseDmlMultiheadAttention1OperatorDesc(descValue, fused, allocator);
+    if (!strcmp(type, "DML_OPERATOR_QUANTIZE") || !strcmp(type, "QUANTIZE")) return ParseDmlQuantizeOperatorDesc(descValue, fused, allocator);
+    if (!strcmp(type, "DML_OPERATOR_DEQUANTIZE") || !strcmp(type, "DEQUANTIZE")) return ParseDmlDequantizeOperatorDesc(descValue, fused, allocator);
     if (!strcmp(type, "DML_OPERATOR_ACTIVATION_ELU") || !strcmp(type, "ACTIVATION_ELU")) return ParseDmlActivationEluOperatorDesc(descValue, fused, allocator);
     if (!strcmp(type, "DML_OPERATOR_ACTIVATION_CELU") || !strcmp(type, "ACTIVATION_CELU")) return ParseDmlActivationCeluOperatorDesc(descValue, fused, allocator);
     if (!strcmp(type, "DML_OPERATOR_ACTIVATION_HARDMAX") || !strcmp(type, "ACTIVATION_HARDMAX")) return ParseDmlActivationHardmaxOperatorDesc(descValue, fused, allocator);
@@ -5082,6 +5242,10 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_OPERATOR_DESC& de
     case DML_OPERATOR_MULTIHEAD_ATTENTION: return GetBindPoints(*reinterpret_cast<const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC*>(desc.Desc));
     case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return GetBindPoints(*reinterpret_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(desc.Desc));
     case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return GetBindPoints(*reinterpret_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(desc.Desc));
+    case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2: return GetBindPoints(*reinterpret_cast<const DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC*>(desc.Desc));
+    case DML_OPERATOR_MULTIHEAD_ATTENTION1: return GetBindPoints(*reinterpret_cast<const DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC*>(desc.Desc));
+    case DML_OPERATOR_QUANTIZE: return GetBindPoints(*reinterpret_cast<const DML_QUANTIZE_OPERATOR_DESC*>(desc.Desc));
+    case DML_OPERATOR_DEQUANTIZE: return GetBindPoints(*reinterpret_cast<const DML_DEQUANTIZE_OPERATOR_DESC*>(desc.Desc));
     case DML_OPERATOR_ACTIVATION_ELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_ELU_OPERATOR_DESC*>(desc.Desc));
     case DML_OPERATOR_ACTIVATION_CELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_CELU_OPERATOR_DESC*>(desc.Desc));
     case DML_OPERATOR_ACTIVATION_HARDMAX: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_HARDMAX_OPERATOR_DESC*>(desc.Desc));
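Note: as a hedged illustration of what the new switch cases produce, the sketch below builds a QUANTIZE operator desc directly and asks for its bind points; it assumes it is compiled where the GetBindPoints overloads above and the Model types are visible, and that DirectML.h defines the 1.14-era structs:

#include <DirectML.h>

void QuantizeBindPointsExample()
{
    DML_QUANTIZE_OPERATOR_DESC quantize = {};
    quantize.QuantizationTensorCount = 2;  // e.g. scale + zero-point tensors

    DML_OPERATOR_DESC op = { DML_OPERATOR_QUANTIZE, &quantize };
    auto bindPoints = GetBindPoints(op);
    // bindPoints lists "InputTensor" (1, required) and "QuantizationTensors"
    // (2, required) as inputs, and "OutputTensor" (1, required) as the output,
    // mirroring GetBindPoints(const DML_QUANTIZE_OPERATOR_DESC&) above.
    (void)bindPoints;
}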
@@ -1,7 +1,7 @@
 param
 (
     [string]$SchemaFilePath = "$PSScriptRoot\DmlSchema.json",
-    [string]$MaxFeatureLevel = "6.2"
+    [string]$MaxFeatureLevel = "6.3"
 )
 
 function ConvertSnakeToCamelCase($SnakeCaseName)