Rashuai/bert cuda (#614)
* cmake fast gelu * bridge func and cuda kernel * tune ut * fix build warning * fix format * tune ut * drop OCOS_ENABLE_CONTRIB * tune cmake --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
Родитель
040599b265
Коммит
9b2daf56de
|
@ -337,10 +337,17 @@ if(OCOS_ENABLE_MATH)
|
|||
file(GLOB TARGET_SRC_MATH_CUDA "operators/math/cuda/*.*")
|
||||
list(APPEND TARGET_SRC_MATH ${TARGET_SRC_MATH_CUDA})
|
||||
endif()
|
||||
|
||||
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
|
||||
endif()
|
||||
|
||||
# Collect the contrib operator sources (*.cc and header files under operators/contrib).
file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*")
|
||||
# CUDA builds additionally pick up every file under operators/contrib/cuda.
if (OCOS_USE_CUDA)
|
||||
file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*")
|
||||
list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA})
|
||||
endif()
|
||||
# The contrib sources are appended to TARGET_SRC unconditionally (no OCOS_ENABLE_* guard here).
list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB})
|
||||
|
||||
# enable the opencv dependency if we have ops that require it
|
||||
if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
|
||||
set(_ENABLE_OPENCV ON)
|
||||
|
|
|
@ -44,6 +44,7 @@ class CuopContainer {
|
|||
// Each macro below expands to a capture-less lambda that returns a
// std::shared_ptr<ortc::OrtLiteCustomOp> created by ortc::CreateLiteCustomOpV2.
// CustomCpuStructV2: struct-based op `s`, registered as `name` for "CPUExecutionProvider".
#define CustomCpuStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }
|
||||
|
||||
// CustomCudaFuncV2: function-based op `f`, registered as `name` for "CUDAExecutionProvider".
#define CustomCudaFuncV2(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CUDAExecutionProvider", f)); }
|
||||
// CustomCudaStructV2: struct-based op `s`, registered as `name` for "CUDAExecutionProvider".
#define CustomCudaStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CUDAExecutionProvider")); }
|
||||
|
||||
template <typename F>
|
||||
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
|
||||
|
@ -133,3 +134,5 @@ extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
|
|||
#if ENABLE_AZURE
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
|
||||
#endif
|
||||
|
||||
extern FxLoadCustomOpFactory LoadCustomOpClasses_Contrib;
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "cuda/fast_gelu.h"
|
||||
#endif
|
||||
|
||||
// Factory for the contrib operator set. The function-local static OrtOpLoader is
// constructed on the first call; every call returns the loader's custom-op array.
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
|
||||
static OrtOpLoader op_loader(
|
||||
// This first entry always yields nullptr; presumably a placeholder so the argument
// list is never empty when USE_CUDA is not defined — TODO confirm how OrtOpLoader treats it.
[]() { return nullptr; }
|
||||
#ifdef USE_CUDA
|
||||
,
|
||||
// Registers the float FastGelu kernel for the CUDA execution provider.
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>)
|
||||
#endif
|
||||
);
|
||||
return op_loader.GetCustomOps();
|
||||
};
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "ocos.h"
|
||||
#include "fast_gelu_impl.cuh"
|
||||
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
struct FastGelu {
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& /*api*/,
|
||||
const OrtKernelInfo& /*info*/) {
|
||||
return nullptr;
|
||||
}
|
||||
OrtStatusPtr Compute(const Ort::Custom::CudaContext& ctx,
|
||||
const ortc::Tensor<T>& input,
|
||||
std::optional<const ortc::Tensor<T>*> bias,
|
||||
ortc::Tensor<T>& output) const {
|
||||
const T* input_data = input.Data();
|
||||
T* output_data = output.Allocate(input.Shape());
|
||||
auto input_length = input.NumberOfElement();
|
||||
if (0 == input_length) {
|
||||
return nullptr;
|
||||
}
|
||||
const T* bias_data = bias.has_value()?(*bias)->Data():nullptr;
|
||||
auto bias_length = bias.has_value()?(*bias)->NumberOfElement():0;
|
||||
LaunchFastGeluKernel(reinterpret_cast<cudaStream_t>(ctx.cuda_stream),
|
||||
input_length,
|
||||
bias_length,
|
||||
input_data,
|
||||
bias_data,
|
||||
output_data,
|
||||
use_half2_);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_half2_ = false; // to-do, read this from env var
|
||||
};
|
||||
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fast_gelu_impl.cuh"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Device-side tanh helper. The primary template is only declared; each
// supported element type provides its own specialization.
template <typename T>
__device__ __inline__ T _Tanh(T a);

// float specialization: forwards to the single-precision tanhf.
template <>
__device__ __inline__ float _Tanh(float a) { return tanhf(a); }
|
||||
|
||||
// Coefficients passed to FastGeluKernel as (a, b, c); the kernel evaluates
//   out = in * (a + a * tanh(in * (c * in * in + b))).
constexpr float A = 0.5f;
|
||||
|
||||
constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
|
||||
|
||||
constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI)
|
||||
|
||||
// Element-wise FastGelu: out = in * (a + a * tanh(in * (c * in * in + b))),
// where in = input[i], plus bias[i % bias_length] when a bias pointer is given.
//
// One thread handles one element. The global index is derived from the TPB
// template argument rather than blockDim.x, so the kernel must be launched
// with exactly TPB threads per block. Threads past input_length do nothing.
template <typename T, unsigned TPB>
__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length,
                               const T* input, const T* bias, T* output) {
  const int element = blockIdx.x * TPB + threadIdx.x;

  // Tail guard: the last block may extend past the end of the data.
  if (element >= input_length) {
    return;
  }

  T in = input[element];
  if (bias != nullptr) {
    // The bias is reused cyclically across the flattened input.
    in = (T)(in + bias[element % bias_length]);
  }

  const T cdf = a + a * _Tanh(in * (c * in * in + b));
  output[element] = in * cdf;
}
|
||||
|
||||
// float specialization of the FastGelu launcher.
//
// Enqueues FastGeluKernel on `stream` with 256 threads per block and enough
// blocks to cover `input_length` elements. `bias` may be nullptr; the last
// parameter (use_half2) is ignored for float.
//
// Returns cudaSuccess without launching when there is nothing to compute,
// otherwise the result of cudaGetLastError() taken right after the launch.
// The call does not synchronize the stream.
template <>
cudaError_t LaunchFastGeluKernel(cudaStream_t stream, int input_length, int bias_length,
                                 const float* input, const float* bias, float* output, bool /*use_half2*/) {
  // An empty input would produce a zero-block grid, which is not a valid
  // launch configuration; treat it as a successful no-op instead.
  if (input_length <= 0) {
    return cudaSuccess;
  }

  // The kernel indexes the bias with `idx % bias_length`; an empty bias must
  // therefore be handled as "no bias" to avoid a modulo by zero on the device.
  if (bias_length <= 0) {
    bias = nullptr;
  }

  constexpr int blockSize = 256;
  // Ceil-div in a form that cannot overflow int for large input_length
  // (input_length is known to be positive here).
  const int gridSize = (input_length - 1) / blockSize + 1;
  FastGeluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length,
                                                                       input, bias, output);

  return cudaGetLastError();
}
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Host-side entry point that launches the FastGelu CUDA kernel on `stream`.
// Only declared here; definitions are provided as explicit specializations per element type.
//   input_length / bias_length: element counts of `input` and `bias`.
//   bias: may be nullptr, in which case no bias is added.
//   use_half2: NOTE(review) presumably selects a half2 code path for half-precision types — TODO confirm.
// Returns the CUDA error state observed after the launch.
template <typename T>
|
||||
cudaError_t LaunchFastGeluKernel(cudaStream_t stream, int input_length, int bias_length,
|
||||
const T* input, const T* bias, T* output, bool use_half2);
|
|
@ -162,7 +162,8 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
#if defined(ENABLE_DR_LIBS)
|
||||
LoadCustomOpClasses_Audio,
|
||||
#endif
|
||||
LoadCustomOpClasses<>
|
||||
LoadCustomOpClasses_Contrib,
|
||||
LoadCustomOpClasses<>,
|
||||
};
|
||||
|
||||
for (const auto& fx : c_factories) {
|
||||
|
|
|
@ -10,7 +10,7 @@ import onnxruntime as _ort
|
|||
|
||||
class TestCudaOps(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _create_test_model(domain='ai.onnx.contrib'):
|
||||
def _create_negpos_test_model(domain='ai.onnx.contrib'):
|
||||
nodes = [
|
||||
helper.make_node('Identity', ['x'], ['identity1']),
|
||||
helper.make_node(
|
||||
|
@ -32,7 +32,7 @@ class TestCudaOps(unittest.TestCase):
|
|||
def test_cuda_negpos(self):
|
||||
so = _ort.SessionOptions()
|
||||
so.register_custom_ops_library(_get_library_path())
|
||||
onnx_model = self._create_test_model()
|
||||
onnx_model = self._create_negpos_test_model()
|
||||
self.assertIn('op_type: "NegPos"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
|
||||
so,
|
||||
|
@ -42,6 +42,43 @@ class TestCudaOps(unittest.TestCase):
|
|||
diff = x - (neg + pos)
|
||||
assert_almost_equal(diff, np.zeros(diff.shape))
|
||||
|
||||
@staticmethod
def _create_fastgelu_test_model(domain='ai.onnx.contrib'):
    """Build a single-node ONNX model computing ``y = FastGelu(x, bias)``.

    The FastGelu node is placed in ``domain``; ``x``, ``bias`` and ``y`` are
    all declared as FLOAT tensors with an empty shape list.
    """
    float_type = onnx_proto.TensorProto.FLOAT
    x_info, bias_info, y_info = (
        helper.make_tensor_value_info(tensor_name, float_type, [])
        for tensor_name in ('x', 'bias', 'y'))

    fastgelu_node = helper.make_node(
        'FastGelu', ['x', 'bias'], ['y'], domain=domain)

    graph = helper.make_graph(
        [fastgelu_node], 'test1', [x_info, bias_info], [y_info])
    return make_onnx_model(graph)
|
||||
|
||||
def test_cuda_fastgelu(self):
    """Run the FastGelu custom op on the CUDA execution provider and compare
    the result with precomputed reference values.

    The test is reported as skipped when onnxruntime does not list
    CUDAExecutionProvider among its available providers.
    """
    if 'CUDAExecutionProvider' not in _ort.get_available_providers():
        # Use unittest's skip mechanism instead of print(): a printed message
        # made the test count as passed even though nothing was exercised.
        self.skipTest('CUDAExecutionProvider not available, test_cuda_fastgelu skipped.')

    so = _ort.SessionOptions()
    so.register_custom_ops_library(_get_library_path())
    onnx_model = self._create_fastgelu_test_model()
    self.assertIn('op_type: "FastGelu"', str(onnx_model))
    sess = _ort.InferenceSession(onnx_model.SerializeToString(),
                                 so,
                                 providers=['CUDAExecutionProvider'])
    x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float32)
    bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
    # Reference output for FastGelu(x + bias).
    expected_y = np.array([0., 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
    y = sess.run(None, {'x': x, 'bias': bias})[0]
    assert_almost_equal(y, expected_y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче