From b60df02fd0b0ef2065f087b20f7492fd7e214686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 30 May 2024 10:56:06 +0200 Subject: [PATCH] Use of OrtxStatus in kernel NegXPlus1 (#732) * first draft for NegXPlus1 * complete * fix unit test * rename one test * remove test if not cuda * switch to OrtxStatus --------- Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com> --- operators/cuda/negxplus1.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/operators/cuda/negxplus1.h b/operators/cuda/negxplus1.h index 832da604..5460c37a 100644 --- a/operators/cuda/negxplus1.h +++ b/operators/cuda/negxplus1.h @@ -4,29 +4,30 @@ #pragma once #include "ocos.h" #include "negxplus1_impl.cuh" +#include "ortx_common.h" namespace contrib { template struct NegXPlus1 { template - OrtStatusPtr OnModelAttach(const TDict& /*dict*/) { - return nullptr; + OrtxStatus OnModelAttach(const TDict& /*dict*/) { + return {}; } - OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& input, ortc::Tensor& output) const { const T* input_data = input.Data(); T* output_data = output.Allocate(input.Shape()); auto input_length = input.NumberOfElement(); if (0 == input_length) { - return nullptr; + return {}; } LaunchNegXPlus1Kernel(reinterpret_cast(ctx->GetCudaStream()), input_length, input_data, output_data); - return nullptr; + return {}; } };