ort-customops/include/custom_op/tensor_api.h

602 строки
18 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <assert.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "onnxruntime_f16.h"
#include "kernel_context.h"
namespace Ort {
namespace Custom {
// Minimal non-owning, read-only view over `size_` contiguous elements of T.
// The view stores only a pointer and a length; it never allocates or frees.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  // Pointer to the first viewed element (null for a default-constructed view).
  const T* data() const { return data_; }

  // Number of elements currently viewed.
  size_t size() const { return size_; }

  // Returns a copy of the element at `indice`; no bounds checking is performed.
  T operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
#if ORT_API_VERSION >= 16
// Explicit specialization of Span for MFloat16 elements.
// Same pointer-plus-length layout and the same members as the primary template.
template <>
struct Span<MFloat16> {
  const MFloat16* data_ = {};
  size_t size_ = {};

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const MFloat16* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  // Pointer to the first viewed element.
  const MFloat16* data() const { return data_; }

  // Number of elements currently viewed.
  size_t size() const { return size_; }

  // Returns a copy of the element at `indice`; no bounds checking is performed.
  MFloat16 operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
// Explicit specialization of Span for BFloat16 elements.
// Same pointer-plus-length layout and the same members as the primary template.
template <>
struct Span<BFloat16> {
  const BFloat16* data_ = {};
  size_t size_ = {};

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const BFloat16* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  // Pointer to the first viewed element.
  const BFloat16* data() const { return data_; }

  // Number of elements currently viewed.
  size_t size() const { return size_; }

  // Returns a copy of the element at `indice`; no bounds checking is performed.
  BFloat16 operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
#endif
// Abstract storage behind Tensor<T>: exposes a shape and a raw data pointer and
// lets the owning tensor initialize or give up the underlying buffer.
// The notes below describe how OrtEagerTensorStorage (the implementation in this
// file) and Tensor<T> use each member; other implementations may differ.
class ITensorStorage{
public:
// Returns the tensor shape. OrtEagerTensorStorage throws if it is not initialized.
virtual const std::vector<int64_t>& Shape() const = 0;
// Returns an untyped pointer to the tensor data.
virtual const void* DataRaw() const = 0;
// True once the storage has a shape; Tensor<T>::operator bool relies on this.
virtual bool IsInitialized() const = 0;
// Called by Tensor<T>::Allocate with the requested shape and sizeof(element);
// returns the buffer to write into.
virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
// Called by Tensor<T>::Release; returns the buffer pointer being given up.
virtual void* Release() = 0;
virtual ~ITensorStorage() = default;
};
// Minimal allocator interface used by OrtEagerTensorStorage to obtain and
// return tensor buffers.
class IAllocator {
 public:
  // Returns a buffer of at least `size` bytes.
  virtual void* Alloc(size_t size) = 0;
  // Returns a buffer previously obtained from Alloc().
  virtual void Free(void* p) = 0;
  // Polymorphic base: a virtual destructor makes deleting an implementation
  // through an IAllocator* well-defined. Declared last so the Alloc/Free
  // declaration order is unchanged.
  virtual ~IAllocator() = default;
};
// ITensorStorage implementation backed by one raw buffer.
//
// Two modes, selected by the constructor:
//  * (shape, buffer): wraps a caller-provided buffer. No allocator is recorded,
//    so the destructor never frees it.
//  * (allocator): starts uninitialized; Initialize() obtains the buffer from the
//    allocator and the destructor hands it back via IAllocator::Free().
class OrtEagerTensorStorage : public ITensorStorage {
 public:
  OrtEagerTensorStorage(const std::vector<int64_t>& shape,
                        void* buffer) : buffer_(buffer), shape_(shape) {
  }

  OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator) {
  }

  // The destructor may free buffer_; a copy would make two objects free the
  // same pointer, so copying is disabled (this also suppresses implicit moves).
  OrtEagerTensorStorage(const OrtEagerTensorStorage&) = delete;
  OrtEagerTensorStorage& operator=(const OrtEagerTensorStorage&) = delete;

  // Frees the buffer only when this object was created with an allocator and
  // still holds a buffer (i.e. Release() has not handed it out).
  ~OrtEagerTensorStorage() override {
    if (allocator_ && buffer_)
      allocator_->Free(buffer_);
  }

  // Returns the shape; throws if the storage has no shape yet.
  const std::vector<int64_t>& Shape() const override {
    if (!IsInitialized())
      ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
    return *shape_;
  }

  // The storage counts as initialized as soon as a shape is recorded.
  bool IsInitialized() const override {
    return shape_.has_value();
  }

  const void* DataRaw() const override {
    return buffer_;
  }

  // Allocates a buffer for `shape` with `element_size` bytes per element.
  // If the storage is already initialized, the existing buffer is returned and
  // both arguments are ignored.
  // Throws when no allocator is available or the element count is negative.
  void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
    if (IsInitialized())
      return buffer_;
    // Reachable without an allocator after Release() on a wrapped buffer; an
    // assert alone would compile out in release builds and dereference null.
    if (!allocator_)
      ORTX_CXX_API_THROW("Tensor storage has no allocator to initialize with", ORT_RUNTIME_EXCEPTION);
    const int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
    // A negative product would wrap to a huge size_t when multiplied below.
    if (n_elem < 0)
      ORTX_CXX_API_THROW("Tensor shape yields a negative element count", ORT_RUNTIME_EXCEPTION);
    const size_t buffer_size = static_cast<size_t>(n_elem) * element_size;
    // Allocate first: if Alloc throws, the storage must stay uninitialized.
    buffer_ = allocator_->Alloc(buffer_size);
    shape_ = shape;
    return buffer_;
  }

  // Gives up the buffer: returns the current pointer and resets this storage to
  // the uninitialized state, so the destructor will not free it.
  void* Release() override {
    void* tmp = buffer_;
    buffer_ = nullptr;
    shape_ = std::nullopt;
    return tmp;
  }

 private:
  void* buffer_{};
  std::optional<std::vector<int64_t>> shape_;
  // caller need to make sure the allocator is alive
  IAllocator* allocator_{};
};
// Maps a C++ element type to its ONNXTensorElementDataType value.
//
// The branch is selected at compile time. Element types with no branch reach
// the throw at the bottom at runtime.
// std::string_view maps to STRING like std::string, and (when the ORT API
// version provides them) MFloat16 / BFloat16 map to FLOAT16 / BFLOAT16, because
// this header defines tensors over those element types and their Type() calls
// this function.
template <typename TT>
ONNXTensorElementDataType GetOrtDType() {
  if constexpr (std::is_same<TT, bool>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
  else if constexpr (std::is_same<TT, float>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  else if constexpr (std::is_same<TT, double>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
  else if constexpr (std::is_same<TT, uint8_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  else if constexpr (std::is_same<TT, int8_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
  else if constexpr (std::is_same<TT, uint16_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
  else if constexpr (std::is_same<TT, int16_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
  else if constexpr (std::is_same<TT, uint32_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
  else if constexpr (std::is_same<TT, int32_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  else if constexpr (std::is_same<TT, uint64_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
  else if constexpr (std::is_same<TT, int64_t>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  else if constexpr (std::is_same<TT, std::string>::value ||
                     std::is_same<TT, std::string_view>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
#if ORT_API_VERSION >= 16
  else if constexpr (std::is_same<TT, MFloat16>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  else if constexpr (std::is_same<TT, BFloat16>::value)
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
#endif
  ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
  // Kept after the throw macro so every control path has a return statement.
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
// Type-erased tensor interface implemented by Tensor<T> and its string
// specializations. Derives from Arg (declared outside this file).
class TensorBase : public Arg {
public:
virtual ~TensorBase() = default;
// ONNX element type of the tensor.
virtual ONNXTensorElementDataType Type() const = 0;
// Tensor dimensions.
virtual const std::vector<int64_t>& Shape() const = 0;
// Element count; the implementations in this file return the product of the dimensions.
virtual int64_t NumberOfElement() const = 0;
// Untyped pointer to the tensor data.
virtual const void* DataRaw() const = 0;
// Size of the data in bytes. The string specializations in this file only
// support this for a tensor holding exactly one string.
virtual size_t SizeInBytes() const = 0;
// Allocates storage for `shape` and returns it as bytes. The string
// specializations in this file throw instead.
virtual std::byte* AllocateRaw(const std::vector<int64_t>& shape) = 0;
};
// Typed tensor facade over an ITensorStorage.
//
// T is the element type; TT is T with any reference removed. The tensor owns
// its storage through a unique_ptr, so it is move-only. Accessors that need the
// storage throw "tensor not initialized." when the storage pointer is null
// (e.g. after the tensor has been moved from).
template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;

  // Takes ownership of an existing storage object.
  Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)) {
  }
  // Wraps a caller-provided buffer with the given shape in an OrtEagerTensorStorage.
  Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
  // Creates an OrtEagerTensorStorage that obtains its buffer from `allocator`
  // when Allocate() is first called.
  Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)) {}

  Tensor(const Tensor& src) = delete;
  // Must take a const reference: a deleted by-value operator=(Tensor) is
  // ambiguous with operator=(Tensor&&) for rvalue arguments and would make
  // every move assignment fail to compile.
  Tensor& operator=(const Tensor& src) = delete;

  // Moves the storage out of `other`; unique_ptr leaves it null there.
  Tensor(Tensor&& other) : storage_(std::move(other.storage_)) {
    other.span_ = {};
  }

  // Replaces this tensor's storage with `other`'s. Both cached spans are
  // cleared so neither side keeps a view into storage it no longer owns.
  Tensor& operator=(Tensor&& other) {
    if (this != &other) {
      storage_ = std::move(other.storage_);
      span_ = {};
      other.span_ = {};
    }
    return *this;
  }

  // True when the tensor has storage and that storage reports a shape.
  operator bool() const {
    return storage_ && storage_->IsInitialized();
  }

  ONNXTensorElementDataType Type() const override {
    return GetOrtDType<T>();
  }

  const std::vector<int64_t>& Shape() const override {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    return storage_->Shape();
  }

  // Product of all dimensions (1 for an empty shape).
  int64_t NumberOfElement() const override {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    auto& shape = storage_->Shape();
    return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
  }

  // Formats the shape as "d0, d1, ..., " (each dimension is followed by ", "),
  // or returns "empty" when the storage has no shape yet.
  std::string Shape2Str() const {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    if (!storage_->IsInitialized())
      return "empty";
    std::string shape_str;
    for (const auto& dim : storage_->Shape()) {
      shape_str.append(std::to_string(dim));
      shape_str.append(", ");
    }
    return shape_str;
  }

  // Hands the buffer over to the caller (see ITensorStorage::Release) and drops
  // the cached span that pointed into it.
  void* Release() {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    span_ = {};
    return storage_->Release();
  }

  // Typed pointer to the first element.
  const TT* Data() const {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
#if ORT_API_VERSION >= 16
    if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
      return reinterpret_cast<const TT*>(storage_->DataRaw());
    else
#endif
      return static_cast<const TT*>(storage_->DataRaw());
  }

  const void* DataRaw() const override {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    return storage_->DataRaw();
  }

  size_t SizeInBytes() const override {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    return NumberOfElement() * sizeof(TT);
  }

  // Asks the storage to initialize itself for `shape` and returns the buffer
  // as TT*.
  TT* Allocate(const std::vector<int64_t>& shape) {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
    // it should be OK to allocate multiple times
    void* buffer = storage_->Initialize(shape, sizeof(TT));
#if ORT_API_VERSION >= 16
    if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
      return reinterpret_cast<TT*>(buffer);
    else
#endif
      return static_cast<TT*>(buffer);
  }

  std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
    return reinterpret_cast<std::byte*>(Allocate(shape));
  }

  // Returns a view over a 1-D tensor. Throws when the shape is not 1-D, and
  // always throws for MFloat16 / BFloat16 elements. The returned reference
  // refers to a member that is overwritten by the next AsSpan() call.
  const Span<T>& AsSpan() {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
#if ORT_API_VERSION >= 16
    if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
      ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
    } else {
#endif
      auto& shape = storage_->Shape();
      if (shape.size() != 1) {
        ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
      }
      span_.Assign(Data(), shape[0]);
      return span_;
#if ORT_API_VERSION >= 16
    }
#endif
  }

  // Returns the single element of a tensor whose shape is {} or {1}. Throws for
  // any other shape, and always throws for MFloat16 / BFloat16 elements.
  const T& AsScalar() {
    if (!storage_)
      ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
#if ORT_API_VERSION >= 16
    if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
      ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
    } else {
#endif
      auto& shape = storage_->Shape();
      if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
        ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
      }
      return *Data();
#if ORT_API_VERSION >= 16
    }
#endif
  }

 private:
  std::unique_ptr<ITensorStorage> storage_;
  // Cache backing the reference returned by AsSpan().
  Span<T> span_;
};
// Abstract storage behind the string tensor specializations. T is the string
// element type; this file instantiates it with std::string and std::string_view.
template<typename T>
class IStringTensorStorage{
public:
using strings = std::vector<T>;
// Tensor dimensions.
virtual const std::vector<int64_t>& Shape() const = 0;
// Untyped pointer to string data. EagerStringTensorStorage only supports this
// when exactly one string is stored.
virtual const void* DataRaw() const = 0;
// The stored strings.
virtual const strings& Data() const = 0;
// Whether the storage has been given a shape.
virtual bool IsInitialized() const = 0;
// Sets the stored strings and the shape `dims` from a vector of T.
virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
// Sets the stored strings and the shape `dims` from C strings.
virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
virtual ~IStringTensorStorage() = default;
};
// In-memory IStringTensorStorage that keeps its elements in a std::vector<T>.
// Constructed from a list of strings it is a 1-D tensor of that length;
// default-constructed it is uninitialized until SetStringOutput() is called.
template <typename T>
class EagerStringTensorStorage : public IStringTensorStorage<T> {
 public:
  using strings = std::vector<T>;

  // Copies `ss` and records the shape {ss.size()}.
  EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}) {}
  EagerStringTensorStorage() {}

  // Returns the shape; throws if no shape has been set yet.
  const std::vector<int64_t>& Shape() const override {
    if (!IsInitialized())
      ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
    return *shape_;
  }

  // Returns a pointer to the characters of the only stored string; throws
  // unless exactly one string is stored.
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    if constexpr (std::is_same<std::string_view, T>::value)
      return reinterpret_cast<const void*>(input_strings_[0].data());
    else
      return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }

  bool IsInitialized() const override {
    return shape_.has_value();
  }

  // Replaces the stored strings with `ss` and sets the shape to `dims`.
  // Throws for string_view storage.
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
    if constexpr (std::is_same<std::string_view, T>::value)
      ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
    input_strings_.assign(ss.begin(), ss.end());
    shape_ = dims;
  }

  const strings& Data() const override {
    return input_strings_;
  }

  // Replaces the stored strings with copies of the C strings in `ss` and sets
  // the shape to `dims`. Throws for string_view storage.
  void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
    if constexpr (std::is_same<std::string_view, T>::value)
      ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
    // Drop the previous contents first so this overload replaces, like the
    // `strings` overload does, instead of appending to what is already stored.
    input_strings_.clear();
    input_strings_.reserve(ss.size());
    for (const char* s : ss) {
      input_strings_.push_back(s);
    }
    shape_ = dims;
  }

 private:
  std::vector<T> input_strings_;
  std::optional<std::vector<int64_t>> shape_;
};
// Tensor specialization whose elements are owned std::string values.
// All element and shape queries are forwarded to an IStringTensorStorage.
template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  // Takes ownership of an existing string storage.
  Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
  // Builds an eager storage holding a copy of `ss`.
  Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
  // Builds an eager storage with no strings and no shape yet.
  Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}

  ONNXTensorElementDataType Type() const override {
    return GetOrtDType<std::string>();
  }

  // The stored strings.
  const strings& Data() const {
    return storage_->Data();
  }

  const std::vector<int64_t>& Shape() const override {
    return storage_->Shape();
  }

  // Product of all dimensions (1 for an empty shape).
  int64_t NumberOfElement() const override {
    int64_t count = 1;
    for (const int64_t extent : storage_->Shape()) {
      count *= extent;
    }
    return count;
  }

  // Formats the shape with every dimension followed by ", ", or returns
  // "empty" while the storage has no shape.
  std::string Shape2Str() const {
    if (!storage_->IsInitialized()) {
      return "empty";
    }
    std::string text;
    for (const int64_t extent : storage_->Shape()) {
      text += std::to_string(extent);
      text += ", ";
    }
    return text;
  }

  const void* DataRaw() const override {
    return storage_->DataRaw();
  }

  // Length of the only stored string; throws unless exactly one is stored.
  size_t SizeInBytes() const override {
    const strings& values = storage_->Data();
    if (values.size() != 1) {
      ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return values.front().size();
  }

  // Raw allocation is not available for string tensors; always throws.
  std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
    ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
  }

  // Stores `ss` with shape `dims` (forwarded to the storage).
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    storage_->SetStringOutput(ss, dims);
  }

  // Stores copies of the C strings in `ss` with shape `dims` (forwarded to the storage).
  void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
    storage_->SetStringOutput(ss, dims);
  }

  // Not implemented for string tensors; always throws.
  const Span<std::string>& AsSpan() const {
    ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
  }

  // Returns the only stored string; throws unless exactly one is stored.
  const std::string& AsScalar() const {
    const strings& values = storage_->Data();
    if (values.size() != 1) {
      ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    return values.front();
  }

 private:
  std::unique_ptr<IStringTensorStorage<std::string>> storage_;
};
// Tensor specialization whose elements are non-owning std::string_view values.
// All element and shape queries are forwarded to an IStringTensorStorage.
// With the EagerStringTensorStorage defined in this file, SetStringOutput()
// throws for string_view storage.
template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string_view>;

  // Takes ownership of an existing string_view storage.
  Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
  // Builds an eager storage holding a copy of the views in `ss`.
  Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}

  // A string_view tensor is an ONNX string tensor. The value is returned
  // directly because GetOrtDType<std::string_view>() has no matching branch
  // and would throw "Unexpected type" at runtime.
  ONNXTensorElementDataType Type() const override {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }

  // The stored views.
  const strings& Data() const {
    return storage_->Data();
  }

  const std::vector<int64_t>& Shape() const override {
    return storage_->Shape();
  }

  // Product of all dimensions (1 for an empty shape).
  int64_t NumberOfElement() const override {
    auto& shape = storage_->Shape();
    return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
  }

  // Formats the shape with every dimension followed by ", ", or returns
  // "empty" while the storage has no shape.
  std::string Shape2Str() const {
    if (storage_->IsInitialized()) {
      std::string shape_str;
      auto& shape = storage_->Shape();
      for (const auto& dim : shape) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

  const void* DataRaw() const override {
    return storage_->DataRaw();
  }

  // Length of the only stored view; throws unless exactly one is stored.
  size_t SizeInBytes() const override {
    auto& ss = storage_->Data();
    if (ss.size() != 1) {
      ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return ss[0].size();
  }

  // Raw allocation is not available for string tensors; always throws.
  std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
    ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
  }

  // Forwarded to the storage.
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    storage_->SetStringOutput(ss, dims);
  }

  // Forwarded to the storage.
  void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
    storage_->SetStringOutput(ss, dims);
  }

  // Not implemented for string tensors; always throws.
  const Span<std::string_view>& AsSpan() const {
    ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
  }

  // Returns the only stored view; throws unless exactly one is stored.
  const std::string_view& AsScalar() const {
    auto& ss = storage_->Data();
    if (ss.size() != 1) {
      ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    return ss[0];
  }

 private:
  std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
};
// Associates a list of names with a tuple of heterogeneously typed values.
// The I-th name refers to the I-th tuple element.
template <typename... Args>
class NamedArgumentDict {
 public:
  using ValueTuple = std::tuple<Args...>;

  // Copies the key strings and the value tuple. `keys` may be shorter than the
  // tuple; values without a name can never be matched.
  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : entries_(args) {
    names_.reserve(keys.size());
    for (const char* key : keys) {
      names_.push_back(key);
    }
  }

  // Returns the value stored under `name`, or `default_value` when no entry
  // has that name (or `name` is null).
  // Throws std::runtime_error when the name matches an entry whose stored type
  // is not exactly T.
  template <typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    // Comparing a std::string with a null const char* is undefined behaviour.
    if (name == nullptr) {
      return default_value;
    }
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

 private:
  // Checks entry I and recurses to I + 1; recursion ends when I reaches the
  // number of tuple elements.
  template <size_t I, typename T>
  T TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    if constexpr (I == sizeof...(Args)) {
      return default_value;
    } else {
      // The bounds check covers dictionaries built with fewer keys than values.
      if (I < names_.size() && names_[I] == name) {
        if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
          return std::get<I>(entries_);
        else
          throw std::runtime_error("name matched but type is not");
      }
      return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
    }
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;
};
}
}