Add custom kernel ScatterNDOfShape (#705)

* first draft

* clang

* Draft for ScatterNDOfShape

* fix build

* disable test when cuda is missing

* fix implementation

* update test

* add MaskedScatterNdOfShape

* fix merge conflicts
Xavier Dupré 2024-06-11 09:59:46 +02:00 committed by GitHub
Parent 79f3b048d4
Commit f5055466d5
No key found matching this signature
GPG key ID: B5690EEEBB952194
5 changed files: 646 additions and 0 deletions

View file

@@ -7,6 +7,7 @@
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif
@@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
,
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
#endif
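
For context, these registrations make the new kernels reachable from Python once the extensions library is loaded into an onnxruntime session. A minimal sketch, assuming a CUDA build of onnxruntime plus onnxruntime-extensions and a hypothetical model file scatter_nd_of_shape.onnx that uses the op:

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# Register the shared library that exposes the ai.onnx.contrib kernels,
# then run the model on the CUDA execution provider.
opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(
    "scatter_nd_of_shape.onnx",  # hypothetical model using ScatterNDOfShape
    opts,
    providers=["CUDAExecutionProvider"],
)
feeds = {
    "shape": np.array([128, 1024], dtype=np.int64),
    "indices": np.zeros((128, 1), dtype=np.int64),
    "updates": np.ones((128, 1024), dtype=np.float32),
}
y = sess.run(None, feeds)[0]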

View file

@@ -0,0 +1,143 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "string_utils.h"
#include "scatter_nd_of_shape_impl.cuh"
namespace contrib {
template <typename T>
struct ScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string value;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
if (status != nullptr)
return status;
if (value == "add")
reduction_ = ScatterReduction::Add;
else if (value == "mul")
reduction_ = ScatterReduction::Mul;
else if (value == "min")
reduction_ = ScatterReduction::Min;
else if (value == "max")
reduction_ = ScatterReduction::Max;
else
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
return nullptr;
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<int64_t>& output_shape,
const ortc::Tensor<int64_t>& indices,
const ortc::Tensor<T>& updates,
ortc::Tensor<T>& output) const {
auto& output_shape_shape = output_shape.Shape();
auto& indices_shape = indices.Shape();
auto& updates_shape = updates.Shape();
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
}
if (indices_shape[indices_shape.size() - 1] != 1) {
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
}
const int64_t* shape_data = output_shape.Data(); // CPU pointer
const int64_t* indices_data = indices.Data(); // GPU pointer
const T* updates_data = updates.Data(); // GPU pointer
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
T* output_data = output.Allocate(voutput_shape); // GPU pointer
LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
voutput_shape,
indices_shape,
indices_data,
updates_data,
output_data,
reduction_);
return nullptr;
}
static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 0) // shape
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}
ScatterReduction reduction_;
};
template <typename T>
struct MaskedScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string value;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
if (status != nullptr)
return status;
if (value == "add")
reduction_ = ScatterReduction::Add;
else if (value == "mul")
reduction_ = ScatterReduction::Mul;
else if (value == "min")
reduction_ = ScatterReduction::Min;
else if (value == "max")
reduction_ = ScatterReduction::Max;
else
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
if (status != nullptr)
return status;
return nullptr;
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<int64_t>& output_shape,
const ortc::Tensor<int64_t>& indices,
const ortc::Tensor<T>& updates,
ortc::Tensor<T>& output) const {
auto& output_shape_shape = output_shape.Shape();
auto& indices_shape = indices.Shape();
auto& updates_shape = updates.Shape();
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
}
if (indices_shape[indices_shape.size() - 1] != 1) {
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
}
const int64_t* shape_data = output_shape.Data(); // CPU pointer
const int64_t* indices_data = indices.Data(); // GPU pointer
const T* updates_data = updates.Data(); // GPU pointer
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
T* output_data = output.Allocate(voutput_shape); // GPU pointer
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
voutput_shape,
indices_shape,
indices_data,
updates_data,
output_data,
reduction_,
masked_value_);
return nullptr;
}
static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 0) // shape
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}
private:
ScatterReduction reduction_;
int64_t masked_value_;
};
} // namespace contrib
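
For reference, the behaviour these kernels implement (reduction "add", trailing indices dimension equal to 1, and for the masked variant a maskedValue whose rows are skipped) can be sketched in NumPy as follows; the function name and arguments are illustrative, not part of the API:

import numpy as np

def scatter_nd_of_shape_reference(shape, indices, updates, masked_value=None):
    # Start from zeros of the requested shape and add each updates row into
    # the output row selected by the corresponding index.
    out = np.zeros(shape, dtype=updates.dtype)
    flat_idx = indices.reshape(-1)
    flat_upd = updates.reshape(flat_idx.shape[0], -1)
    flat_out = out.reshape(shape[0], -1)  # view over the same buffer
    for i, row in enumerate(flat_idx):
        if masked_value is not None and row == masked_value:
            continue  # MaskedScatterNDOfShape ignores masked rows
        flat_out[row] += flat_upd[i]
    return out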

View file

@@ -0,0 +1,264 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "device_prop.cuh"
#include "utils.cuh"
#include "scatter_nd_of_shape_impl.cuh"
#include "cuda_type.h"
namespace contrib {
#define _ENFORCE(cond, msg) \
if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
#ifndef HIP_LONG
#define HIP_LONG int32_t
#endif
#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif
template <typename T>
__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }
template <>
__device__ __forceinline__ void _add_inplace(half& x, const half a) {
#if __CUDA_ARCH__ < 700
x = __float2half(__half2float(x) + __half2float(a));
#else
x += a;
#endif
}
template <typename T>
__global__ void
addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
const T* __restrict__ updates_data, const CUDA_LONG indice_size,
const CUDA_LONG nrows, const CUDA_LONG stride) {
HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= stride)
return;
for (size_t i = 0; i < nrows; ++i) {
output_data[i * stride + id] = 0;
}
int64_t index;
for (size_t i = 0; i < indice_size; ++i) {
index = (indices_data[i] + nrows) % nrows;
_add_inplace(output_data[index * stride + id], updates_data[i * stride + id]);
}
}
template <typename T>
__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
const int64_t *__restrict__ indices_data,
const T *__restrict__ updates_data,
const CUDA_LONG indice_size,
const CUDA_LONG nrows, const CUDA_LONG stride,
const int64_t masked_value) {
auto id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= stride)
return;
for (size_t i = 0; i < nrows; ++i) {
output_data[i * stride + id] = 0;
}
for (size_t i = 0; i < indice_size; ++i) {
if (indices_data[i] == masked_value)
continue;
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
}
}
template <typename T, int NTHREAD>
__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
const int64_t *__restrict__ indices_data,
const T *__restrict__ updates_data,
const CUDA_LONG indice_size,
const CUDA_LONG nrows, const CUDA_LONG stride,
const int64_t masked_value) {
__shared__ int64_t shared_indices[NTHREAD];
CUDA_LONG tid = threadIdx.x;
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
for (size_t i = 0; i < nrows; ++i) {
output_data[i * stride + id] = 0;
}
int begin = 0;
int end = std::min(begin + NTHREAD, indice_size);
while (begin < end && (end == begin + NTHREAD)) {
shared_indices[tid] = indices_data[tid + begin];
__syncthreads();
for (size_t i = begin; i < end; ++i) {
if (shared_indices[tid] == masked_value)
continue;
_add_inplace(output_data[shared_indices[tid] * stride + id],
updates_data[i * stride + id]);
}
begin = end;
end = std::min(begin + NTHREAD, indice_size);
}
for (size_t i = begin; i < indice_size; ++i) {
if (indices_data[i] == masked_value)
continue;
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
}
}
template <class NTYPE>
NTYPE flattened_dimension(const std::vector<NTYPE>& values, size_t first = 0) {
NTYPE r = 1;
for (auto it = values.begin() + first; it != values.end(); ++it)
r *= *it;
return r;
}
template <typename T>
cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices_data,
const T* updates_data,
T* output_data,
ScatterReduction reduction) {
if (reduction != ScatterReduction::Add)
ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
size_t output_size = static_cast<size_t>(flattened_dimension(output_shape));
size_t rank = output_shape.size() - indices_shape.size();
size_t stride = static_cast<size_t>(flattened_dimension(output_shape, output_shape.size() - 1 - rank));
size_t nrows = output_size / stride;
int threads_per_block = 256;
int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
using TT = typename CudaT<T>::MappedType;
addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride);
return cudaGetLastError();
}
template <typename T>
cudaError_t MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t> &input_shape,
const std::vector<int64_t> &indices_shape,
const int64_t *indices_data, const T *updates_data,
T *output_data,
ScatterReduction reduction, int64_t masked_value) {
if (reduction != ScatterReduction::Add)
ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
size_t stride = input_shape[input_shape.size() - 1];
size_t nrows = input_size / stride;
std::vector<size_t> next_batch(indice_size);
std::vector<uint8_t> processed(input_shape[0], 0);
std::vector<uint8_t> processed_once(input_shape[0], 0);
int threads_per_block = 256;
bool split = stride / threads_per_block <= 32;
int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
using TT = typename CudaT<T>::MappedType;
if (split && stride >= 256 && threads_per_block == 256) {
masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride, masked_value);
} else {
masked_addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(
reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride, masked_value);
}
return cudaGetLastError();
}
template <>
cudaError_t LaunchScatterNDOfShapeKernel<float>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const float* updates,
float* output,
ScatterReduction reduction) {
return ScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction);
}
template <>
cudaError_t LaunchScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const ortc::MFloat16* updates,
ortc::MFloat16* output,
ScatterReduction reduction) {
return ScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction);
}
template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<float>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const float* updates,
float* output,
ScatterReduction reduction,
int64_t masked_value) {
return MaskedScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction,
masked_value);
}
template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const ortc::MFloat16* updates,
ortc::MFloat16* output,
ScatterReduction reduction,
int64_t masked_value) {
return MaskedScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction,
masked_value);
}
} // namespace contrib
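
As a rough illustration of the launch configuration computed in ScatterNDOfShapeKernel above, the numbers below correspond to output_shape=(128, 1024) with indices_shape=(128, 1); they are worked out by hand, not taken from a profiler:

# stride: elements written per scattered row; nrows: rows in the output.
stride = 1024
nrows = (128 * 1024) // stride            # 128
threads_per_block = 256
blocks_per_grid = (stride + threads_per_block - 1) // threads_per_block  # 4
# 4 blocks x 256 threads = 1024 threads: each thread owns one output column
# and walks every index serially, so duplicate indices never race on the
# same output element.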

View file

@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
namespace contrib {
enum class ScatterReduction : int {
None = 0,
Add = 1,
Mul = 2,
Min = 3,
Max = 4,
};
template <typename T>
cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const T* updates,
T* output,
ScatterReduction reduction);
template <typename T>
cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const T* updates,
T* output,
ScatterReduction reduction,
int64_t masked_value);
} // namespace contrib

View file

@@ -4,6 +4,7 @@ from numpy.testing import assert_almost_equal
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_scatternd import _scatter_nd_impl
from onnxruntime_extensions import make_onnx_model
from onnxruntime_extensions import get_library_path as _get_library_path
@@ -14,6 +15,15 @@ def has_cuda():
return "CUDAExecutionProvider" in _ort.get_available_providers()
class ScatterNDOfShape(OpRun):
op_domain = "ai.onnx.contrib"
def _run(self, shape, indices, updates, reduction=None, strategy=None):
data = np.zeros(shape, dtype=updates.dtype)
y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
return (y,)
class NegXPlus1(OpRun):
op_domain = "ai.onnx.contrib"
@@ -274,6 +284,193 @@ class TestCudaOps(unittest.TestCase):
shapec=(3, 2, 3),
)
def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype):
indices_shape = ["i", "j", 1] if dim3 else ["j", 1]
updates_shape = ["i", "j", "b"] if dim3 else ["j", "b"]
model = helper.make_model(
helper.make_graph(
[
helper.make_node(
"ScatterNDOfShape",
inputs=["shape", "indices", "updates"],
outputs=["y"],
reduction="add",
strategy="optimize" if optimize else "none",
domain="ai.onnx.contrib",
)
],
"nd",
[
helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
helper.make_tensor_value_info("updates", itype, updates_shape),
],
[helper.make_tensor_value_info("y", itype, [None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)
if dim3:
shape = (128, 1024)
indices = np.zeros((2, 64, 1)).astype(np.int64)
indices[:, ::2, 0] = 87
indices[:, ::3, 0] = 85
updates = np.ones((2, 64, 1024)).astype(np.float32)
else:
shape = (128, 1024)
indices = np.zeros((128, 1)).astype(np.int64)
indices[::2, 0] = 87
indices[::3, 0] = 85
updates = np.ones((128, 1024)).astype(np.float32)
if itype != 1:
updates = updates.astype(np.float16)
feeds = dict(shape=np.array(shape, dtype=np.int64), indices=indices, updates=updates)
ref = ReferenceEvaluator(model, new_ops=[ScatterNDOfShape])
expected = ref.run(None, feeds)[0]
opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
ro = None
got = sess.run(None, feeds, ro)[0]
self.assertEqual(expected.tolist(), got.tolist())
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_scatternd_of_shape_optimize_cuda(self):
with self.subTest(optimize=True, dim3=True):
self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT)
self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT)
self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT)
with self.subTest(optimize=True, dim3=False):
self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT)
with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16):
self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16)
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_scatternd_of_shape_standalone_cuda(self):
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)
def _masked_scatternd_of_shape_cuda(self, reduction, line, itype, big):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
model1 = helper.make_model(
helper.make_graph(
[
helper.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
helper.make_node(
"Where",
["masked_indices", "zero", "updates"],
["masked_updates"],
),
helper.make_node(
"ScatterND",
inputs=["data", "indices", "masked_updates"],
outputs=["y"],
reduction=reduction,
),
],
"nd",
[
helper.make_tensor_value_info("data", itype, [None, None]),
helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
helper.make_tensor_value_info("updates", itype, [None, None, None]),
],
[helper.make_tensor_value_info("y", itype, [None, None])],
[
numpy_helper.from_array(np.array([-1], dtype=np.int64), name="mone"),
numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)
model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"MaskedScatterNDOfShape",
inputs=["shape", "indices", "updates"],
outputs=["y"],
reduction=reduction,
maskedValue=-1,
domain="ai.onnx.contrib",
)
],
"nd",
[
helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
helper.make_tensor_value_info("updates", itype, [None, None, None]),
],
[helper.make_tensor_value_info("y", itype, [None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)
if big:
data = np.zeros((2048, 4096), dtype=dtype)
indices = np.ones((2, 1024), dtype=np.int64)
indices = indices[..., np.newaxis]
shape = tuple(indices.shape[:2]) + (data.shape[-1],)
updates = (np.arange(np.prod(shape)).reshape(shape) / np.prod(shape)).astype(dtype)
else:
data = np.zeros((32, 16), dtype=dtype)
indices = np.array(
[
[0, 1, 2],
[2, 3, 4],
[-1, 30, 31],
[-1, 7, 8],
[10, 11, -1],
[20, -1, 21],
],
dtype=np.int64,
)
indices = indices[..., np.newaxis]
shape = (6, 3, data.shape[-1])
updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)
feeds1 = dict(data=data, indices=indices, updates=updates)
feeds2 = dict(shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates)
ref = ReferenceEvaluator(model1)
expected = ref.run(None, feeds1)[0]
opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
# opts.log_severity_level = 0
# opts.log_verbosity_level = 0
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds2)[0]
assert_almost_equal(expected.tolist(), got.tolist())
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_masked_scatternd_of_shape_standalone_cuda_small(self):
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, False)
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, False)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, False)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, False)
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_masked_scatternd_of_shape_standalone_cuda_big(self):
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, True)
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, True)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)
def _transpose_cast_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16