* cmake fast gelu

* bridge func and cuda kernel

* tune ut

* fix build warning

* fix format

* tune ut

* drop OCOS_ENABLE_CONTRIB

* tune cmake

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-12-05 12:59:19 -08:00 коммит произвёл GitHub
Родитель 040599b265
Коммит 9b2daf56de
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 165 добавлений и 4 удалений

Просмотреть файл

@ -337,10 +337,17 @@ if(OCOS_ENABLE_MATH)
file(GLOB TARGET_SRC_MATH_CUDA "operators/math/cuda/*.*")
list(APPEND TARGET_SRC_MATH ${TARGET_SRC_MATH_CUDA})
endif()
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
endif()
# Contrib operators: collect the host sources and headers under operators/contrib.
# NOTE(review): no OCOS_ENABLE_* guard is visible around this in the hunk, so these
# appear to be built unconditionally -- confirm against the full CMakeLists.txt.
file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*")
if (OCOS_USE_CUDA)
# CUDA builds additionally pick up every file under operators/contrib/cuda.
file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*")
list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA})
endif()
# Add the contrib sources to the overall source list.
list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB})
# enable the opencv dependency if we have ops that require it
if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
set(_ENABLE_OPENCV ON)

Просмотреть файл

@ -44,6 +44,7 @@ class CuopContainer {
#define CustomCpuStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }
#define CustomCudaFuncV2(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CUDAExecutionProvider", f)); }
#define CustomCudaStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CUDAExecutionProvider")); }
template <typename F>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
@ -133,3 +134,5 @@ extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
#if ENABLE_AZURE
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
#endif
extern FxLoadCustomOpFactory LoadCustomOpClasses_Contrib;

Просмотреть файл

@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ocos.h"
#ifdef USE_CUDA
#include "cuda/fast_gelu.h"
#endif
// Factory for the contrib custom ops. The OrtOpLoader is a function-local static,
// so it is constructed on the first call and the same op array is returned afterwards.
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
// Placeholder factory that yields nullptr; presumably it keeps the argument list
// non-empty when USE_CUDA is not defined -- TODO confirm how OrtOpLoader treats it.
[]() { return nullptr; }
#ifdef USE_CUDA
,
// FastGelu<float>, created for "CUDAExecutionProvider" by the CustomCudaStructV2 macro.
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>)
#endif
);
return op_loader.GetCustomOps();
};

Просмотреть файл

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include "fast_gelu_impl.cuh"
namespace contrib {
// FastGelu custom-op kernel for the CUDA execution provider.
//
// Applies the tanh-based GELU approximation element-wise to `input`, after adding
// the optional `bias` (indexed modulo its element count), by delegating to
// LaunchFastGeluKernel on the stream carried by the CUDA context.
template <typename T>
struct FastGelu {
  // Nothing is read from the kernel info; always succeeds.
  OrtStatusPtr OnModelAttach(const OrtApi& /*api*/,
                             const OrtKernelInfo& /*info*/) {
    return nullptr;
  }

  // Computes output = FastGelu(input [+ bias]); output takes the shape of input.
  //
  // Returns nullptr on success. Returns an error status when the element counts do
  // not fit the launcher's `int` parameters or when the kernel launch is rejected.
  OrtStatusPtr Compute(const Ort::Custom::CudaContext& ctx,
                       const ortc::Tensor<T>& input,
                       std::optional<const ortc::Tensor<T>*> bias,
                       ortc::Tensor<T>& output) const {
    const T* input_data = input.Data();
    T* output_data = output.Allocate(input.Shape());
    const auto input_length = input.NumberOfElement();
    if (0 == input_length) {
      return nullptr;
    }

    const T* bias_data = bias.has_value() ? (*bias)->Data() : nullptr;
    const auto bias_length = bias.has_value() ? (*bias)->NumberOfElement() : 0;
    if (0 == bias_length) {
      // The kernel indexes bias modulo bias_length; an empty bias must be handled
      // as "no bias" to avoid a modulo by zero on the device.
      bias_data = nullptr;
    }

    // The launcher takes `int` lengths: reject sizes that would be silently truncated.
    const int input_length_i32 = static_cast<int>(input_length);
    const int bias_length_i32 = static_cast<int>(bias_length);
    if (input_length_i32 != input_length || bias_length_i32 != bias_length) {
      return OrtW::CreateStatus("FastGelu: tensor element count exceeds the supported 32-bit range.",
                                ORT_INVALID_ARGUMENT);
    }

    const cudaError_t launch_status =
        LaunchFastGeluKernel(reinterpret_cast<cudaStream_t>(ctx.cuda_stream),
                             input_length_i32,
                             bias_length_i32,
                             input_data,
                             bias_data,
                             output_data,
                             use_half2_);
    if (launch_status != cudaSuccess) {
      // Surface launch failures instead of reporting success with an unwritten output.
      return OrtW::CreateStatus(cudaGetErrorString(launch_status), ORT_EP_FAIL);
    }
    return nullptr;
  }

 private:
  bool use_half2_ = false;  // to-do, read this from env var
};
}

Просмотреть файл

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "fast_gelu_impl.cuh"
#include <cuda_runtime.h>
// Device-side hyperbolic tangent, selected by element type.
// The primary template is only declared here; each supported type supplies
// its own explicit specialization.
template <typename T>
__device__ __inline__ T _Tanh(T a);

// float specialization: forwards to tanhf.
template <>
__device__ __inline__ float _Tanh<float>(float a) {
  return tanhf(a);
}
// Coefficients passed to FastGeluKernel, which evaluates
//   x * (A + A * tanh(x * (C * x * x + B)))
// i.e. the tanh-based approximation of GELU.
constexpr float A = 0.5f;
constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI)
// Element-wise FastGelu kernel.
//
// For each index i < input_length:
//   v         = input[i]                      (bias == nullptr)
//   v         = input[i] + bias[i % bias_length]   (otherwise)
//   output[i] = v * (a + a * tanh(v * (c * v * v + b)))
//
// The global index is derived from blockIdx.x, TPB and threadIdx.x, so the kernel
// expects a one-dimensional launch whose block size equals TPB. Threads whose index
// is at or beyond input_length do nothing. bias_length must be non-zero whenever
// bias is non-null.
template <typename T, unsigned TPB>
__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length,
                               const T* input, const T* bias, T* output) {
  const int i = blockIdx.x * TPB + threadIdx.x;
  if (i >= input_length) {
    return;  // tail threads of the last block
  }

  T v = input[i];
  if (bias != nullptr) {
    v = (T)(v + bias[i % bias_length]);
  }

  const T cdf = a + a * _Tanh(v * (c * v * v + b));
  output[i] = v * cdf;
}
// float specialization of the FastGelu launcher.
//
// Enqueues FastGeluKernel on `stream` with 256 threads per block and enough blocks
// to cover `input_length` elements. `bias` may be nullptr; when present it is indexed
// modulo `bias_length`. `use_half2` is ignored by this specialization.
//
// Returns cudaSuccess without launching when there is nothing to compute; otherwise
// returns the result of cudaGetLastError() taken right after the launch.
template <>
cudaError_t LaunchFastGeluKernel(cudaStream_t stream, int input_length, int bias_length,
                                 const float* input, const float* bias, float* output, bool /*use_half2*/) {
  // A zero-block grid is an invalid launch configuration; empty input is a no-op.
  if (input_length <= 0) {
    return cudaSuccess;
  }
  // The kernel computes idx % bias_length whenever bias is non-null, so a bias with
  // no elements is treated as absent to avoid a modulo by zero on the device.
  if (bias_length <= 0) {
    bias = nullptr;
  }

  constexpr int blockSize = 256;
  // Ceil-division in a form that cannot overflow int when input_length is near INT_MAX
  // (the usual (n + blockSize - 1) / blockSize does).
  const int gridSize = (input_length - 1) / blockSize + 1;
  FastGeluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length,
                                                                       input, bias, output);
  return cudaGetLastError();
}

Просмотреть файл

@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
// Launches the FastGelu kernel on `stream` and returns the CUDA launch status.
//
//   input_length  number of elements read from `input` and written to `output`
//   bias_length   number of elements in `bias`; the bias is indexed modulo this value
//   bias          optional, pass nullptr to skip the bias addition
//   use_half2     ignored by the float specialization
//
// Only declared here; the explicit specializations live in the accompanying .cu file.
template <typename T>
cudaError_t LaunchFastGeluKernel(cudaStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output, bool use_half2);

Просмотреть файл

@ -162,7 +162,8 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
#if defined(ENABLE_DR_LIBS)
LoadCustomOpClasses_Audio,
#endif
LoadCustomOpClasses<>
LoadCustomOpClasses_Contrib,
LoadCustomOpClasses<>,
};
for (const auto& fx : c_factories) {

Просмотреть файл

@ -10,7 +10,7 @@ import onnxruntime as _ort
class TestCudaOps(unittest.TestCase):
@staticmethod
def _create_test_model(domain='ai.onnx.contrib'):
def _create_negpos_test_model(domain='ai.onnx.contrib'):
nodes = [
helper.make_node('Identity', ['x'], ['identity1']),
helper.make_node(
@ -32,7 +32,7 @@ class TestCudaOps(unittest.TestCase):
def test_cuda_negpos(self):
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_test_model()
onnx_model = self._create_negpos_test_model()
self.assertIn('op_type: "NegPos"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
so,
@ -42,6 +42,43 @@ class TestCudaOps(unittest.TestCase):
diff = x - (neg + pos)
assert_almost_equal(diff, np.zeros(diff.shape))
@staticmethod
def _create_fastgelu_test_model(domain='ai.onnx.contrib'):
    """Build a one-node model computing ``y = FastGelu(x, bias)``.

    The node is placed in ``domain``; ``x``, ``bias`` and ``y`` are all
    declared as FLOAT tensors with an empty shape list.
    """
    fastgelu_node = helper.make_node(
        'FastGelu', ['x', 'bias'], ['y'], domain=domain)
    x_info, bias_info, y_info = [
        helper.make_tensor_value_info(name, onnx_proto.TensorProto.FLOAT, [])
        for name in ('x', 'bias', 'y')
    ]
    graph = helper.make_graph(
        [fastgelu_node], 'test1', [x_info, bias_info], [y_info])
    return make_onnx_model(graph)
def test_cuda_fastgelu(self):
    """Run FastGelu with a bias on the CUDA execution provider and check the output.

    The test is reported as skipped (rather than silently passing) when
    onnxruntime does not offer the CUDAExecutionProvider.
    """
    if 'CUDAExecutionProvider' not in _ort.get_available_providers():
        self.skipTest('CUDAExecutionProvider not available, test_cuda_fastgelu skipped.')

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    onnx_model = self._create_fastgelu_test_model()
    self.assertIn('op_type: "FastGelu"', str(onnx_model))
    sess = _ort.InferenceSession(onnx_model.SerializeToString(),
                                 so,
                                 providers=['CUDAExecutionProvider'])
    x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float32)
    bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
    expected_y = np.array([0., 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
    y = sess.run(None, {'x': x, 'bias': bias})[0]
    # float32 results near 5.5 have a spacing of ~4.8e-7, so the default 7-decimal
    # comparison demands bit-exact output; 5 decimals still pins the values firmly.
    assert_almost_equal(y, expected_y, decimal=5)
if __name__ == "__main__":
unittest.main()