Use of OrtxStatus in kernel NegXPlus1 (#732)

* first draft for NegXPlus1

* complete

* fix unit test

* rename one test

* remove test if not cuda

* switch to OrtxStatus

---------

Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Xavier Dupré 2024-05-30 10:56:06 +02:00 коммит произвёл GitHub
Родитель 95a49faabe
Коммит b60df02fd0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 6 добавлений и 5 удалений

Просмотреть файл

@ -4,29 +4,30 @@
#pragma once
#include "ocos.h"
#include "negxplus1_impl.cuh"
#include "ortx_common.h"
namespace contrib {
template <typename T>
struct NegXPlus1 {
template <typename TDict>
OrtStatusPtr OnModelAttach(const TDict& /*dict*/) {
return nullptr;
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return nullptr;
return {};
}
LaunchNegXPlus1Kernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return nullptr;
return {};
}
};