Use of OrtxStatus in kernel NegXPlus1 (#732)
* first draft for NegXPlus1 * complete * fix unit test * rename one test * remove test if not cuda * switch to OrtxStatus --------- Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Родитель
95a49faabe
Коммит
b60df02fd0
|
@ -4,29 +4,30 @@
|
|||
#pragma once
|
||||
#include "ocos.h"
|
||||
#include "negxplus1_impl.cuh"
|
||||
#include "ortx_common.h"
|
||||
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
struct NegXPlus1 {
|
||||
template <typename TDict>
|
||||
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
|
||||
return nullptr;
|
||||
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
|
||||
return {};
|
||||
}
|
||||
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
|
||||
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
|
||||
const ortc::Tensor<T>& input,
|
||||
ortc::Tensor<T>& output) const {
|
||||
const T* input_data = input.Data();
|
||||
T* output_data = output.Allocate(input.Shape());
|
||||
auto input_length = input.NumberOfElement();
|
||||
if (0 == input_length) {
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
|
||||
input_length,
|
||||
input_data,
|
||||
output_data);
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче