Add custom ops ReplaceZero (#739)
* Add custom ops ReplaceZero * fix merge conflicts
This commit is contained in:
Родитель
05df33b302
Коммит
bef5f07e33
|
@ -8,12 +8,12 @@
|
||||||
#include "cuda/fast_gelu.h"
|
#include "cuda/fast_gelu.h"
|
||||||
#include "cuda/mul_sigmoid.h"
|
#include "cuda/mul_sigmoid.h"
|
||||||
#include "cuda/negxplus1.h"
|
#include "cuda/negxplus1.h"
|
||||||
|
#include "cuda/replace_zero.h"
|
||||||
#include "cuda/scatter_nd_of_shape.h"
|
#include "cuda/scatter_nd_of_shape.h"
|
||||||
#include "cuda/transpose_cast.h"
|
#include "cuda/transpose_cast.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
||||||
|
|
||||||
using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
|
using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
|
||||||
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
|
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
||||||
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
|
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
static OrtOpLoader op_loader(
|
static OrtOpLoader op_loader(
|
||||||
[]() { return nullptr; }
|
[]() { return nullptr; }
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
|
@ -36,6 +35,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
||||||
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
|
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
|
||||||
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
|
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
|
||||||
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
|
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
|
||||||
|
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
|
||||||
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
|
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
|
||||||
#if ORT_API_VERSION >= 16
|
#if ORT_API_VERSION >= 16
|
||||||
|
|
||||||
|
@ -47,6 +47,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
||||||
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
|
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
|
||||||
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
|
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
|
||||||
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
|
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
|
||||||
|
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
|
||||||
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
|
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
|
||||||
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
|
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
|
||||||
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
|
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "ocos.h"
|
||||||
|
#include "replace_zero_impl.cuh"
|
||||||
|
#include "ortx_common.h"
|
||||||
|
|
||||||
|
namespace contrib {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Y = ReplaceZero(X, by=c) is equivalent to:
|
||||||
|
*
|
||||||
|
* Y = X.copy()
|
||||||
|
* X[X == 0] = c
|
||||||
|
*
|
||||||
|
* This operation usually appears when a tensor is updated with an operator Equal and Where.
|
||||||
|
* This kernel avoids the creation of one null tensor.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
struct ReplaceZero {
|
||||||
|
template <typename TDict>
|
||||||
|
OrtxStatus OnModelAttach(const TDict& dict) {
|
||||||
|
float default_value=0;
|
||||||
|
by_ = dict.TryToGetAttributeWithDefault("by", default_value);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
|
||||||
|
const ortc::Tensor<T>& input,
|
||||||
|
ortc::Tensor<T>& output) const {
|
||||||
|
const T* input_data = input.Data();
|
||||||
|
auto input_shape = input.Shape();
|
||||||
|
T* output_data = output.Allocate(input_shape);
|
||||||
|
auto input_length = input.NumberOfElement();
|
||||||
|
if (0 == input_length) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
LaunchReplaceZeroKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
|
||||||
|
input_length,
|
||||||
|
input_data,
|
||||||
|
output_data,
|
||||||
|
by_);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
float by_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace contrib
|
|
@ -0,0 +1,68 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#include "device_prop.cuh"
|
||||||
|
#include "utils.cuh"
|
||||||
|
#include "replace_zero_impl.cuh"
|
||||||
|
#include "cuda_type.h"
|
||||||
|
|
||||||
|
#ifndef CUDA_LONG
|
||||||
|
#define CUDA_LONG int32_t
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using namespace Ort::Custom;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __inline__ T _replace_zero(const T x, const T by) {
|
||||||
|
return x == (T)0 ? by : x;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ half _replace_zero(const half x, const half by) {
|
||||||
|
#if __CUDA_ARCH__ < 700
|
||||||
|
return __half2float(x) == 0 ? by : x;
|
||||||
|
#else
|
||||||
|
return x == (half)0 ? by : x;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) {
|
||||||
|
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
if (id >= N)
|
||||||
|
return;
|
||||||
|
output_data[id] = _replace_zero(input_data[id], by);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T _cast(float value) { return (T)value; }
|
||||||
|
|
||||||
|
template <>
|
||||||
|
half _cast(float value) { return __float2half(value); }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) {
|
||||||
|
if (input_length == 0)
|
||||||
|
return cudaGetLastError();
|
||||||
|
using TT = typename contrib::CudaT<T>::MappedType;
|
||||||
|
|
||||||
|
CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
|
||||||
|
|
||||||
|
const int num_threads_per_block = 256;
|
||||||
|
const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block;
|
||||||
|
|
||||||
|
TT cby = _cast<TT>(by);
|
||||||
|
ReplaceZeroKernel<TT><<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(
|
||||||
|
reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
|
||||||
|
return cudaGetLastError();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
cudaError_t LaunchReplaceZeroKernel<float>(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) {
|
||||||
|
return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
cudaError_t LaunchReplaceZeroKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) {
|
||||||
|
return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by);
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);
|
|
@ -652,6 +652,66 @@ class TestCudaOps(unittest.TestCase):
|
||||||
self._transpose_cast_cuda(TensorProto.FLOAT)
|
self._transpose_cast_cuda(TensorProto.FLOAT)
|
||||||
self._transpose_cast_cuda(TensorProto.FLOAT16)
|
self._transpose_cast_cuda(TensorProto.FLOAT16)
|
||||||
|
|
||||||
|
def _replace_zero_cuda(self, itype):
|
||||||
|
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
|
||||||
|
model1 = helper.make_model(
|
||||||
|
helper.make_graph(
|
||||||
|
[
|
||||||
|
helper.make_node("Equal", ["X", "zero"], ["cond"]),
|
||||||
|
helper.make_node("Where", ["cond", "cst", "X"], ["Y"]),
|
||||||
|
],
|
||||||
|
"nd",
|
||||||
|
[helper.make_tensor_value_info("X", itype, [None, None, None])],
|
||||||
|
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
|
||||||
|
[
|
||||||
|
numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
|
||||||
|
numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
opset_imports=[helper.make_opsetid("", 18)],
|
||||||
|
ir_version=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
model2 = helper.make_model(
|
||||||
|
helper.make_graph(
|
||||||
|
[
|
||||||
|
helper.make_node(
|
||||||
|
"ReplaceZero",
|
||||||
|
["X"],
|
||||||
|
["Y"],
|
||||||
|
by=1.67,
|
||||||
|
domain="ai.onnx.contrib",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"nd",
|
||||||
|
[helper.make_tensor_value_info("X", itype, [None, None, None])],
|
||||||
|
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
|
||||||
|
),
|
||||||
|
opset_imports=[
|
||||||
|
helper.make_opsetid("", 18),
|
||||||
|
helper.make_opsetid("ai.onnx.contrib", 1),
|
||||||
|
],
|
||||||
|
ir_version=9,
|
||||||
|
)
|
||||||
|
|
||||||
|
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
|
||||||
|
x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype)
|
||||||
|
|
||||||
|
feeds1 = dict(X=x)
|
||||||
|
ref = ReferenceEvaluator(model1)
|
||||||
|
expected = ref.run(None, feeds1)[0]
|
||||||
|
|
||||||
|
opts = _ort.SessionOptions()
|
||||||
|
opts.register_custom_ops_library(_get_library_path())
|
||||||
|
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
|
||||||
|
got = sess.run(None, feeds1)[0]
|
||||||
|
assert_allclose(expected, got, atol=1e-5)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_cuda(), reason="cuda not available")
|
||||||
|
def test_replace_zero_cuda(self):
|
||||||
|
self._replace_zero_cuda(TensorProto.FLOAT)
|
||||||
|
self._replace_zero_cuda(TensorProto.FLOAT16)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче