Add custom kernel ScatterNDOfShape (#705)
* first draft
* clang
* Draft for ScatterNDOfShape
* fix build
* disable test when cuda is missing
* fix implementation
* update test
* add MaskedScatterNDOfShape
* fix merge conflicts
This commit is contained in:
Parent: 79f3b048d4
Commit: f5055466d5
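For context: ScatterNDOfShape behaves like ONNX ScatterND except that it takes the output shape instead of a data tensor and scatters into a zero-filled output, and MaskedScatterNDOfShape additionally skips every index equal to its maskedValue attribute. The sketch below is a minimal way to exercise the new op; it mirrors the tests added further down in this commit, and the concrete input values are illustrative only.

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper
from onnxruntime_extensions import get_library_path

# Minimal graph: y = ScatterNDOfShape(shape, indices, updates) with reduction="add".
model = helper.make_model(
    helper.make_graph(
        [
            helper.make_node(
                "ScatterNDOfShape",
                ["shape", "indices", "updates"],
                ["y"],
                reduction="add",
                domain="ai.onnx.contrib",
            )
        ],
        "g",
        [
            helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
            helper.make_tensor_value_info("indices", TensorProto.INT64, [None, 1]),
            helper.make_tensor_value_info("updates", TensorProto.FLOAT, [None, None]),
        ],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, None])],
    ),
    opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("ai.onnx.contrib", 1)],
    ir_version=9,
)

opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])

feeds = {
    "shape": np.array([8, 4], dtype=np.int64),             # output is a zero-filled (8, 4) tensor
    "indices": np.array([[0], [2], [0]], dtype=np.int64),  # rows to update; duplicates accumulate
    "updates": np.ones((3, 4), dtype=np.float32),
}
print(sess.run(None, feeds)[0])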
@@ -7,6 +7,7 @@
 #include "cuda/add_mul.h"
 #include "cuda/fast_gelu.h"
 #include "cuda/negxplus1.h"
+#include "cuda/scatter_nd_of_shape.h"
 #include "cuda/transpose_cast.h"
 #endif
@@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
         ,
         CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
         CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
+        CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
         CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
         CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
+        CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
 #if ORT_API_VERSION >= 16

         CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
         CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
         CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
+        CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
         CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
         CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
+        CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
         CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
         CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
 #endif
@@ -0,0 +1,143 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "string_utils.h"
#include "scatter_nd_of_shape_impl.cuh"

namespace contrib {

template <typename T>
struct ScatterNDOfShape {
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    std::string value;
    OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
    if (status != nullptr)
      return status;

    if (value == "add")
      reduction_ = ScatterReduction::Add;
    else if (value == "mul")
      reduction_ = ScatterReduction::Mul;
    else if (value == "min")
      reduction_ = ScatterReduction::Min;
    else if (value == "max")
      reduction_ = ScatterReduction::Max;
    else
      ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);

    return nullptr;
  }

  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
                       const ortc::Tensor<int64_t>& output_shape,
                       const ortc::Tensor<int64_t>& indices,
                       const ortc::Tensor<T>& updates,
                       ortc::Tensor<T>& output) const {
    auto& output_shape_shape = output_shape.Shape();
    auto& indices_shape = indices.Shape();
    auto& updates_shape = updates.Shape();

    if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
      ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
    }
    if (indices_shape[indices_shape.size() - 1] != 1) {
      ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
    }

    const int64_t* shape_data = output_shape.Data();  // CPU pointer
    const int64_t* indices_data = indices.Data();     // GPU pointer
    const T* updates_data = updates.Data();           // GPU pointer
    std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
    T* output_data = output.Allocate(voutput_shape);  // GPU pointer
    LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                                    voutput_shape,
                                    indices_shape,
                                    indices_data,
                                    updates_data,
                                    output_data,
                                    reduction_);
    return nullptr;
  }

  static OrtMemType GetInputMemoryType(size_t input_index) {
    if (input_index == 0)  // shape
      return OrtMemType::OrtMemTypeCPUInput;
    return OrtMemType::OrtMemTypeDefault;
  }

  ScatterReduction reduction_;
};

template <typename T>
struct MaskedScatterNDOfShape {
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    std::string value;
    OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
    if (status != nullptr)
      return status;

    if (value == "add")
      reduction_ = ScatterReduction::Add;
    else if (value == "mul")
      reduction_ = ScatterReduction::Mul;
    else if (value == "min")
      reduction_ = ScatterReduction::Min;
    else if (value == "max")
      reduction_ = ScatterReduction::Max;
    else
      ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);

    status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
    if (status != nullptr)
      return status;

    return nullptr;
  }

  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
                       const ortc::Tensor<int64_t>& output_shape,
                       const ortc::Tensor<int64_t>& indices,
                       const ortc::Tensor<T>& updates,
                       ortc::Tensor<T>& output) const {
    auto& output_shape_shape = output_shape.Shape();
    auto& indices_shape = indices.Shape();
    auto& updates_shape = updates.Shape();

    if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
      ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
    }
    if (indices_shape[indices_shape.size() - 1] != 1) {
      ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
    }

    const int64_t* shape_data = output_shape.Data();  // CPU pointer
    const int64_t* indices_data = indices.Data();     // GPU pointer
    const T* updates_data = updates.Data();           // GPU pointer
    std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
    T* output_data = output.Allocate(voutput_shape);  // GPU pointer
    LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                                          voutput_shape,
                                          indices_shape,
                                          indices_data,
                                          updates_data,
                                          output_data,
                                          reduction_,
                                          masked_value_);
    return nullptr;
  }

  static OrtMemType GetInputMemoryType(size_t input_index) {
    if (input_index == 0)  // shape
      return OrtMemType::OrtMemTypeCPUInput;
    return OrtMemType::OrtMemTypeDefault;
  }

 private:
  ScatterReduction reduction_;
  int64_t masked_value_;
};

}  // namespace contrib
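Two points in the header are easy to miss: GetInputMemoryType pins the first input (the shape) to CPU memory so Compute can read it on the host before allocating the output, and the launch helpers view the output as a 2-D array of nrows rows of stride elements. The sketch below is my own NumPy reading of the reduction="add" semantics for the 2-D case used in the tests; the helper name is illustrative and not part of the commit.

import numpy as np

# Hypothetical helper (name is mine): reference semantics of ScatterNDOfShape with
# reduction="add" for a 2-D output, matching the CUDA kernel in the next file.
def scatter_nd_of_shape_add(shape, indices, updates):
    out = np.zeros(shape, dtype=updates.dtype)
    rows = indices.reshape(-1) % shape[0]          # negative indices wrap, as in the kernel
    upd = updates.reshape(rows.shape[0], shape[1])
    for i, r in enumerate(rows):
        out[r] += upd[i]                           # duplicate indices accumulate
    return out

# Row 0 ends up at 2, row 2 at 1, every other row stays zero.
print(scatter_nd_of_shape_add((8, 4), np.array([[0], [2], [0]]), np.ones((3, 4), np.float32)))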
@@ -0,0 +1,264 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "scatter_nd_of_shape_impl.cuh"
#include "cuda_type.h"

namespace contrib {

#define _ENFORCE(cond, msg) \
  if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);

#ifndef HIP_LONG
#define HIP_LONG int32_t
#endif

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

template <typename T>
__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }

template <>
__device__ __forceinline__ void _add_inplace(half& x, const half a) {
#if __CUDA_ARCH__ < 700
  x = __float2half(__half2float(x) + __half2float(a));
#else
  x += a;
#endif
}

template <typename T>
__global__ void
addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
                        const T* __restrict__ updates_data, const CUDA_LONG indice_size,
                        const CUDA_LONG nrows, const CUDA_LONG stride) {
  HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= stride)
    return;

  for (size_t i = 0; i < nrows; ++i) {
    output_data[i * stride + id] = 0;
  }

  int64_t index;
  for (size_t i = 0; i < indice_size; ++i) {
    index = (indices_data[i] + nrows) % nrows;
    _add_inplace(output_data[index * stride + id], updates_data[i * stride + id]);
  }
}

template <typename T>
__global__ void masked_addition_inplace_kernel(T* __restrict__ output_data,
                                               const int64_t* __restrict__ indices_data,
                                               const T* __restrict__ updates_data,
                                               const CUDA_LONG indice_size,
                                               const CUDA_LONG nrows, const CUDA_LONG stride,
                                               const int64_t masked_value) {
  auto id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= stride)
    return;

  for (size_t i = 0; i < nrows; ++i) {
    output_data[i * stride + id] = 0;
  }

  for (size_t i = 0; i < indice_size; ++i) {
    if (indices_data[i] == masked_value)
      continue;
    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
  }
}

template <typename T, int NTHREAD>
__global__ void masked_addition_inplace_kernelN(T* __restrict__ output_data,
                                                const int64_t* __restrict__ indices_data,
                                                const T* __restrict__ updates_data,
                                                const CUDA_LONG indice_size,
                                                const CUDA_LONG nrows, const CUDA_LONG stride,
                                                const int64_t masked_value) {
  __shared__ int64_t shared_indices[NTHREAD];

  CUDA_LONG tid = threadIdx.x;
  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;

  for (size_t i = 0; i < nrows; ++i) {
    output_data[i * stride + id] = 0;
  }

  int begin = 0;
  int end = std::min(begin + NTHREAD, indice_size);
  while (begin < end && (end == begin + NTHREAD)) {
    shared_indices[tid] = indices_data[tid + begin];
    __syncthreads();

    for (size_t i = begin; i < end; ++i) {
      if (shared_indices[tid] == masked_value)
        continue;
      _add_inplace(output_data[shared_indices[tid] * stride + id],
                   updates_data[i * stride + id]);
    }

    begin = end;
    end = std::min(begin + NTHREAD, indice_size);
  }

  for (size_t i = begin; i < indice_size; ++i) {
    if (indices_data[i] == masked_value)
      continue;
    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
  }
}

template <class NTYPE>
NTYPE flattened_dimension(const std::vector<NTYPE>& values, size_t first = 0) {
  NTYPE r = 1;
  for (auto it = values.begin() + first; it != values.end(); ++it)
    r *= *it;
  return r;
}

template <typename T>
cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,
                                   const std::vector<int64_t>& output_shape,
                                   const std::vector<int64_t>& indices_shape,
                                   const int64_t* indices_data,
                                   const T* updates_data,
                                   T* output_data,
                                   ScatterReduction reduction) {
  if (reduction != ScatterReduction::Add)
    ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
  size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
  size_t output_size = static_cast<size_t>(flattened_dimension(output_shape));
  size_t rank = output_shape.size() - indices_shape.size();
  size_t stride = static_cast<size_t>(flattened_dimension(output_shape, output_shape.size() - 1 - rank));
  size_t nrows = output_size / stride;

  int threads_per_block = 256;
  int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;

  dim3 threads(threads_per_block);
  dim3 blocks(blocks_per_grid);
  using TT = typename CudaT<T>::MappedType;
  addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
                                                              reinterpret_cast<const TT*>(updates_data),
                                                              indice_size, nrows, stride);
  return cudaGetLastError();
}

template <typename T>
cudaError_t MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t>& input_shape,
                                         const std::vector<int64_t>& indices_shape,
                                         const int64_t* indices_data, const T* updates_data,
                                         T* output_data,
                                         ScatterReduction reduction, int64_t masked_value) {
  if (reduction != ScatterReduction::Add)
    ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
  size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
  size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
  size_t stride = input_shape[input_shape.size() - 1];
  size_t nrows = input_size / stride;

  std::vector<size_t> next_batch(indice_size);
  std::vector<uint8_t> processed(input_shape[0], 0);
  std::vector<uint8_t> processed_once(input_shape[0], 0);

  int threads_per_block = 256;
  bool split = stride / threads_per_block <= 32;

  int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
  dim3 threads(threads_per_block);
  dim3 blocks(blocks_per_grid);

  using TT = typename CudaT<T>::MappedType;

  if (split && stride >= 256 && threads_per_block == 256) {
    masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<TT*>(output_data), indices_data,
        reinterpret_cast<const TT*>(updates_data),
        indice_size, nrows, stride, masked_value);
  } else {
    masked_addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<TT*>(output_data), indices_data,
        reinterpret_cast<const TT*>(updates_data),
        indice_size, nrows, stride, masked_value);
  }
  return cudaGetLastError();
}

template <>
cudaError_t LaunchScatterNDOfShapeKernel<float>(cudaStream_t stream,
                                                const std::vector<int64_t>& output_shape,
                                                const std::vector<int64_t>& indices_shape,
                                                const int64_t* indices,
                                                const float* updates,
                                                float* output,
                                                ScatterReduction reduction) {
  return ScatterNDOfShapeKernel(stream,
                                output_shape,
                                indices_shape,
                                indices,
                                updates,
                                output,
                                reduction);
}

template <>
cudaError_t LaunchScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
                                                         const std::vector<int64_t>& output_shape,
                                                         const std::vector<int64_t>& indices_shape,
                                                         const int64_t* indices,
                                                         const ortc::MFloat16* updates,
                                                         ortc::MFloat16* output,
                                                         ScatterReduction reduction) {
  return ScatterNDOfShapeKernel(stream,
                                output_shape,
                                indices_shape,
                                indices,
                                updates,
                                output,
                                reduction);
}

template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<float>(cudaStream_t stream,
                                                      const std::vector<int64_t>& output_shape,
                                                      const std::vector<int64_t>& indices_shape,
                                                      const int64_t* indices,
                                                      const float* updates,
                                                      float* output,
                                                      ScatterReduction reduction,
                                                      int64_t masked_value) {
  return MaskedScatterNDOfShapeKernel(stream,
                                      output_shape,
                                      indices_shape,
                                      indices,
                                      updates,
                                      output,
                                      reduction,
                                      masked_value);
}

template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
                                                               const std::vector<int64_t>& output_shape,
                                                               const std::vector<int64_t>& indices_shape,
                                                               const int64_t* indices,
                                                               const ortc::MFloat16* updates,
                                                               ortc::MFloat16* output,
                                                               ScatterReduction reduction,
                                                               int64_t masked_value) {
  return MaskedScatterNDOfShapeKernel(stream,
                                      output_shape,
                                      indices_shape,
                                      indices,
                                      updates,
                                      output,
                                      reduction,
                                      masked_value);
}

}  // namespace contrib
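Both launch wrappers map one CUDA thread to one column of the flattened (nrows, stride) view and let each thread walk the whole index list; the kernelN variant additionally stages 256 indices at a time in shared memory when stride is large. Functionally, the masked variant behaves like the NumPy sketch below for a 2-D output (my reading of the code; the helper name is illustrative, not part of the commit), which is what the -1 entries in the small test case rely on.

import numpy as np

# Hypothetical helper (name is mine): reference semantics of MaskedScatterNDOfShape
# with reduction="add" for a 2-D output. Rows whose index equals masked_value are
# skipped entirely instead of being scattered.
def masked_scatter_nd_of_shape_add(shape, indices, updates, masked_value=-1):
    out = np.zeros(shape, dtype=updates.dtype)
    rows = indices.reshape(-1)
    upd = updates.reshape(rows.shape[0], shape[-1])
    for i, r in enumerate(rows):
        if r == masked_value:
            continue
        out[r] += upd[i]
    return out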
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

namespace contrib {

enum class ScatterReduction : int {
  None = 0,
  Add = 1,
  Mul = 2,
  Min = 3,
  Max = 4,
};

template <typename T>
cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
                                         const std::vector<int64_t>& output_shape,
                                         const std::vector<int64_t>& indices_shape,
                                         const int64_t* indices,
                                         const T* updates,
                                         T* output,
                                         ScatterReduction reduction);

template <typename T>
cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
                                               const std::vector<int64_t>& output_shape,
                                               const std::vector<int64_t>& indices_shape,
                                               const int64_t* indices,
                                               const T* updates,
                                               T* output,
                                               ScatterReduction reduction,
                                               int64_t masked_value);

}  // namespace contrib
@@ -4,6 +4,7 @@ from numpy.testing import assert_almost_equal
 from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
 from onnx.reference import ReferenceEvaluator
 from onnx.reference.op_run import OpRun
+from onnx.reference.ops.op_scatternd import _scatter_nd_impl
 from onnxruntime_extensions import make_onnx_model
 from onnxruntime_extensions import get_library_path as _get_library_path
@@ -14,6 +15,15 @@ def has_cuda():
     return "CUDAExecutionProvider" in _ort.get_available_providers()
 
 
+class ScatterNDOfShape(OpRun):
+    op_domain = "ai.onnx.contrib"
+
+    def _run(self, shape, indices, updates, reduction=None, strategy=None):
+        data = np.zeros(shape, dtype=updates.dtype)
+        y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
+        return (y,)
+
+
 class NegXPlus1(OpRun):
     op_domain = "ai.onnx.contrib"
@@ -274,6 +284,193 @@ class TestCudaOps(unittest.TestCase):
             shapec=(3, 2, 3),
         )
 
+    def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype):
+        indices_shape = ["i", "j", 1] if dim3 else ["j", 1]
+        updates_shape = ["i", "j", "b"] if dim3 else ["j", "b"]
+
+        model = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "ScatterNDOfShape",
+                        inputs=["shape", "indices", "updates"],
+                        outputs=["y"],
+                        reduction="add",
+                        strategy="optimize" if optimize else "none",
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+                    helper.make_tensor_value_info("updates", itype, updates_shape),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        if dim3:
+            shape = (128, 1024)
+            indices = np.zeros((2, 64, 1)).astype(np.int64)
+            indices[:, ::2, 0] = 87
+            indices[:, ::3, 0] = 85
+            updates = np.ones((2, 64, 1024)).astype(np.float32)
+        else:
+            shape = (128, 1024)
+            indices = np.zeros((128, 1)).astype(np.int64)
+            indices[::2, 0] = 87
+            indices[::3, 0] = 85
+            updates = np.ones((128, 1024)).astype(np.float32)
+        if itype != 1:
+            updates = updates.astype(np.float16)
+        feeds = dict(shape=np.array(shape, dtype=np.int64), indices=indices, updates=updates)
+
+        ref = ReferenceEvaluator(model, new_ops=[ScatterNDOfShape])
+        expected = ref.run(None, feeds)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        ro = None
+        got = sess.run(None, feeds, ro)[0]
+        self.assertEqual(expected.tolist(), got.tolist())
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_scatternd_of_shape_optimize_cuda(self):
+        with self.subTest(optimize=True, dim3=True):
+            self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT)
+            self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT)
+            self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT)
+        with self.subTest(optimize=True, dim3=False):
+            self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT)
+        with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16):
+            self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_scatternd_of_shape_standalone_cuda(self):
+        self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
+        self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
+        self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
+        self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)
+
+    def _masked_scatternd_of_shape_cuda(self, reduction, line, itype, big):
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
+                    helper.make_node(
+                        "Where",
+                        ["masked_indices", "zero", "updates"],
+                        ["masked_updates"],
+                    ),
+                    helper.make_node(
+                        "ScatterND",
+                        inputs=["data", "indices", "masked_updates"],
+                        outputs=["y"],
+                        reduction=reduction,
+                    ),
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("data", itype, [None, None]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
+                    helper.make_tensor_value_info("updates", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+                [
+                    numpy_helper.from_array(np.array([-1], dtype=np.int64), name="mone"),
+                    numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
+                ],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "MaskedScatterNDOfShape",
+                        inputs=["shape", "indices", "updates"],
+                        outputs=["y"],
+                        reduction=reduction,
+                        maskedValue=-1,
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
+                    helper.make_tensor_value_info("updates", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        if big:
+            data = np.zeros((2048, 4096), dtype=dtype)
+            indices = np.ones((2, 1024), dtype=np.int64)
+            indices = indices[..., np.newaxis]
+            shape = tuple(indices.shape[:2]) + (data.shape[-1],)
+            updates = (np.arange(np.prod(shape)).reshape(shape) / np.prod(shape)).astype(dtype)
+        else:
+            data = np.zeros((32, 16), dtype=dtype)
+            indices = np.array(
+                [
+                    [0, 1, 2],
+                    [2, 3, 4],
+                    [-1, 30, 31],
+                    [-1, 7, 8],
+                    [10, 11, -1],
+                    [20, -1, 21],
+                ],
+                dtype=np.int64,
+            )
+            indices = indices[..., np.newaxis]
+            shape = (6, 3, data.shape[-1])
+            updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)
+
+        feeds1 = dict(data=data, indices=indices, updates=updates)
+        feeds2 = dict(shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates)
+        ref = ReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        # opts.log_severity_level = 0
+        # opts.log_verbosity_level = 0
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds2)[0]
+        assert_almost_equal(expected.tolist(), got.tolist())
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_masked_scatternd_of_shape_standalone_cuda_small(self):
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, False)
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, False)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, False)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, False)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_masked_scatternd_of_shape_standalone_cuda_big(self):
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, True)
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, True)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)
+
     def _transpose_cast_cuda(self, itype):
         dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
         itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16