* use lite custom op api for math

* add vision ops

* add cv2 ops

* remove useless code

* support register custom kernel struct

* add string tensor support

* add more text kernels

* fix issue with std string as scalar

* migrate all text ops

* initial tokenizer change

* migrate all tokenizers

* Resolve conflict with main (#433)

* resolve conflict

* resolve conflict

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Update custom-op-lite PR (#440)

* add the onnxruntime 1.14 release into the CI pipeline (#387)

* add the onnxruntime 1.14 release into the CI pipeline

* torch 2.0 crashed on Linux

* Fix size_t overflow issue for RobertaTokenizer (#388)

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Pre and Post processing example for OpenAI Whisper model (#380)

* add a stft-norm custom op for log-mel spectrum.

* undo the debug change

* Support ONNX standard STFT op signature.

* Add a unit test onnx STFT compatible mode.

* add whisper pre-/post- processing example

* Update dlib.cmake

* undo test code changes

* Update setup.cfg

* update the end2end example with STFT op

* Added optional outputs for GPT2, CLIP and Roberta Tokenizers (#389)

* Initial optional i/o for robertap

* Small fix

* Added working optional output functionality to RobertaTokenizer with tests

* Added optional outputs to CLIPTokenizer

* Added optional outputs to GPT2Tokenizer

* Use ternary operators

---------

Authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* ignore the unknown token id on bpe decoder (#391)

* Use dependency name 'nlohmann_json' which is the same name that ORT uses. (#393)

* Add an audio decoder custom op for whisper end-to-end processing (#385)

* evaluate the audio decoder library

* MP3 Decoder

* rename it to test_audio_codec

* add the audio decoder to whisper model

* whisper end-to-end draft

* fix the mp3 decoder

* Running with ONNX models

* Add support for more audio formats

* refine the end-to-end script

* Update operators/audio/audio_decoder.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Update operators/audio/audio_decoder.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Update operators/audio/audio_decoder.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* some fixes for comments and more test cases.

* changes for review comments.

* Update audio_decoder.hpp

* Update audio_decoder.hpp

* code refinement

* Update operators/audio/audio_decoder.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

---------

Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* make tensorflow optional for unittest (#394)

* make tensorflow optional for unittest.

* typo

* built-in bounding box op (#382)

* built-in bounding box op
* update boundary check
* assert policy
* more boundary test and check
* XYXY--> X horizon
---------

Co-authored-by: Scott McKay <skottmckay@gmail.com>

* a quick nuget package impl. (#396)

* Update wheels_linux.yml: change the linux machine pool name (#398)

* Add a merge step in whisper end-to-end script and fix some issues (#399)

* add merged models in whisper model

* verify the final model

* support batch > 1 in BpeDecoder (#400)

* support batch > 1 in BpeDecoder

* update the shape in helper function

* [object detection ppp] YoLo as example (#397)

* object detection
* Unit test
add e2e fastestdet model test

---------

Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>

* some fixes for the python package (#401)

* more code fixes related to whisper models (#403)

* Added windows nuget work temporarily for testing (#402)

* Added windows nuget work temporarily for testing

* Cleanup

* Add back onnxruntime.lib in props file for possible future ORT need

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>

* Remove unnecessary nupkg file and update nuspec (#405)

* Add nuget pack to build.bat and small nuget changes for demo

* Temporarily adding nuget.exe to build package until we can add to CI machine

* Switch back from Release to RelWithDebInfo

* Remove unnecessary changes

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Add initial NuGet pipeline for Windows x64 build (#406)

* initial nuget pipeline

* Update nuget.yml for Azure Pipelines

* update nuget.yml for extensions specific packaging

TODO: add certain template yml files

* added component governance template yaml

* change template yaml path

* remove RoslynAnalyzers

* Add packDestination to nuget pack task (change from default)

* fix nuspec path

* Update nuget.yml for Azure Pipelines

* Update nuget.yml for Azure Pipelines

* Update nuget.yml for Azure Pipelines

* Update 2 nuget.yml for Azure Pipelines

* Update NativeNuget.nuspec

* Update nuget.yml for Azure Pipelines

* update nuspec

* Update 3 nuget.yml for Azure Pipelines

* Update 4 nuget.yml for Azure Pipelines

* Update 7 nuget.yml for Azure Pipelines

* Remove unnecessary nupkg file and update nuspec (#405)

* Add nuget pack to build.bat and small nuget changes for demo

* Temporarily adding nuget.exe to build package until we can add to CI machine

* Switch back from Release to RelWithDebInfo

* Remove unnecessary changes

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Update 8 nuget.yml for Azure Pipelines

* Update 9 nuget.yml for Azure Pipelines

* add DLL signing

* Update nuget.yml for Azure Pipelines

* fix indentation

* Update 11 nuget.yml for Azure Pipelines

* Update 12 nuget.yml for Azure Pipelines

* Update 12 nuget.yml for Azure Pipelines

* Revert some unnecessary changes on nuget.yml

* clean up nuget.yml and update nuspec release notes

* small changes

* update commit id and release notes

---------

Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Compatible with onnxruntime-gpu package (#410)

* be compatible without onnxruntime-gpu version

* some fixing

* Add nuget README and remove ort lib references from props (#409)

* Add nuget README and remove ort lib references from props

* replace commit id in nuspec dynamically

* remove $ sign for commit id token

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>

* Add an C# demo project for NuGet package (#407)

* Add a nuget test app

* remove unused file

* Compatible with onnxruntime-gpu package (#410)

* be compatible without onnxruntime-gpu version

* some fixing

* turn it as a .net demo project

---------

Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>

* Make Whisper E2E script more portable (#412)

This PR makes the Whisper E2E script more portable for other environments.

* Update macos wheel timeout to 180 min (#390)

* Update ci timeout to 120 min

* Only update WindowsPython job timeout

* Update ci timeout to 90 min

* update macos wheel timeout to 180 min

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Fix OneBranch PR pipeline CodeQL issue (#413)

* test codeql 3000

* switch codeql from compiled to python

* switch back to compiled

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Adding down-sampling and stereo mixing features for AudioDecoder (#420)

* initial draft

* second

* third

* polishing

* fix the M_PI name in LINUX platform

* fix bessel function issue

* add a unit test case

* fix the unit test name

* Fix Secure Supply Chain Analysis Warning in PR pipeline (#414)

* remove package sources

* remove NuGet.config

* add .sscignore for cfs0011

* change sscignore

* add CFS0013 to sscignore

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* fix onnx version to 1.13.1 (#422)

* [NuGet] All platform package pipeline (#408)

* nuget ci package
* disable macos arm64 build due to an error

* Get the iOS xcframework build working with the split build/pack approach. (#416)

* refine build_xcframework.py
Cleanup/clarify various things
- naming of parameters and files
- consistency
Make handling of additional build args more generic
Update the artifact download dir/extract dir to more intuitive names
Update scripts
- make usage from CI pipeline clearer (e.g. don't hide directory names inside script)
- keep comments in nuspec
- remove unused args
- make additional arg handling more generic
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* Add new required pre/post processing ops to Android and iOS packages. (#415)

* Revert "Pin onnx version to 1.13.1" (#423)

* Revert "fix onnx version to 1.13.1 (#422)"

This reverts commit eb29d225a7.

* Update requirements.txt

* PyOp attribute supports int and float data type (#425)

* Fix Android AAR in nuget package. Requires libortextensions.so. (#429)

* build for mac M1 (#430)

* Fix the unit test failure with ONNX 1.14 package. (#428)

* Fix the unit test failure with ONNX 1.14 package.

* more tests

* Update whisper_e2e.py

* Add nuget.org publish version option (#426)

* Add nuget.org publish version option

* typo

* small fix

* typo

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

---------

Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: JiCheng <247153481@qq.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Fix a build err (#442)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Fix build err on ort 141 (#444)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Remove shape from span (#445)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Fix python tests (#446)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

* fix python tests

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Fix max build (#449)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

* fix python tests

* fix packaging err

* fix mac build

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Fix comments (#452)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

* fix python tests

* fix packaging err

* fix mac build

* fixing the universal2 python package for macOS (#448)

* Remove onnx<1.14 from requirements.txt (#447)

* remove onnx<1.14 from requirements.txt

* downgrade protobuf

* move protobuf req to requirements-dev.txt

---------

Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>

* fix comments

* comment version macro

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>

* Fix build err (#453)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

* fix python tests

* fix packaging err

* fix mac build

* fix comments

* comment version macro

* define Compute for StftNormal

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* Merge latest main (#461)

* resolve conflict

* resolve conflict

* minor fix

* rename from TensorT to Tensor

* fix string tensor

* Add OrtLiteCustomOp

* switch to string view

* fix regex ops

* fix build

* fix a build err

* remove shape

* fix python tests

* fix packaging err

* fix mac build

* fix comments

* comment version macro

* define Compute for StftNormal

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>

* revert wanted changes in test

* revert unwanted changes

* add string_strip op

---------

Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: JiCheng <247153481@qq.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Authored by Tang, Cheng on 2023-05-30 18:04:44 -07:00, committed by GitHub
Parent: 30eb7afcfa
Commit: 8f36cf3272
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
74 changed files: 1436 additions, 2179 deletions

includes/custom_op_lite.h (new file, 663 lines)

@@ -0,0 +1,663 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "onnxruntime_customop.hpp"
#include <optional>
#include <numeric>
// uplevel the version when supported ort version migrates to newer ones
#define SUPPORT_ORT_API_VERSION_TO 13
namespace Ort {
namespace Custom {
class TensorBase {
public:
TensorBase(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : api_(api),
ctx_(ctx),
indice_(indice),
is_input_(is_input) {}
virtual ~TensorBase() = default;
operator bool() const {
return shape_.has_value();
}
const std::vector<int64_t>& Shape() const {
if (shape_.has_value()) {
return *shape_;
} else {
ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
}
}
int64_t NumberOfElement() const {
if (shape_.has_value()) {
return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
} else {
ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
}
}
std::string Shape2Str() const {
if (shape_.has_value()) {
std::string shape_str;
for (const auto& dim: *shape_) {
shape_str.append(std::to_string(dim));
shape_str.append(", ");
}
return shape_str;
} else {
return "empty";
}
}
protected:
const OrtW::CustomOpApi& api_;
OrtKernelContext& ctx_;
size_t indice_;
bool is_input_;
std::optional<std::vector<int64_t>> shape_;
};
template <typename T>
struct Span {
const T* data_ = {};
size_t size_ = {};
void Assign(const T* data, size_t size) {
data_ = data;
size_ = size;
}
size_t size() const { return size_; }
T operator[](size_t indice) const {
return data_[indice];
}
const T* data() const { return data_; }
};
template <typename T>
class Tensor : public TensorBase {
public:
using TT = typename std::remove_reference<T>::type;
Tensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value_);
shape_ = api_.GetTensorShape(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
}
}
const TT* Data() const {
return api_.GetTensorData<TT>(const_value_);
}
TT* Allocate(const std::vector<int64_t>& shape) {
if (!data_) {
OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
shape_ = shape;
data_ = api_.GetTensorMutableData<TT>(out);
}
return data_;
}
const Span<T>& AsSpan() {
if (!shape_.has_value() || shape_->size() != 1) {
ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
span_.Assign(Data(), (*shape_)[0]);
return span_;
}
const T& AsScalar() {
if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return *Data();
}
private:
const OrtValue* const_value_{}; // for input
TT* data_{}; // for output
Span<T> span_;
};
template <>
class Tensor<std::string> : public TensorBase {
public:
using strings = std::vector<std::string>;
Tensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
if (is_input) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value);
shape_ = api_.GetTensorShape(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
size_t num_chars;
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
std::vector<char> chars(num_chars + 1, '\0');
auto num_strings = NumberOfElement();
std::vector<size_t> offsets(NumberOfElement());
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
(void*)chars.data(),
num_chars,
offsets.data(),
offsets.size()));
auto upper_bound = static_cast<int64_t>(num_strings) - 1;
input_strings_.resize(num_strings);
for (int64_t i = upper_bound; i >= 0; --i) {
if (i < upper_bound) {
chars[offsets[i + 1]] = '\0';
}
input_strings_[i] = chars.data() + offsets[i];
}
}
}
const strings& Data() const {
return input_strings_;
}
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
std::vector<const char*> raw;
for (const auto& s : ss) {
raw.push_back(s.data());
}
auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
}
void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
}
const Span<std::string>& AsSpan() {
ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
}
const std::string& AsScalar() {
if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return input_strings_[0];
}
private:
std::vector<std::string> input_strings_; // for input
};
template <>
class Tensor<std::string_view> : public TensorBase {
public:
using strings = std::vector<std::string>;
using string_views = std::vector<std::string_view>;
Tensor(const OrtW::CustomOpApi& api,
OrtKernelContext& ctx,
size_t indice,
bool is_input) : TensorBase(api,
ctx,
indice,
is_input) {
if (is_input_) {
auto input_count = api_.KernelContext_GetInputCount(&ctx_);
if (indice >= input_count) {
ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
}
auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
auto* info = api_.GetTensorTypeAndShape(const_value);
shape_ = api_.GetTensorShape(info);
api_.ReleaseTensorTypeAndShapeInfo(info);
size_t num_chars;
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
chars_.resize(num_chars + 1, '\0');
auto num_strings = static_cast<size_t>(NumberOfElement());
if (num_strings) {
std::vector<size_t> offsets(num_strings);
OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
(void*)chars_.data(),
num_chars,
offsets.data(),
offsets.size()));
offsets.push_back(num_chars);
for (size_t i = 0; i < num_strings; ++i) {
input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
}
}
}
}
int64_t NumberOfElement() const {
if (shape_.has_value()) {
return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
} else {
return 0;
}
}
const string_views& Data() const {
return input_string_views_;
}
const Span<std::string_view>& AsSpan() {
ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
}
std::string_view AsScalar() {
if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
}
return input_string_views_[0];
}
private:
std::vector<char> chars_; // for input
std::vector<std::string_view> input_string_views_; // for input
};
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
struct OrtLiteCustomOp : public OrtCustomOp {
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
// CreateTuple
template <size_t ith_input, size_t ith_output, typename... Ts>
static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
return std::make_tuple();
}
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}
#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if (ith_input < num_input) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} else { \
std::tuple<T> current = std::tuple<T>{}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if ("CPUExecutionProvider" != ep) { \
ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
} \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if ("CPUExecutionProvider" != ep) { \
ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
} \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if (ith_input < num_input) { \
if ("CPUExecutionProvider" != ep) { \
ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
} \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} else { \
std::tuple<T> current = std::tuple<T>{}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if ("CPUExecutionProvider" != ep) { \
ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
} \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if (ith_input < num_input) { \
if ("CPUExecutionProvider" != ep) { \
ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
} \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} else { \
std::tuple<T> current = std::tuple<T>{}; \
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
}
#define CREATE_TUPLE_OUTPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
if (ith_output < num_output) { \
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} else { \
std::tuple<T> current = std::tuple<T>{}; \
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
return std::tuple_cat(current, next); \
} \
}
#define CREATE_TUPLE(data_type) \
CREATE_TUPLE_INPUT(data_type) \
CREATE_TUPLE_OUTPUT(data_type)
CREATE_TUPLE(bool)
CREATE_TUPLE(float)
CREATE_TUPLE(double)
CREATE_TUPLE(int8_t)
CREATE_TUPLE(int16_t)
CREATE_TUPLE(int32_t)
CREATE_TUPLE(int64_t)
CREATE_TUPLE(uint8_t)
CREATE_TUPLE(uint16_t)
CREATE_TUPLE(uint32_t)
CREATE_TUPLE(uint64_t)
CREATE_TUPLE(std::string)
CREATE_TUPLE_INPUT(std::string_view)
// ParseArgs ...
template <typename... Ts>
static typename std::enable_if<0 == sizeof...(Ts)>::type
ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
}
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
ParseArgs<Ts...>(input_types, output_types);
}
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
input_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
input_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
}
#define PARSE_INPUT(data_type, onnx_type) \
PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
PARSE_INPUT_BASE(data_type, onnx_type)
#define PARSE_OUTPUT(data_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
} \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
output_types.push_back(onnx_type); \
ParseArgs<Ts...>(input_types, output_types); \
}
#define PARSE_ARGS(data_type, onnx_type) \
PARSE_INPUT(data_type, onnx_type) \
PARSE_OUTPUT(data_type, onnx_type)
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->input_types_.size();
};
OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->input_types_[indice];
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->output_types_.size();
};
OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->output_types_[indice];
};
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
};
}
const std::string op_name_;
const std::string execution_provider_;
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;
};
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
using ComputeFn = void (*)(Args...);
using MyType = OrtLiteCustomFunc<Args...>;
struct Kernel {
ComputeFn compute_fn_{};
std::string ep_{};
std::unique_ptr<OrtW::CustomOpApi> api_;
};
OrtLiteCustomFunc(const char* op_name,
const char* execution_provider,
ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
compute_fn_(compute_fn) {
ParseArgs<Args...>(input_types_, output_types_);
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
std::vector<TensorPtr> tensors;
auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
context,
tensors,
kernel->api_->KernelContext_GetInputCount(context),
kernel->api_->KernelContext_GetOutputCount(context),
kernel->ep_);
std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
};
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
kernel->compute_fn_ = self->compute_fn_;
kernel->ep_ = self->execution_provider_;
kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
return reinterpret_cast<void*>(kernel.release());
};
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
delete reinterpret_cast<Kernel*>(op_kernel);
};
}
ComputeFn compute_fn_;
};
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
template <typename... Args>
using CustomComputeFn = void (CustomOp::*)(Args...);
using MyType = OrtLiteCustomStruct<CustomOp>;
struct Kernel {
std::unique_ptr<CustomOp> custom_op_;
std::string ep_{};
std::unique_ptr<OrtW::CustomOpApi> api_;
};
OrtLiteCustomStruct(const char* op_name,
const char* execution_provider) : OrtLiteCustomOp(op_name,
execution_provider) {
init(&CustomOp::Compute);
}
template <typename... Args>
void init(CustomComputeFn<Args...>) {
ParseArgs<Args...>(input_types_, output_types_);
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
std::vector<TensorPtr> tensors;
auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
context,
tensors,
kernel->api_->KernelContext_GetInputCount(context),
kernel->api_->KernelContext_GetOutputCount(context),
kernel->ep_);
std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
};
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
auto self = static_cast<const MyType*>(this_);
kernel->ep_ = self->execution_provider_;
kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
return reinterpret_cast<void*>(kernel.release());
};
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
delete reinterpret_cast<Kernel*>(op_kernel);
};
}
};
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
const char* execution_provider,
void (*custom_compute_fn)(Args...)) {
using LiteOp = OrtLiteCustomFunc<Args...>;
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
}
template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
const char* execution_provider) {
using LiteOp = OrtLiteCustomStruct<CustomOp>;
return std::make_unique<LiteOp>(op_name, execution_provider).release();
}
} // namespace Custom
} // namespace Ort
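
As a point of reference, here is a minimal sketch of how a kernel can be written against the lite API declared above: a plain function over ortc tensors is wrapped by CreateLiteCustomOp into an OrtCustomOp, which can then be added to a custom-op domain via the usual ORT C API. The sketch is illustrative only and not part of this diff; the op name "CustomNegate" and the negate/MakeNegateOp names are made up.

#include "custom_op_lite.h"
namespace ortc = Ort::Custom;
// Element-wise negation: one float input, one float output of the same shape.
void negate(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
  const float* in = input.Data();
  float* out = output.Allocate(input.Shape());  // allocate the output with the input's shape
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    out[i] = -in[i];
  }
}
// Wrap the free function; the returned pointer can be handed to an OrtCustomOpDomain.
ortc::OrtLiteCustomOp* MakeNegateOp() {
  return ortc::CreateLiteCustomOp("CustomNegate", "CPUExecutionProvider", negate);
}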

@@ -85,6 +85,53 @@ class CuopContainer {
std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
};
#define CustomCpuFunc(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
template <typename F>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
F arg) {
ops.emplace_back(std::move(arg()));
}
template <typename T, typename... Args>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
T arg, Args... args) {
AppendCustomOp(ops, arg);
AppendCustomOp(ops, args...);
}
class OrtOpLoader {
public:
template <typename... Args>
OrtOpLoader(Args... args) {
LoadOps(args...);
for (auto& ptr : op_instances_) {
if (ptr)
ocos_list_.push_back(ptr.get());
}
}
const std::vector<const OrtCustomOp*>& GetCustomOps() const {
return ocos_list_;
}
private:
template <typename T>
void LoadOps(T fn) {
AppendCustomOp(op_instances_, fn);
}
template <typename T, typename... Args>
void LoadOps(T fn, Args... args) {
AppendCustomOp(op_instances_, fn);
AppendCustomOp(op_instances_, args...);
}
std::vector<const OrtCustomOp*> ocos_list_;
std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
};
struct CustomOpClassBegin {
};
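
To illustrate the struct-based registration path introduced here (the function-based path is shown by Cv2Loader further below), a stateful kernel only needs a constructor taking (const OrtApi&, const OrtKernelInfo&) and a Compute method over ortc tensors; CustomCpuStruct then wraps it for an OrtOpLoader. The MyScale kernel and MyDomainLoader function below are hypothetical examples, not part of this diff.

// Hypothetical kernel with state; attributes could be read from `info` in the constructor.
struct MyScale {
  MyScale(const OrtApi& /*api*/, const OrtKernelInfo& /*info*/) {}
  void Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
    const float* in = input.Data();
    float* out = output.Allocate(input.Shape());
    for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
      out[i] = in[i] * scale_;
    }
  }
  float scale_{1.0f};
};
// Group kernels into one loader; CustomCpuFunc would register a free function the same way.
const std::vector<const OrtCustomOp*>& MyDomainLoader() {
  static OrtOpLoader op_loader(CustomCpuStruct("MyScale", MyScale));
  return op_loader.GetCustomOps();
}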

@@ -18,6 +18,8 @@
#include "exceptions.h"
#define MIN_ORT_VERSION_SUPPORTED 10
namespace OrtW {
//
@@ -54,6 +56,8 @@ struct CustomOpApi {
OrtW::ThrowOnError(api_, status);
}
const OrtApi& GetOrtApi() const { return api_; }
private:
const OrtApi& api_;
};
@@ -61,7 +65,7 @@ struct CustomOpApi {
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = 10; // The minimum ORT version supported
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
void* result = nullptr;
OCOS_API_IMPL_BEGIN
@@ -295,3 +299,7 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
}
} // namespace OrtW
// !! TODO: only do it for legacy ort build
#include "custom_op_lite.h"
namespace ortc = Ort::Custom;

@@ -12,70 +12,28 @@
#include <cstdint>
struct KernelImageDecoder : BaseKernel {
KernelImageDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
OrtTensorDimensions dimensions(ort_, inputs);
if (dimensions.size() != 1ULL) {
ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
}
// Get data & the length
const uint8_t* const encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);
OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
// Decode the image
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
// Setup output & copy to destination
const cv::Size decoded_image_size = decoded_image.size();
const int64_t colors = 3;
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
OrtValue* const output_value = ort_.KernelContext_GetOutput(
context, 0, output_dimensions.data(), output_dimensions.size());
uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
}
};
struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
const char* GetName() const {
return "ImageDecoder";
void image_decoder(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 1ULL) {
ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
}
size_t GetInputTypeCount() const {
return 1;
}
// Get data & the length
const uint8_t* const encoded_image_data = input.Data();
const int64_t encoded_image_data_len = input.NumberOfElement();
ONNXTensorElementDataType GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
}
// Decode the image
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
size_t GetOutputTypeCount() const {
return 1;
}
// Setup output & copy to destination
const cv::Size decoded_image_size = decoded_image.size();
const int64_t colors = 3;
ONNXTensorElementDataType GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
}
};
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
uint8_t* const decoded_image_data = output.Allocate(output_dimensions);
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
}

@@ -27,6 +27,20 @@ struct KernelImageReader : BaseKernel {
}
};
void image_reader(const ortc::Tensor<std::string>& input,
ortc::Tensor<uint8_t>& output) {
auto& input_data_dimensions = input.Shape();
int n = input_data_dimensions[0];
if (n != 1) {
ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
}
auto& image_paths = input.Data();
cv::Mat img = cv::imread(image_paths[0], cv::IMREAD_COLOR);
std::vector<int64_t> output_dimensions = {1, img.size[0], img.size[1], static_cast<int64_t>(img.elemSize())};
std::uint8_t* p_output_image = output.Allocate(output_dimensions);
memcpy(p_output_image, img.data, img.total() * img.elemSize());
}
struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
size_t GetInputTypeCount() const {
return 1;

@@ -1,81 +1,40 @@
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
struct KernelGaussianBlur : BaseKernel {
KernelGaussianBlur(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
void gaussian_blur(const ortc::Tensor<float>& input_data,
const ortc::Span<int64_t>& input_ksize,
const ortc::Span<double>& input_sigma,
ortc::Tensor<float>& output) {
const float* p_input_data = input_data.Data();
std::int64_t ksize[] = {3, 3};
double sigma[] = {0., 0.};
if (input_ksize.size() != 2) {
ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
}
std::copy_n(input_ksize.data(), 2, ksize);
void Compute(OrtKernelContext* context) {
size_t input_c = ort_.KernelContext_GetInputCount(context);
const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0);
const float* p_input_data = ort_.GetTensorData<float>(input_data);
std::int64_t ksize[] = {3, 3};
double sigma[] = {0., 0.};
if (input_c > 1) {
const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions dim_ksize(ort_, input_ksize);
if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
}
std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
}
if (input_c > 2) {
const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
OrtTensorDimensions dim_sigma(ort_, input_sigma);
if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
}
std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
}
OrtTensorDimensions input_data_dimensions(ort_, input_data);
int n = static_cast<int>(input_data_dimensions[0]);
int h = static_cast<int>(input_data_dimensions[1]);
int w = static_cast<int>(input_data_dimensions[2]);
int c = static_cast<int>(input_data_dimensions[3]);
(void)n;
(void)c;
cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
cv::Mat output_image;
cv::GaussianBlur(input_image,
output_image,
cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
sigma[0], sigma[1], cv::BORDER_DEFAULT);
OrtValue* image_y = ort_.KernelContext_GetOutput(context,
0, input_data_dimensions.data(), input_data_dimensions.size());
float* p_output_image = ort_.GetTensorMutableData<float>(image_y);
memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
if (input_sigma.size() != 2) {
ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
}
};
std::copy_n(input_sigma.data(), 2, sigma);
struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
size_t GetInputTypeCount() const {
return 3;
}
auto& input_data_dimensions = input_data.Shape();
size_t GetOutputTypeCount() const {
return 1;
}
int n = static_cast<int>(input_data_dimensions[0]);
int h = static_cast<int>(input_data_dimensions[1]);
int w = static_cast<int>(input_data_dimensions[2]);
int c = static_cast<int>(input_data_dimensions[3]);
(void)n;
(void)c;
ONNXTensorElementDataType GetInputType(size_t index) const {
if (index == 0) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
} else if (index == 1) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
} else {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
}
}
cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
cv::Mat output_image;
cv::GaussianBlur(input_image,
output_image,
cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
sigma[0], sigma[1], cv::BORDER_DEFAULT);
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
const char* GetName() const {
return "GaussianBlur";
}
};
float* p_output_image = output.Allocate(input_data_dimensions);
memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
}

@@ -3,14 +3,17 @@
#ifdef ENABLE_OPENCV_CODECS
#include "imgcodecs/imread.hpp"
#include "imgcodecs/imdecode.hpp"
#endif // ENABLE_OPENCV_CODECS
#endif // ENABLE_OPENCV_CODECS
FxLoadCustomOpFactory LoadCustomOpClasses_CV2 =
LoadCustomOpClasses<CustomOpClassBegin
, CustomOpGaussianBlur
const std::vector<const OrtCustomOp*>& Cv2Loader() {
static OrtOpLoader op_loader(CustomCpuFunc("GaussianBlur", gaussian_blur)
#ifdef ENABLE_OPENCV_CODECS
, CustomOpImageReader
, CustomOpImageDecoder
#endif // ENABLE_OPENCV_CODECS
>;
,
CustomCpuFunc("ImageDecoder", image_decoder),
CustomCpuFunc("ImageReader", image_reader)
#endif // ENABLE_OPENCV_CODECS
);
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_CV2 = Cv2Loader;

@@ -6,52 +6,17 @@
#include "ocos.h"
#include <dlib/matrix.h>
struct KernelInverse : BaseKernel {
KernelInverse(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
void inverse(const ortc::Tensor<float>& input,
ortc::Tensor<float>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 2) {
throw std::runtime_error("Only 2-d matrix supported.");
}
const float* X = input.Data();
float* out = output.Allocate(dimensions);
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const float* X = ort_.GetTensorData<float>(input_X);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 2) {
throw std::runtime_error("Only 2-d matrix supported.");
}
OrtValue* output0 = ort_.KernelContext_GetOutput(
context, 0, dimensions.data(), dimensions.size());
float* out0 = ort_.GetTensorMutableData<float>(output0);
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
std::copy(X, X + dm_x.size(), dm_x.begin());
dlib::matrix<float> dm = dlib::inv(dm_x);
memcpy(out0, dm.steal_memory().get(), dm_x.size() * sizeof(float));
}
};
struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
const char* GetName() const {
return "Inverse";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
};
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
std::copy(X, X + dm_x.size(), dm_x.begin());
dlib::matrix<float> dm = dlib::inv(dm_x);
memcpy(out, dm.steal_memory().get(), dm_x.size() * sizeof(float));
}

@@ -6,38 +6,31 @@
#include "ocos.h"
#include <dlib/matrix.h>
struct KernelStft : BaseKernel {
KernelStft(const OrtApi& api, const OrtKernelInfo& info, bool return_magnitude)
: BaseKernel(api, info), return_magnitude_(return_magnitude) {
struct STFT : public BaseKernel {
STFT(const OrtApi& api, const OrtKernelInfo& info,
bool with_norm = false) : BaseKernel(api, info),
with_norm_(with_norm) {
onesided_ = TryToGetAttributeWithDefault<int64_t>("onesided", 1);
}
void Compute(OrtKernelContext* context) {
const OrtValue* input_x1 = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_x2 = ort_.KernelContext_GetInput(context, 1);
const OrtValue* input_x3 = ort_.KernelContext_GetInput(context, 2);
const OrtValue* input_x4 = ort_.KernelContext_GetInput(context, 3);
const OrtValue* input_x5 = ort_.KernelContext_GetInput(context, 4);
void Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) {
auto X = input0.Data();
auto window = input3.data();
auto dimensions = input0.Shape();
auto win_length = input3.size();
const float* X = ort_.GetTensorData<float>(input_x1);
auto n_fft = *ort_.GetTensorData<int64_t>(input_x2);
auto hop_length = *ort_.GetTensorData<int64_t>(input_x3);
auto window = ort_.GetTensorData<float>(input_x4);
auto frame_length = *ort_.GetTensorData<int64_t>(input_x5);
OrtTensorDimensions dimensions(ort_, input_x1);
OrtTensorDimensions win_dim(ort_, input_x4);
if (dimensions.size() < 2 || dimensions.Size() != dimensions[1]) {
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
ORTX_CXX_API_THROW("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
}
if (win_dim.size() != 1) {
ORTX_CXX_API_THROW("[Stft] Only 1-d hanning window supported.", ORT_INVALID_ARGUMENT);
}
if (frame_length != n_fft) {
ORTX_CXX_API_THROW("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
}
auto win_length = win_dim[0];
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
dlib::matrix<float> hann_win = dlib::mat(window, 1, win_length);
@@ -49,23 +42,19 @@ struct KernelStft : BaseKernel {
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
}
if (return_magnitude_) {
if (with_norm_) {
dlib::matrix<float> result = dlib::norm(m_stft);
result = dlib::trans(result);
int64_t outdim[] = {1, result.nr(), result.nc()};
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
auto result_size = result.size();
OrtValue* output0 = ort_.KernelContext_GetOutput(
context, 0, outdim, 3);
float* out0 = ort_.GetTensorMutableData<float>(output0);
auto out0 = output0.Allocate(outdim);
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
} else {
auto result = m_stft;
// No transpose here since it is done while copying the data;
// nr and nc are switched, so the output dims will be the transposed ones.
int64_t outdim[] = {1, result.nc(), result.nr(), 2};
OrtValue* output0 = ort_.KernelContext_GetOutput(
context, 0, outdim, 4);
float* out0 = ort_.GetTensorMutableData<float>(output0);
std::vector<int64_t> outdim{1, result.nc(), result.nr(), 2};
auto out0 = output0.Allocate(outdim);
for (size_t c = 0; c < result.nc(); ++c) {
for (size_t r = 0; r < result.nr(); ++r) {
*out0 = result(r, c).real();
@@ -76,50 +65,18 @@ struct KernelStft : BaseKernel {
}
}
private:
bool return_magnitude_;
int64_t onesided_;
int64_t onesided_{};
bool with_norm_{};
};
struct CustomOpStft : OrtW::CustomOpBase<CustomOpStft, KernelStft> {
const char* GetName() const {
return op_name_;
struct StftNormal : public STFT {
StftNormal(const OrtApi& api, const OrtKernelInfo& info) : STFT(api, info, true) {}
void Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) {
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
}
size_t GetInputTypeCount() const {
return 5;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
// pcm and window are float
if (index == 0 || index == 3) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
} else {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
}
size_t GetOutputTypeCount() const {
return 1;
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
return new KernelStft(api, info, with_norm_);
};
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
protected:
bool with_norm_ = false;
const char* op_name_ = "STFT";
};
struct CustomOpStftNorm : CustomOpStft {
public:
CustomOpStftNorm() {
with_norm_ = true;
op_name_ = "StftNorm";
}
};
};
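// A minimal sketch of the onesided output width implied by the subm() call above; OnesidedBins
// is an illustrative helper, not a function in this codebase.
#include <cstdint>
constexpr int64_t OnesidedBins(int64_t n_fft) {
  return (n_fft >> 1) + 1;  // e.g. n_fft == 400 keeps 201 non-negative frequency bins per frame
}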

@@ -7,14 +7,16 @@
#include "segment_extraction.hpp"
#include "segment_sum.hpp"
FxLoadCustomOpFactory LoadCustomOpClasses_Math =
LoadCustomOpClasses<CustomOpClassBegin,
CustomOpNegPos,
#ifdef ENABLE_DLIB
CustomOpInverse,
CustomOpStft,
CustomOpStftNorm,
#endif
CustomOpSegmentExtraction,
CustomOpSegmentSum>;
const std::vector<const OrtCustomOp*>& MathLoader() {
static OrtOpLoader op_loader(CustomCpuFunc("NegPos", neg_pos),
#ifdef ENABLE_DLIB
CustomCpuFunc("Inverse", inverse),
CustomCpuStruct("STFT", STFT),
CustomCpuStruct("StftNorm", StftNormal),
#endif
CustomCpuFunc("SegmentExtraction", segment_extraction),
CustomCpuFunc("SegmentSum", segment_sum));
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Math = MathLoader;
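// A minimal sketch of how another kernel would plug into the loader above; ScaleByTwo and
// scale_by_two are illustrative names only, not ops that exist in this repository.
void scale_by_two(const ortc::Tensor<float>& input,
                  ortc::Tensor<float>& output) {
  const float* x = input.Data();
  float* y = output.Allocate(input.Shape());  // output takes the input's shape
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    y[i] = 2.0f * x[i];
  }
}
// Registered with a line such as CustomCpuFunc("ScaleByTwo", scale_by_two) for a stateless
// function, or CustomCpuStruct("MyKernel", MyKernel) for a kernel that reads attributes in
// its constructor (MyKernel being hypothetical here).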

@@ -5,58 +5,21 @@
#include "ocos.h"
struct KernelNegPos : BaseKernel {
KernelNegPos(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const float* X = ort_.GetTensorData<float>(input_X);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out0 = ort_.GetTensorMutableData<float>(output0);
OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
float* out1 = ort_.GetTensorMutableData<float>(output1);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
for (int64_t i = 0; i < size; i++) {
if (X[i] > 0) {
out0[i] = 0;
out1[i] = X[i];
} else {
out0[i] = X[i];
out1[i] = 0;
}
}
}
};
struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
const char* GetName() const {
return "NegPos";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
size_t GetOutputTypeCount() const {
return 2;
}
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
};
void neg_pos(const ortc::Tensor<float>& input,
ortc::Tensor<float>& out0_tensor,
ortc::Tensor<float>& out1_tensor) {
int64_t size = input.NumberOfElement();
float* out0 = out0_tensor.Allocate(input.Shape());
float* out1 = out1_tensor.Allocate(input.Shape());
const float* X = input.Data();
// Do computation
for (int64_t i = 0; i < size; i++) {
if (X[i] > 0) {
out0[i] = 0;
out1[i] = X[i];
} else {
out0[i] = X[i];
out1[i] = 0;
}
}
}
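// Worked example with illustrative values (not taken from a test in this change):
//   X    = {-1.5f, 2.0f, 0.0f}
//   out0 = {-1.5f, 0.0f, 0.0f}   // non-positive part
//   out1 = { 0.0f, 2.0f, 0.0f}   // positive part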

@@ -3,21 +3,17 @@
#include "segment_extraction.hpp"
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
OrtTensorDimensions input_dim(ort_, input);
void segment_extraction(const ortc::Tensor<int64_t>& input,
ortc::Tensor<int64_t>& output0,
ortc::Tensor<int64_t>& output1) {
auto& input_dim = input.Shape();
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
const int64_t* p_data = input.Data();
std::vector<std::int64_t> segment_value;
std::vector<std::int64_t> segment_position;
for (std::int64_t i = 0; i < input_dim.Size(); i++) {
for (std::int64_t i = 0; i < input.NumberOfElement(); i++) {
if (!p_data[i]) {
continue;
}
@@ -29,33 +25,17 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
}
// push end position
if (i == (input_dim.Size() - 1) || p_data[i + 1] != p_data[i]) {
if (i == (input.NumberOfElement() - 1) || p_data[i + 1] != p_data[i]) {
segment_position.push_back(i + 1);
}
}
std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
SetOutput(context, 0, segment_position_dim, segment_position);
SetOutput(context, 1, segment_value_dim, segment_value);
int64_t* out0_data = output0.Allocate(segment_position_dim);
std::copy(segment_position.begin(), segment_position.end(), out0_data);
int64_t* out1_data = output1.Allocate(segment_value_dim);
std::copy(segment_value.begin(), segment_value.end(), out1_data);
}
size_t CustomOpSegmentExtraction::GetInputTypeCount() const {
return 1;
};
size_t CustomOpSegmentExtraction::GetOutputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
const char* CustomOpSegmentExtraction::GetName() const {
return "SegmentExtraction";
};
ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
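// Worked example with illustrative values, assuming the elided branch above records the start
// index of each non-zero run (as the [num_segments, 2] position shape suggests):
//   input   = {0, 2, 2, 0, 3}
//   output0 = {{1, 3}, {4, 5}}   // [start, end) of each run
//   output1 = {2, 3}             // the run values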

@@ -6,15 +6,6 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelSegmentExtraction : BaseKernel {
KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpSegmentExtraction : OrtW::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
void segment_extraction(const ortc::Tensor<int64_t>& input,
ortc::Tensor<int64_t>& output0,
ortc::Tensor<int64_t>& output1);

@@ -3,17 +3,11 @@
#include "segment_sum.hpp"
template <typename T>
void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs
const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
const T* p_data = ort_.GetTensorData<T>(data);
const OrtValue* segment_ids = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_segment_ids = ort_.GetTensorData<int64_t>(segment_ids);
// Setup output
OrtTensorDimensions dim_data(ort_, data);
OrtTensorDimensions dim_seg(ort_, segment_ids);
void segment_sum(const ortc::Tensor<float>& data,
const ortc::Tensor<int64_t>& segment_ids,
ortc::Tensor<float>& output) {
auto& dim_data = data.Shape();
auto& dim_seg = segment_ids.Shape();
if (dim_data.size() == 0 || dim_seg.size() == 0)
ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
if (dim_seg.size() != 1)
@@ -24,22 +18,24 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
" segment_ids shape: ", dim_seg),
ORT_INVALID_GRAPH);
const int64_t* p_segment_ids = segment_ids.Data();
const float* p_data = data.Data();
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
OrtTensorDimensions dim_out = dim_data;
std::vector<int64_t> dim_out = dim_data;
dim_out[0] = last_seg + 1;
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
T* p_output = ort_.GetTensorMutableData<T>(v);
int64_t out_size = dim_out.Size();
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(T)));
float* p_output = output.Allocate(dim_out);
int64_t out_size = output.NumberOfElement();
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(float)));
// The implementation is naive. It could be parallelized and
// use SIMD instructions to be faster.
int64_t in_stride = dim_data.Size();
const T* begin = p_data;
const T* end = p_data + in_stride;
int64_t in_stride = data.NumberOfElement();
const float* begin = p_data;
const float* end = p_data + in_stride;
in_stride /= dim_data[0];
T *p_out, *p_out_end;
float *p_out, *p_out_end;
const int64_t* p_seg = p_segment_ids;
for (; begin != end; ++p_seg) {
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
@@ -53,37 +49,3 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
*p_out += *begin;
}
}
KernelSegmentSum::KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelSegmentSum::Compute(OrtKernelContext* context) {
KernelSegmentSum_Compute<float>(ort_, context);
}
size_t CustomOpSegmentSum::GetInputTypeCount() const {
return 2;
};
size_t CustomOpSegmentSum::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
const char* CustomOpSegmentSum::GetName() const {
return "SegmentSum";
};
ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
}
};
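// Worked example with illustrative values: rows of `data` that share a segment id are summed,
// and the output has last_segment_id + 1 rows.
//   data        = {{1, 2}, {3, 4}, {5, 6}, {7, 8}}   // shape [4, 2]
//   segment_ids = {0, 0, 1, 1}
//   output      = {{4, 6}, {12, 14}}                 // shape [2, 2]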

@@ -6,15 +6,6 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelSegmentSum : BaseKernel {
KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpSegmentSum : OrtW::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
void segment_sum(const ortc::Tensor<float>& data,
const ortc::Tensor<int64_t>& segment_ids,
ortc::Tensor<float>& output);

@@ -8,18 +8,12 @@
#include <codecvt>
#include <algorithm>
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelMaskedFill::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_value = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_mask = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions value_dimensions(ort_, input_value);
OrtTensorDimensions mask_dimensions(ort_, input_mask);
if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
void masked_fill(const ortc::Tensor<std::string>& input,
const ortc::Tensor<bool>& input_mask,
ortc::Tensor<std::string>& output) {
auto& value_dimensions = input.Shape();
auto& mask_dimensions = input_mask.Shape();
if (!(value_dimensions.empty() || mask_dimensions.size() == 1)) {
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
}
@@ -27,11 +21,8 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
}
std::vector<std::string> value;
const bool* mask = nullptr;
GetTensorMutableDataString(api_, ort_, context, input_value, value);
mask = ort_.GetTensorData<bool>(input_mask);
auto& value = input.Data();
const bool* mask = input_mask.Data();
std::vector<std::string> result;
std::vector<int64_t> result_dimension;
@@ -44,33 +35,5 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
result.push_back(value[i]);
}
result_dimension.push_back(result.size());
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, result_dimension.data(), result_dimension.size());
FillTensorDataString(api_, ort_, context, result, output);
output.SetStringOutput(result, result_dimension);
}
const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; };
size_t CustomOpMaskedFill::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpMaskedFill::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

@@ -7,18 +7,6 @@
#include "string_utils.h"
#include <unordered_map>
struct KernelMaskedFill : BaseKernel {
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void masked_fill(const ortc::Tensor<std::string>& input,
const ortc::Tensor<bool>& input_mask,
ortc::Tensor<std::string>& output);

@@ -8,26 +8,9 @@
KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringEqual::Compute(OrtKernelContext* context) {
void KernelStringEqual::Compute(OrtKernelContext* context,
const ortc::Tensor<std::string>&,
const ortc::Tensor<std::string>&,
ortc::Tensor<bool>& output) {
KernelEqual_Compute<std::string>(api_, ort_, context);
}
size_t CustomOpStringEqual::GetInputTypeCount() const {
return 2;
};
size_t CustomOpStringEqual::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
};
const char* CustomOpStringEqual::GetName() const {
return "StringEqual";
};
ONNXTensorElementDataType CustomOpStringEqual::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

@@ -8,13 +8,8 @@
struct KernelStringEqual : BaseKernel {
KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringEqual : OrtW::CustomOpBase<CustomOpStringEqual, KernelStringEqual> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
void Compute(OrtKernelContext* context,
const ortc::Tensor<std::string>&,
const ortc::Tensor<std::string>&,
ortc::Tensor<bool>& output);
};

@@ -4,11 +4,12 @@
#include "string_tensor.h"
#include "op_ragged_tensor.hpp"
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
const OrtValue* n_elements = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_n_elements = ort_.GetTensorData<int64_t>(n_elements);
void KernelRaggedTensoroSparse::Compute(const ortc::Tensor<int64_t>& n_element,
ortc::Tensor<int64_t>& output_0,
ortc::Tensor<int64_t>& output_1) {
const int64_t* p_n_elements = n_element.Data();
OrtTensorDimensions d_length(ort_, n_elements);
auto& d_length = n_element.Shape();
if (d_length.size() != 1)
ORTX_CXX_API_THROW(MakeString(
@@ -19,10 +20,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
std::vector<int64_t> shape{n_values, 2};
std::vector<int64_t> shape2{2};
OrtValue* v0 = ort_.KernelContext_GetOutput(context, 0, shape.data(), shape.size());
int64_t* out0 = ort_.GetTensorMutableData<int64_t>(v0);
OrtValue* v1 = ort_.KernelContext_GetOutput(context, 1, shape2.data(), shape2.size());
int64_t* out1 = ort_.GetTensorMutableData<int64_t>(v1);
int64_t* out0 = output_0.Allocate(shape);
int64_t* out1 = output_1.Allocate(shape2);
out1[0] = n_els;
out1[1] = 0;
int64_t row = 0;
@@ -39,38 +38,18 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
}
}
size_t CustomOpRaggedTensorToSparse::GetInputTypeCount() const {
return 1;
};
size_t CustomOpRaggedTensorToSparse::GetOutputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
const char* CustomOpRaggedTensorToSparse::GetName() const {
return "RaggedTensorToSparse";
};
ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
CommonRaggedTensoroDense::CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
for (int i = 0; i < 4; ++i) {
inputs[i] = ort_.KernelContext_GetInput(context, i);
dims[i] = OrtTensorDimensions(ort_, inputs[i]);
}
}
int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
int64_t size = n;
int64_t max_col = 0;
for (int64_t i = 1; i < size; ++i) {
@@ -79,26 +58,25 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
return max_col;
}
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
: CommonRaggedTensorToDense(api, info) {
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1) ;
KernelRaggedTensoroDense::KernelRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
: CommonRaggedTensoroDense(api, info) {
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1);
}
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
const OrtValue* inputs[4];
OrtTensorDimensions dims[4];
GetInputDims(context, inputs, dims);
void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<int64_t>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<int64_t>& input3,
ortc::Tensor<int64_t>& output) {
const int64_t* p_values = input1.Data();
const int64_t* p_missing = input2.Data();
const int64_t* p_indices = input3.Data();
const int64_t* p_values = ort_.GetTensorData<int64_t>(inputs[1]);
const int64_t* p_missing = ort_.GetTensorData<int64_t>(inputs[2]);
const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
int64_t size = dims[3].Size();
int64_t size = input3.NumberOfElement();
int64_t max_col = GetMaxCol(size, p_indices);
std::vector<int64_t> shape_out{size - 1, max_col};
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
int64_t* dense = ort_.GetTensorMutableData<int64_t>(output);
int64_t* dense = output.Allocate(shape_out);
int64_t pos = 0;
int64_t j, pos_end;
@@ -119,38 +97,20 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
}
}
size_t CustomOpRaggedTensorToDense::GetInputTypeCount() const {
return 4;
};
size_t CustomOpRaggedTensorToDense::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
const char* CustomOpRaggedTensorToDense::GetName() const {
return "RaggedTensorToDense";
};
ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensorToDense(api, info) {
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensoroDense(api, info) {
}
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
const OrtValue* inputs[4];
OrtTensorDimensions dims[4];
GetInputDims(context, inputs, dims);
void KernelStringRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<std::string>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<std::string>& input3,
ortc::Tensor<std::string>& output) {
// const OrtValue* inputs[4];
// OrtTensorDimensions dims[4];
std::vector<std::string> input;
GetTensorMutableDataString(api_, ort_, context, inputs[1], input);
const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
int64_t size = dims[3].Size();
auto& input = input1.Data();
const int64_t* p_indices = input2.Data();
int64_t size = input3.NumberOfElement();
int64_t max_col = GetMaxCol(size, p_indices);
std::vector<int64_t> shape_out{size - 1, max_col};
@@ -170,36 +130,5 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
}
pos = pos_end;
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
FillTensorDataString(api_, ort_, context, dense, output);
output.SetStringOutput(dense, shape_out);
}
size_t CustomOpStringRaggedTensorToDense::GetInputTypeCount() const {
return 4;
};
size_t CustomOpStringRaggedTensorToDense::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
const char* CustomOpStringRaggedTensorToDense::GetName() const {
return "StringRaggedTensorToDense";
};
ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t index) const {
switch (index) {
case 1:
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 0:
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
}
};

@@ -5,55 +5,40 @@
#include "ocos.h"
struct KernelRaggedTensorToSparse : BaseKernel {
KernelRaggedTensorToSparse(const OrtApi& api, const OrtKernelInfo& info)
struct KernelRaggedTensoroSparse : BaseKernel {
KernelRaggedTensoroSparse(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {}
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<int64_t>& n_element,
ortc::Tensor<int64_t>& output_0,
ortc::Tensor<int64_t>& output_1);
};
struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
struct CommonRaggedTensorToDense : BaseKernel {
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
struct CommonRaggedTensoroDense : BaseKernel {
CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
protected:
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
int64_t GetMaxCol(int64_t n, const int64_t* p_indices);
};
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
KernelRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<int64_t>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<int64_t>& input3,
ortc::Tensor<int64_t>& output);
private:
int64_t missing_value_;
};
struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
};
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense,
KernelStringRaggedTensorToDense> {
size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
const char* GetName() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
struct KernelStringRaggedTensoroDense : CommonRaggedTensoroDense {
KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
void Compute(const ortc::Tensor<int64_t>& input0,
const ortc::Tensor<std::string>& input1,
const ortc::Tensor<int64_t>& input2,
const ortc::Tensor<std::string>& input3,
ortc::Tensor<std::string>& output);
};

@@ -13,45 +13,21 @@ KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtK
global_replace_ = TryToGetAttributeWithDefault("global_replace", 1);
}
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
std::vector<std::string> str_input, str_pattern, str_rewrite;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
// Verifications
OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."),
ORT_INVALID_ARGUMENT);
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."),
ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
std::string_view str_rewrite,
ortc::Tensor<std::string>& output) {
if (str_pattern.empty())
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::vector<std::string> str_input{input.Data()};
auto dim = input.Shape();
size_t size = input.NumberOfElement();
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
re2::StringPiece piece(str_rewrite.data());
re2::RE2 reg(str_pattern.data());
re2::StringPiece piece(str_rewrite[0]);
re2::RE2 reg(str_pattern[0]);
// Do computation
if (global_replace_) {
for (size_t i = 0; i < size; i++) {
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
@@ -61,24 +37,5 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
re2::RE2::Replace(&(str_input[i]), reg, piece);
}
}
FillTensorDataString(api_, ort_, context, str_input, output);
}
const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringRegexReplace::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringRegexReplace::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringRegexReplace::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
output.SetStringOutput(str_input, dim);
}
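// A minimal, self-contained sketch of the two re2 code paths used above (the demo function
// and its literals are illustrative only):
#include <re2/re2.h>
#include <string>
inline void Re2ReplaceDemo() {
  std::string all = "a-b-c";
  std::string first = "a-b-c";
  re2::RE2 reg("-");
  re2::RE2::GlobalReplace(&all, reg, "+");  // every match: "a+b+c"
  re2::RE2::Replace(&first, reg, "+");      // first match only: "a+b-c"
}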

@@ -8,16 +8,11 @@
struct KernelStringRegexReplace : BaseKernel {
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
std::string_view str_rewrite,
ortc::Tensor<std::string>& output);
protected:
int64_t global_replace_;
};
struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
};

@@ -7,39 +7,30 @@
#include <vector>
#include <cmath>
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
}
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
void KernelStringRegexSplitWithOffsets(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
const ortc::Tensor<std::string>& str_keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output_begin,
ortc::Tensor<int64_t>& output_end,
ortc::Tensor<int64_t>& output_offset) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2);
std::vector<std::string> str_input, str_pattern, str_keep_pattern;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern);
std::vector<std::string> str_input(input.Data());
// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
str_pattern.size(), " values."),
ORT_INVALID_ARGUMENT);
if (str_keep_pattern.size() > 1)
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
str_keep_pattern.size(), " values."),
ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty())
if (str_pattern.empty()) {
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
}
if (str_keep_pattern.Data().size() > 1) {
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
str_keep_pattern.Data().size(), " values."),
ORT_INVALID_ARGUMENT);
}
auto dimensions = input.Shape();
bool include_delimiter = (str_keep_pattern.Data().size() == 1) && (!str_keep_pattern.Data()[0].empty());
OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
re2::RE2 reg(str_pattern[0]);
re2::RE2 keep_reg(include_delimiter ? str_keep_pattern[0] : "");
re2::RE2 reg(str_pattern.data());
re2::RE2 keep_reg(include_delimiter ? str_keep_pattern.Data()[0].data() : "");
std::vector<std::string> all_tokens;
std::vector<int64_t> all_begin_offsets, all_end_offsets;
@@ -63,47 +54,15 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Setup output
std::vector<int64_t> dim_out{(int64_t)all_tokens.size()};
OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
FillTensorDataString(api_, ort_, context, all_tokens, output_text);
output_text.SetStringOutput(all_tokens, dim_out);
OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size());
int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output);
memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
auto output_raw = output_begin.Allocate(dim_out);
memcpy(output_raw, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
output_raw = output_end.Allocate(dim_out);
memcpy(output_raw, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()};
output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
output_raw = output_offset.Allocate(dim_out_row);
memcpy(output_raw, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; };
size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringRegexSplitWithOffsets::GetOutputTypeCount() const {
return 4;
};
ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
case 2:
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
}
};
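// Worked example with illustrative values: splitting "hello world" on the pattern "\s" with
// no keep pattern yields
//   tokens        = {"hello", "world"}
//   begin offsets = {0, 6}
//   end offsets   = {5, 11}
// and the row-offset output records where each input string's tokens start in that flat list.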

@@ -7,15 +7,10 @@
#include "string_utils.h"
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringRegexSplitWithOffsets : BaseKernel {
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void KernelStringRegexSplitWithOffsets(const ortc::Tensor<std::string>& input,
std::string_view str_pattern,
const ortc::Tensor<std::string>& str_keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output_begin,
ortc::Tensor<int64_t>& output_end,
ortc::Tensor<int64_t>& output_offset);

@@ -8,48 +8,19 @@
#include <codecvt>
#include <algorithm>
KernelStringConcat::KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringConcat::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* left = ort_.KernelContext_GetInput(context, 0);
const OrtValue* right = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions left_dim(ort_, left);
OrtTensorDimensions right_dim(ort_, right);
if (left_dim != right_dim) {
void string_concat(const ortc::Tensor<std::string>& left,
const ortc::Tensor<std::string>& right,
ortc::Tensor<std::string>& output) {
if (left.Shape() != right.Shape()) {
ORTX_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
}
// make a copy as input is const
std::vector<std::string> left_value = left.Data();
auto& right_value = right.Data();
std::vector<std::string> left_value;
std::vector<std::string> right_value;
GetTensorMutableDataString(api_, ort_, context, left, left_value);
GetTensorMutableDataString(api_, ort_, context, right, right_value);
// reuse left_value as output to save memory
for (size_t i = 0; i < left_value.size(); i++) {
left_value[i].append(right_value[i]);
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, left_dim.data(), left_dim.size());
FillTensorDataString(api_, ort_, context, left_value, output);
output.SetStringOutput(left_value, left.Shape());
}
const char* CustomOpStringConcat::GetName() const { return "StringConcat"; };
size_t CustomOpStringConcat::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpStringConcat::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringConcat::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

@@ -6,15 +6,6 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringConcat : BaseKernel {
KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringConcat : OrtW::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_concat(const ortc::Tensor<std::string>& left,
const ortc::Tensor<std::string>& right,
ortc::Tensor<std::string>& output);

@@ -13,76 +13,34 @@ KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, co
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}
void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
std::vector<std::string> str_input, str_pattern, str_rewrite;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
// Verifications
OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (rewrite_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString("rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (str_pattern[0].empty()) {
void KernelStringECMARegexReplace::Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant;
std::vector<std::string> str_input = input.Data();
if (pattern.empty()) {
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
}
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
size_t size = input.NumberOfElement();
auto regex_flag = std::regex_constants::optimize | std::regex_constants::ECMAScript;
if (ignore_case_) {
regex_flag |= std::regex_constants::icase;
}
std::regex reg(str_pattern[0], regex_flag);
std::regex reg(pattern.data(), regex_flag);
if (global_replace_) {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0]);
str_input[i] = std::regex_replace(str_input[i], reg, rewrite.data());
}
} else {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0], std::regex_constants::format_first_only);
str_input[i] = std::regex_replace(str_input[i], reg, rewrite.data(), std::regex_constants::format_first_only);
}
}
FillTensorDataString(api_, ort_, context, str_input, output);
auto& dimensions = input.Shape();
output.SetStringOutput(str_input, dimensions);
}
const char* CustomOpStringECMARegexReplace::GetName() const { return "StringECMARegexReplace"; };
size_t CustomOpStringECMARegexReplace::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringECMARegexReplace::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
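// A minimal, self-contained sketch of the two std::regex_replace modes used above (the demo
// function and its literals are illustrative only):
#include <regex>
#include <string>
inline void EcmaReplaceDemo() {
  std::regex re("-");
  std::string all = std::regex_replace(std::string("a-b-c"), re, "+");  // "a+b+c"
  std::string first = std::regex_replace(std::string("a-b-c"), re, "+",
                                         std::regex_constants::format_first_only);  // "a+b-c"
}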

@@ -8,17 +8,12 @@
struct KernelStringECMARegexReplace : BaseKernel {
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output);
protected:
bool global_replace_;
bool ignore_case_;
};
struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};

@@ -15,40 +15,26 @@ KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(con
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}
void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
void KernelStringECMARegexSplitWithOffsets::Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2);
auto& str_input = input.Data();
std::vector<std::string> str_input, str_pattern, str_keep_pattern;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern);
// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_keep_pattern.size() > 1)
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_pattern[0].empty())
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
auto& dimensions = input.Shape();
bool include_delimiter = !keep_pattern.empty();
auto regex_flag = std::regex_constants::ECMAScript;
if (ignore_case_) {
regex_flag |= std::regex_constants::icase;
}
std::regex reg(str_pattern[0], regex_flag);
std::regex keep_reg(include_delimiter ? str_keep_pattern[0] : "", regex_flag);
std::regex reg(pattern.data(), regex_flag);
std::regex keep_reg(include_delimiter ? keep_pattern.data() : "", regex_flag);
std::vector<std::string> all_tokens;
std::vector<int64_t> all_begin_offsets, all_end_offsets;
@@ -72,48 +58,15 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Setup output
std::vector<int64_t> dim_out{(int64_t)all_tokens.size()};
OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
FillTensorDataString(api_, ort_, context, all_tokens, output_text);
output_text.SetStringOutput(all_tokens, dim_out);
OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size());
int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output);
int64_t* p_output = output1.Allocate(dim_out);
memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
p_output = output2.Allocate(dim_out);
memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()};
output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
p_output = output3.Allocate(dim_out_row);
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}
const char* CustomOpStringECMARegexSplitWithOffsets::GetName() const { return "StringECMARegexSplitWithOffsets"; };
size_t CustomOpStringECMARegexSplitWithOffsets::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringECMARegexSplitWithOffsets::GetOutputTypeCount() const {
return 4;
};
ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
case 2:
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
}
};

@@ -10,20 +10,18 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3);
private:
bool ignore_case_;
};
struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
template <typename T>
void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
bool include_delimiter, const std::regex& include_delim_regex,

@@ -8,121 +8,37 @@
#include "string_tensor.h"
#include "string_hash.hpp"
KernelStringHash::KernelStringHash(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringHash::Compute(OrtKernelContext* context) {
void string_hash(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
auto& str_input = input.Data();
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
auto& dimensions = input.Shape();
int64_t* out = output.Allocate(dimensions);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
size_t size = output.NumberOfElement();
// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
size_t nb = static_cast<size_t>(num_buckets);
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(Hash64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}
const char* CustomOpStringHash::GetName() const { return "StringToHashBucket"; };
size_t CustomOpStringHash::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpStringHash::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
KernelStringHashFast::KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringHashFast::Compute(OrtKernelContext* context) {
void string_hash_fast(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
auto& str_input = input.Data();
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
auto& dimensions = input.Shape();
int64_t* out = output.Allocate(dimensions);
size_t size = output.NumberOfElement();
// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
size_t nb = static_cast<size_t>(num_buckets);
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(util::Fingerprint64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}
const char* CustomOpStringHashFast::GetName() const { return "StringToHashBucketFast"; };
size_t CustomOpStringHashFast::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpStringHashFast::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringHashFast::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
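// A minimal sketch of the bucketing scheme above, using std::hash purely as a stand-in for
// Hash64 / Fingerprint64 (it is NOT the same hash, so the bucket ids will differ):
#include <cstdint>
#include <functional>
#include <string>
inline int64_t ToBucket(const std::string& s, int64_t num_buckets) {
  return static_cast<int64_t>(std::hash<std::string>{}(s) % static_cast<size_t>(num_buckets));
}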

@@ -6,28 +6,9 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringHash : BaseKernel {
KernelStringHash(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringHash> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
struct KernelStringHashFast : BaseKernel {
KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringHashFast : OrtW::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_hash(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output);
void string_hash_fast(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output);

@@ -4,40 +4,26 @@
#include "string_join.hpp"
#include "string_tensor.h"
KernelStringJoin::KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringJoin::Compute(OrtKernelContext* context) {
void string_join(const ortc::Tensor<std::string>& input_X,
std::string_view input_sep,
int64_t axis,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
std::vector<std::string> X, sep;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
// Check input
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORTX_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
ORTX_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
auto& X = input_X.Data();
auto& dimensions = input_X.Shape();
if (dimensions.size() == 0) {
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
if (X.size() != 1)
ORTX_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
} else {
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
if (axis < 0 || axis >= static_cast<int64_t>(dimensions.size()))
ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", axis), ORT_INVALID_ARGUMENT);
}
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
if (dimensions.size() > 1) {
for (size_t i = 0, pos = 0; i < dimensions.size(); ++i) {
if (static_cast<int64_t>(i) == *axis)
if (static_cast<int64_t>(i) == axis)
continue;
dimensions_out[pos++] = dimensions[i];
}
@ -45,22 +31,19 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
dimensions_out[0] = 1;
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
int64_t size = std::accumulate(dimensions_out.begin(), dimensions_out.end(), 1ULL, std::multiplies<int64_t>());
std::vector<std::string> out(static_cast<size_t>(size));
if (dimensions.size() > 0) {
if (X.size() > 0) {
// Do computation
int64_t h = 1;
for (size_t i = static_cast<size_t>(*axis + 1); i < dimensions.size(); ++i) {
for (size_t i = static_cast<size_t>(axis + 1); i < dimensions.size(); ++i) {
h *= dimensions[i];
}
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[static_cast<size_t>(*axis)] - 1;
int64_t n_red = dimensions[static_cast<size_t>(axis)] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {
@ -68,7 +51,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[static_cast<size_t>(index)] << sep[0];
st << X[static_cast<size_t>(index)] << input_sep;
}
st << X[static_cast<size_t>(index)];
out[static_cast<size_t>(pos)] = st.str();
@ -83,33 +66,5 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
out[0] = X[0];
}
FillTensorDataString(api_, ort_, context, out, output);
output.SetStringOutput(out, dimensions_out);
}
const char* CustomOpStringJoin::GetName() const {
return "StringJoin";
};
size_t CustomOpStringJoin::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpStringJoin::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringJoin::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -6,15 +6,7 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringJoin : BaseKernel {
KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringJoin : OrtW::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_join(const ortc::Tensor<std::string>& input_X,
std::string_view input_sep,
int64_t axis,
ortc::Tensor<std::string>& output);
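
The join itself is unchanged: the output is viewed as `left_part` blocks of `right_part` positions, with `h` the stride of the reduced axis and `n_red + 1` elements folded into each output string. A small standalone program (plain C++, no ORT headers) that reproduces the same indexing for a [2, 3] input joined along axis 1:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Reproduces the string_join indexing from the kernel above: reduce a [2,3]
// batch of strings along axis=1 with separator "-", giving a [2] output.
int main() {
  std::vector<std::string> X = {"a", "b", "c", "d", "e", "f"};  // shape [2,3]
  std::vector<int64_t> dims = {2, 3};
  int64_t axis = 1;
  std::string sep = "-";

  // Output element count: product of the non-reduced dimensions.
  int64_t size = 1;
  for (size_t i = 0; i < dims.size(); ++i)
    if (static_cast<int64_t>(i) != axis) size *= dims[i];

  int64_t h = 1;  // stride of the reduced axis
  for (size_t i = static_cast<size_t>(axis + 1); i < dims.size(); ++i) h *= dims[i];
  int64_t left_part = size / h;
  int64_t right_part = size / left_part;
  int64_t n_red = dims[static_cast<size_t>(axis)] - 1;
  int64_t inc = right_part * (n_red + 1);

  std::vector<std::string> out(static_cast<size_t>(size));
  int64_t pos = 0;
  for (int64_t li = 0; li < left_part; ++li) {
    for (int64_t ri = 0; ri < right_part; ++ri) {
      std::ostringstream st;
      int64_t index = ri + li * inc;
      for (int64_t j = 0; j < n_red; ++j, index += h) st << X[static_cast<size_t>(index)] << sep;
      st << X[static_cast<size_t>(index)];
      out[static_cast<size_t>(pos++)] = st.str();
    }
  }
  for (const auto& s : out) std::cout << s << "\n";  // prints "a-b-c" then "d-e-f"
  return 0;
}
```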


@ -9,38 +9,15 @@
#include <algorithm>
#include "ustring.h"
KernelStringLength::KernelStringLength(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringLength::Compute(OrtKernelContext* context) {
void string_length(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
auto& input_data = input.Data();
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
auto* output_data = ort_.GetTensorMutableData<int64_t>(output);
auto& dimensions = input.Shape();
auto* output_data = output.Allocate(dimensions);
for (int i = 0; i < dimensions.Size(); i++) {
for (int i = 0; i < input.NumberOfElement(); i++) {
output_data[i] = ustring(input_data[i]).size();
}
}
const char* CustomOpStringLength::GetName() const { return "StringLength"; };
size_t CustomOpStringLength::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringLength::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringLength::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringLength::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};


@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringLength : BaseKernel {
KernelStringLength(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringLength : OrtW::CustomOpBase<CustomOpStringLength, KernelStringLength> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_length(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output);
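
The numeric-output version of the same pattern: `Allocate(input.Shape())` plus `NumberOfElement()` replaces the old `GetTensorTypeAndShape` / `GetTensorShapeElementCount` / `ReleaseTensorTypeAndShapeInfo` sequence. A cleaned-up sketch of the migrated function, assuming the project's `ocos.h` and `ustring.h` headers:

```cpp
// Sketch reconstructed from the hunk above.
#include "ocos.h"
#include "ustring.h"

void string_length(const ortc::Tensor<std::string>& input,
                   ortc::Tensor<int64_t>& output) {
  auto& input_data = input.Data();
  // The int64 output reuses the input shape; Allocate returns the writable buffer.
  int64_t* output_data = output.Allocate(input.Shape());
  const auto count = static_cast<int64_t>(input.NumberOfElement());
  for (int64_t i = 0; i < count; ++i) {
    // ustring is the project's UTF-32 string type, so size() counts characters, not bytes.
    output_data[i] = static_cast<int64_t>(ustring(input_data[i]).size());
  }
}
```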


@ -7,38 +7,14 @@
#include <cmath>
#include <algorithm>
KernelStringLower::KernelStringLower(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringLower::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
void string_lower(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant
std::vector<std::string> X = input.Data();
for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) {return static_cast<char>(ToLower(c));});
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(ToLower(c)); });
}
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}
const char* CustomOpStringLower::GetName() const { return "StringLower"; };
size_t CustomOpStringLower::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringLower::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringLower::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringLower::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringLower : BaseKernel {
KernelStringLower(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringLower : OrtW::CustomOpBase<CustomOpStringLower, KernelStringLower> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_lower(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);


@ -21,39 +21,15 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo&
}
}
void KernelStringMapping::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
OrtTensorDimensions dimensions(ort_, input);
void KernelStringMapping::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant
std::vector<std::string> input_data = input.Data();
for (auto& str : input_data) {
if (map_.find(str) != map_.end()) {
str = map_[str];
}
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, input_data, output);
output.SetStringOutput(input_data, input.Shape());
}
const char* CustomOpStringMapping::GetName() const { return "StringMapping"; };
size_t CustomOpStringMapping::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringMapping::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringMapping::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringMapping::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -9,16 +9,9 @@
struct KernelStringMapping : BaseKernel {
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);
private:
std::unordered_map<std::string, std::string> map_;
};
struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
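
StringMapping keeps its struct form because the mapping table is loaded once from a node attribute; only the `Compute` signature changes to typed `ortc` tensors, and the op is registered with `CustomCpuStruct` instead of a `CustomOpBase` wrapper. A hedged sketch of that stateful-kernel shape; the kernel name, the "map" attribute, and the skipped parsing step are placeholders for illustration, not the actual StringMapping code:

```cpp
// Illustrative sketch of the stateful-kernel pattern, not the real KernelStringMapping.
#include <string>
#include <unordered_map>
#include <vector>
#include "ocos.h"

struct KernelMyMapping : BaseKernel {
  KernelMyMapping(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
    // State is read once from the node attributes in the constructor
    // ("map" is a placeholder attribute name in this sketch).
    std::string raw = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
    (void)raw;  // parsing of `raw` into map_ elided
  }

  // Compute now takes typed ortc tensors instead of an OrtKernelContext*.
  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<std::string>& output) {
    std::vector<std::string> data = input.Data();  // copy, since the input is const
    for (auto& s : data) {
      auto it = map_.find(s);
      if (it != map_.end()) s = it->second;
    }
    output.SetStringOutput(data, input.Shape());
  }

 private:
  std::unordered_map<std::string, std::string> map_;
};
```

Stateless ops, by contrast, become free functions and are registered with `CustomCpuFunc`, as shown in the TextLoader section below.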


@ -4,27 +4,17 @@
#include "string_split.hpp"
#include "string_tensor.h"
KernelStringSplit::KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringSplit::Compute(OrtKernelContext* context) {
void string_split(const ortc::Tensor<std::string>& input_X,
std::string_view sep,
bool skip_empty,
ortc::Tensor<int64_t>& out_indices,
ortc::Tensor<std::string>& out_text,
ortc::Tensor<int64_t>& out_shape) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
std::vector<std::string> X, sep;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
auto& X = input_X.Data();
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORTX_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
ORTX_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
auto& dimensions = input_X.Shape();
if (dimensions.size() != 1)
ORTX_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
@ -32,8 +22,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
std::vector<int64_t> indices;
int64_t maxc = 0;
int64_t col;
std::string delimiter = sep[0];
if (delimiter.size() == 0) {
if (sep.size() == 0) {
char word[2] = "a";
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[static_cast<size_t>(row)];
@ -48,7 +37,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
}
}
} else {
bool keep = !(*skip_empty);
bool keep = !skip_empty;
std::size_t current, previous = 0;
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[static_cast<size_t>(row)];
@ -56,7 +45,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
continue;
previous = 0;
col = 0;
current = str.find_first_of(delimiter);
current = str.find_first_of(sep);
while (current != std::string::npos) {
if (keep || current > previous) {
words.push_back(str.substr(previous, current - previous));
@ -65,7 +54,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
++col;
}
previous = current + 1;
current = str.find_first_of(delimiter, previous);
current = str.find_first_of(sep, previous);
}
current = str.size();
if (keep || current > previous) {
@ -79,55 +68,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
}
std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());
int64_t* p_indices = out_indices.Allocate(shape_indices);
std::vector<int64_t> shape_text(1, words.size());
OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());
std::vector<int64_t> shape_shape(1, 2);
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
int64_t* p_shape = out_shape.Allocate(shape_shape);
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
p_shape[0] = dimensions[0];
p_shape[1] = maxc;
FillTensorDataString(api_, ort_, context, words, out_text);
out_text.SetStringOutput(words, shape_text);
}
const char* CustomOpStringSplit::GetName() const {
return "StringSplit";
};
size_t CustomOpStringSplit::GetInputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpStringSplit::GetOutputTypeCount() const {
return 3;
};
ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
switch (index) {
case 0:
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
default:
ORTX_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};


@ -6,15 +6,9 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringSplit : BaseKernel {
KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringSplit : OrtW::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_split(const ortc::Tensor<std::string>& input_X,
std::string_view sep,
bool skip_empty,
ortc::Tensor<int64_t>& out_indices,
ortc::Tensor<std::string>& out_text,
ortc::Tensor<int64_t>& out_shape);
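
StringSplit keeps its three outputs: the flat word list, the sparse [n, 2] (row, column) indices, and the dense shape [rows, max_columns]; only the scalar separator and skip_empty flag now arrive as plain arguments. A simplified standalone reimplementation of the splitting loop (the bookkeeping hidden in the elided hunk lines is approximated here):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified standalone version of the StringSplit loop above.
int main() {
  std::vector<std::string> X = {"a,b,,c", "d"};  // 1-D batch of strings
  std::string sep = ",";
  bool skip_empty = true;
  bool keep = !skip_empty;

  std::vector<std::string> words;
  std::vector<int64_t> indices;  // flattened (row, col) pairs
  int64_t maxc = 0;

  for (int64_t row = 0; row < static_cast<int64_t>(X.size()); ++row) {
    const std::string& str = X[static_cast<size_t>(row)];
    if (str.empty()) continue;
    int64_t col = 0;
    std::size_t previous = 0;
    std::size_t current = str.find_first_of(sep);
    auto emit = [&](std::size_t end) {
      if (keep || end > previous) {
        words.push_back(str.substr(previous, end - previous));
        indices.push_back(row);
        indices.push_back(col);
        ++col;
      }
    };
    while (current != std::string::npos) {
      emit(current);
      previous = current + 1;
      current = str.find_first_of(sep, previous);
    }
    emit(str.size());
    maxc = std::max(maxc, col);
  }

  std::cout << "words:";
  for (const auto& w : words) std::cout << " \"" << w << "\"";
  std::cout << "\nindices shape: [" << indices.size() / 2 << ", 2]\n";
  std::cout << "dense shape: [" << X.size() << ", " << maxc << "]\n";  // [2, 3]
  return 0;
}
```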


@ -9,46 +9,16 @@
const char* WHITE_SPACE_CHARS = " \t\n\r\f\v";
KernelStringStrip::KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringStrip::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
// For each string in input, replace with whitespace-trimmed version.
void string_strip(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
std::vector<std::string> X = input.Data();
for (size_t i = 0; i < X.size(); ++i) {
size_t nonWhitespaceBegin = X[i].find_first_not_of(WHITE_SPACE_CHARS);
if (nonWhitespaceBegin != std::string::npos) {
size_t nonWhitespaceEnd = X[i].find_last_not_of(WHITE_SPACE_CHARS);
size_t nonWhitespaceRange = nonWhitespaceEnd - nonWhitespaceBegin + 1;
X[i] = X[i].substr(nonWhitespaceBegin, nonWhitespaceRange);
}
}
// Fills the output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}
const char* CustomOpStringStrip::GetName() const { return "StringStrip"; };
size_t CustomOpStringStrip::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringStrip::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringStrip::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringStrip::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringStrip : BaseKernel {
KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringStrip : OrtW::CustomOpBase<CustomOpStringStrip, KernelStringStrip> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_strip(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);


@ -9,7 +9,9 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
ParseUnkownValue(unk);
}
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<std::vector<int64_t>> result;
// Set output dimension
@ -107,19 +109,15 @@ KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInf
impl_ = std::make_shared<StringToVectorImpl>(map, unk);
}
void KernelStringToVector::Compute(OrtKernelContext* context) {
void KernelStringToVector::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out) {
// Setup input
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
OrtTensorDimensions input_dim(ort_, input);
auto& input_data = input.Data();
// Get output
OrtTensorDimensions output_dim;
auto mapping_result = impl_->Compute(input_data, input_dim, output_dim);
std::vector<int64_t> output_dim;
auto mapping_result = impl_->Compute(input_data, input.Shape(), output_dim);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
auto* output_data = ort_.GetTensorMutableData<int64_t>(output);
auto* output_data = out.Allocate(output_dim);
// Set output tensor data
int idx = 0;
@ -130,21 +128,3 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
}
}
}
const char* CustomOpStringToVector::GetName() const { return "StringToVector"; };
size_t CustomOpStringToVector::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringToVector::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringToVector::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringToVector::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};


@ -12,7 +12,9 @@
class StringToVectorImpl {
public:
StringToVectorImpl(std::string& map, std::string& unk);
std::vector<std::vector<int64_t>> Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim);
std::vector<std::vector<int64_t>> Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);
private:
void ParseMappingTable(std::string& map);
@ -29,16 +31,9 @@ class StringToVectorImpl {
struct KernelStringToVector : BaseKernel {
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out);
private:
std::shared_ptr<StringToVectorImpl> impl_;
};
struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};


@ -7,39 +7,14 @@
#include <cmath>
#include <algorithm>
KernelStringUpper::KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}
void KernelStringUpper::Compute(OrtKernelContext* context) {
void string_upper(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
std::vector<std::string> X = input.Data();
for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(::toupper(c)); });
}
// Fills the output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}
const char* CustomOpStringUpper::GetName() const { return "StringUpper"; };
size_t CustomOpStringUpper::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringUpper::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpStringUpper::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpStringUpper::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"
struct KernelStringUpper : BaseKernel {
KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpStringUpper : OrtW::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_upper(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);


@ -20,28 +20,33 @@
#include "text/re2_strings/string_regex_split.hpp"
#endif // ENABLE_RE2_REGEX
FxLoadCustomOpFactory LoadCustomOpClasses_Text =
LoadCustomOpClasses<CustomOpClassBegin,
const std::vector<const OrtCustomOp*>& TextLoader() {
static OrtOpLoader op_loader(
#if defined(ENABLE_RE2_REGEX)
CustomOpStringRegexReplace,
CustomOpStringRegexSplitWithOffsets,
CustomCpuStruct("StringRegexReplace", KernelStringRegexReplace),
CustomCpuFunc("StringRegexSplitWithOffsets", KernelStringRegexSplitWithOffsets),
#endif // ENABLE_RE2_REGEX
CustomOpRaggedTensorToDense,
CustomOpRaggedTensorToSparse,
CustomOpStringRaggedTensorToDense,
CustomOpStringEqual,
CustomOpStringHash,
CustomOpStringHashFast,
CustomOpStringJoin,
CustomOpStringLower,
CustomOpStringUpper,
CustomOpStringMapping,
CustomOpMaskedFill,
CustomOpStringSplit,
CustomOpStringStrip,
CustomOpStringToVector,
CustomOpVectorToString,
CustomOpStringLength,
CustomOpStringConcat,
CustomOpStringECMARegexReplace,
CustomOpStringECMARegexSplitWithOffsets>;
CustomCpuStruct("RaggedTensorToSparse", KernelRaggedTensoroSparse),
CustomCpuStruct("RaggedTensorToDense", KernelRaggedTensoroDense),
CustomCpuStruct("StringRaggedTensorToDense", KernelStringRaggedTensoroDense),
CustomCpuStruct("StringEqual", KernelStringEqual),
CustomCpuFunc("StringToHashBucket", string_hash),
CustomCpuFunc("StringToHashBucketFast", string_hash_fast),
CustomCpuFunc("StringJoin", string_join),
CustomCpuFunc("StringLower", string_lower),
CustomCpuFunc("StringUpper", string_upper),
CustomCpuStruct("StringMapping", KernelStringMapping),
CustomCpuFunc("MaskedFill", masked_fill),
CustomCpuFunc("StringSplit", string_split),
CustomCpuFunc("StringStrip", string_strip),
CustomCpuStruct("StringToVector", KernelStringToVector),
CustomCpuStruct("VectorToString", KernelVectorToString),
CustomCpuFunc("StringLength", string_length),
CustomCpuFunc("StringConcat", string_concat),
CustomCpuStruct("StringECMARegexReplace", KernelStringECMARegexReplace),
CustomCpuStruct("StringECMARegexSplitWithOffsets", KernelStringECMARegexSplitWithOffsets));
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Text = TextLoader;
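
This is the one place a new op has to be wired up: stateless function kernels go through `CustomCpuFunc(op_name, function)` and stateful struct kernels through `CustomCpuStruct(op_name, KernelType)`, all collected by a single static `OrtOpLoader`. A trimmed-down sketch of adding one more op to this loader; `string_reverse` and the `StringReverse` op name are invented for illustration and are not part of this PR:

```cpp
// Hypothetical example of extending the text-domain loader above.
#include <algorithm>
#include <vector>
#include "ocos.h"

// Invented stateless kernel used only to illustrate registration.
void string_reverse(const ortc::Tensor<std::string>& input,
                    ortc::Tensor<std::string>& output) {
  std::vector<std::string> data = input.Data();
  for (auto& s : data) std::reverse(s.begin(), s.end());
  output.SetStringOutput(data, input.Shape());
}

const std::vector<const OrtCustomOp*>& TextLoader() {
  static OrtOpLoader op_loader(
      CustomCpuFunc("StringLower", string_lower),             // existing function kernel
      CustomCpuStruct("StringMapping", KernelStringMapping),  // existing struct kernel
      CustomCpuFunc("StringReverse", string_reverse));        // the invented addition
  return op_loader.GetCustomOps();
}

FxLoadCustomOpFactory LoadCustomOpClasses_Text = TextLoader;
```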


@ -18,16 +18,18 @@ VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk
ParseMappingTable(map);
}
std::vector<std::string> VectorToStringImpl::Compute(const void* input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
std::vector<std::string> VectorToStringImpl::Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<std::string> result;
const int64_t* ptr = static_cast<const int64_t*>(input);
if (vector_len_ == 1 && (input_dim.size() == 1 || input_dim.IsScalar())) {
if (vector_len_ == 1 && (input_dim.size() == 1 || input_dim.empty())) {
// only hit when the key is a scalar and the input is a vector
output_dim = input_dim;
} else {
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
if (input_dim.empty() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
}
@ -36,7 +38,8 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
}
std::vector<int64_t> key(vector_len_);
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
int64_t input_element_size = std::accumulate(input_dim.begin(), input_dim.end(), 1ULL, std::multiplies<int64_t>());
for (int64_t i = 0; i < input_element_size; i = static_cast<int64_t>(i + vector_len_)) {
// construct key
for (size_t j = 0; j < vector_len_; j++) {
key[j] = ptr[j];
@ -110,33 +113,11 @@ KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInf
impl_ = std::make_shared<VectorToStringImpl>(map, unk);
}
void KernelVectorToString::Compute(OrtKernelContext* context) {
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const void* input_data = ort_.GetTensorData<int64_t>(input);
void KernelVectorToString::Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out) {
const void* input_data = input.Data();
OrtTensorDimensions input_dim(ort_, input);
OrtTensorDimensions output_dim;
std::vector<std::string> mapping_result = impl_->Compute(input_data, input_dim, output_dim);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
FillTensorDataString(api_, ort_, context, mapping_result, output);
std::vector<int64_t> output_dim;
std::vector<std::string> mapping_result = impl_->Compute(input_data, input.Shape(), output_dim);
out.SetStringOutput(mapping_result, output_dim);
}
const char* CustomOpVectorToString::GetName() const { return "VectorToString"; };
size_t CustomOpVectorToString::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpVectorToString::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
size_t CustomOpVectorToString::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpVectorToString::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -20,7 +20,9 @@ struct hash<std::vector<T>> {
class VectorToStringImpl {
public:
VectorToStringImpl(std::string& map, std::string& unk);
std::vector<std::string> Compute(const void* input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim);
std::vector<std::string> Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);
private:
void ParseMappingTable(std::string& map);
@ -34,16 +36,9 @@ class VectorToStringImpl {
struct KernelVectorToString : BaseKernel {
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out);
private:
std::shared_ptr<VectorToStringImpl> impl_;
};
struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
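
With `OrtTensorDimensions` gone, the impl classes work on plain `std::vector<int64_t>` shapes: `IsScalar()` becomes `empty()`, and `Size()` becomes a `std::accumulate` product. A small standalone check of that replacement:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Element count of a shape vector, as computed in VectorToStringImpl::Compute
// above: the product of all dimensions, with an empty (scalar) shape giving 1.
int64_t ElementCount(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  std::cout << ElementCount({}) << "\n";         // 1  (scalar shape)
  std::cout << ElementCount({4}) << "\n";        // 4
  std::cout << ElementCount({2, 3, 5}) << "\n";  // 30
  return 0;
}
```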


@ -91,37 +91,9 @@ KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInf
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
}
void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
void KernelBasicTokenizer::Compute(std::string_view input,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
OrtTensorDimensions dimensions(ort_, input);
if (dimensions.size() != 1 && dimensions[0] != 1) {
ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
FillTensorDataString(api_, ort_, context, result, output);
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
output.SetStringOutput({result[0].operator std::string()}, {1});
}
const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -23,16 +23,9 @@ class BasicTokenizer {
struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(std::string_view input,
ortc::Tensor<std::string>& output);
private:
std::shared_ptr<BasicTokenizer> tokenizer_;
};
struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};


@ -292,11 +292,12 @@ KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo&
ustring(suffix_indicator), max_len, truncation_strategy_name);
}
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
auto& input_data = input.Data();
if (input_data.size() != 1 && input_data.size() != 2) {
ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
@ -323,37 +324,23 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
SetOutput(context, 0, output_dim, input_ids);
SetOutput(context, 1, output_dim, token_type_ids);
SetOutput(context, 2, output_dim, attention_mask);
}
const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
size_t CustomOpBertTokenizer::GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
return 3;
}
ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
auto* p_out = output.Allocate(output_dim);
std::copy(input_ids.begin(), input_ids.end(), p_out);
auto* p_out1 = output1.Allocate(output_dim);
std::copy(token_type_ids.begin(), token_type_ids.end(), p_out1);
auto* p_out2 = output2.Allocate(output_dim);
std::copy(attention_mask.begin(), attention_mask.end(), p_out2);
}
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: KernelBertTokenizer(api, info) {}
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2) {
// Setup inputs
const OrtValue* const input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
auto& input_data = input.Data();
if (input_data.size() != 2) {
ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
@ -368,33 +355,11 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
const std::vector<int64_t> inner_dims{1LL};
for (int32_t i = 0; i < 3; ++i) {
OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
OrtTensorTypeAndShapeInfo* const info = ort_.GetTensorTypeAndShape(value);
ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
ort_.ReleaseTensorTypeAndShapeInfo(info);
}
SetOutput(context, 0, outer_dims, input_ids);
SetOutput(context, 1, outer_dims, attention_mask);
SetOutput(context, 2, outer_dims, token_type_ids);
}
const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
return 3;
}
ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
auto* p_out = output.Allocate(outer_dims);
std::copy(input_ids.begin(), input_ids.end(), p_out);
auto* p_out1 = output1.Allocate(outer_dims);
std::copy(attention_mask.begin(), attention_mask.end(), p_out1);
auto* p_out2 = output2.Allocate(outer_dims);
std::copy(token_type_ids.begin(), token_type_ids.end(), p_out2);
}


@ -11,7 +11,6 @@
#include <unordered_map>
class BertTokenizerVocab final {
public:
explicit BertTokenizerVocab(std::string_view vocab);
@ -92,29 +91,19 @@ class BertTokenizer final {
struct KernelBertTokenizer : BaseKernel {
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2);
protected:
std::unique_ptr<BertTokenizer> tokenizer_;
};
struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
struct KernelHfBertTokenizer : KernelBertTokenizer {
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};
struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2);
};
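
Multiple required outputs simply become additional `ortc::Tensor<int64_t>&` parameters; each one is allocated with the shared shape and filled with `std::copy`, which is what replaces the old `SetOutput(context, i, ...)` calls. A minimal sketch of that pattern in isolation (the three vectors stand in for the real `input_ids` / `token_type_ids` / `attention_mask` produced by the tokenizer):

```cpp
// Sketch of the multi-output pattern used by KernelBertTokenizer::Compute above.
#include <algorithm>
#include <cstdint>
#include <vector>
#include "ocos.h"

void write_three_outputs(const std::vector<int64_t>& input_ids,
                         const std::vector<int64_t>& token_type_ids,
                         const std::vector<int64_t>& attention_mask,
                         ortc::Tensor<int64_t>& output,
                         ortc::Tensor<int64_t>& output1,
                         ortc::Tensor<int64_t>& output2) {
  std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};

  // Each output gets its own allocation, then the host vector is copied in.
  std::copy(input_ids.begin(), input_ids.end(), output.Allocate(output_dim));
  std::copy(token_type_ids.begin(), token_type_ids.end(), output1.Allocate(output_dim));
  std::copy(attention_mask.begin(), attention_mask.end(), output2.Allocate(output_dim));
}
```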


@ -136,30 +136,30 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const
cls_token, mask_token, suffix_indicator);
}
void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
OrtTensorDimensions ids_dim(ort_, ids);
void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
const ortc::Tensor<int64_t>& positions,
ortc::Tensor<std::string>& output) {
const int64_t* p_ids = ids.Data();
auto& ids_dim = ids.Shape();
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
// const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions positions_dim(ort_, positions);
auto& positions_dim = positions.Shape();
if (use_indices_ &&
(!((positions_dim.Size() == 0) ||
(!((positions.NumberOfElement() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
}
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
const int64_t* p_positions = positions.NumberOfElement() == 0 ? nullptr : positions.Data();
std::vector<std::string> result;
std::vector<int64_t> output_dim(1);
if (!use_indices_) {
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()), skip_special_tokens_, clean_up_tokenization_spaces_));
output_dim[0] = 1;
} else {
if (p_positions != nullptr) {
@ -172,25 +172,5 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
output_dim[0] = positions_dim[0];
}
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
FillTensorDataString(api_, ort_, context, result, output);
output.SetStringOutput(result, output_dim);
}
const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
size_t CustomOpBertTokenizerDecoder::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -31,7 +31,9 @@ class BertTokenizerDecoder {
struct KernelBertTokenizerDecoder : BaseKernel {
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<int64_t>& ids,
const ortc::Tensor<int64_t>& positions,
ortc::Tensor<std::string>& output);
private:
std::shared_ptr<BertTokenizerDecoder> decoder_;
@ -39,11 +41,3 @@ struct KernelBertTokenizerDecoder : BaseKernel {
bool skip_special_tokens_;
bool clean_up_tokenization_spaces_;
};
struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};


@ -27,26 +27,14 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
max_sentence = TryToGetAttributeWithDefault("max_sentence", -1);
}
void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
OrtTensorDimensions dimensions(ort_, input);
// TODO: fix this scalar check.
if (dimensions.Size() != 1 && dimensions[0] != 1) {
ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
}
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
std::string& input_string = input_data[0];
int max_length = static_cast<int>(2 * input_string.size() + 1);
void KernelBlingFireSentenceBreaker::Compute(std::string_view input,
ortc::Tensor<std::string>& output) {
int max_length = static_cast<int>(2 * input.size() + 1);
std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
if (output_length < 0) {
ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
ORTX_CXX_API_THROW(MakeString("splitting input:\"", input, "\" failed"), ORT_INVALID_ARGUMENT);
}
// inline split output_str by newline '\n'
@ -72,25 +60,5 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
std::vector<int64_t> output_dimensions(1);
output_dimensions[0] = output_sentences.size();
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
output.SetStringOutput(output_sentences, output_dimensions);
}
const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
return 1;
};
ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};


@ -17,7 +17,8 @@ extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
struct KernelBlingFireSentenceBreaker : BaseKernel {
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(std::string_view input,
ortc::Tensor<std::string>& output);
private:
using ModelPtr = std::shared_ptr<void>;
@ -25,12 +26,3 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
std::string model_data_;
int max_sentence;
};
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker,
KernelBlingFireSentenceBreaker> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
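
Scalar string inputs such as this one are now received as a plain `std::string_view`, so the old one-element shape check disappears; the kernel still writes the model result into a `2 * input.size() + 1` byte buffer and splits it on '\n'. A standalone sketch of that post-processing step, with the BlingFire `TextToSentencesWithOffsetsWithModel` call replaced by a canned buffer:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

// Standalone sketch of the sentence-splitting step; the BlingFire call is mocked.
int main() {
  std::string_view input = "First sentence. Second one!";
  int max_length = static_cast<int>(2 * input.size() + 1);  // buffer size used by the kernel
  std::cout << "buffer capacity: " << max_length << "\n";

  std::string buffer = "First sentence.\nSecond one!";  // pretend model output

  std::vector<std::string> output_sentences;
  std::size_t start = 0;
  while (start < buffer.size()) {
    std::size_t end = buffer.find('\n', start);
    if (end == std::string::npos) end = buffer.size();
    output_sentences.push_back(buffer.substr(start, end - start));
    start = end + 1;
  }

  for (const auto& s : output_sentences) std::cout << s << "\n";  // two sentences
  return 0;
}
```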


@ -99,10 +99,10 @@ struct KernelBpeDecoder : public BaseKernel {
arr_vocab_.shrink_to_fit();
}
void Compute(OrtKernelContext* context) {
const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
OrtTensorDimensions ids_dim(ort_, ids);
void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) {
const int64_t* p_ids = ids.Data();
const auto& ids_dim = ids.Shape();
std::vector<int64_t> output_dim = {1};
if (ids_dim.size() > 1) {
output_dim.resize(ids_dim.size() - 1);
@ -110,14 +110,15 @@ struct KernelBpeDecoder : public BaseKernel {
}
size_t seq_len = ids_dim.back();
size_t string_batch = ids_dim.Size() / seq_len;
size_t string_batch = ids.NumberOfElement() / seq_len;
std::vector<std::string> decoded_strings;
decoded_strings.reserve(string_batch);
for (auto n = string_batch; n > 0; n--) {
std::string text;
bool f_special_last = false;
bool f_special = false;
auto count = static_cast<size_t>(ids_dim.Size());
auto count = static_cast<size_t>(ids.NumberOfElement());
for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
const auto token = *(p_ids + tok_idx);
@ -163,9 +164,7 @@ struct KernelBpeDecoder : public BaseKernel {
decoded_strings.emplace_back(std::move(text));
p_ids += seq_len;
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
FillTensorDataString(api_, ort_, context, decoded_strings, output);
output.SetStringOutput(decoded_strings, output_dim);
}
private:
@ -182,26 +181,4 @@ struct KernelBpeDecoder : public BaseKernel {
std::map<char32_t, unsigned char> byte_decoder_;
std::map<int64_t, std::string> added_tokens_;
std::set<int64_t> all_special_ids_;
};
struct CustomOpBpeDecoder : OrtW::CustomOpBase<CustomOpBpeDecoder, KernelBpeDecoder> {
const char* GetName() const {
return "BpeDecoder";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
};
};
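
The decoder's shape handling is easy to miss in the hunk: the string output drops the trailing sequence dimension of `ids` (falling back to [1] for 1-D input), and the batch count is the total element count divided by that sequence length. A small standalone check of the arithmetic:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Shape arithmetic from KernelBpeDecoder::Compute above.
int main() {
  std::vector<int64_t> ids_dim = {2, 3, 7};  // 2x3 batch of 7 token ids each

  std::vector<int64_t> output_dim = {1};
  if (ids_dim.size() > 1) {
    output_dim.assign(ids_dim.begin(), ids_dim.end() - 1);  // drop seq_len
  }
  int64_t element_count =
      std::accumulate(ids_dim.begin(), ids_dim.end(), int64_t{1}, std::multiplies<int64_t>());
  int64_t seq_len = ids_dim.back();
  int64_t string_batch = element_count / seq_len;

  std::cout << "output_dim: [";
  for (size_t i = 0; i < output_dim.size(); ++i)
    std::cout << output_dim[i] << (i + 1 < output_dim.size() ? ", " : "");
  std::cout << "], batch = " << string_batch << "\n";  // output_dim: [2, 3], batch = 6
  return 0;
}
```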


@ -3,6 +3,7 @@
// Partial code comes from other Microsoft employee.
#include "clip_tokenizer.hpp"
#include "narrow.h"
#include <optional>
KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
@ -115,13 +116,14 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
return res;
}
void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> str_input;
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
OrtTensorDimensions input_dim(ort_, input);
const auto& input_dim = input.Shape();
std::vector<std::vector<int64_t>> tokenize_results;
for (auto& str : str_input) {
@ -138,18 +140,15 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
max_length = static_cast<size_t>(padding_length_);
}
OrtTensorDimensions output_dim = input_dim;
std::vector<int64_t> output_dim = input_dim;
output_dim.push_back(max_length);
OrtTensorDimensions offset_dim = output_dim;
std::vector<int64_t> offset_dim = output_dim;
offset_dim.push_back(2); // tuple of offsets for each input id
OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
if (attention_mask != nullptr) {
auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
auto* token = tokenize_output.Allocate(output_dim);
if (attention_mask.has_value()) {
auto* mask = (*attention_mask)->Allocate(output_dim);
int idx = 0;
for (auto& res : tokenize_results) {
for (int64_t id : res) {
@ -163,8 +162,8 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
}
if (offset_mapping != nullptr) {
auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
if (offset_mapping.has_value()) {
auto* offset = (*offset_mapping)->Allocate(offset_dim);
int idx2 = 0;
for (auto& res : offset_map) {
for (auto& mapping : res) {
@ -188,32 +187,3 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
}
const char* CustomOpClipBpeTokenizer::GetName() const {
return "CLIPTokenizer";
}
size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
return 3;
}
ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetOutputCharacteristic(size_t index) const {
return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
: OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
}


@ -6,7 +6,10 @@
struct KernelClipBpeTokenizer : BaseKernel {
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
private:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
@ -16,13 +19,3 @@ struct KernelClipBpeTokenizer : BaseKernel {
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};
struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokenizer, KernelClipBpeTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
};
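
Optional outputs no longer need `GetOutputCharacteristic`: they appear in the `Compute` signature as `std::optional<ortc::Tensor<int64_t>*>`, which is empty when the graph does not connect the output, so the kernel can simply skip it. A minimal sketch of the pattern with an invented kernel that writes a required `token_ids` output and an optional `attention_mask`:

```cpp
// Sketch of the optional-output pattern used by the CLIP/GPT-2/RoBERTa
// tokenizers above; the kernel body itself is invented for illustration.
#include <cstdint>
#include <optional>
#include <vector>
#include "ocos.h"

void tokenize_with_optional_mask(const ortc::Tensor<std::string>& input,
                                 ortc::Tensor<int64_t>& token_ids,
                                 std::optional<ortc::Tensor<int64_t>*> attention_mask) {
  std::vector<int64_t> output_dim = input.Shape();
  const int64_t max_length = 4;  // placeholder sequence length
  output_dim.push_back(max_length);

  int64_t* token = token_ids.Allocate(output_dim);
  const int64_t count = static_cast<int64_t>(token_ids.NumberOfElement());
  for (int64_t i = 0; i < count; ++i) token[i] = 0;  // placeholder ids

  // Only touch the optional output if the graph actually requested it.
  if (attention_mask.has_value()) {
    int64_t* mask = (*attention_mask)->Allocate(output_dim);
    for (int64_t i = 0; i < count; ++i) mask[i] = 1;  // same shape as token_ids
  }
}
```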


@ -4,7 +4,6 @@
#include "gpt2_tokenizer.hpp"
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
@ -80,12 +79,12 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
return res;
}
void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
OrtTensorDimensions input_dim(ort_, input);
std::vector<std::string> str_input{input.Data()};
const auto& input_dim = input.Shape();
std::vector<std::vector<int64_t>> tokenize_results;
for (auto& str : str_input) {
@ -101,13 +100,11 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
max_length = static_cast<size_t>(padding_length_);
}
OrtTensorDimensions output_dim = input_dim;
std::vector<int64_t> output_dim = input_dim;
output_dim.push_back(max_length);
OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
if (attention_mask != nullptr) {
auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
auto* token = tokenize_output.Allocate(output_dim);
if (attention_mask.has_value()) {
auto* mask = (*attention_mask)->Allocate(output_dim);
int idx = 0;
for (auto& res : tokenize_results) {
for (int64_t id : res) {
@ -134,32 +131,3 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
}
const char* CustomOpBpeTokenizer::GetName() const {
return "GPT2Tokenizer";
}
size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
return 2;
}
ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetOutputCharacteristic(size_t index) const {
return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
: OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
}


@ -6,7 +6,9 @@
struct KernelBpeTokenizer : BaseKernel {
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask);
private:
std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length);
@ -15,13 +17,3 @@ struct KernelBpeTokenizer : BaseKernel {
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};
struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
};
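
For readers unfamiliar with the lite custom-op API used throughout this change, here is a minimal sketch of the migration pattern: the old kernels pulled inputs and outputs out of OrtKernelContext by index, while the new ones declare them as typed ortc::Tensor parameters. The kernel name KernelToyTokenLength is purely illustrative and not part of this PR.

// Illustrative only: a string-in / int64-out kernel in the lite API style.
// ortc::Tensor and BaseKernel come from the extensions' custom-op-lite headers.
struct KernelToyTokenLength : BaseKernel {
  KernelToyTokenLength(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}

  void Compute(const ortc::Tensor<std::string>& input,
               ortc::Tensor<int64_t>& output) {
    const std::vector<std::string>& strs = input.Data();  // typed accessor replaces GetTensorMutableDataString
    std::vector<int64_t> output_dim = input.Shape();       // plain std::vector<int64_t> replaces OrtTensorDimensions
    int64_t* out = output.Allocate(output_dim);            // Allocate replaces KernelContext_GetOutput + GetTensorMutableData
    for (size_t i = 0; i < strs.size(); ++i) {
      out[i] = static_cast<int64_t>(strs[i].size());       // e.g. byte length per input string
    }
  }
};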


@@ -5,7 +5,6 @@
#include "roberta_tokenizer.hpp"
#include "narrow.h"
KernelRobertaBpeTokenizer::KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
@@ -74,7 +73,7 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
size_t space_dif = 0;
if (utf8_token.at(0) == ' ') {
offset++;
space_dif = -1; // account for spaces used in offset map algorithm in bpe(byte_list_)
space_dif = -1; // account for spaces used in offset map algorithm in bpe(byte_list_)
}
// Get byte encodings prior to performing BPE
@@ -108,13 +107,14 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
return res;
}
void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> str_input;
std::vector<std::string> str_input{input.Data()};
std::list<OffsetMappingType> offset_map;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
OrtTensorDimensions input_dim(ort_, input);
const auto& input_dim = input.Shape();
std::vector<std::vector<int64_t>> tokenize_results;
for (auto& str : str_input) {
@@ -131,18 +131,15 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
max_length = static_cast<size_t>(padding_length_);
}
OrtTensorDimensions output_dim = input_dim;
std::vector<int64_t> output_dim = input_dim;
output_dim.push_back(max_length);
OrtTensorDimensions offset_dim = output_dim;
offset_dim.push_back(2); // tuple of offsets for each input id
std::vector<int64_t> offset_dim = output_dim;
offset_dim.push_back(2); // tuple of offsets for each input id
OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
if (attention_mask != nullptr) {
auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
auto* token = tokenize_output.Allocate(output_dim);
if (attention_mask.has_value()) {
auto* mask = (*attention_mask)->Allocate(output_dim);
int idx = 0;
for (auto& res : tokenize_results) {
for (int64_t id : res) {
@@ -156,8 +153,8 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
}
if (offset_mapping != nullptr) {
auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
if (offset_mapping.has_value()) {
auto* offset = (*offset_mapping)->Allocate(offset_dim);
int idx2 = 0;
for (auto& res : offset_map) {
for (auto& mapping : res) {
@@ -181,32 +178,3 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
}
}
}
const char* CustomOpRobertaBpeTokenizer::GetName() const {
return "RobertaTokenizer";
}
size_t CustomOpRobertaBpeTokenizer::GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
return 3;
}
ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetOutputCharacteristic(size_t index) const {
return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
: OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
}


@@ -6,7 +6,10 @@
struct KernelRobertaBpeTokenizer : BaseKernel {
KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask,
std::optional<ortc::Tensor<int64_t>*> offset_mapping);
private:
using OffsetMappingType = std::list<std::pair<size_t, size_t>>;
@@ -16,13 +19,3 @@ struct KernelRobertaBpeTokenizer : BaseKernel {
std::list<std::pair<int, int>> byte_list_;
std::shared_ptr<VocabData> bbpe_tokenizer_;
};
struct CustomOpRobertaBpeTokenizer : OrtW::CustomOpBase<CustomOpRobertaBpeTokenizer, KernelRobertaBpeTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
};
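
The std::optional parameters above carry outputs that the consuming graph may or may not request: an unrequested output arrives as an empty optional and the kernel skips it. A hedged sketch of that pattern, with a hypothetical helper name:

// Illustrative only: fill an optional int64 output only when the graph consumes it.
// Assumes <algorithm> is available for std::fill_n.
static void FillOptionalOutput(std::optional<ortc::Tensor<int64_t>*> maybe_out,
                               const std::vector<int64_t>& dims,
                               int64_t value) {
  if (!maybe_out.has_value()) {
    return;                                  // output not wired up in the model
  }
  int64_t* data = (*maybe_out)->Allocate(dims);
  int64_t count = 1;
  for (int64_t d : dims) {
    count *= d;
  }
  std::fill_n(data, count, value);
}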


@@ -22,10 +22,10 @@ struct KernelSentencepieceDecoder : BaseKernel {
}
}
void Compute(OrtKernelContext* context) {
const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
OrtTensorDimensions ids_dim(ort_, ids);
void Compute(const ortc::Tensor<int64_t>& ids,
ortc::Tensor<std::string>& output) {
const int64_t* p_ids = ids.Data();
auto& ids_dim = ids.Shape();
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
@@ -34,7 +34,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
std::string decoded_string;
std::vector<int64_t> output_dim = {1};
std::vector<int> tids;
std::transform(p_ids, p_ids + ids_dim.Size(),
std::transform(p_ids, p_ids + ids.NumberOfElement(),
std::back_inserter(tids),
[](auto _id) { return static_cast<int>(_id); });
auto status = tokenizer_.Decode(tids, &decoded_string);
@@ -43,32 +43,9 @@ struct KernelSentencepieceDecoder : BaseKernel {
}
std::vector<std::string> result = {decoded_string};
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
FillTensorDataString(api_, ort_, context, result, output);
output.SetStringOutput(result, output_dim);
}
private:
sentencepiece::SentencePieceProcessor tokenizer_;
};
struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
const char* GetName() const {
return "SentencepieceDecoder";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
}
};
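
The decoder above also shows the new way to emit string tensors: instead of KernelContext_GetOutput plus FillTensorDataString, the output tensor exposes SetStringOutput(values, shape). A minimal sketch with a made-up kernel name:

// Illustrative only: an int64 id tensor in, a single joined string out.
struct KernelToyIdsToText : BaseKernel {
  KernelToyIdsToText(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}

  void Compute(const ortc::Tensor<int64_t>& ids,
               ortc::Tensor<std::string>& output) {
    const int64_t* p_ids = ids.Data();
    std::string joined;
    for (int64_t i = 0; i < ids.NumberOfElement(); ++i) {
      joined += std::to_string(p_ids[i]);
      joined += ' ';
    }
    std::vector<int64_t> output_dim{1};
    std::vector<std::string> result = {joined};
    output.SetStringOutput(result, output_dim);  // shape [1], mirroring the decoder above
  }
};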


@@ -31,32 +31,16 @@ static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort
ORT_INVALID_ARGUMENT);
}
void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
void KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
int64_t /*nbest_size*/,
float /*alpha*/,
bool add_bos,
bool add_eos,
bool add_rev,
ortc::Tensor<int32_t>& output,
ortc::Tensor<int64_t>& output1) {
// Update with the new API
const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, ort_input, str_input);
const OrtValue* ort_nbest_size = ort_.KernelContext_GetInput(context, 1);
const float* p_nbest_size = ort_.GetTensorData<float>(ort_nbest_size);
const OrtValue* ort_alpha = ort_.KernelContext_GetInput(context, 2);
const float* p_alpha = ort_.GetTensorData<float>(ort_alpha);
const OrtValue* ort_add_bos = ort_.KernelContext_GetInput(context, 3);
const bool* p_add_bos = ort_.GetTensorData<bool>(ort_add_bos);
const OrtValue* ort_add_eos = ort_.KernelContext_GetInput(context, 4);
const bool* p_add_eos = ort_.GetTensorData<bool>(ort_add_eos);
const OrtValue* ort_add_rev = ort_.KernelContext_GetInput(context, 5);
const bool* p_add_rev = ort_.GetTensorData<bool>(ort_add_rev);
(void)p_nbest_size;
(void)p_alpha;
// Verifications
_check_dimension_constant(ort_, ort_nbest_size, "nbest_size");
_check_dimension_constant(ort_, ort_alpha, "alpha");
_check_dimension_constant(ort_, ort_add_bos, "add_bos");
_check_dimension_constant(ort_, ort_add_eos, "add_eos");
_check_dimension_constant(ort_, ort_add_rev, "add_rev");
auto& str_input = input.Data();
// computation
std::vector<int64_t> indices;
@@ -68,20 +52,20 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
ORTX_CXX_API_THROW(MakeString("Unable to encode string '", str_input[i], "'."), ORT_INVALID_ARGUMENT);
indices.push_back(content.size());
if (*p_add_rev) {
if (*p_add_eos) {
if (add_rev) {
if (add_eos) {
content.push_back(tokenizer_.eos_id());
}
content.insert(content.end(), inloop.rbegin(), inloop.rend());
if (*p_add_bos) {
if (add_bos) {
content.push_back(tokenizer_.bos_id());
}
} else {
if (*p_add_bos) {
if (add_bos) {
content.push_back(tokenizer_.bos_id());
}
content.insert(content.end(), inloop.begin(), inloop.end());
if (*p_add_eos) {
if (add_eos) {
content.push_back(tokenizer_.eos_id());
}
}
@@ -91,54 +75,12 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
// Setup output
std::vector<int64_t> size_content(1);
size_content[0] = content.size();
OrtValue* out_content = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size());
std::vector<int64_t> size_indices(1);
size_indices[0] = indices.size();
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 1, size_indices.data(), size_indices.size());
int* ptr_content = ort_.GetTensorMutableData<int>(out_content);
int* ptr_content = output.Allocate(size_content);
memcpy(ptr_content, content.data(), content.size() * sizeof(int));
int64_t* ptr_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
int64_t* ptr_indices = output1.Allocate(size_indices);
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
}
const char* CustomOpSentencepieceTokenizer::GetName() const {
return "SentencepieceTokenizer";
};
size_t CustomOpSentencepieceTokenizer::GetInputTypeCount() const {
return 6;
};
ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
case 3:
case 4:
case 5:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpSentencepieceTokenizer::GetOutputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("[SentencepieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};


@@ -9,17 +9,15 @@
struct KernelSentencepieceTokenizer : BaseKernel {
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
int64_t /*nbest_size*/,
float /*alpha*/,
bool add_bos,
bool add_eos,
bool add_rev,
ortc::Tensor<int32_t>& output,
ortc::Tensor<int64_t>& output1);
private:
sentencepiece::SentencePieceProcessor tokenizer_;
};
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer,
KernelSentencepieceTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
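
Note how nbest_size, alpha and the add_* flags above become plain value parameters instead of one-element input tensors fetched through the context (which is also why the old _check_dimension_constant calls disappear). A hedged sketch of the same idea on a toy kernel; the names are illustrative, not part of this PR:

// Illustrative only: scalar graph inputs taken by value in the lite API.
struct KernelToyScale : BaseKernel {
  KernelToyScale(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}

  void Compute(const ortc::Tensor<float>& x,
               float scale,    // scalar input, no GetTensorData<float> needed
               bool negate,    // scalar input, no GetTensorData<bool> needed
               ortc::Tensor<float>& y) {
    const float* src = x.Data();
    std::vector<int64_t> dims = x.Shape();
    float* dst = y.Allocate(dims);
    const float k = negate ? -scale : scale;
    for (int64_t i = 0; i < x.NumberOfElement(); ++i) {
      dst[i] = src[i] * k;
    }
  }
};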


@@ -27,33 +27,42 @@
#include "bert_tokenizer_decoder.hpp"
#endif
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = LoadCustomOpClasses<
CustomOpClassBegin
const std::vector<const OrtCustomOp*>& TokenizerLoader() {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_GPT2_TOKENIZER
, CustomOpBpeTokenizer
, CustomOpClipBpeTokenizer
, CustomOpRobertaBpeTokenizer
, CustomOpBpeDecoder
,
CustomCpuStruct("GPT2Tokenizer", KernelBpeTokenizer),
CustomCpuStruct("CLIPTokenizer", KernelClipBpeTokenizer),
CustomCpuStruct("RobertaTokenizer", KernelRobertaBpeTokenizer),
CustomCpuStruct("BpeDecoder", KernelBpeDecoder)
#endif
#ifdef ENABLE_SPM_TOKENIZER
, CustomOpSentencepieceTokenizer
, CustomOpSentencepieceDecoder
,
CustomCpuStruct("SentencepieceTokenizer", KernelSentencepieceTokenizer),
CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder)
#endif
#ifdef ENABLE_WORDPIECE_TOKENIZER
, CustomOpWordpieceTokenizer
,
CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer)
#endif
#ifdef ENABLE_BERT_TOKENIZER
, CustomOpBasicTokenizer
, CustomOpBertTokenizer
, CustomOpBertTokenizerDecoder
, CustomOpHfBertTokenizer
,
CustomCpuStruct("BasicTokenizer", KernelBasicTokenizer),
CustomCpuStruct("BertTokenizer", KernelBertTokenizer),
CustomCpuStruct("BertTokenizerDecoder", KernelBertTokenizerDecoder),
CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer)
#endif
#ifdef ENABLE_BLINGFIRE
, CustomOpBlingFireSentenceBreaker
,
CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker)
#endif
>;
);
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = TokenizerLoader;
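
The loader change above replaces the old LoadCustomOpClasses<CustomOpClassBegin, ...> template list with an OrtOpLoader built from CustomCpuStruct("OpName", KernelStruct) entries. As a hedged sketch only, a hypothetical loader for the illustrative KernelToyTokenLength struct shown earlier might look like this; the OrtOpLoader constructor details are assumed from the pattern above:

// Illustrative only: a loader for a hypothetical domain with one struct kernel.
const std::vector<const OrtCustomOp*>& ToyLoader() {
  static OrtOpLoader op_loader(
      []() { return nullptr; }  // leading no-op entry, mirroring the pattern above
      ,
      CustomCpuStruct("ToyTokenLength", KernelToyTokenLength));
  return op_loader.GetCustomOps();
}

FxLoadCustomOpFactory LoadCustomOpClasses_Toy = ToyLoader;  // exported factory symbol; wiring elsewhere not shown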


@@ -122,14 +122,20 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
rows.push_back(indices.size());
}
void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
void KernelWordpieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
const ortc::Tensor<int64_t>& row_indices,
ortc::Tensor<std::string>& output,
ortc::Tensor<int64_t>& row_lengths,
ortc::Tensor<int64_t>& out_row_begin,
ortc::Tensor<int64_t>& output_limit_values) {
// Update with the new API
const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0);
// make a copy as we need ustring
std::vector<ustring> str_input;
GetTensorMutableDataString(api_, ort_, context, ort_input, str_input);
const OrtValue* ort_row_indices = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions ort_row_indices_dim(ort_, ort_row_indices);
const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
str_input.reserve(input.NumberOfElement());
for (auto& str : input.Data()) {
str_input.emplace_back(str);
}
const int64_t* p_row_indices = row_indices.Shape().empty() ? nullptr : row_indices.Data();
std::vector<ustring> tokens;
std::vector<int32_t> indices;
@@ -137,21 +143,21 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
KernelWordpieceTokenizer_Tokenizer(vocab_, suffix_indicator_, unk_token_, str_input,
tokens, indices, row_begins,
p_row_indices, ort_row_indices_dim.Size(),
p_row_indices, row_indices.NumberOfElement(),
max_input_chars_per_word_);
std::vector<int64_t> size_content{(int64_t)indices.size()};
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size());
FillTensorDataString(api_, ort_, context, tokens, output);
// TODO: avoid copy
std::vector<std::string> out_content;
for (auto& s : tokens)
out_content.emplace_back(s);
output.SetStringOutput(out_content, size_content);
std::vector<int64_t> size_row_lengths{(int64_t)row_begins.size()};
OrtValue* output_row_lengths = ort_.KernelContext_GetOutput(context, 1, size_row_lengths.data(), size_row_lengths.size());
int64_t* ptr_row_lengths = row_lengths.Allocate(size_row_lengths);
--size_row_lengths[0];
OrtValue* output_row_begins = ort_.KernelContext_GetOutput(context, 2, size_row_lengths.data(), size_row_lengths.size());
OrtValue* output_limit_values = ort_.KernelContext_GetOutput(context, 3, size_row_lengths.data(), size_row_lengths.size());
int64_t* ptr_row_lengths = ort_.GetTensorMutableData<int64_t>(output_row_lengths);
int64_t* ptr_row_begins = ort_.GetTensorMutableData<int64_t>(output_row_begins);
int64_t* ptr_limit_values = ort_.GetTensorMutableData<int64_t>(output_limit_values);
int64_t* ptr_row_begins = out_row_begin.Allocate(size_row_lengths);
int64_t* ptr_limit_values = output_limit_values.Allocate(size_row_lengths);
int64_t i;
for (i = 0; i < size_row_lengths[0]; ++i) {
@@ -163,39 +169,3 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
i = size_row_lengths[0];
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
}
const char* CustomOpWordpieceTokenizer::GetName() const {
return "WordpieceTokenizer";
};
size_t CustomOpWordpieceTokenizer::GetInputTypeCount() const {
return 2;
};
ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};
size_t CustomOpWordpieceTokenizer::GetOutputTypeCount() const {
return 4;
};
ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
case 2:
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};
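
The wordpiece kernel above still copies between std::string and ustring on both the input and output side (note the "TODO: avoid copy"). For clarity, a hedged sketch of that round trip in isolation; the helper names are hypothetical:

// Illustrative only: lite-API string tensor -> the ustring type used internally, and back.
static std::vector<ustring> ToUstrings(const ortc::Tensor<std::string>& input) {
  std::vector<ustring> out;
  out.reserve(static_cast<size_t>(input.NumberOfElement()));
  for (const auto& s : input.Data()) {
    out.emplace_back(s);   // std::string -> ustring
  }
  return out;
}

static std::vector<std::string> ToStrings(const std::vector<ustring>& tokens) {
  std::vector<std::string> out;
  out.reserve(tokens.size());
  for (const auto& t : tokens) {
    out.emplace_back(t);   // ustring -> std::string
  }
  return out;
}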


@@ -10,10 +10,14 @@
#include <unordered_map>
struct KernelWordpieceTokenizer : BaseKernel {
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
const ortc::Tensor<int64_t>& row_indices,
ortc::Tensor<std::string>& output,
ortc::Tensor<int64_t>& row_lengths,
ortc::Tensor<int64_t>& out_row_begin,
ortc::Tensor<int64_t>& output_limit_values);
private:
int64_t max_input_chars_per_word_;
@@ -22,14 +26,6 @@ struct KernelWordpieceTokenizer : BaseKernel {
std::unordered_map<std::u32string, int32_t> vocab_;
};
struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
const std::u32string& text,
std::vector<std::u32string>& words);


@@ -9,6 +9,10 @@
#include <cstdint>
namespace ort_extensions {
void decode_image(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output);
struct KernelDecodeImage : BaseKernel {
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
@@ -50,4 +54,5 @@ struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecod
}
}
};
} // namespace ort_extensions
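
decode_image above is declared as a free function rather than a kernel struct, which the lite API appears to support for stateless ops. A hedged sketch of a function-style op with the same signature shape; invert_bytes is illustrative and not a real operator:

// Illustrative only: a stateless byte-wise op in the free-function style.
void invert_bytes(const ortc::Tensor<uint8_t>& input,
                  ortc::Tensor<uint8_t>& output) {
  const uint8_t* src = input.Data();
  std::vector<int64_t> dims = input.Shape();
  uint8_t* dst = output.Allocate(dims);
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    dst[i] = static_cast<uint8_t>(255 - src[i]);
  }
}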