Refactor OrtxStatus to be header-only implmentation. (#720)

This commit is contained in:
Wenbing Li 2024-05-17 15:40:11 -07:00 коммит произвёл GitHub
Родитель f0ef40d074
Коммит 97ee9eb56f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
16 изменённых файлов: 128 добавлений и 143 удалений

Просмотреть файл

@ -4,6 +4,15 @@
#include "ocos.h"
#include "narrow.h"
OrtxStatus::operator OrtStatus*() const noexcept {
if (IsOk()) {
return nullptr;
}
OrtStatus* status = OrtW::CreateStatus(Message(), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
return status;
}
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
if (status == nullptr) {
return ORT_OK;

Просмотреть файл

@ -5,7 +5,7 @@
#include <optional>
#include <string>
#include <sstream>
#include "status.h"
#include "string_utils.h"
#ifdef _WIN32
#include <Windows.h>

Просмотреть файл

@ -1,102 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "status.h"
#include "ort_c_to_cpp.h"
struct OrtxStatus::Rep {
extError_t code{kOrtxOK};
std::string error_message;
};
OrtxStatus::OrtxStatus() = default;
OrtxStatus::~OrtxStatus() = default;
OrtxStatus::OrtxStatus(extError_t code, const std::string& error_message)
: rep_(new Rep) {
rep_->code = code;
rep_->error_message = std::string(error_message);
}
OrtxStatus::OrtxStatus(const OrtxStatus& s)
: rep_((s.rep_ == nullptr) ? nullptr : new Rep(*s.rep_)) {}
OrtxStatus& OrtxStatus::operator=(const OrtxStatus& s) {
if (rep_ != s.rep_)
rep_.reset((s.rep_ == nullptr) ? nullptr : new Rep(*s.rep_));
return *this;
}
bool OrtxStatus::operator==(const OrtxStatus& s) const { return (rep_ == s.rep_); }
bool OrtxStatus::operator!=(const OrtxStatus& s) const { return (rep_ != s.rep_); }
const char* OrtxStatus::Message() const {
return IsOk() ? "" : rep_->error_message.c_str();
}
void OrtxStatus::SetErrorMessage(const char* str) {
if (rep_ == nullptr)
rep_ = std::make_unique<Rep>();
rep_->error_message = str;
}
extError_t OrtxStatus::Code() const { return IsOk() ? extError_t() : rep_->code; }
OrtStatus* OrtxStatus::CreateOrtStatus() const {
if (IsOk()) {
return nullptr;
}
OrtStatus* status = OrtW::CreateStatus(Message(), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
return status;
}
std::string OrtxStatus::ToString() const {
if (rep_ == nullptr)
return "OK";
std::string result;
switch (Code()) {
case extError_t::kOrtxOK:
result = "Success";
break;
case extError_t::kOrtxErrorInvalidArgument:
result = "Invalid argument";
break;
case extError_t::kOrtxErrorOutOfMemory:
result = "Out of Memory";
break;
case extError_t::kOrtxErrorCorruptData:
result = "Corrupt data";
break;
case extError_t::kOrtxErrorInvalidFile:
result = "Invalid data file";
break;
case extError_t::kOrtxErrorNotFound:
result = "Not found";
break;
case extError_t::kOrtxErrorAlreadyExists:
result = "Already exists";
break;
case extError_t::kOrtxErrorOutOfRange:
result = "Out of range";
break;
case extError_t::kOrtxErrorNotImplemented:
result = "Not implemented";
break;
case extError_t::kOrtxErrorInternal:
result = "Internal";
break;
case extError_t::kOrtxErrorUnknown:
result = "Unknown";
break;
default:
break;
}
result += ": ";
result += rep_->error_message;
return result;
}

Просмотреть файл

@ -1,31 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <memory>
#include "ortx_types.h"
struct OrtStatus;
struct OrtxStatus {
OrtxStatus();
~OrtxStatus();
OrtxStatus(extError_t code, const std::string& error_message);
OrtxStatus(const OrtxStatus& s);
OrtxStatus& operator=(const OrtxStatus& s);
bool operator==(const OrtxStatus& s) const;
bool operator!=(const OrtxStatus& s) const;
[[nodiscard]] inline bool IsOk() const { return rep_ == nullptr; }
void SetErrorMessage(const char* str);
[[nodiscard]] const char* Message() const;
[[nodiscard]] extError_t Code() const;
std::string ToString() const;
OrtStatus* CreateOrtStatus() const;
private:
struct Rep;
std::unique_ptr<Rep> rep_;
};

Просмотреть файл

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <optional>
#include <numeric>

Просмотреть файл

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <optional>
#include <numeric>

105
include/ext_status.h Normal file
Просмотреть файл

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <memory>
#include "ortx_types.h"
struct OrtStatus;
class OrtxStatus {
struct Rep {
extError_t code{kOrtxOK};
std::string error_message;
};
public:
OrtxStatus() = default;
~OrtxStatus() = default;
OrtxStatus(extError_t code, const std::string& error_message)
: rep_(std::make_unique<Rep>().release()) {
rep_->code = code;
rep_->error_message = std::string(error_message);
}
OrtxStatus(const OrtxStatus& s)
: rep_((s.rep_ == nullptr) ? nullptr : std::make_unique<Rep>(*s.rep_).release()) {}
OrtxStatus& operator=(const OrtxStatus& s) {
if (rep_ != s.rep_)
rep_.reset((s.rep_ == nullptr) ? nullptr : std::make_unique<Rep>(*s.rep_).release());
return *this;
}
bool operator==(const OrtxStatus& s) const { return (rep_ == s.rep_); }
bool operator!=(const OrtxStatus& s) const { return (rep_ != s.rep_); }
[[nodiscard]] inline bool IsOk() const noexcept{ return rep_ == nullptr; }
void SetErrorMessage(const char* str) {
if (rep_ == nullptr)
rep_ = std::make_unique<Rep>();
rep_->error_message = str;
}
[[nodiscard]] const char* Message() const noexcept{
return IsOk() ? "" : rep_->error_message.c_str();
}
[[nodiscard]] extError_t Code() const { return IsOk() ? extError_t() : rep_->code; }
std::string ToString() const {
if (rep_ == nullptr)
return "OK";
std::string result;
switch (Code()) {
case extError_t::kOrtxOK:
result = "Success";
break;
case extError_t::kOrtxErrorInvalidArgument:
result = "Invalid argument";
break;
case extError_t::kOrtxErrorOutOfMemory:
result = "Out of Memory";
break;
case extError_t::kOrtxErrorCorruptData:
result = "Corrupt data";
break;
case extError_t::kOrtxErrorInvalidFile:
result = "Invalid data file";
break;
case extError_t::kOrtxErrorNotFound:
result = "Not found";
break;
case extError_t::kOrtxErrorAlreadyExists:
result = "Already exists";
break;
case extError_t::kOrtxErrorOutOfRange:
result = "Out of range";
break;
case extError_t::kOrtxErrorNotImplemented:
result = "Not implemented";
break;
case extError_t::kOrtxErrorInternal:
result = "Internal";
break;
case extError_t::kOrtxErrorUnknown:
result = "Unknown";
break;
default:
break;
}
result += ": ";
result += rep_->error_message;
return result;
}
operator OrtStatus*() const noexcept;
private:
std::unique_ptr<Rep> rep_;
};

Просмотреть файл

@ -11,6 +11,7 @@
#include <vector>
#include "op_def_struct.h"
#include "ext_status.h"
// A helper API to support test kernels.
// Must be invoked before RegisterCustomOps.
@ -18,6 +19,8 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
constexpr const char* c_OpDomain = "ai.onnx.contrib";
constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";
template <typename... Args>
class CuopContainer {
public:

Просмотреть файл

@ -28,7 +28,7 @@ namespace Custom {
template <typename T>
inline OrtStatusPtr ToApiStatus(const T& status) {
return status.CreateOrtStatus();
return (OrtStatus*)status;
}
template <>

Просмотреть файл

@ -1,7 +1,6 @@
#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include "status.h"
#include "string_tensor.h"
inline OrtxStatus image_reader(const ortc::Tensor<std::string>& input,

Просмотреть файл

@ -6,7 +6,6 @@
#include <fstream>
#include <filesystem>
#include "ocos.h"
#include "status.h"
#include "nlohmann/json.hpp"
#include "bpe_types.h"

Просмотреть файл

@ -143,14 +143,14 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
bpe_conf_.get().GetSpecialTokens().c_str(),
IsSpmModel(ModelName()));
if (!status.IsOk()) {
return status.CreateOrtStatus();
return (OrtStatusPtr)status;
}
std::string added_token;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
status = bbpe_tokenizer_->LoadAddedTokens(added_token.c_str());
if (!status.IsOk()) {
return status.CreateOrtStatus();
return (OrtStatusPtr)status;
}
// TODO: need to check if the special token ids are the same as the ones in HFTokenizer

Просмотреть файл

@ -4,7 +4,6 @@
#pragma once
#include "ocos.h"
#include "status.h"
#include "ustring.h"
#include <list>

Просмотреть файл

@ -3,7 +3,6 @@
#pragma once
#include "ocos.h"
#include "status.h"
#include "narrow.h"
#include "ustring.h"
#include "string_utils.h"

Просмотреть файл

@ -5,7 +5,7 @@
#include <vector>
#include "ortx_utils.h"
#include "status.h"
#include "ext_status.h"
#include "op_def_struct.h"
namespace ort_extensions {

Просмотреть файл

@ -4,7 +4,6 @@
#pragma once
#include "ocos.h"
#include "status.h"
constexpr int max_crops = 16;
constexpr int num_img_tokens = 144;