This commit is contained in:
Xiang Zhang 2023-08-16 15:04:33 -07:00 коммит произвёл GitHub
Родитель 46a71aa9a0
Коммит ae44b1895d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 8 добавлений и 3 удалений

Просмотреть файл

@ -3322,6 +3322,7 @@ namespace dml
Expression input,
TensorDimensions outputSizes,
DML_INTERPOLATION_MODE mode,
DML_AXIS_DIRECTION roundingDirection,
Span<const float> scales = {},
Span<const float> inputPixelOffsets = {},
Span<const float> outputPixelOffsets = {})
@ -3331,6 +3332,9 @@ namespace dml
TensorDesc inputTensor = input.Impl()->GetOutputDesc();
uint32_t dimensionCount = static_cast<uint32_t>(inputTensor.sizes.size());
assert(outputSizes.size() == dimensionCount);
assert(scales.empty() || scales.size() == dimensionCount);
assert(inputPixelOffsets.empty() || inputPixelOffsets.size() == dimensionCount);
assert(outputPixelOffsets.empty() || outputPixelOffsets.size() == dimensionCount);
SmallVector<float, 4> defaultScales;
if (scales.empty())
@ -3358,17 +3362,18 @@ namespace dml
TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy());
DML_RESAMPLE1_OPERATOR_DESC desc = {};
DML_RESAMPLE2_OPERATOR_DESC desc = {};
desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
desc.InterpolationMode = mode;
desc.DimensionCount = static_cast<UINT>(scales.size());
desc.RoundingDirection = roundingDirection;
desc.DimensionCount = dimensionCount;
desc.Scales = scales.data();
desc.InputPixelOffsets = inputPixelOffsets.data();
desc.OutputPixelOffsets = outputPixelOffsets.data();
detail::NodeOutput* const inputs[] = { input.Impl() };
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE1, &desc, inputs);
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE2, &desc, inputs);
detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
return output;