update resample in DirectMLX (#495)
This commit is contained in:
Родитель
46a71aa9a0
Коммит
ae44b1895d
|
@ -3322,6 +3322,7 @@ namespace dml
|
|||
Expression input,
|
||||
TensorDimensions outputSizes,
|
||||
DML_INTERPOLATION_MODE mode,
|
||||
DML_AXIS_DIRECTION roundingDirection,
|
||||
Span<const float> scales = {},
|
||||
Span<const float> inputPixelOffsets = {},
|
||||
Span<const float> outputPixelOffsets = {})
|
||||
|
@ -3331,6 +3332,9 @@ namespace dml
|
|||
TensorDesc inputTensor = input.Impl()->GetOutputDesc();
|
||||
uint32_t dimensionCount = static_cast<uint32_t>(inputTensor.sizes.size());
|
||||
assert(outputSizes.size() == dimensionCount);
|
||||
assert(scales.empty() || scales.size() == dimensionCount);
|
||||
assert(inputPixelOffsets.empty() || inputPixelOffsets.size() == dimensionCount);
|
||||
assert(outputPixelOffsets.empty() || outputPixelOffsets.size() == dimensionCount);
|
||||
|
||||
SmallVector<float, 4> defaultScales;
|
||||
if (scales.empty())
|
||||
|
@ -3358,17 +3362,18 @@ namespace dml
|
|||
|
||||
TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy());
|
||||
|
||||
DML_RESAMPLE1_OPERATOR_DESC desc = {};
|
||||
DML_RESAMPLE2_OPERATOR_DESC desc = {};
|
||||
desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
|
||||
desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
|
||||
desc.InterpolationMode = mode;
|
||||
desc.DimensionCount = static_cast<UINT>(scales.size());
|
||||
desc.RoundingDirection = roundingDirection;
|
||||
desc.DimensionCount = dimensionCount;
|
||||
desc.Scales = scales.data();
|
||||
desc.InputPixelOffsets = inputPixelOffsets.data();
|
||||
desc.OutputPixelOffsets = outputPixelOffsets.data();
|
||||
|
||||
detail::NodeOutput* const inputs[] = { input.Impl() };
|
||||
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE1, &desc, inputs);
|
||||
detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE2, &desc, inputs);
|
||||
detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
|
||||
|
||||
return output;
|
||||
|
|
Загрузка…
Ссылка в новой задаче