Upgrade dxdispatch to ORT 1.18 (#584)

* upgrade ort

* add feature level 6.3 (FL 6.3) ops

* update version
This commit is contained in:
Justin Stoecker 2024-05-21 22:09:53 -07:00 коммит произвёл GitHub
Родитель 3e69d6a4c6
Коммит e3a75c1f58
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
5 изменённых файлов: 170 добавлений и 6 удалений

Просмотреть файл

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.19)
project(dxdispatch VERSION 0.17.3 LANGUAGES CXX)
project(dxdispatch VERSION 0.18.0 LANGUAGES CXX)
# ==============================================================================
# External Libraries/Helpers

Просмотреть файл

@ -48,13 +48,13 @@ function(init_directml_cache_variables prefix)
# <PREFIX>_DIRECTML_NUGET_VERSION
set(${prefix}_DIRECTML_NUGET_VERSION
1.13.1
1.14.2
CACHE STRING "Version of the DirectML NuGet package (TYPE == nuget)."
)
# <PREFIX>_DIRECTML_NUGET_HASH
set(${prefix}_DIRECTML_NUGET_HASH
a38cef0d59f314fbcc0cd6551c5a762db7cfdaf8a977f85df32a0b1e279d3ba7
09253C0FB45E8A03313B6EB41ABD21A5607B8996B63AAE4A93A420EA1E5BD1AB
CACHE STRING "SHA256 hash of the DirectML NuGet package (TYPE == nuget)."
)

Просмотреть файл

@ -52,13 +52,13 @@ function(init_onnxruntime_cache_variables prefix)
# <PREFIX>_ONNXRUNTIME_NUGET_VERSION
set(${prefix}_ONNXRUNTIME_NUGET_VERSION
1.17.1
1.18.0
CACHE STRING "Version of the ONNX Runtime NuGet package (TYPE == nuget)."
)
# <PREFIX>_ONNXRUNTIME_NUGET_HASH
set(${prefix}_ONNXRUNTIME_NUGET_HASH
834E9F02A348BE0AE0FDF0E71DF59661B64072E6F89FD6DA19BCF74765F6574E
16D73AF3FC1EDD8392E8B6843FDEA281E89EE68A0C78DDE6325C30D20080EEE5
CACHE STRING "SHA256 hash of the ONNX Runtime NuGet package (TYPE == nuget)."
)

Просмотреть файл

@ -23,6 +23,8 @@ DML_TENSOR_DATA_TYPE ParseDmlTensorDataType(const rapidjson::Value& value)
if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_FLOAT64") || !strcmp(valueString, "FLOAT64")) { return DML_TENSOR_DATA_TYPE_FLOAT64; }
if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_UINT64") || !strcmp(valueString, "UINT64")) { return DML_TENSOR_DATA_TYPE_UINT64; }
if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_INT64") || !strcmp(valueString, "INT64")) { return DML_TENSOR_DATA_TYPE_INT64; }
if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_UINT4") || !strcmp(valueString, "UINT4")) { return DML_TENSOR_DATA_TYPE_UINT4; }
if (!strcmp(valueString, "DML_TENSOR_DATA_TYPE_INT4") || !strcmp(valueString, "INT4")) { return DML_TENSOR_DATA_TYPE_INT4; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_TENSOR_DATA_TYPE.", valueString));
}
@ -229,6 +231,10 @@ DML_OPERATOR_TYPE ParseDmlOperatorType(const rapidjson::Value& value)
if (!strcmp(valueString, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(valueString, "MULTIHEAD_ATTENTION")) { return DML_OPERATOR_MULTIHEAD_ATTENTION; }
if (!strcmp(valueString, "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING") || !strcmp(valueString, "QUANTIZED_LINEAR_AVERAGE_POOLING")) { return DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; }
if (!strcmp(valueString, "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT") || !strcmp(valueString, "MATRIX_MULTIPLY_INTEGER_TO_FLOAT")) { return DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; }
if (!strcmp(valueString, "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2") || !strcmp(valueString, "MEAN_VARIANCE_NORMALIZATION2")) { return DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2; }
if (!strcmp(valueString, "DML_OPERATOR_MULTIHEAD_ATTENTION1") || !strcmp(valueString, "MULTIHEAD_ATTENTION1")) { return DML_OPERATOR_MULTIHEAD_ATTENTION1; }
if (!strcmp(valueString, "DML_OPERATOR_QUANTIZE") || !strcmp(valueString, "QUANTIZE")) { return DML_OPERATOR_QUANTIZE; }
if (!strcmp(valueString, "DML_OPERATOR_DEQUANTIZE") || !strcmp(valueString, "DEQUANTIZE")) { return DML_OPERATOR_DEQUANTIZE; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_OPERATOR_TYPE.", valueString));
}
@ -444,6 +450,7 @@ DML_FEATURE_LEVEL ParseDmlFeatureLevel(const rapidjson::Value& value)
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_0") || !strcmp(valueString, "6_0")) { return DML_FEATURE_LEVEL_6_0; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_1") || !strcmp(valueString, "6_1")) { return DML_FEATURE_LEVEL_6_1; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_2") || !strcmp(valueString, "6_2")) { return DML_FEATURE_LEVEL_6_2; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_3") || !strcmp(valueString, "6_3")) { return DML_FEATURE_LEVEL_6_3; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_FEATURE_LEVEL.", valueString));
}
@ -572,6 +579,26 @@ DML_MULTIHEAD_ATTENTION_MASK_TYPE ParseDmlMultiheadAttentionMaskTypeField(const
});
}
// Converts a JSON string into the matching DML_QUANTIZATION_TYPE enumerator.
// Each enumerator is accepted under its full name (e.g. "DML_QUANTIZATION_TYPE_SCALE")
// or its short suffix (e.g. "SCALE"). Throws std::invalid_argument when the JSON
// value is not a string or when the string names no known enumerator.
DML_QUANTIZATION_TYPE ParseDmlQuantizationType(const rapidjson::Value& value)
{
    if (value.GetType() != rapidjson::Type::kStringType)
    {
        throw std::invalid_argument("DML_QUANTIZATION_TYPE must be a string.");
    }

    const char* const text = value.GetString();

    // True when the JSON text equals either spelling of an enumerator.
    const auto isEither = [text](const char* fullName, const char* shortName)
    {
        return strcmp(text, fullName) == 0 || strcmp(text, shortName) == 0;
    };

    if (isEither("DML_QUANTIZATION_TYPE_NONE", "NONE"))
    {
        return DML_QUANTIZATION_TYPE_NONE;
    }
    if (isEither("DML_QUANTIZATION_TYPE_SCALE", "SCALE"))
    {
        return DML_QUANTIZATION_TYPE_SCALE;
    }
    if (isEither("DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT", "SCALE_ZERO_POINT"))
    {
        return DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT;
    }

    throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_QUANTIZATION_TYPE.", text));
}
// Reads the member `fieldName` of `object` as a DML_QUANTIZATION_TYPE.
// Field lookup and the handling of `required` / `defaultValue` are delegated to
// ParseFieldHelper; this wrapper only supplies the per-value conversion.
DML_QUANTIZATION_TYPE ParseDmlQuantizationTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, DML_QUANTIZATION_TYPE defaultValue)
{
    const auto convertValue = [](auto& fieldValue)
    {
        return ParseDmlQuantizationType(fieldValue);
    };

    return ParseFieldHelper<DML_QUANTIZATION_TYPE>(object, fieldName, required, defaultValue, convertValue);
}
// ====================================================================================================
// DIRECTML FLAGS
// ====================================================================================================
@ -4187,6 +4214,135 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MATRIX_MULTIPLY_I
return bindPoints;
}
// Builds a DML_OPERATOR_DESC of type DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2 from
// a JSON object. Both the operator-specific struct and the returned wrapper are
// obtained from `allocator`. When `fused` is true every tensor field is set to
// nullptr without consulting the JSON; the scalar fields are always parsed.
// Throws std::invalid_argument if `value` is not a JSON object.
DML_OPERATOR_DESC* ParseDmlMeanVarianceNormalization2OperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
    if (!value.IsObject())
    {
        throw std::invalid_argument("Expected a valid JSON object.");
    }

    // Tensor descs are only parsed for non-fused operators.
    // NOTE(review): the trailing bool looks like a "required" flag - confirm against ParseDmlTensorDescField.
    const auto tensorField = [&](const char* fieldName, bool required) -> const DML_TENSOR_DESC*
    {
        if (fused)
        {
            return nullptr;
        }
        return ParseDmlTensorDescField(value, fieldName, allocator, required);
    };

    auto mvn = allocator.Allocate<DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC>();
    mvn->InputTensor = tensorField("InputTensor", true);
    mvn->ScaleTensor = tensorField("ScaleTensor", false);
    mvn->BiasTensor = tensorField("BiasTensor", false);
    mvn->OutputTensor = tensorField("OutputTensor", true);
    mvn->AxisCount = ParseUInt32Field(value, "AxisCount", true);
    mvn->Axes = AsPointer(ParseUInt32ArrayField(value, "Axes", allocator, true));
    mvn->UseMean = ParseBoolField(value, "UseMean", true) ? 1 : 0;
    mvn->UseVariance = ParseBoolField(value, "UseVariance", true) ? 1 : 0;
    mvn->Epsilon = ParseFloat32Field(value, "Epsilon", true);
    mvn->FusedActivation = ParseDmlOperatorDescField(value, "FusedActivation", true, allocator, false);

    // Wrap the typed desc in the generic operator desc that is handed back to the caller.
    auto wrapper = allocator.Allocate<DML_OPERATOR_DESC>();
    wrapper->Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2;
    wrapper->Desc = mvn;
    return wrapper;
}
// Lists the bind points of a MEAN_VARIANCE_NORMALIZATION2 operator: three inputs
// (InputTensor, ScaleTensor, BiasTensor) and one output (OutputTensor).
// The result does not depend on the contents of `desc`.
// NOTE(review): each entry appears to be {name, resource count, required} - confirm against BindPoints.
Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC& desc)
{
    Model::DmlDispatchableDesc::BindPoints result = {};

    auto& inputs = result.inputs;
    inputs.push_back({"InputTensor", 1, true});
    inputs.push_back({"ScaleTensor", 1, false});
    inputs.push_back({"BiasTensor", 1, false});

    auto& outputs = result.outputs;
    outputs.push_back({"OutputTensor", 1, true});

    return result;
}
// Builds a DML_OPERATOR_DESC of type DML_OPERATOR_MULTIHEAD_ATTENTION1 from a JSON
// object. Both the operator-specific struct and the returned wrapper are obtained
// from `allocator`. When `fused` is true every tensor field is set to nullptr
// without consulting the JSON; the scalar fields are always parsed.
// Throws std::invalid_argument if `value` is not a JSON object.
DML_OPERATOR_DESC* ParseDmlMultiheadAttention1OperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
    if (!value.IsObject())
    {
        throw std::invalid_argument("Expected a valid JSON object.");
    }

    // Tensor descs are only parsed for non-fused operators.
    // NOTE(review): the trailing bool looks like a "required" flag - confirm against ParseDmlTensorDescField.
    const auto tensorField = [&](const char* fieldName, bool required) -> const DML_TENSOR_DESC*
    {
        if (fused)
        {
            return nullptr;
        }
        return ParseDmlTensorDescField(value, fieldName, allocator, required);
    };

    auto mha = allocator.Allocate<DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC>();

    // Input tensors: all are parsed with the flag set to false.
    mha->QueryTensor = tensorField("QueryTensor", false);
    mha->KeyTensor = tensorField("KeyTensor", false);
    mha->ValueTensor = tensorField("ValueTensor", false);
    mha->StackedQueryKeyTensor = tensorField("StackedQueryKeyTensor", false);
    mha->StackedKeyValueTensor = tensorField("StackedKeyValueTensor", false);
    mha->StackedQueryKeyValueTensor = tensorField("StackedQueryKeyValueTensor", false);
    mha->BiasTensor = tensorField("BiasTensor", false);
    mha->MaskTensor = tensorField("MaskTensor", false);
    mha->RelativePositionBiasTensor = tensorField("RelativePositionBiasTensor", false);
    mha->PastKeyTensor = tensorField("PastKeyTensor", false);
    mha->PastValueTensor = tensorField("PastValueTensor", false);
    mha->PastSequenceLengthsTensor = tensorField("PastSequenceLengthsTensor", false);

    // Output tensors: only OutputTensor is parsed with the flag set to true.
    mha->OutputTensor = tensorField("OutputTensor", true);
    mha->OutputPresentKeyTensor = tensorField("OutputPresentKeyTensor", false);
    mha->OutputPresentValueTensor = tensorField("OutputPresentValueTensor", false);

    // Scalar attributes.
    mha->Scale = ParseFloat32Field(value, "Scale", true);
    mha->MaskFilterValue = ParseFloat32Field(value, "MaskFilterValue", true);
    mha->QueryHeadCount = ParseUInt32Field(value, "QueryHeadCount", true);
    mha->KeyValueHeadCount = ParseUInt32Field(value, "KeyValueHeadCount", true);
    mha->MaskType = ParseDmlMultiheadAttentionMaskTypeField(value, "MaskType", true, {});

    // Wrap the typed desc in the generic operator desc that is handed back to the caller.
    auto wrapper = allocator.Allocate<DML_OPERATOR_DESC>();
    wrapper->Type = DML_OPERATOR_MULTIHEAD_ATTENTION1;
    wrapper->Desc = mha;
    return wrapper;
}
// Lists the bind points of a MULTIHEAD_ATTENTION1 operator: twelve inputs, all
// registered with the flag set to false, followed by three outputs of which only
// OutputTensor is registered with the flag set to true.
// The result does not depend on the contents of `desc`.
// NOTE(review): each entry appears to be {name, resource count, required} - confirm against BindPoints.
Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC& desc)
{
    Model::DmlDispatchableDesc::BindPoints result = {};

    // Inputs are appended in exactly this order.
    for (const char* inputName : {
             "QueryTensor",
             "KeyTensor",
             "ValueTensor",
             "StackedQueryKeyTensor",
             "StackedKeyValueTensor",
             "StackedQueryKeyValueTensor",
             "BiasTensor",
             "MaskTensor",
             "RelativePositionBiasTensor",
             "PastKeyTensor",
             "PastValueTensor",
             "PastSequenceLengthsTensor"})
    {
        result.inputs.push_back({inputName, 1, false});
    }

    result.outputs.push_back({"OutputTensor", 1, true});
    for (const char* outputName : {"OutputPresentKeyTensor", "OutputPresentValueTensor"})
    {
        result.outputs.push_back({outputName, 1, false});
    }

    return result;
}
// Builds a DML_OPERATOR_DESC of type DML_OPERATOR_QUANTIZE from a JSON object.
// Both the operator-specific struct and the returned wrapper are obtained from
// `allocator`. When `fused` is true the tensor fields (InputTensor,
// QuantizationTensors, OutputTensor) are set to nullptr without consulting the
// JSON; QuantizationType and QuantizationTensorCount are always parsed.
// Throws std::invalid_argument if `value` is not a JSON object.
DML_OPERATOR_DESC* ParseDmlQuantizeOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
    if (!value.IsObject())
    {
        throw std::invalid_argument("Expected a valid JSON object.");
    }

    auto quantize = allocator.Allocate<DML_QUANTIZE_OPERATOR_DESC>();

    if (fused)
    {
        quantize->InputTensor = nullptr;
    }
    else
    {
        quantize->InputTensor = ParseDmlTensorDescField(value, "InputTensor", allocator, true);
    }

    quantize->QuantizationType = ParseDmlQuantizationTypeField(value, "QuantizationType", true, {});
    // The count is read from its own JSON field, independently of the tensor array below.
    quantize->QuantizationTensorCount = ParseUInt32Field(value, "QuantizationTensorCount", true);

    if (fused)
    {
        quantize->QuantizationTensors = nullptr;
        quantize->OutputTensor = nullptr;
    }
    else
    {
        quantize->QuantizationTensors = AsPointer(ParseDmlTensorDescArrayField(value, "QuantizationTensors", allocator, true));
        quantize->OutputTensor = ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
    }

    // Wrap the typed desc in the generic operator desc that is handed back to the caller.
    auto wrapper = allocator.Allocate<DML_OPERATOR_DESC>();
    wrapper->Type = DML_OPERATOR_QUANTIZE;
    wrapper->Desc = quantize;
    return wrapper;
}
// Lists the bind points of a QUANTIZE operator: InputTensor and QuantizationTensors
// as inputs, OutputTensor as the output. The QuantizationTensors entry carries
// desc.QuantizationTensorCount as its second member; the other entries carry 1.
// NOTE(review): each entry appears to be {name, resource count, required} - confirm against BindPoints.
Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_QUANTIZE_OPERATOR_DESC& desc)
{
    Model::DmlDispatchableDesc::BindPoints result = {};

    auto& inputs = result.inputs;
    inputs.push_back({"InputTensor", 1, true});
    inputs.push_back({"QuantizationTensors", desc.QuantizationTensorCount, true});

    auto& outputs = result.outputs;
    outputs.push_back({"OutputTensor", 1, true});

    return result;
}
// Builds a DML_OPERATOR_DESC of type DML_OPERATOR_DEQUANTIZE from a JSON object.
// Both the operator-specific struct and the returned wrapper are obtained from
// `allocator`. When `fused` is true the tensor fields (InputTensor,
// QuantizationTensors, OutputTensor) are set to nullptr without consulting the
// JSON; QuantizationType and QuantizationTensorCount are always parsed.
// Throws std::invalid_argument if `value` is not a JSON object.
DML_OPERATOR_DESC* ParseDmlDequantizeOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
    if (!value.IsObject())
    {
        throw std::invalid_argument("Expected a valid JSON object.");
    }

    auto dequantize = allocator.Allocate<DML_DEQUANTIZE_OPERATOR_DESC>();

    if (fused)
    {
        dequantize->InputTensor = nullptr;
    }
    else
    {
        dequantize->InputTensor = ParseDmlTensorDescField(value, "InputTensor", allocator, true);
    }

    dequantize->QuantizationType = ParseDmlQuantizationTypeField(value, "QuantizationType", true, {});
    // The count is read from its own JSON field, independently of the tensor array below.
    dequantize->QuantizationTensorCount = ParseUInt32Field(value, "QuantizationTensorCount", true);

    if (fused)
    {
        dequantize->QuantizationTensors = nullptr;
        dequantize->OutputTensor = nullptr;
    }
    else
    {
        dequantize->QuantizationTensors = AsPointer(ParseDmlTensorDescArrayField(value, "QuantizationTensors", allocator, true));
        dequantize->OutputTensor = ParseDmlTensorDescField(value, "OutputTensor", allocator, true);
    }

    // Wrap the typed desc in the generic operator desc that is handed back to the caller.
    auto wrapper = allocator.Allocate<DML_OPERATOR_DESC>();
    wrapper->Type = DML_OPERATOR_DEQUANTIZE;
    wrapper->Desc = dequantize;
    return wrapper;
}
// Lists the bind points of a DEQUANTIZE operator: InputTensor and QuantizationTensors
// as inputs, OutputTensor as the output. The QuantizationTensors entry carries
// desc.QuantizationTensorCount as its second member; the other entries carry 1.
// NOTE(review): each entry appears to be {name, resource count, required} - confirm against BindPoints.
Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_DEQUANTIZE_OPERATOR_DESC& desc)
{
    Model::DmlDispatchableDesc::BindPoints result = {};

    auto& inputs = result.inputs;
    inputs.push_back({"InputTensor", 1, true});
    inputs.push_back({"QuantizationTensors", desc.QuantizationTensorCount, true});

    auto& outputs = result.outputs;
    outputs.push_back({"OutputTensor", 1, true});

    return result;
}
DML_OPERATOR_DESC* ParseDmlActivationEluOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
@ -4905,6 +5061,10 @@ DML_OPERATOR_DESC* ParseDmlOperatorDesc(const rapidjson::Value& value, bool fuse
if (!strcmp(type, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(type, "MULTIHEAD_ATTENTION")) return ParseDmlMultiheadAttentionOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING") || !strcmp(type, "QUANTIZED_LINEAR_AVERAGE_POOLING")) return ParseDmlQuantizedLinearAveragePoolingOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT") || !strcmp(type, "MATRIX_MULTIPLY_INTEGER_TO_FLOAT")) return ParseDmlMatrixMultiplyIntegerToFloatOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2") || !strcmp(type, "MEAN_VARIANCE_NORMALIZATION2")) return ParseDmlMeanVarianceNormalization2OperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_MULTIHEAD_ATTENTION1") || !strcmp(type, "MULTIHEAD_ATTENTION1")) return ParseDmlMultiheadAttention1OperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_QUANTIZE") || !strcmp(type, "QUANTIZE")) return ParseDmlQuantizeOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_DEQUANTIZE") || !strcmp(type, "DEQUANTIZE")) return ParseDmlDequantizeOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_ELU") || !strcmp(type, "ACTIVATION_ELU")) return ParseDmlActivationEluOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_CELU") || !strcmp(type, "ACTIVATION_CELU")) return ParseDmlActivationCeluOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_HARDMAX") || !strcmp(type, "ACTIVATION_HARDMAX")) return ParseDmlActivationHardmaxOperatorDesc(descValue, fused, allocator);
@ -5082,6 +5242,10 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_OPERATOR_DESC& de
case DML_OPERATOR_MULTIHEAD_ATTENTION: return GetBindPoints(*reinterpret_cast<const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return GetBindPoints(*reinterpret_cast<const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return GetBindPoints(*reinterpret_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2: return GetBindPoints(*reinterpret_cast<const DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_MULTIHEAD_ATTENTION1: return GetBindPoints(*reinterpret_cast<const DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_QUANTIZE: return GetBindPoints(*reinterpret_cast<const DML_QUANTIZE_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_DEQUANTIZE: return GetBindPoints(*reinterpret_cast<const DML_DEQUANTIZE_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_ELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_ELU_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_CELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_CELU_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_HARDMAX: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_HARDMAX_OPERATOR_DESC*>(desc.Desc));

Просмотреть файл

@ -1,7 +1,7 @@
param
(
[string]$SchemaFilePath = "$PSScriptRoot\DmlSchema.json",
[string]$MaxFeatureLevel = "6.2"
[string]$MaxFeatureLevel = "6.3"
)
function ConvertSnakeToCamelCase($SnakeCaseName)