diff --git a/Libraries/DirectMLX.h b/Libraries/DirectMLX.h index bbeb5b7..07bcd87 100644 --- a/Libraries/DirectMLX.h +++ b/Libraries/DirectMLX.h @@ -3322,6 +3322,7 @@ namespace dml Expression input, TensorDimensions outputSizes, DML_INTERPOLATION_MODE mode, + DML_AXIS_DIRECTION roundingDirection, Span scales = {}, Span inputPixelOffsets = {}, Span outputPixelOffsets = {}) @@ -3331,6 +3332,9 @@ namespace dml TensorDesc inputTensor = input.Impl()->GetOutputDesc(); uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); assert(outputSizes.size() == dimensionCount); + assert(scales.empty() || scales.size() == dimensionCount); + assert(inputPixelOffsets.empty() || inputPixelOffsets.size() == dimensionCount); + assert(outputPixelOffsets.empty() || outputPixelOffsets.size() == dimensionCount); SmallVector defaultScales; if (scales.empty()) @@ -3358,17 +3362,18 @@ namespace dml TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); - DML_RESAMPLE1_OPERATOR_DESC desc = {}; + DML_RESAMPLE2_OPERATOR_DESC desc = {}; desc.InputTensor = inputTensor.AsPtr(); desc.OutputTensor = outputTensor.AsPtr(); desc.InterpolationMode = mode; - desc.DimensionCount = static_cast(scales.size()); + desc.RoundingDirection = roundingDirection; + desc.DimensionCount = dimensionCount; desc.Scales = scales.data(); desc.InputPixelOffsets = inputPixelOffsets.data(); desc.OutputPixelOffsets = outputPixelOffsets.data(); detail::NodeOutput* const inputs[] = { input.Impl() }; - detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE1, &desc, inputs); + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE2, &desc, inputs); detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); return output;