Add MVN2 and Dequantize support (#591)

* Add NormalizeMean support to MeanVarianceNormalization

* Fix build break

* Fix

* Add Dequantize support
This commit is contained in:
Patrice Vignola 2024-06-07 16:15:19 -07:00 committed by GitHub
Parent e3a75c1f58
Commit 4d65cad0be
No known key found for this signature
GPG key ID: B5690EEEBB952194
1 changed file with 72 additions and 1 deletion

View file

@@ -3193,6 +3193,9 @@ namespace dml
Optional<Expression> bias,
Span<const uint32_t> axes,
bool normalizeVariance,
#if DML_TARGET_VERSION >= 0x6300
bool normalizeMean,
#endif
float epsilon,
FusedActivation fusedActivation = FusedActivation::None())
{
@@ -3213,14 +3216,20 @@ namespace dml
detail::FusedActivationStorage storage;
#if DML_TARGET_VERSION >= 0x6300
DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC desc = {};
desc.UseMean = normalizeMean;
desc.UseVariance = normalizeVariance;
#else
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC desc = {};
desc.NormalizeVariance = normalizeVariance;
#endif
desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
desc.ScaleTensor = scale ? scaleTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
desc.BiasTensor = bias ? biasTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
desc.AxisCount = static_cast<UINT>(axes.size());
desc.Axes = axes.data();
desc.NormalizeVariance = normalizeVariance;
desc.Epsilon = epsilon;
desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage);
@@ -3230,7 +3239,13 @@ namespace dml
scale ? scale->Impl() : nullptr,
bias ? bias->Impl() : nullptr
};
#if DML_TARGET_VERSION >= 0x6300
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &desc, inputs);
#else
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &desc, inputs);
#endif
detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
return output;
@@ -4222,6 +4237,62 @@ namespace dml
#endif // DML_TARGET_VERSION >= 0x5000
#if DML_TARGET_VERSION >= 0x6300
// Builds a DML_OPERATOR_DEQUANTIZE graph node that dequantizes `input` using the
// supplied quantization parameter tensors (e.g. scale, and optionally zero point).
// Only available when DML_TARGET_VERSION >= 0x6300.
//
// input                  - the quantized tensor to dequantize.
// quantizationParameters - the quantization parameter tensors; must be non-empty
//                          and owned by the same graph builder as `input`.
// quantizationType       - must be DML_QUANTIZATION_TYPE_SCALE or
//                          DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT (asserted below).
//
// Returns an Expression whose tensor has the same sizes as `input` and the data
// type of the first quantization parameter (the scale) tensor.
inline Expression Dequantize(
Expression input,
Span<const Expression> quantizationParameters,
DML_QUANTIZATION_TYPE quantizationType)
{
// Every quantization parameter must come from the same graph as the input.
for (const auto& quantizationParameter : quantizationParameters)
{
assert(detail::HasSameOwner({ quantizationParameter, input }));
}
assert(quantizationType != DML_QUANTIZATION_TYPE_NONE);
assert(!quantizationParameters.empty());
detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
TensorDesc inputTensor = input.Impl()->GetOutputDesc();
// Node inputs are the quantized tensor followed by each quantization parameter.
std::vector<detail::NodeOutput*> inputs;
inputs.reserve(quantizationParameters.size() + 1);
inputs.push_back(input.Impl());
std::vector<TensorDesc> quantizationParametersTensors;
quantizationParametersTensors.reserve(quantizationParameters.size());
std::vector<DML_TENSOR_DESC> quantizationParametersDescs;
quantizationParametersDescs.reserve(quantizationParameters.size());
// The output data type is always the same as the data type of the scale
assert(quantizationType == DML_QUANTIZATION_TYPE_SCALE || quantizationType == DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT);
DML_TENSOR_DATA_TYPE outputDataType = quantizationParameters[0].GetOutputDesc().dataType;
for (const auto& quantizationParameter : quantizationParameters)
{
// NOTE(review): the DML_TENSOR_DESC obtained via AsPtr presumably points into
// the owning TensorDesc's storage; the reserve() calls above prevent vector
// reallocation so those pointers stay valid until the node is created — confirm
// AsPtr's lifetime semantics if modifying this loop.
quantizationParametersTensors.push_back(quantizationParameter.Impl()->GetOutputDesc());
TensorDesc& quantizationParameterDesc = quantizationParametersTensors.back();
quantizationParametersDescs.push_back(*quantizationParameterDesc.AsPtr<DML_TENSOR_DESC>());
inputs.push_back(quantizationParameter.Impl());
}
// Output tensor: same sizes as the input, data type taken from the scale tensor.
TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy());
DML_DEQUANTIZE_OPERATOR_DESC desc = {};
desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
desc.QuantizationTensors = quantizationParametersDescs.data();
desc.QuantizationTensorCount = static_cast<uint32_t>(quantizationParametersDescs.size());
desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
desc.QuantizationType = quantizationType;
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_DEQUANTIZE, &desc, inputs);
detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
return output;
}
#endif
// Reinterprets the memory of a tensor with a different type and dimensions (analogously to using
// reinterpret_cast to access raw bits). Note that this is different to the DML Cast operator, which performs
// a type cast on the contents of a tensor (analogously to static_cast). The total tensor size of the output