Add custom kernel ScatterNDOfShape (#705)
* first draft * clang * Draft for ScatterNFOfShape * fix build * disable test when cuda is missing * fix implementation * update test * add MaskedScatterNdOfShape * fix merge conflicts
This commit is contained in:
Родитель
79f3b048d4
Коммит
f5055466d5
|
@ -7,6 +7,7 @@
|
|||
#include "cuda/add_mul.h"
|
||||
#include "cuda/fast_gelu.h"
|
||||
#include "cuda/negxplus1.h"
|
||||
#include "cuda/scatter_nd_of_shape.h"
|
||||
#include "cuda/transpose_cast.h"
|
||||
#endif
|
||||
|
||||
|
@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
|||
,
|
||||
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
|
||||
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
|
||||
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
|
||||
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
|
||||
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
|
||||
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
|
||||
#if ORT_API_VERSION >= 16
|
||||
|
||||
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
|
||||
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
|
||||
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
|
||||
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
|
||||
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
|
||||
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
|
||||
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
|
||||
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
|
||||
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
#include "scatter_nd_of_shape_impl.cuh"
|
||||
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
struct ScatterNDOfShape {
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
std::string value;
|
||||
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
|
||||
if (status != nullptr)
|
||||
return status;
|
||||
|
||||
if (value == "add")
|
||||
reduction_ = ScatterReduction::Add;
|
||||
else if (value == "mul")
|
||||
reduction_ = ScatterReduction::Mul;
|
||||
else if (value == "min")
|
||||
reduction_ = ScatterReduction::Min;
|
||||
else if (value == "max")
|
||||
reduction_ = ScatterReduction::Max;
|
||||
else
|
||||
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
|
||||
const ortc::Tensor<int64_t>& output_shape,
|
||||
const ortc::Tensor<int64_t>& indices,
|
||||
const ortc::Tensor<T>& updates,
|
||||
ortc::Tensor<T>& output) const {
|
||||
auto& output_shape_shape = output_shape.Shape();
|
||||
auto& indices_shape = indices.Shape();
|
||||
auto& updates_shape = updates.Shape();
|
||||
|
||||
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
|
||||
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
if (indices_shape[indices_shape.size() - 1] != 1) {
|
||||
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
||||
const int64_t* shape_data = output_shape.Data(); // CPU pointer
|
||||
const int64_t* indices_data = indices.Data(); // GPU pointer
|
||||
const T* updates_data = updates.Data(); // GPU pointer
|
||||
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
|
||||
T* output_data = output.Allocate(voutput_shape); // GPU pointer
|
||||
LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
|
||||
voutput_shape,
|
||||
indices_shape,
|
||||
indices_data,
|
||||
updates_data,
|
||||
output_data,
|
||||
reduction_);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static OrtMemType GetInputMemoryType(size_t input_index) {
|
||||
if (input_index == 0) // shape
|
||||
return OrtMemType::OrtMemTypeCPUInput;
|
||||
return OrtMemType::OrtMemTypeDefault;
|
||||
}
|
||||
|
||||
ScatterReduction reduction_;
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct MaskedScatterNDOfShape {
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
std::string value;
|
||||
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
|
||||
if (status != nullptr)
|
||||
return status;
|
||||
|
||||
if (value == "add")
|
||||
reduction_ = ScatterReduction::Add;
|
||||
else if (value == "mul")
|
||||
reduction_ = ScatterReduction::Mul;
|
||||
else if (value == "min")
|
||||
reduction_ = ScatterReduction::Min;
|
||||
else if (value == "max")
|
||||
reduction_ = ScatterReduction::Max;
|
||||
else
|
||||
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
|
||||
|
||||
status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
|
||||
if (status != nullptr)
|
||||
return status;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
|
||||
const ortc::Tensor<int64_t>& output_shape,
|
||||
const ortc::Tensor<int64_t>& indices,
|
||||
const ortc::Tensor<T>& updates,
|
||||
ortc::Tensor<T>& output) const {
|
||||
auto& output_shape_shape = output_shape.Shape();
|
||||
auto& indices_shape = indices.Shape();
|
||||
auto& updates_shape = updates.Shape();
|
||||
|
||||
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
|
||||
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
if (indices_shape[indices_shape.size() - 1] != 1) {
|
||||
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
||||
const int64_t* shape_data = output_shape.Data(); // CPU pointer
|
||||
const int64_t* indices_data = indices.Data(); // GPU pointer
|
||||
const T* updates_data = updates.Data(); // GPU pointer
|
||||
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
|
||||
T* output_data = output.Allocate(voutput_shape); // GPU pointer
|
||||
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
|
||||
voutput_shape,
|
||||
indices_shape,
|
||||
indices_data,
|
||||
updates_data,
|
||||
output_data,
|
||||
reduction_,
|
||||
masked_value_);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static OrtMemType GetInputMemoryType(size_t input_index) {
|
||||
if (input_index == 0) // shape
|
||||
return OrtMemType::OrtMemTypeCPUInput;
|
||||
return OrtMemType::OrtMemTypeDefault;
|
||||
}
|
||||
|
||||
private:
|
||||
ScatterReduction reduction_;
|
||||
int64_t masked_value_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
|
@ -0,0 +1,264 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "device_prop.cuh"
|
||||
#include "utils.cuh"
|
||||
#include "scatter_nd_of_shape_impl.cuh"
|
||||
#include "cuda_type.h"
|
||||
|
||||
namespace contrib {
|
||||
|
||||
#define _ENFORCE(cond, msg) \
|
||||
if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
|
||||
|
||||
#ifndef HIP_LONG
|
||||
#define HIP_LONG int32_t
|
||||
#endif
|
||||
|
||||
#ifndef CUDA_LONG
|
||||
#define CUDA_LONG int32_t
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void _add_inplace(half& x, const half a) {
|
||||
#if __CUDA_ARCH__ < 700
|
||||
x = __float2half(__half2float(x) + __half2float(a));
|
||||
#else
|
||||
x += a;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
|
||||
const T* __restrict__ updates_data, const CUDA_LONG indice_size,
|
||||
const CUDA_LONG nrows, const CUDA_LONG stride) {
|
||||
HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (id >= stride)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < nrows; ++i) {
|
||||
output_data[i * stride + id] = 0;
|
||||
}
|
||||
|
||||
int64_t index;
|
||||
for (size_t i = 0; i < indice_size; ++i) {
|
||||
index = (indices_data[i] + nrows) % nrows;
|
||||
_add_inplace(output_data[index * stride + id], updates_data[i * stride + id]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
|
||||
const int64_t *__restrict__ indices_data,
|
||||
const T *__restrict__ updates_data,
|
||||
const CUDA_LONG indice_size,
|
||||
const CUDA_LONG nrows, const CUDA_LONG stride,
|
||||
const int64_t masked_value) {
|
||||
auto id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (id >= stride)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < nrows; ++i) {
|
||||
output_data[i * stride + id] = 0;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < indice_size; ++i) {
|
||||
if (indices_data[i] == masked_value)
|
||||
continue;
|
||||
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int NTHREAD>
|
||||
__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
|
||||
const int64_t *__restrict__ indices_data,
|
||||
const T *__restrict__ updates_data,
|
||||
const CUDA_LONG indice_size,
|
||||
const CUDA_LONG nrows, const CUDA_LONG stride,
|
||||
const int64_t masked_value) {
|
||||
__shared__ int64_t shared_indices[NTHREAD];
|
||||
|
||||
CUDA_LONG tid = threadIdx.x;
|
||||
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
for (size_t i = 0; i < nrows; ++i) {
|
||||
output_data[i * stride + id] = 0;
|
||||
}
|
||||
|
||||
int begin = 0;
|
||||
int end = std::min(begin + NTHREAD, indice_size);
|
||||
while (begin < end && (end == begin + NTHREAD)) {
|
||||
shared_indices[tid] = indices_data[tid + begin];
|
||||
__syncthreads();
|
||||
|
||||
for (size_t i = begin; i < end; ++i) {
|
||||
if (shared_indices[tid] == masked_value)
|
||||
continue;
|
||||
_add_inplace(output_data[shared_indices[tid] * stride + id],
|
||||
updates_data[i * stride + id]);
|
||||
}
|
||||
|
||||
begin = end;
|
||||
end = std::min(begin + NTHREAD, indice_size);
|
||||
}
|
||||
|
||||
for (size_t i = begin; i < indice_size; ++i) {
|
||||
if (indices_data[i] == masked_value)
|
||||
continue;
|
||||
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class NTYPE>
|
||||
NTYPE flattened_dimension(const std::vector<NTYPE>& values, size_t first = 0) {
|
||||
NTYPE r = 1;
|
||||
for (auto it = values.begin() + first; it != values.end(); ++it)
|
||||
r *= *it;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices_data,
|
||||
const T* updates_data,
|
||||
T* output_data,
|
||||
ScatterReduction reduction) {
|
||||
if (reduction != ScatterReduction::Add)
|
||||
ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
|
||||
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
|
||||
size_t output_size = static_cast<size_t>(flattened_dimension(output_shape));
|
||||
size_t rank = output_shape.size() - indices_shape.size();
|
||||
size_t stride = static_cast<size_t>(flattened_dimension(output_shape, output_shape.size() - 1 - rank));
|
||||
size_t nrows = output_size / stride;
|
||||
|
||||
int threads_per_block = 256;
|
||||
int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
dim3 threads(threads_per_block);
|
||||
dim3 blocks(blocks_per_grid);
|
||||
using TT = typename CudaT<T>::MappedType;
|
||||
addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
|
||||
reinterpret_cast<const TT*>(updates_data),
|
||||
indice_size, nrows, stride);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
cudaError_t MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t> &input_shape,
|
||||
const std::vector<int64_t> &indices_shape,
|
||||
const int64_t *indices_data, const T *updates_data,
|
||||
T *output_data,
|
||||
ScatterReduction reduction, int64_t masked_value) {
|
||||
if (reduction != ScatterReduction::Add)
|
||||
ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
|
||||
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
|
||||
size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
|
||||
size_t stride = input_shape[input_shape.size() - 1];
|
||||
size_t nrows = input_size / stride;
|
||||
|
||||
std::vector<size_t> next_batch(indice_size);
|
||||
std::vector<uint8_t> processed(input_shape[0], 0);
|
||||
std::vector<uint8_t> processed_once(input_shape[0], 0);
|
||||
|
||||
int threads_per_block = 256;
|
||||
bool split = stride / threads_per_block <= 32;
|
||||
|
||||
int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
|
||||
dim3 threads(threads_per_block);
|
||||
dim3 blocks(blocks_per_grid);
|
||||
|
||||
using TT = typename CudaT<T>::MappedType;
|
||||
|
||||
if (split && stride >= 256 && threads_per_block == 256) {
|
||||
masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<TT*>(output_data), indices_data,
|
||||
reinterpret_cast<const TT*>(updates_data),
|
||||
indice_size, nrows, stride, masked_value);
|
||||
} else {
|
||||
masked_addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<TT*>(output_data), indices_data,
|
||||
reinterpret_cast<const TT*>(updates_data),
|
||||
indice_size, nrows, stride, masked_value);
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
template <>
|
||||
cudaError_t LaunchScatterNDOfShapeKernel<float>(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const float* updates,
|
||||
float* output,
|
||||
ScatterReduction reduction) {
|
||||
return ScatterNDOfShapeKernel(stream,
|
||||
output_shape,
|
||||
indices_shape,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
reduction);
|
||||
}
|
||||
|
||||
template <>
|
||||
cudaError_t LaunchScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const ortc::MFloat16* updates,
|
||||
ortc::MFloat16* output,
|
||||
ScatterReduction reduction) {
|
||||
return ScatterNDOfShapeKernel(stream,
|
||||
output_shape,
|
||||
indices_shape,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
reduction);
|
||||
}
|
||||
|
||||
template <>
|
||||
cudaError_t LaunchMaskedScatterNDOfShapeKernel<float>(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const float* updates,
|
||||
float* output,
|
||||
ScatterReduction reduction,
|
||||
int64_t masked_value) {
|
||||
return MaskedScatterNDOfShapeKernel(stream,
|
||||
output_shape,
|
||||
indices_shape,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
reduction,
|
||||
masked_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
cudaError_t LaunchMaskedScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const ortc::MFloat16* updates,
|
||||
ortc::MFloat16* output,
|
||||
ScatterReduction reduction,
|
||||
int64_t masked_value) {
|
||||
return MaskedScatterNDOfShapeKernel(stream,
|
||||
output_shape,
|
||||
indices_shape,
|
||||
indices,
|
||||
updates,
|
||||
output,
|
||||
reduction,
|
||||
masked_value);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace contrib {
|
||||
|
||||
enum class ScatterReduction : int {
|
||||
None = 0,
|
||||
Add = 1,
|
||||
Mul = 2,
|
||||
Min = 3,
|
||||
Max = 4,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const T* updates,
|
||||
T* output,
|
||||
ScatterReduction reduction);
|
||||
|
||||
template <typename T>
|
||||
cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
|
||||
const std::vector<int64_t>& output_shape,
|
||||
const std::vector<int64_t>& indices_shape,
|
||||
const int64_t* indices,
|
||||
const T* updates,
|
||||
T* output,
|
||||
ScatterReduction reduction,
|
||||
int64_t masked_value);
|
||||
|
||||
} // namespace contrib
|
|
@ -4,6 +4,7 @@ from numpy.testing import assert_almost_equal
|
|||
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
|
||||
from onnx.reference import ReferenceEvaluator
|
||||
from onnx.reference.op_run import OpRun
|
||||
from onnx.reference.ops.op_scatternd import _scatter_nd_impl
|
||||
from onnxruntime_extensions import make_onnx_model
|
||||
from onnxruntime_extensions import get_library_path as _get_library_path
|
||||
|
||||
|
@ -14,6 +15,15 @@ def has_cuda():
|
|||
return "CUDAExecutionProvider" in _ort.get_available_providers()
|
||||
|
||||
|
||||
class ScatterNDOfShape(OpRun):
|
||||
op_domain = "ai.onnx.contrib"
|
||||
|
||||
def _run(self, shape, indices, updates, reduction=None, strategy=None):
|
||||
data = np.zeros(shape, dtype=updates.dtype)
|
||||
y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
|
||||
return (y,)
|
||||
|
||||
|
||||
class NegXPlus1(OpRun):
|
||||
op_domain = "ai.onnx.contrib"
|
||||
|
||||
|
@ -274,6 +284,193 @@ class TestCudaOps(unittest.TestCase):
|
|||
shapec=(3, 2, 3),
|
||||
)
|
||||
|
||||
def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype):
|
||||
indices_shape = ["i", "j", 1] if dim3 else ["j", 1]
|
||||
updates_shape = ["i", "j", "b"] if dim3 else ["j", "b"]
|
||||
|
||||
model = helper.make_model(
|
||||
helper.make_graph(
|
||||
[
|
||||
helper.make_node(
|
||||
"ScatterNDOfShape",
|
||||
inputs=["shape", "indices", "updates"],
|
||||
outputs=["y"],
|
||||
reduction="add",
|
||||
strategy="optimize" if optimize else "none",
|
||||
domain="ai.onnx.contrib",
|
||||
)
|
||||
],
|
||||
"nd",
|
||||
[
|
||||
helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
|
||||
helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
|
||||
helper.make_tensor_value_info("updates", itype, updates_shape),
|
||||
],
|
||||
[helper.make_tensor_value_info("y", itype, [None, None])],
|
||||
),
|
||||
opset_imports=[
|
||||
helper.make_opsetid("", 18),
|
||||
helper.make_opsetid("ai.onnx.contrib", 1),
|
||||
],
|
||||
ir_version=9,
|
||||
)
|
||||
|
||||
if dim3:
|
||||
shape = (128, 1024)
|
||||
indices = np.zeros((2, 64, 1)).astype(np.int64)
|
||||
indices[:, ::2, 0] = 87
|
||||
indices[:, ::3, 0] = 85
|
||||
updates = np.ones((2, 64, 1024)).astype(np.float32)
|
||||
else:
|
||||
shape = (128, 1024)
|
||||
indices = np.zeros((128, 1)).astype(np.int64)
|
||||
indices[::2, 0] = 87
|
||||
indices[::3, 0] = 85
|
||||
updates = np.ones((128, 1024)).astype(np.float32)
|
||||
if itype != 1:
|
||||
updates = updates.astype(np.float16)
|
||||
feeds = dict(shape=np.array(shape, dtype=np.int64), indices=indices, updates=updates)
|
||||
|
||||
ref = ReferenceEvaluator(model, new_ops=[ScatterNDOfShape])
|
||||
expected = ref.run(None, feeds)[0]
|
||||
|
||||
opts = _ort.SessionOptions()
|
||||
opts.register_custom_ops_library(_get_library_path())
|
||||
sess = _ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
|
||||
ro = None
|
||||
got = sess.run(None, feeds, ro)[0]
|
||||
self.assertEqual(expected.tolist(), got.tolist())
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
def test_scatternd_of_shape_optimize_cuda(self):
|
||||
with self.subTest(optimize=True, dim3=True):
|
||||
self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT)
|
||||
self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT)
|
||||
self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT)
|
||||
with self.subTest(optimize=True, dim3=False):
|
||||
self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT)
|
||||
with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16):
|
||||
self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16)
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
def test_scatternd_of_shape_standalone_cuda(self):
|
||||
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
|
||||
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
|
||||
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
|
||||
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)
|
||||
|
||||
def _masked_scatternd_of_shape_cuda(self, reduction, line, itype, big):
|
||||
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
|
||||
|
||||
model1 = helper.make_model(
|
||||
helper.make_graph(
|
||||
[
|
||||
helper.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
|
||||
helper.make_node(
|
||||
"Where",
|
||||
["masked_indices", "zero", "updates"],
|
||||
["masked_updates"],
|
||||
),
|
||||
helper.make_node(
|
||||
"ScatterND",
|
||||
inputs=["data", "indices", "masked_updates"],
|
||||
outputs=["y"],
|
||||
reduction=reduction,
|
||||
),
|
||||
],
|
||||
"nd",
|
||||
[
|
||||
helper.make_tensor_value_info("data", itype, [None, None]),
|
||||
helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
|
||||
helper.make_tensor_value_info("updates", itype, [None, None, None]),
|
||||
],
|
||||
[helper.make_tensor_value_info("y", itype, [None, None])],
|
||||
[
|
||||
numpy_helper.from_array(np.array([-1], dtype=np.int64), name="mone"),
|
||||
numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
|
||||
],
|
||||
),
|
||||
opset_imports=[helper.make_opsetid("", 18)],
|
||||
ir_version=9,
|
||||
)
|
||||
|
||||
model2 = helper.make_model(
|
||||
helper.make_graph(
|
||||
[
|
||||
helper.make_node(
|
||||
"MaskedScatterNDOfShape",
|
||||
inputs=["shape", "indices", "updates"],
|
||||
outputs=["y"],
|
||||
reduction=reduction,
|
||||
maskedValue=-1,
|
||||
domain="ai.onnx.contrib",
|
||||
)
|
||||
],
|
||||
"nd",
|
||||
[
|
||||
helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
|
||||
helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
|
||||
helper.make_tensor_value_info("updates", itype, [None, None, None]),
|
||||
],
|
||||
[helper.make_tensor_value_info("y", itype, [None, None])],
|
||||
),
|
||||
opset_imports=[
|
||||
helper.make_opsetid("", 18),
|
||||
helper.make_opsetid("ai.onnx.contrib", 1),
|
||||
],
|
||||
ir_version=9,
|
||||
)
|
||||
|
||||
if big:
|
||||
data = np.zeros((2048, 4096), dtype=dtype)
|
||||
indices = np.ones((2, 1024), dtype=np.int64)
|
||||
indices = indices[..., np.newaxis]
|
||||
shape = tuple(indices.shape[:2]) + (data.shape[-1],)
|
||||
updates = (np.arange(np.prod(shape)).reshape(shape) / np.prod(shape)).astype(dtype)
|
||||
else:
|
||||
data = np.zeros((32, 16), dtype=dtype)
|
||||
indices = np.array(
|
||||
[
|
||||
[0, 1, 2],
|
||||
[2, 3, 4],
|
||||
[-1, 30, 31],
|
||||
[-1, 7, 8],
|
||||
[10, 11, -1],
|
||||
[20, -1, 21],
|
||||
],
|
||||
dtype=np.int64,
|
||||
)
|
||||
indices = indices[..., np.newaxis]
|
||||
shape = (6, 3, data.shape[-1])
|
||||
updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)
|
||||
|
||||
feeds1 = dict(data=data, indices=indices, updates=updates)
|
||||
feeds2 = dict(shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates)
|
||||
ref = ReferenceEvaluator(model1)
|
||||
expected = ref.run(None, feeds1)[0]
|
||||
|
||||
opts = _ort.SessionOptions()
|
||||
opts.register_custom_ops_library(_get_library_path())
|
||||
# opts.log_severity_level = 0
|
||||
# opts.log_verbosity_level = 0
|
||||
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
|
||||
got = sess.run(None, feeds2)[0]
|
||||
assert_almost_equal(expected.tolist(), got.tolist())
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
def test_masked_scatternd_of_shape_standalone_cuda_small(self):
|
||||
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, False)
|
||||
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, False)
|
||||
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, False)
|
||||
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, False)
|
||||
|
||||
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||
def test_masked_scatternd_of_shape_standalone_cuda_big(self):
|
||||
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, True)
|
||||
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, True)
|
||||
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
|
||||
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)
|
||||
|
||||
def _transpose_cast_cuda(self, itype):
|
||||
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
|
||||
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
|
||||
|
|
Загрузка…
Ссылка в новой задаче