Add MVN2 and Dequantize support (#591)
* Add NormalizeMean support to MeanVarianceNormalization
* Fix build break
* Fix
* Add Dequantize support
Parent: e3a75c1f58
Commit: 4d65cad0be
@@ -3193,6 +3193,9 @@ namespace dml
         Optional<Expression> bias,
         Span<const uint32_t> axes,
         bool normalizeVariance,
+#if DML_TARGET_VERSION >= 0x6300
+        bool normalizeMean,
+#endif
         float epsilon,
         FusedActivation fusedActivation = FusedActivation::None())
     {
@@ -3213,14 +3216,20 @@ namespace dml
 
         detail::FusedActivationStorage storage;
 
+#if DML_TARGET_VERSION >= 0x6300
+        DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC desc = {};
+        desc.UseMean = normalizeMean;
+        desc.UseVariance = normalizeVariance;
+#else
         DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC desc = {};
+        desc.NormalizeVariance = normalizeVariance;
+#endif
         desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
         desc.ScaleTensor = scale ? scaleTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
         desc.BiasTensor = bias ? biasTensor.AsPtr<DML_TENSOR_DESC>() : nullptr;
         desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
         desc.AxisCount = static_cast<UINT>(axes.size());
         desc.Axes = axes.data();
-        desc.NormalizeVariance = normalizeVariance;
         desc.Epsilon = epsilon;
         desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage);
 
@@ -3230,7 +3239,13 @@ namespace dml
             scale ? scale->Impl() : nullptr,
             bias ? bias->Impl() : nullptr
         };
+
+#if DML_TARGET_VERSION >= 0x6300
+        detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2, &desc, inputs);
+#else
         detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &desc, inputs);
+#endif
+
         detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
 
         return output;
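For reference, a minimal usage sketch of the updated helper (not part of this commit). It assumes DML_TARGET_VERSION >= 0x6300 so the normalizeMean parameter is compiled in, an already-constructed dml::Graph named graph (hypothetical), and DirectMLX's usual Optional/NullOpt helpers; the NCHW shape and axes are illustrative only:

// Sketch only: 'graph' is a hypothetical, already-constructed dml::Graph.
dml::Expression input = dml::InputTensor(graph, 0,
    dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 3, 32, 32 }));

uint32_t axes[] = { 1, 2, 3 }; // normalize over the C, H, W axes of an NCHW tensor

dml::Expression mvn = dml::MeanVarianceNormalization(
    input,
    dml::NullOpt, // optional scale omitted
    dml::NullOpt, // optional bias omitted
    axes,
    true,         // normalizeVariance -> desc.UseVariance on MVN2
    true,         // normalizeMean     -> desc.UseMean (new in this commit)
    1e-5f);       // epsilon

On older target versions the #if above removes the normalizeMean parameter entirely and the node falls back to DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, so callers pinned below 0x6300 keep compiling unchanged.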
@@ -4222,6 +4237,62 @@ namespace dml
 
 #endif // DML_TARGET_VERSION >= 0x5000
 
+#if DML_TARGET_VERSION >= 0x6300
+    inline Expression Dequantize(
+        Expression input,
+        Span<const Expression> quantizationParameters,
+        DML_QUANTIZATION_TYPE quantizationType)
+    {
+        for (const auto& quantizationParameter : quantizationParameters)
+        {
+            assert(detail::HasSameOwner({ quantizationParameter, input }));
+        }
+
+        assert(quantizationType != DML_QUANTIZATION_TYPE_NONE);
+        assert(!quantizationParameters.empty());
+
+        detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+        TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+
+        std::vector<detail::NodeOutput*> inputs;
+        inputs.reserve(quantizationParameters.size() + 1);
+        inputs.push_back(input.Impl());
+
+        std::vector<TensorDesc> quantizationParametersTensors;
+        quantizationParametersTensors.reserve(quantizationParameters.size());
+
+        std::vector<DML_TENSOR_DESC> quantizationParametersDescs;
+        quantizationParametersDescs.reserve(quantizationParameters.size());
+
+        // The output data type is always the same as the data type of the scale
+        assert(quantizationType == DML_QUANTIZATION_TYPE_SCALE || quantizationType == DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT);
+        DML_TENSOR_DATA_TYPE outputDataType = quantizationParameters[0].GetOutputDesc().dataType;
+
+        for (const auto& quantizationParameter : quantizationParameters)
+        {
+            quantizationParametersTensors.push_back(quantizationParameter.Impl()->GetOutputDesc());
+            TensorDesc& quantizationParameterDesc = quantizationParametersTensors.back();
+            quantizationParametersDescs.push_back(*quantizationParameterDesc.AsPtr<DML_TENSOR_DESC>());
+            inputs.push_back(quantizationParameter.Impl());
+        }
+
+        TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+        DML_DEQUANTIZE_OPERATOR_DESC desc = {};
+        desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+        desc.QuantizationTensors = quantizationParametersDescs.data();
+        desc.QuantizationTensorCount = static_cast<uint32_t>(quantizationParametersDescs.size());
+        desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+        desc.QuantizationType = quantizationType;
+
+        detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_DEQUANTIZE, &desc, inputs);
+        detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+        return output;
+    }
+#endif
+
     // Reinterprets the memory of a tensor with a different type and dimensions (analogously to using
     // reinterpret_cast to access raw bits). Note that this is different to the DML Cast operator, which performs
     // a type cast on the contents of a tensor (analogously to static_cast). The total tensor size of the output
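Likewise, a minimal usage sketch for the new Dequantize helper (not part of this commit). Apart from dml::Dequantize and DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT themselves, the names and shapes below are illustrative assumptions, using per-tensor (scalar) quantization parameters for simplicity. Note that the scale must be quantizationParameters[0], because the helper derives the output data type from the first parameter:

// Sketch only: 'graph' is a hypothetical, already-constructed dml::Graph.
dml::Expression quantized = dml::InputTensor(graph, 0,
    dml::TensorDesc(DML_TENSOR_DATA_TYPE_UINT8, { 1, 1, 2, 2 }));
dml::Expression scale = dml::InputTensor(graph, 1,
    dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 1, 1, 1 }));
dml::Expression zeroPoint = dml::InputTensor(graph, 2,
    dml::TensorDesc(DML_TENSOR_DATA_TYPE_UINT8, { 1, 1, 1, 1 }));

// Scale first: the output data type is taken from quantizationParameters[0].
dml::Expression quantizationParameters[] = { scale, zeroPoint };

dml::Expression dequantized = dml::Dequantize(
    quantized,
    quantizationParameters,
    DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT);
// 'dequantized' is DML_TENSOR_DATA_TYPE_FLOAT32, matching the scale tensor.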