Use API-lite for custom ops (#386)
* use lite custom op api for math
* add vision ops
* add cx2 ops
* remove useless code
* support register custom kernel struct
* add string tensor support
* add more text kernels
* fix issue with std string as scalar
* migrate all text ops
* initial tokenizer change
* migrate all tokenizers
* Resolve conflict with main (#433)
* resolve conflict
* resolve conflict
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Update custom-op-lite PR (#440)
* add the onnxruntime 1.14 release into the CI pipeline (#387)
* add the onnxruntime 1.14 release into the CI pipeline
* torch 2.0 crashed on Linux
* Fix size_t overflow issue for RobertaTokenizer (#388)
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Pre and Post processing example for openAI Whisper model (#380)
* add a stft-norm custom op for log-mel spectrum.
* undo the debug change
* Support ONNX standard STFT op signature.
* Add a unit test onnx STFT compatible mode.
* add whisper pre-/post- processing example
* Update dlib.cmake
* undo test code changes
* Update setup.cfg
* update the end2end example with STFT op
* Added optional outputs for GPT2, CLIP and Roberta Tokenizers (#389)
* Initial optional i/o for robertap
* Small fix
* Added working optional output functionality to RobertaTokenizer with tests
* Added optional outputs to CLIPTokenizer
* Added optional outputs to GPT2Tokenizer
* Use ternary operators
---------
Authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* ignore the unknown token id on bpe decoder (#391)
* Use dependency name 'nlohmann_json' which is the same name that ORT uses. (#393)
* Add an audio decoder custom op for whisper end-to-end processing (#385)
* evaluate the audio decoder library
* MP3 Decoder
* rename it to test_audio_codec
* add the audio decoder to whisper model
* whisper end-to-end draft
* fix the mp3 decoder
* Running with ONNX models
* Add more audio format supports
* refine the end-to-end script
* Update operators/audio/audio_decoder.hpp
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
* Update operators/audio/audio_decoder.hpp
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
* Update operators/audio/audio_decoder.hpp
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
* some fixings of comments and more test cases.
* changes for review comments.
* Update audio_decoder.hpp
* Update audio_decoder.hpp
* code refinement
* Update operators/audio/audio_decoder.hpp
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
---------
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
* make tensorflow be optional for unittest (#394)
* make tensorflow optional for unittest.
* typo
* built-in bounding box op (#382)
* built-in bounding box op
* update boundary check
* assert policy
* more boundary test and check
* XYXY--> X horizon
---------
Co-authored-by: Scott McKay <skottmckay@gmail.com>
* a quick nuget package impl. (#396)
* Update wheels_linux.yml: change the linux machine pool name (#398)
* Add a merge step in whisper end-to-end script and fixed some issues (#399)
* add merged models in whisper model
* verify the final model
* support batch > 1 in BpeDecoder (#400)
* support batch > 1 in BpeDecoder
* update the shape in helper function
* [object detection ppp] YoLo as example (#397)
* object detection
* Unit test
add e2e fastestdet model test
---------
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
* some fixing for python package (#401)
* more code fixing related whisper models (#403)
* Added windows nuget work temporarily for testing (#402)
* Added windows nuget work temporarily for testing
* Cleanup
* Add back onnxruntime.lib in props file for possible future ORT need
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
* Remove unnecessary nupkg file and update nuspec (#405)
* Add nuget pack to build.bat and small nuget changes for demo
* Temporarily adding nuget.exe to build package until we can add to CI machine
* Switch back from Release to RelWithDebInfo
* Remove unnecessary changes
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Add initial NuGet pipeline for Windows x64 build (#406)
* initial nuget pipeline
* Update nuget.yml for Azure Pipelines
* update nuget.yml for extensions specific packaging
TODO: add certain template yml files
* added component governance template yaml
* change template yaml path
* remove RoslynAnalyzers
* Add packDestination to nuget pack task (change from default)
* fix nuspec path
* Update nuget.yml for Azure Pipelines
* Update nuget.yml for Azure Pipelines
* Update nuget.yml for Azure Pipelines
* Update 2 nuget.yml for Azure Pipelines
* Update NativeNuget.nuspec
* Update nuget.yml for Azure Pipelines
* update nuspec
* Update 3 nuget.yml for Azure Pipelines
* Update 4 nuget.yml for Azure Pipelines
* Update 7 nuget.yml for Azure Pipelines
* Remove unnecessary nupkg file and update nuspec (#405)
* Add nuget pack to build.bat and small nuget changes for demo
* Temporarily adding nuget.exe to build package until we can add to CI machine
* Switch back from Release to RelWithDebInfo
* Remove unnecessary changes
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Update 8 nuget.yml for Azure Pipelines
* Update 9 nuget.yml for Azure Pipelines
* add DLL signing
* Update nuget.yml for Azure Pipelines
* fix indentation
* Update 11 nuget.yml for Azure Pipelines
* Update 12 nuget.yml for Azure Pipelines
* Update 12 nuget.yml for Azure Pipelines
* Revert some unnecessary changes on nuget.yml
* clean up nuget.yml and update nuspec release notes
* small changes
* update commit id and release notes
---------
Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Compatible with onnxruntime-gpu package (#410)
* be compatible without onnxruntime-gpu version
* some fixing
* Add nuget README and remove ort lib references from props (#409)
* Add nuget README and remove ort lib references from props
* replace commit id in nuspec dynamically
* remove $ sign for commit id token
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
* Add an C# demo project for NuGet package (#407)
* Add a nuget test app
* remove unused file
* Compatible with onnxruntime-gpu package (#410)
* be compatible without onnxruntime-gpu version
* some fixing
* turn it as a .net demo project
---------
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
* Make Whisper E2E script more portable (#412)
This PR makes the Whisper E2E script more portable for other environments.
* Update macos wheel timeout to 180 min (#390)
* Update ci timeout to 120 min
* Only update WindowsPython job timeout
* Update ci timeout to 90 min
* update macos wheel timeout to 180 min
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Fix OneBranch PR pipeline CodeQL issue (#413)
* test codeql 3000
* switch codeql from compiled to python
* switch back to compiled
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Adding down-sampling and stereo mixing features for AudioDecoder (#420)
* initial draft
* second
* third
* polishing
* fix the M_PI name in LINUX platform
* fix bessel function issue
* add a unit test case
* fix the unit test name
* Fix Secure Supply Chain Analysis Warning in PR pipeline (#414)
* remove package sources
* remove NuGet.config
* add .sscignore for cfs0011
* change sscignore
* add CFS0013 to sscignore
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* fix onnx version to 1.13.1 (#422)
* [NuGet] All platform package pipeline (#408)
* nuget ci package
* disable macos arm64 build for err
* Get the iOS xcframework build working with the split build/pack approach. (#416)
* refine build_xcframework.py
Cleanup/clarify various things
- naming of parameters and files
- consistency
Make handling of additional build args more generic
Update the artifact download dir/extract dir to more intuitive names
Update scripts
- make usage from CI pipeline clearer (e.g. don't hide directory names inside script)
- keep comments in nuspec
- remove unused args
- make additional arg handling more
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
* Add new required pre/post processing ops to Android and iOS packages. (#415)
* Revert "Pin onnx version to 1.13.1" (#423)
* Revert "fix onnx version to 1.13.1 (#422)"
This reverts commit eb29d225a7.
* Update requirements.txt
* PyOp attribute supports int and float data type (#425)
* Fix Android AAR in nuget package. Requires libortextensions.so. (#429)
* build for mac M1 (#430)
* Fix the unit test failure with ONNX 1.14 package. (#428)
* Fix the unit test failure with ONNX 1.14 package.
* more tests
* Update whisper_e2e.py
* Add nuget.org publish version option (#426)
* Add nuget.org publish version option
* typo
* small fix
* typo
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
---------
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: JiCheng <247153481@qq.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Fix a build err (#442)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Fix build err on ort 141 (#444)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Remove shape from span (#445)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Fix python tests (#446)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
* fix python tests
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Fix max build (#449)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
* fix python tests
* fix packaging err
* fix mac build
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Fix comments (#452)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
* fix python tests
* fix packaging err
* fix mac build
* fixing the universal2 python package for macOS (#448)
* Remove onnx<1.14 from requirements.txt (#447)
* remove onnx<1.14 from requirements.txt
* downgrade protobuf
* move protobuf req to requirements-dev.txt
---------
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
* fix comments
* comment version macro
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
* Fix build err (#453)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
* fix python tests
* fix packaging err
* fix mac build
* fix comments
* comment version macro
* define Compute for StftNormal
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* Merge latest main (#461)
* resolve conflict
* resolve conflict
* minor fix
* rename from TensorT to Tensor
* fix string tensor
* Add OrtLiteCustomOp
* switch to string view
* fix regex ops
* fix build
* fix a build err
* remove shape
* fix python tests
* fix packaging err
* fix mac build
* fix comments
* comment version macro
* define Compute for StftNormal
---------
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
* revert wanted changes in test
* revert unwanted changes
* add string_strip op
---------
Co-authored-by: Cheng Tang <chenta@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Co-authored-by: Randy Shuai <rashuai@microsoft.com>
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
Co-authored-by: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com>
Co-authored-by: Sayan Shaw <sayanshaw@microsoft.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: JiCheng <247153481@qq.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: Wenbing Li <wenbingl@outlook.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
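To summarize what the migration buys, here is a minimal sketch under the API introduced by this diff; the op name and body are illustrative only, not code from the repo. A kernel becomes a plain function over ortc::Tensor arguments, and its ONNX input/output types are deduced from the signature, so the old CustomOpBase boilerplate disappears.

// Hypothetical example of a lite-API kernel: inputs/outputs are expressed as
// ortc::Tensor parameters and the element types are inferred at compile time.
void neg_pos(const ortc::Tensor<float>& input,
             ortc::Tensor<float>& negatives,
             ortc::Tensor<float>& positives) {
  const float* x = input.Data();
  float* neg = negatives.Allocate(input.Shape());
  float* pos = positives.Allocate(input.Shape());
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    neg[i] = x[i] < 0 ? x[i] : 0.f;
    pos[i] = x[i] > 0 ? x[i] : 0.f;
  }
}

// Registration then reduces to one line instead of a CustomOpBase subclass:
//   CustomCpuFunc("NegPos", neg_pos)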
Parent: 30eb7afcfa
Commit: 8f36cf3272

@ -0,0 +1,663 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "onnxruntime_customop.hpp"
#include <optional>
#include <numeric>
// uplevel the version when supported ort version migrates to newer ones
#define SUPPORT_ORT_API_VERSION_TO 13

namespace Ort {
namespace Custom {

class TensorBase {
 public:
  TensorBase(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, bool is_input)
      : api_(api), ctx_(ctx), indice_(indice), is_input_(is_input) {}

  virtual ~TensorBase() = default;
  operator bool() const {
    return shape_.has_value();
  }
  const std::vector<int64_t>& Shape() const {
    if (shape_.has_value()) {
      return *shape_;
    } else {
      ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
    }
  }
  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    } else {
      ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
    }
  }
  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

 protected:
  const OrtW::CustomOpApi& api_;
  OrtKernelContext& ctx_;
  size_t indice_;
  bool is_input_;
  std::optional<std::vector<int64_t>> shape_;
};

template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
};

template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;
  Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, bool is_input)
      : TensorBase(api, ctx, indice, is_input) {
    if (is_input) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      if (indice >= input_count) {
        ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
      }
      const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
      auto* info = api_.GetTensorTypeAndShape(const_value_);
      shape_ = api_.GetTensorShape(info);
      api_.ReleaseTensorTypeAndShapeInfo(info);
    }
  }
  const TT* Data() const {
    return api_.GetTensorData<TT>(const_value_);
  }
  TT* Allocate(const std::vector<int64_t>& shape) {
    if (!data_) {
      OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
      shape_ = shape;
      data_ = api_.GetTensorMutableData<TT>(out);
    }
    return data_;
  }
  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), (*shape_)[0]);
    return span_;
  }
  const T& AsScalar() {
    if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
      ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }

 private:
  const OrtValue* const_value_{};  // for input
  TT* data_{};                     // for output
  Span<T> span_;
};

template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, bool is_input)
      : TensorBase(api, ctx, indice, is_input) {
    if (is_input) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      if (indice >= input_count) {
        ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
      }

      auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
      auto* info = api_.GetTensorTypeAndShape(const_value);
      shape_ = api_.GetTensorShape(info);
      api_.ReleaseTensorTypeAndShapeInfo(info);

      size_t num_chars;
      OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
      std::vector<char> chars(num_chars + 1, '\0');
      auto num_strings = NumberOfElement();
      std::vector<size_t> offsets(NumberOfElement());
      OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, (void*)chars.data(), num_chars, offsets.data(), offsets.size()));
      auto upper_bound = static_cast<int64_t>(num_strings) - 1;
      input_strings_.resize(num_strings);
      for (int64_t i = upper_bound; i >= 0; --i) {
        if (i < upper_bound) {
          chars[offsets[i + 1]] = '\0';
        }
        input_strings_[i] = chars.data() + offsets[i];
      }
    }
  }
  const strings& Data() const {
    return input_strings_;
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
  }
  void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
    auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
    OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
  }
  const Span<std::string>& AsSpan() {
    ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
  }
  const std::string& AsScalar() {
    if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
      ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};

template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(const OrtW::CustomOpApi& api, OrtKernelContext& ctx, size_t indice, bool is_input)
      : TensorBase(api, ctx, indice, is_input) {
    if (is_input_) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      if (indice >= input_count) {
        ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
      }
      auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
      auto* info = api_.GetTensorTypeAndShape(const_value);
      shape_ = api_.GetTensorShape(info);
      api_.ReleaseTensorTypeAndShapeInfo(info);

      size_t num_chars;
      OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
      chars_.resize(num_chars + 1, '\0');

      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, (void*)chars_.data(), num_chars, offsets.data(), offsets.size()));
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }
  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
    } else {
      return 0;
    }
  }
  const string_views& Data() const {
    return input_string_views_;
  }
  const Span<std::string_view>& AsSpan() {
    ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
  }
  std::string_view AsScalar() {
    if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
      ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                           // for input
  std::vector<std::string_view> input_string_views_;  // for input
};

using TensorPtr = std::unique_ptr<Custom::TensorBase>;

struct OrtLiteCustomOp : public OrtCustomOp {
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;

  // CreateTuple
  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
      } \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
      } \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)

  // ParseArgs ...
  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)

#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)

  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output

  OrtLiteCustomOp(const char* op_name, const char* execution_provider)
      : op_name_(op_name), execution_provider_(execution_provider) {
    OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
      return INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
      return INPUT_OUTPUT_OPTIONAL;
    };
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;
};

template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    ComputeFn compute_fn_{};
    std::string ep_{};
    std::unique_ptr<OrtW::CustomOpApi> api_;
  };

  OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFn compute_fn)
      : OrtLiteCustomOp(op_name, execution_provider), compute_fn_(compute_fn) {
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<TensorPtr> tensors;
      auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
                                          context,
                                          tensors,
                                          kernel->api_->KernelContext_GetInputCount(context),
                                          kernel->api_->KernelContext_GetOutputCount(context),
                                          kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->compute_fn_ = self->compute_fn_;
      kernel->ep_ = self->execution_provider_;
      kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };
  }

  ComputeFn compute_fn_;
};

template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);
  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
    std::unique_ptr<OrtW::CustomOpApi> api_;
  };

  OrtLiteCustomStruct(const char* op_name, const char* execution_provider)
      : OrtLiteCustomOp(op_name, execution_provider) {
    init(&CustomOp::Compute);
  }

  template <typename... Args>
  void init(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<TensorPtr> tensors;
      auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
                                          context,
                                          tensors,
                                          kernel->api_->KernelContext_GetInputCount(context),
                                          kernel->api_->KernelContext_GetOutputCount(context),
                                          kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
      auto self = static_cast<const MyType*>(this_);
      kernel->ep_ = self->execution_provider_;
      kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };
  }
};

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...)) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider).release();
}

}  // namespace Custom
}  // namespace Ort
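To make the dispatch above concrete, here is a small hypothetical kernel; every name in it is illustrative and not part of this PR. It shows how each parameter kind is matched by the ParseArgs/CreateTuple overloads: tensors map to required inputs and outputs, Span and plain scalars are CPU-only conveniences over 1-D and {1}-shaped inputs, and std::optional marks trailing inputs that may be absent.

// Hypothetical example exercising the argument kinds supported by the lite API:
//   const ortc::Tensor<int64_t>&                  -> required INT64 input
//   const ortc::Span<double>&                     -> 1-D DOUBLE input exposed as a span (CPU EP only)
//   float                                         -> scalar FLOAT input with shape {1} (CPU EP only)
//   std::optional<const ortc::Tensor<float>*>     -> optional FLOAT input, may be missing
//   ortc::Tensor<float>&                          -> FLOAT output, allocated on demand
void hypothetical_op(const ortc::Tensor<int64_t>& ids,
                     const ortc::Span<double>& weights,
                     float scale,
                     std::optional<const ortc::Tensor<float>*> bias,
                     ortc::Tensor<float>& out) {
  // Gather one weight per id, apply the scalar scale, and add the optional
  // {1}-shaped bias when that input is present (assumes a non-empty weights span).
  std::vector<int64_t> shape{ids.NumberOfElement()};
  float* y = out.Allocate(shape);
  const float b = bias.has_value() ? (*bias)->AsScalar() : 0.0f;
  for (int64_t i = 0; i < ids.NumberOfElement(); ++i) {
    size_t idx = static_cast<size_t>(ids.Data()[i]) % weights.size();
    y[i] = static_cast<float>(weights[idx]) * scale + b;
  }
}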
@ -85,6 +85,53 @@ class CuopContainer {
  std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;  // use shared_ptr to capture type specific deleter
};

#define CustomCpuFunc(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }

template <typename F>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
                    F arg) {
  ops.emplace_back(std::move(arg()));
}

template <typename T, typename... Args>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
                    T arg, Args... args) {
  AppendCustomOp(ops, arg);
  AppendCustomOp(ops, args...);
}

class OrtOpLoader {
 public:
  template <typename... Args>
  OrtOpLoader(Args... args) {
    LoadOps(args...);
    for (auto& ptr : op_instances_) {
      if (ptr)
        ocos_list_.push_back(ptr.get());
    }
  }

  const std::vector<const OrtCustomOp*>& GetCustomOps() const {
    return ocos_list_;
  }

 private:
  template <typename T>
  void LoadOps(T fn) {
    AppendCustomOp(op_instances_, fn);
  }

  template <typename T, typename... Args>
  void LoadOps(T fn, Args... args) {
    AppendCustomOp(op_instances_, fn);
    AppendCustomOp(op_instances_, args...);
  }

  std::vector<const OrtCustomOp*> ocos_list_;
  std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
};

struct CustomOpClassBegin {
};
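Since the commit list above also mentions registering a custom kernel struct, a minimal sketch of what CustomCpuStruct expects may help. The struct name and its computation are assumptions; only the constructor signature and the Compute convention follow from OrtLiteCustomStruct.

// Hypothetical struct-based kernel: constructed with (OrtApi, OrtKernelInfo) so
// attributes can be read once and cached, then Compute uses the same lite
// tensor parameters as a free-function kernel.
struct MyStatefulKernel {
  MyStatefulKernel(const OrtApi& /*api*/, const OrtKernelInfo& /*info*/) {}

  void Compute(const ortc::Tensor<float>& input,
               ortc::Tensor<float>& output) {
    const float* x = input.Data();
    float* y = output.Allocate(input.Shape());
    for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
      y[i] = x[i] * 2.0f;  // placeholder computation
    }
  }
};

// Both styles can be mixed in one loader, e.g.:
//   static OrtOpLoader loader(CustomCpuFunc("NegPos", neg_pos),
//                             CustomCpuStruct("MyStateful", MyStatefulKernel));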
@ -18,6 +18,8 @@

#include "exceptions.h"

#define MIN_ORT_VERSION_SUPPORTED 10

namespace OrtW {

//

@ -54,6 +56,8 @@ struct CustomOpApi {
    OrtW::ThrowOnError(api_, status);
  }

  const OrtApi& GetOrtApi() const { return api_; }

 private:
  const OrtApi& api_;
};

@ -61,7 +65,7 @@ struct CustomOpApi {
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
  CustomOpBase() {
-    OrtCustomOp::version = 10;  // The minimum ORT version supported
+    OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;  // The minimum ORT version supported
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
      void* result = nullptr;
      OCOS_API_IMPL_BEGIN

@ -295,3 +299,7 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
}

}  // namespace OrtW

// !! TODO: only do it for legacy ort build
#include "custom_op_lite.h"
namespace ortc = Ort::Custom;
@ -12,70 +12,28 @@

#include <cstdint>

-struct KernelImageDecoder : BaseKernel {
-  KernelImageDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
-
-  void Compute(OrtKernelContext* context) {
-    // Setup inputs
-    const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
-    OrtTensorDimensions dimensions(ort_, inputs);
-    if (dimensions.size() != 1ULL) {
-      ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
-    }
-
-    // Get data & the length
-    const uint8_t* const encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);
-
-    OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
-    const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
-    ort_.ReleaseTensorTypeAndShapeInfo(input_info);
-
-    // Decode the image
-    const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
-    const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
-                                const_cast<void*>(static_cast<const void*>(encoded_image_data)));
-    const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
-
-    // Setup output & copy to destination
-    const cv::Size decoded_image_size = decoded_image.size();
-    const int64_t colors = 3;
-
-    const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
-    OrtValue* const output_value = ort_.KernelContext_GetOutput(
-        context, 0, output_dimensions.data(), output_dimensions.size());
-    uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
-    memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
-  }
-};
-
-struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
-  const char* GetName() const {
-    return "ImageDecoder";
-  }
-
-  size_t GetInputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetInputType(size_t index) const {
-    switch (index) {
-      case 0:
-        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
-      default:
-        ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
-    }
-  }
-
-  size_t GetOutputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetOutputType(size_t index) const {
-    switch (index) {
-      case 0:
-        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
-      default:
-        ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
-    }
-  }
-};
+void image_decoder(const ortc::Tensor<uint8_t>& input,
+                   ortc::Tensor<uint8_t>& output) {
+  auto& dimensions = input.Shape();
+  if (dimensions.size() != 1ULL) {
+    ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
+  }
+
+  // Get data & the length
+  const uint8_t* const encoded_image_data = input.Data();
+  const int64_t encoded_image_data_len = input.NumberOfElement();
+
+  // Decode the image
+  const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
+  const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
+                              const_cast<void*>(static_cast<const void*>(encoded_image_data)));
+  const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
+
+  // Setup output & copy to destination
+  const cv::Size decoded_image_size = decoded_image.size();
+  const int64_t colors = 3;
+
+  const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
+  uint8_t* const decoded_image_data = output.Allocate(output_dimensions);
+  memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
+}
@ -27,6 +27,20 @@ struct KernelImageReader : BaseKernel {
  }
};

void image_reader(const ortc::Tensor<std::string>& input,
                  ortc::Tensor<uint8_t>& output) {
  auto& input_data_dimensions = input.Shape();
  int n = input_data_dimensions[0];
  if (n != 1) {
    ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
  }
  auto& image_paths = input.Data();
  cv::Mat img = cv::imread(image_paths[0], cv::IMREAD_COLOR);
  std::vector<int64_t> output_dimensions = {1, img.size[0], img.size[1], static_cast<int64_t>(img.elemSize())};
  std::uint8_t* p_output_image = output.Allocate(output_dimensions);
  memcpy(p_output_image, img.data, img.total() * img.elemSize());
}

struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
  size_t GetInputTypeCount() const {
    return 1;
@ -1,81 +1,40 @@
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>

-struct KernelGaussianBlur : BaseKernel {
-  KernelGaussianBlur(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
-  }
-
-  void Compute(OrtKernelContext* context) {
-    size_t input_c = ort_.KernelContext_GetInputCount(context);
-    const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0);
-    const float* p_input_data = ort_.GetTensorData<float>(input_data);
-    std::int64_t ksize[] = {3, 3};
-    double sigma[] = {0., 0.};
-    if (input_c > 1) {
-      const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
-      OrtTensorDimensions dim_ksize(ort_, input_ksize);
-      if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
-        ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
-      }
-      std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
-    }
-
-    if (input_c > 2) {
-      const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
-      OrtTensorDimensions dim_sigma(ort_, input_sigma);
-      if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
-        ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
-      }
-      std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
-    }
-
-    OrtTensorDimensions input_data_dimensions(ort_, input_data);
-
-    int n = static_cast<int>(input_data_dimensions[0]);
-    int h = static_cast<int>(input_data_dimensions[1]);
-    int w = static_cast<int>(input_data_dimensions[2]);
-    int c = static_cast<int>(input_data_dimensions[3]);
-    (void)n;
-    (void)c;
-
-    cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
-    cv::Mat output_image;
-    cv::GaussianBlur(input_image,
-                     output_image,
-                     cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
-                     sigma[0], sigma[1], cv::BORDER_DEFAULT);
-
-    OrtValue* image_y = ort_.KernelContext_GetOutput(context,
-                                                     0, input_data_dimensions.data(), input_data_dimensions.size());
-    float* p_output_image = ort_.GetTensorMutableData<float>(image_y);
-    memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
-  }
-};
-
-struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
-  size_t GetInputTypeCount() const {
-    return 3;
-  }
-
-  size_t GetOutputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetInputType(size_t index) const {
-    if (index == 0) {
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-    } else if (index == 1) {
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    } else {
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
-    }
-  }
-
-  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  }
-
-  const char* GetName() const {
-    return "GaussianBlur";
-  }
-};
+void gaussian_blur(const ortc::Tensor<float>& input_data,
+                   const ortc::Span<int64_t>& input_ksize,
+                   const ortc::Span<double>& input_sigma,
+                   ortc::Tensor<float>& output) {
+  const float* p_input_data = input_data.Data();
+  std::int64_t ksize[] = {3, 3};
+  double sigma[] = {0., 0.};
+
+  if (input_ksize.size() != 2) {
+    ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
+  }
+  std::copy_n(input_ksize.data(), 2, ksize);
+
+  if (input_sigma.size() != 2) {
+    ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
+  }
+  std::copy_n(input_sigma.data(), 2, sigma);
+
+  auto& input_data_dimensions = input_data.Shape();
+
+  int n = static_cast<int>(input_data_dimensions[0]);
+  int h = static_cast<int>(input_data_dimensions[1]);
+  int w = static_cast<int>(input_data_dimensions[2]);
+  int c = static_cast<int>(input_data_dimensions[3]);
+  (void)n;
+  (void)c;
+
+  cv::Mat input_image(cv::Size(w, h), CV_32FC3, (void*)p_input_data);
+  cv::Mat output_image;
+  cv::GaussianBlur(input_image,
+                   output_image,
+                   cv::Size(static_cast<int>(ksize[1]), static_cast<int>(ksize[0])),
+                   sigma[0], sigma[1], cv::BORDER_DEFAULT);
+
+  float* p_output_image = output.Allocate(input_data_dimensions);
+  memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
+}
@ -3,14 +3,17 @@
|
|||
#ifdef ENABLE_OPENCV_CODECS
|
||||
#include "imgcodecs/imread.hpp"
|
||||
#include "imgcodecs/imdecode.hpp"
|
||||
#endif // ENABLE_OPENCV_CODECS
|
||||
#endif // ENABLE_OPENCV_CODECS
|
||||
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_CV2 =
|
||||
LoadCustomOpClasses<CustomOpClassBegin
|
||||
, CustomOpGaussianBlur
|
||||
const std::vector<const OrtCustomOp*>& Cv2Loader() {
|
||||
static OrtOpLoader op_loader(CustomCpuFunc("GaussianBlur", gaussian_blur)
|
||||
#ifdef ENABLE_OPENCV_CODECS
|
||||
, CustomOpImageReader
|
||||
, CustomOpImageDecoder
|
||||
#endif // ENABLE_OPENCV_CODECS
|
||||
>;
|
||||
,
|
||||
CustomCpuFunc("ImageDecoder", image_decoder),
|
||||
CustomCpuFunc("ImageReader", image_reader)
|
||||
#endif // ENABLE_OPENCV_CODECS
|
||||
);
|
||||
return op_loader.GetCustomOps();
|
||||
}
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_CV2 = Cv2Loader;
|
||||
|
|
|
@ -6,52 +6,17 @@
|
|||
#include "ocos.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
|
||||
struct KernelInverse : BaseKernel {
|
||||
KernelInverse(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
void inverse(const ortc::Tensor<float>& input,
|
||||
ortc::Tensor<float>& output) {
|
||||
auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 2) {
|
||||
throw std::runtime_error("Only 2-d matrix supported.");
|
||||
}
|
||||
const float* X = input.Data();
|
||||
float* out = output.Allocate(dimensions);
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
if (dimensions.size() != 2) {
|
||||
throw std::runtime_error("Only 2-d matrix supported.");
|
||||
}
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(
|
||||
context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
|
||||
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
|
||||
std::copy(X, X + dm_x.size(), dm_x.begin());
|
||||
dlib::matrix<float> dm = dlib::inv(dm_x);
|
||||
|
||||
memcpy(out0, dm.steal_memory().get(), dm_x.size() * sizeof(float));
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
|
||||
const char* GetName() const {
|
||||
return "Inverse";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
||||
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
|
||||
std::copy(X, X + dm_x.size(), dm_x.begin());
|
||||
dlib::matrix<float> dm = dlib::inv(dm_x);
|
||||
memcpy(out, dm.steal_memory().get(), dm_x.size() * sizeof(float));
|
||||
}
|
||||
|
|
|
@ -6,38 +6,31 @@
|
|||
#include "ocos.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
struct KernelStft : BaseKernel {
|
||||
KernelStft(const OrtApi& api, const OrtKernelInfo& info, bool return_magnitude)
|
||||
: BaseKernel(api, info), return_magnitude_(return_magnitude) {
|
||||
struct STFT : public BaseKernel {
|
||||
STFT(const OrtApi& api, const OrtKernelInfo& info,
|
||||
bool with_norm = false) : BaseKernel(api, info),
|
||||
with_norm_(with_norm) {
|
||||
onesided_ = TryToGetAttributeWithDefault<int64_t>("onesided", 1);
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
const OrtValue* input_x1 = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_x2 = ort_.KernelContext_GetInput(context, 1);
|
||||
const OrtValue* input_x3 = ort_.KernelContext_GetInput(context, 2);
|
||||
const OrtValue* input_x4 = ort_.KernelContext_GetInput(context, 3);
|
||||
const OrtValue* input_x5 = ort_.KernelContext_GetInput(context, 4);
|
||||
void Compute(const ortc::Tensor<float>& input0,
|
||||
int64_t n_fft,
|
||||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) {
|
||||
auto X = input0.Data();
|
||||
auto window = input3.data();
|
||||
auto dimensions = input0.Shape();
|
||||
auto win_length = input3.size();
|
||||
|
||||
const float* X = ort_.GetTensorData<float>(input_x1);
|
||||
auto n_fft = *ort_.GetTensorData<int64_t>(input_x2);
|
||||
auto hop_length = *ort_.GetTensorData<int64_t>(input_x3);
|
||||
auto window = ort_.GetTensorData<float>(input_x4);
|
||||
auto frame_length = *ort_.GetTensorData<int64_t>(input_x5);
|
||||
|
||||
OrtTensorDimensions dimensions(ort_, input_x1);
|
||||
OrtTensorDimensions win_dim(ort_, input_x4);
|
||||
if (dimensions.size() < 2 || dimensions.Size() != dimensions[1]) {
|
||||
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
|
||||
ORTX_CXX_API_THROW("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
if (win_dim.size() != 1) {
|
||||
ORTX_CXX_API_THROW("[Stft] Only 1-d hanning window supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
if (frame_length != n_fft) {
|
||||
ORTX_CXX_API_THROW("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto win_length = win_dim[0];
|
||||
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
|
||||
dlib::matrix<float> hann_win = dlib::mat(window, 1, win_length);
|
||||
|
||||
|
@ -49,23 +42,19 @@ struct KernelStft : BaseKernel {
|
|||
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
|
||||
}
|
||||
|
||||
if (return_magnitude_) {
|
||||
if (with_norm_) {
|
||||
dlib::matrix<float> result = dlib::norm(m_stft);
|
||||
result = dlib::trans(result);
|
||||
int64_t outdim[] = {1, result.nr(), result.nc()};
|
||||
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
|
||||
auto result_size = result.size();
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(
|
||||
context, 0, outdim, 3);
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
auto out0 = output0.Allocate(outdim);
|
||||
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
|
||||
} else {
|
||||
auto result = m_stft;
|
||||
// No transpose here since it is done on copying data,
|
||||
// switch nr and nc, so the output dim willbe tranposed one.
|
||||
int64_t outdim[] = {1, result.nc(), result.nr(), 2};
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(
|
||||
context, 0, outdim, 4);
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
std::vector<int64_t> outdim{1, result.nc(), result.nr(), 2};
|
||||
auto out0 = output0.Allocate(outdim);
|
||||
for (size_t c = 0; c < result.nc(); ++c) {
|
||||
for (size_t r = 0; r < result.nr(); ++r) {
|
||||
*out0 = result(r, c).real();
|
||||
|
@ -76,50 +65,18 @@ struct KernelStft : BaseKernel {
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool return_magnitude_;
|
||||
int64_t onesided_;
|
||||
int64_t onesided_{};
|
||||
bool with_norm_{};
|
||||
};
|
||||
|
||||
struct CustomOpStft : OrtW::CustomOpBase<CustomOpStft, KernelStft> {
|
||||
const char* GetName() const {
|
||||
return op_name_;
|
||||
struct StftNormal : public STFT {
|
||||
StftNormal(const OrtApi& api, const OrtKernelInfo& info) : STFT(api, info, true) {}
|
||||
void Compute(const ortc::Tensor<float>& input0,
|
||||
int64_t n_fft,
|
||||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) {
|
||||
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 5;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
// pcm and window are float
|
||||
if (index == 0 || index == 3) {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
} else {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
return new KernelStft(api, info, with_norm_);
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool with_norm_ = false;
|
||||
const char* op_name_ = "STFT";
|
||||
};
|
||||
|
||||
struct CustomOpStftNorm : CustomOpStft {
|
||||
public:
|
||||
CustomOpStftNorm() {
|
||||
with_norm_ = true;
|
||||
op_name_ = "StftNorm";
|
||||
}
|
||||
};
|
||||
};
|
|
@ -7,14 +7,16 @@
|
|||
#include "segment_extraction.hpp"
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
CustomOpNegPos,
|
||||
const std::vector<const OrtCustomOp*>& MathLoader() {
|
||||
static OrtOpLoader op_loader(CustomCpuFunc("NegPos", neg_pos),
|
||||
#ifdef ENABLE_DLIB
|
||||
CustomOpInverse,
|
||||
CustomOpStft,
|
||||
CustomOpStftNorm,
|
||||
CustomCpuFunc("Inverse", inverse),
|
||||
CustomCpuStruct("STFT", STFT),
|
||||
CustomCpuStruct("StftNorm", StftNormal),
|
||||
#endif
|
||||
CustomOpSegmentExtraction,
|
||||
CustomOpSegmentSum>;
|
||||
CustomCpuFunc("SegmentExtraction", segment_extraction),
|
||||
CustomCpuFunc("SegmentSum", segment_sum));
|
||||
return op_loader.GetCustomOps();
|
||||
}
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = MathLoader;
|
|
@ -5,58 +5,21 @@
|
|||
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
|
||||
float* out1 = ort_.GetTensorMutableData<float>(output1);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
if (X[i] > 0) {
|
||||
out0[i] = 0;
|
||||
out1[i] = X[i];
|
||||
} else {
|
||||
out0[i] = X[i];
|
||||
out1[i] = 0;
|
||||
}
|
||||
void neg_pos(const ortc::Tensor<float>& input,
|
||||
ortc::Tensor<float>& out0_tensor,
|
||||
ortc::Tensor<float>& out1_tensor) {
|
||||
int64_t size = input.NumberOfElement();
|
||||
float* out0 = out0_tensor.Allocate(input.Shape());
|
||||
float* out1 = out1_tensor.Allocate(input.Shape());
|
||||
const float* X = input.Data();
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
if (X[i] > 0) {
|
||||
out0[i] = 0;
|
||||
out1[i] = X[i];
|
||||
} else {
|
||||
out0[i] = X[i];
|
||||
out1[i] = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||
const char* GetName() const {
|
||||
return "NegPos";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -3,21 +3,17 @@
|
|||
|
||||
#include "segment_extraction.hpp"
|
||||
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
|
||||
OrtTensorDimensions input_dim(ort_, input);
|
||||
void segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<int64_t>& output0,
|
||||
ortc::Tensor<int64_t>& output1) {
|
||||
auto& input_dim = input.Shape();
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
const int64_t* p_data = input.Data();
|
||||
std::vector<std::int64_t> segment_value;
|
||||
std::vector<std::int64_t> segment_position;
|
||||
for (std::int64_t i = 0; i < input_dim.Size(); i++) {
|
||||
for (std::int64_t i = 0; i < input.NumberOfElement(); i++) {
|
||||
if (!p_data[i]) {
|
||||
continue;
|
||||
}
|
||||
|
@ -29,33 +25,17 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
|
||||
// push end position
|
||||
if (i == (input_dim.Size() - 1) || p_data[i + 1] != p_data[i]) {
|
||||
if (i == (input.NumberOfElement() - 1) || p_data[i + 1] != p_data[i]) {
|
||||
segment_position.push_back(i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
|
||||
std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
|
||||
SetOutput(context, 0, segment_position_dim, segment_position);
|
||||
SetOutput(context, 1, segment_value_dim, segment_value);
|
||||
|
||||
int64_t* out0_data = output0.Allocate(segment_position_dim);
|
||||
std::copy(segment_position.begin(), segment_position.end(), out0_data);
|
||||
|
||||
int64_t* out1_data = output1.Allocate(segment_value_dim);
|
||||
std::copy(segment_value.begin(), segment_value.end(), out1_data);
|
||||
}
|
||||
|
||||
size_t CustomOpSegmentExtraction::GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
size_t CustomOpSegmentExtraction::GetOutputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpSegmentExtraction::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
const char* CustomOpSegmentExtraction::GetName() const {
|
||||
return "SegmentExtraction";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpSegmentExtraction::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
|
|
@ -6,15 +6,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentExtraction : BaseKernel {
|
||||
KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpSegmentExtraction : OrtW::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
void segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<int64_t>& output0,
|
||||
ortc::Tensor<int64_t>& output1);
|
||||
|
|
|
@ -3,17 +3,11 @@
|
|||
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
template <typename T>
|
||||
void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
|
||||
const T* p_data = ort_.GetTensorData<T>(data);
|
||||
const OrtValue* segment_ids = ort_.KernelContext_GetInput(context, 1);
|
||||
const int64_t* p_segment_ids = ort_.GetTensorData<int64_t>(segment_ids);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dim_data(ort_, data);
|
||||
OrtTensorDimensions dim_seg(ort_, segment_ids);
|
||||
void segment_sum(const ortc::Tensor<float>& data,
|
||||
const ortc::Tensor<int64_t>& segment_ids,
|
||||
ortc::Tensor<float>& output) {
|
||||
auto& dim_data = data.Shape();
|
||||
auto& dim_seg = segment_ids.Shape();
|
||||
if (dim_data.size() == 0 || dim_seg.size() == 0)
|
||||
ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
|
||||
if (dim_seg.size() != 1)
|
||||
|
@ -24,22 +18,24 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
|
|||
" segment_ids shape: ", dim_seg),
|
||||
ORT_INVALID_GRAPH);
|
||||
|
||||
const int64_t* p_segment_ids = segment_ids.Data();
|
||||
const float* p_data = data.Data();
|
||||
|
||||
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
|
||||
OrtTensorDimensions dim_out = dim_data;
|
||||
std::vector<int64_t> dim_out = dim_data;
|
||||
dim_out[0] = last_seg + 1;
|
||||
|
||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||
T* p_output = ort_.GetTensorMutableData<T>(v);
|
||||
int64_t out_size = dim_out.Size();
|
||||
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(T)));
|
||||
float* p_output = output.Allocate(dim_out);
|
||||
int64_t out_size = output.NumberOfElement();
|
||||
memset(p_output, 0, static_cast<size_t>(out_size * sizeof(float)));
|
||||
|
||||
// The implementation is naive. It could be parallelized and
|
||||
// use SIMD instructions to be faster.
|
||||
int64_t in_stride = dim_data.Size();
|
||||
const T* begin = p_data;
|
||||
const T* end = p_data + in_stride;
|
||||
int64_t in_stride = data.NumberOfElement();
|
||||
const float* begin = p_data;
|
||||
const float* end = p_data + in_stride;
|
||||
in_stride /= dim_data[0];
|
||||
T *p_out, *p_out_end;
|
||||
float *p_out, *p_out_end;
|
||||
const int64_t* p_seg = p_segment_ids;
|
||||
for (; begin != end; ++p_seg) {
|
||||
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
|
||||
|
@ -53,37 +49,3 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
|
|||
*p_out += *begin;
|
||||
}
|
||||
}
|
||||
|
||||
KernelSegmentSum::KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelSegmentSum::Compute(OrtKernelContext* context) {
|
||||
KernelSegmentSum_Compute<float>(ort_, context);
|
||||
}
|
||||
|
||||
size_t CustomOpSegmentSum::GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
size_t CustomOpSegmentSum::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpSegmentSum::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
const char* CustomOpSegmentSum::GetName() const {
|
||||
return "SegmentSum";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
ORTX_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -6,15 +6,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentSum : BaseKernel {
|
||||
KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpSegmentSum : OrtW::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
void segment_sum(const ortc::Tensor<float>& data,
|
||||
const ortc::Tensor<int64_t>& segment_ids,
|
||||
ortc::Tensor<float>& output);
|
||||
|
|
|
@ -8,18 +8,12 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_value = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_mask = ort_.KernelContext_GetInput(context, 1);
|
||||
|
||||
OrtTensorDimensions value_dimensions(ort_, input_value);
|
||||
OrtTensorDimensions mask_dimensions(ort_, input_mask);
|
||||
|
||||
if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
|
||||
void masked_fill(const ortc::Tensor<std::string>& input,
|
||||
const ortc::Tensor<bool>& input_mask,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
auto& value_dimensions = input.Shape();
|
||||
auto& mask_dimensions = input_mask.Shape();
|
||||
if (!(value_dimensions.empty() || mask_dimensions.size() == 1)) {
|
||||
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
|
@ -27,11 +21,8 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<std::string> value;
|
||||
const bool* mask = nullptr;
|
||||
|
||||
GetTensorMutableDataString(api_, ort_, context, input_value, value);
|
||||
mask = ort_.GetTensorData<bool>(input_mask);
|
||||
auto& value = input.Data();
|
||||
const bool* mask = input_mask.Data();
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::vector<int64_t> result_dimension;
|
||||
|
@ -44,33 +35,5 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
result.push_back(value[i]);
|
||||
}
|
||||
result_dimension.push_back(result.size());
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, result_dimension.data(), result_dimension.size());
|
||||
|
||||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
output.SetStringOutput(result, result_dimension);
|
||||
}
|
||||
|
||||
const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; };
|
||||
|
||||
size_t CustomOpMaskedFill::GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpMaskedFill::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
|
|
@ -7,18 +7,6 @@
|
|||
#include "string_utils.h"
|
||||
#include <unordered_map>
|
||||
|
||||
struct KernelMaskedFill : BaseKernel {
|
||||
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
};
|
||||
|
||||
struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
void masked_fill(const ortc::Tensor<std::string>& input,
|
||||
const ortc::Tensor<bool>& input_mask,
|
||||
ortc::Tensor<std::string>& output);
|
||||
|
|
|
@ -8,26 +8,9 @@
|
|||
KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringEqual::Compute(OrtKernelContext* context) {
|
||||
void KernelStringEqual::Compute(OrtKernelContext* context,
|
||||
const ortc::Tensor<std::string>&,
|
||||
const ortc::Tensor<std::string>&,
|
||||
ortc::Tensor<bool>& output) {
|
||||
KernelEqual_Compute<std::string>(api_, ort_, context);
|
||||
}
|
||||
|
||||
size_t CustomOpStringEqual::GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
size_t CustomOpStringEqual::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringEqual::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
};
|
||||
|
||||
const char* CustomOpStringEqual::GetName() const {
|
||||
return "StringEqual";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringEqual::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
|
|
@ -8,13 +8,8 @@
|
|||
|
||||
struct KernelStringEqual : BaseKernel {
|
||||
KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringEqual : OrtW::CustomOpBase<CustomOpStringEqual, KernelStringEqual> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
void Compute(OrtKernelContext* context,
|
||||
const ortc::Tensor<std::string>&,
|
||||
const ortc::Tensor<std::string>&,
|
||||
ortc::Tensor<bool>& output);
|
||||
};
|
||||
|
|
|
@ -4,11 +4,12 @@
|
|||
#include "string_tensor.h"
|
||||
#include "op_ragged_tensor.hpp"
|
||||
|
||||
void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
||||
const OrtValue* n_elements = ort_.KernelContext_GetInput(context, 0);
|
||||
const int64_t* p_n_elements = ort_.GetTensorData<int64_t>(n_elements);
|
||||
void KernelRaggedTensoroSparse::Compute(const ortc::Tensor<int64_t>& n_element,
|
||||
ortc::Tensor<int64_t>& output_0,
|
||||
ortc::Tensor<int64_t>& output_1) {
|
||||
const int64_t* p_n_elements = n_element.Data();
|
||||
|
||||
OrtTensorDimensions d_length(ort_, n_elements);
|
||||
auto& d_length = n_element.Shape();
|
||||
|
||||
if (d_length.size() != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
|
@ -19,10 +20,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
|||
std::vector<int64_t> shape{n_values, 2};
|
||||
std::vector<int64_t> shape2{2};
|
||||
|
||||
OrtValue* v0 = ort_.KernelContext_GetOutput(context, 0, shape.data(), shape.size());
|
||||
int64_t* out0 = ort_.GetTensorMutableData<int64_t>(v0);
|
||||
OrtValue* v1 = ort_.KernelContext_GetOutput(context, 1, shape2.data(), shape2.size());
|
||||
int64_t* out1 = ort_.GetTensorMutableData<int64_t>(v1);
|
||||
int64_t* out0 = output_0.Allocate(shape);
|
||||
int64_t* out1 = output_1.Allocate(shape2);
|
||||
out1[0] = n_els;
|
||||
out1[1] = 0;
|
||||
int64_t row = 0;
|
||||
|
@ -39,38 +38,18 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t CustomOpRaggedTensorToSparse::GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
size_t CustomOpRaggedTensorToSparse::GetOutputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
const char* CustomOpRaggedTensorToSparse::GetName() const {
|
||||
return "RaggedTensorToSparse";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
CommonRaggedTensoroDense::CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||
void CommonRaggedTensoroDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
inputs[i] = ort_.KernelContext_GetInput(context, i);
|
||||
dims[i] = OrtTensorDimensions(ort_, inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
|
||||
int64_t CommonRaggedTensoroDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
|
||||
int64_t size = n;
|
||||
int64_t max_col = 0;
|
||||
for (int64_t i = 1; i < size; ++i) {
|
||||
|
@ -79,26 +58,25 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
|
|||
return max_col;
|
||||
}
|
||||
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: CommonRaggedTensorToDense(api, info) {
|
||||
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1) ;
|
||||
KernelRaggedTensoroDense::KernelRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: CommonRaggedTensoroDense(api, info) {
|
||||
missing_value_ = TryToGetAttributeWithDefault("missing_value", -1);
|
||||
}
|
||||
|
||||
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
const OrtValue* inputs[4];
|
||||
OrtTensorDimensions dims[4];
|
||||
GetInputDims(context, inputs, dims);
|
||||
void KernelRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<int64_t>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<int64_t>& input3,
|
||||
ortc::Tensor<int64_t>& output) {
|
||||
const int64_t* p_values = input1.Data();
|
||||
const int64_t* p_missing = input2.Data();
|
||||
const int64_t* p_indices = input3.Data();
|
||||
|
||||
const int64_t* p_values = ort_.GetTensorData<int64_t>(inputs[1]);
|
||||
const int64_t* p_missing = ort_.GetTensorData<int64_t>(inputs[2]);
|
||||
const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
|
||||
|
||||
int64_t size = dims[3].Size();
|
||||
int64_t size = input3.NumberOfElement();
|
||||
int64_t max_col = GetMaxCol(size, p_indices);
|
||||
|
||||
std::vector<int64_t> shape_out{size - 1, max_col};
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
|
||||
int64_t* dense = ort_.GetTensorMutableData<int64_t>(output);
|
||||
int64_t* dense = output.Allocate(shape_out);
|
||||
|
||||
int64_t pos = 0;
|
||||
int64_t j, pos_end;
|
||||
|
@ -119,38 +97,20 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t CustomOpRaggedTensorToDense::GetInputTypeCount() const {
|
||||
return 4;
|
||||
};
|
||||
|
||||
size_t CustomOpRaggedTensorToDense::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
const char* CustomOpRaggedTensorToDense::GetName() const {
|
||||
return "RaggedTensorToDense";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensorToDense(api, info) {
|
||||
KernelStringRaggedTensoroDense::KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensoroDense(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
const OrtValue* inputs[4];
|
||||
OrtTensorDimensions dims[4];
|
||||
GetInputDims(context, inputs, dims);
|
||||
void KernelStringRaggedTensoroDense::Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<std::string>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<std::string>& input3,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
// const OrtValue* inputs[4];
|
||||
// OrtTensorDimensions dims[4];
|
||||
|
||||
std::vector<std::string> input;
|
||||
GetTensorMutableDataString(api_, ort_, context, inputs[1], input);
|
||||
const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
|
||||
int64_t size = dims[3].Size();
|
||||
auto& input = input1.Data();
|
||||
const int64_t* p_indices = input2.Data();
|
||||
int64_t size = input3.NumberOfElement();
|
||||
int64_t max_col = GetMaxCol(size, p_indices);
|
||||
std::vector<int64_t> shape_out{size - 1, max_col};
|
||||
|
||||
|
@ -170,36 +130,5 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
pos = pos_end;
|
||||
}
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
|
||||
FillTensorDataString(api_, ort_, context, dense, output);
|
||||
output.SetStringOutput(dense, shape_out);
|
||||
}
|
||||
|
||||
size_t CustomOpStringRaggedTensorToDense::GetInputTypeCount() const {
|
||||
return 4;
|
||||
};
|
||||
|
||||
size_t CustomOpStringRaggedTensorToDense::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
const char* CustomOpStringRaggedTensorToDense::GetName() const {
|
||||
return "StringRaggedTensorToDense";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 1:
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
case 0:
|
||||
case 3:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -5,55 +5,40 @@
|
|||
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||
KernelRaggedTensorToSparse(const OrtApi& api, const OrtKernelInfo& info)
|
||||
struct KernelRaggedTensoroSparse : BaseKernel {
|
||||
KernelRaggedTensoroSparse(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
void Compute(const ortc::Tensor<int64_t>& n_element,
|
||||
ortc::Tensor<int64_t>& output_0,
|
||||
ortc::Tensor<int64_t>& output_1);
|
||||
};
|
||||
|
||||
struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct CommonRaggedTensorToDense : BaseKernel {
|
||||
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
struct CommonRaggedTensoroDense : BaseKernel {
|
||||
CommonRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
||||
protected:
|
||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
||||
int64_t GetMaxCol(int64_t n, const int64_t* p_indices);
|
||||
};
|
||||
|
||||
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
struct KernelRaggedTensoroDense : CommonRaggedTensoroDense {
|
||||
KernelRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<int64_t>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<int64_t>& input3,
|
||||
ortc::Tensor<int64_t>& output);
|
||||
|
||||
private:
|
||||
int64_t missing_value_;
|
||||
};
|
||||
|
||||
struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense,
|
||||
KernelStringRaggedTensorToDense> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
struct KernelStringRaggedTensoroDense : CommonRaggedTensoroDense {
|
||||
KernelStringRaggedTensoroDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(const ortc::Tensor<int64_t>& input0,
|
||||
const ortc::Tensor<std::string>& input1,
|
||||
const ortc::Tensor<int64_t>& input2,
|
||||
const ortc::Tensor<std::string>& input3,
|
||||
ortc::Tensor<std::string>& output);
|
||||
};
|
||||
|
|
|
@ -13,45 +13,21 @@ KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtK
|
|||
global_replace_ = TryToGetAttributeWithDefault("global_replace",1);
|
||||
}
|
||||
|
||||
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
|
||||
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
|
||||
|
||||
std::vector<std::string> str_input, str_pattern, str_rewrite;
|
||||
GetTensorMutableDataString(api_, ort_, context, input, str_input);
|
||||
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
|
||||
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
||||
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
||||
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"pattern (second input) must contain only one element. It has ",
|
||||
pattern_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"rewrite (third input) must contain only one element. It has ",
|
||||
rewrite_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_pattern[0].empty())
|
||||
void KernelStringRegexReplace::Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
std::string_view str_rewrite,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
if (str_pattern.empty())
|
||||
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
std::vector<std::string> str_input{input.Data()};
|
||||
auto dim = input.Shape();
|
||||
size_t size = input.NumberOfElement();
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
size_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
re2::StringPiece piece(str_rewrite.data());
|
||||
re2::RE2 reg(str_pattern.data());
|
||||
|
||||
re2::StringPiece piece(str_rewrite[0]);
|
||||
re2::RE2 reg(str_pattern[0]);
|
||||
|
||||
// Do computation
|
||||
if (global_replace_) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
|
||||
|
@ -61,24 +37,5 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
re2::RE2::Replace(&(str_input[i]), reg, piece);
|
||||
}
|
||||
}
|
||||
|
||||
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
|
||||
|
||||
size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRegexReplace::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
size_t CustomOpStringRegexReplace::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRegexReplace::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
output.SetStringOutput(str_input, dim);
|
||||
}
|
|
@ -8,16 +8,11 @@
|
|||
|
||||
struct KernelStringRegexReplace : BaseKernel {
|
||||
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
void Compute(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
std::string_view str_rewrite,
|
||||
ortc::Tensor<std::string>& output);
|
||||
|
||||
protected:
|
||||
int64_t global_replace_;
|
||||
};
|
||||
|
||||
struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
};
|
|
@ -7,39 +7,30 @@
|
|||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||
void KernelStringRegexSplitWithOffsets(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
const ortc::Tensor<std::string>& str_keep_pattern,
|
||||
ortc::Tensor<std::string>& output_text,
|
||||
ortc::Tensor<int64_t>& output_begin,
|
||||
ortc::Tensor<int64_t>& output_end,
|
||||
ortc::Tensor<int64_t>& output_offset) {
|
||||
// Setup inputs
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
|
||||
const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2);
|
||||
|
||||
std::vector<std::string> str_input, str_pattern, str_keep_pattern;
|
||||
GetTensorMutableDataString(api_, ort_, context, input, str_input);
|
||||
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
|
||||
GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern);
|
||||
std::vector<std::string> str_input(input.Data());
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
||||
if (str_pattern.size() != 1)
|
||||
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
|
||||
str_pattern.size(), " values."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_keep_pattern.size() > 1)
|
||||
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
|
||||
str_keep_pattern.size(), " values."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_pattern[0].empty())
|
||||
if (str_pattern.empty()) {
|
||||
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
if (str_keep_pattern.Data().size() > 1) {
|
||||
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
|
||||
str_keep_pattern.Data().size(), " values."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
auto dimensions = input.Shape();
|
||||
bool include_delimiter = (str_keep_pattern.Data().size() == 1) && (!str_keep_pattern.Data()[0].empty());
|
||||
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
|
||||
|
||||
re2::RE2 reg(str_pattern[0]);
|
||||
re2::RE2 keep_reg(include_delimiter ? str_keep_pattern[0] : "");
|
||||
re2::RE2 reg(str_pattern.data());
|
||||
re2::RE2 keep_reg(include_delimiter ? str_keep_pattern.Data()[0].data() : "");
|
||||
|
||||
std::vector<std::string> all_tokens;
|
||||
std::vector<int64_t> all_begin_offsets, all_end_offsets;
|
||||
|
@ -63,47 +54,15 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
|
||||
// Setup output
|
||||
std::vector<int64_t> dim_out{(int64_t)all_tokens.size()};
|
||||
OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||
FillTensorDataString(api_, ort_, context, all_tokens, output_text);
|
||||
output_text.SetStringOutput(all_tokens, dim_out);
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size());
|
||||
int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output);
|
||||
memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
|
||||
auto output_raw = output_begin.Allocate(dim_out);
|
||||
memcpy(output_raw, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
|
||||
|
||||
output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size());
|
||||
p_output = ort_.GetTensorMutableData<int64_t>(output);
|
||||
memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
|
||||
output_raw = output_end.Allocate(dim_out);
|
||||
memcpy(output_raw, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
|
||||
|
||||
std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()};
|
||||
output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size());
|
||||
p_output = ort_.GetTensorMutableData<int64_t>(output);
|
||||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
output_raw = output_offset.Allocate(dim_out_row);
|
||||
memcpy(output_raw, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; };
|
||||
|
||||
size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const {
|
||||
return 3;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
size_t CustomOpStringRegexSplitWithOffsets::GetOutputTypeCount() const {
|
||||
return 4;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -7,15 +7,10 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
||||
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
void KernelStringRegexSplitWithOffsets(const ortc::Tensor<std::string>& input,
|
||||
std::string_view str_pattern,
|
||||
const ortc::Tensor<std::string>& str_keep_pattern,
|
||||
ortc::Tensor<std::string>& output_text,
|
||||
ortc::Tensor<int64_t>& output_begin,
|
||||
ortc::Tensor<int64_t>& output_end,
|
||||
ortc::Tensor<int64_t>& output_offset);
|
|
@ -8,48 +8,19 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
KernelStringConcat::KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* left = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* right = ort_.KernelContext_GetInput(context, 1);
|
||||
OrtTensorDimensions left_dim(ort_, left);
|
||||
OrtTensorDimensions right_dim(ort_, right);
|
||||
|
||||
if (left_dim != right_dim) {
|
||||
void string_concat(const ortc::Tensor<std::string>& left,
|
||||
const ortc::Tensor<std::string>& right,
|
||||
ortc::Tensor<std::string>& output) {
|
||||
if (left.Shape() != right.Shape()) {
|
||||
ORTX_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
// make a copy as input is const
|
||||
std::vector<std::string> left_value = left.Data();
|
||||
auto& right_value = right.Data();
|
||||
|
||||
std::vector<std::string> left_value;
|
||||
std::vector<std::string> right_value;
|
||||
GetTensorMutableDataString(api_, ort_, context, left, left_value);
|
||||
GetTensorMutableDataString(api_, ort_, context, right, right_value);
|
||||
|
||||
// reuse left_value as output to save memory
|
||||
for (size_t i = 0; i < left_value.size(); i++) {
|
||||
left_value[i].append(right_value[i]);
|
||||
}
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, left_dim.data(), left_dim.size());
|
||||
FillTensorDataString(api_, ort_, context, left_value, output);
|
||||
output.SetStringOutput(left_value, left.Shape());
|
||||
}
|
||||
|
||||
const char* CustomOpStringConcat::GetName() const { return "StringConcat"; };
|
||||
|
||||
size_t CustomOpStringConcat::GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringConcat::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
size_t CustomOpStringConcat::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
|
|
@ -6,15 +6,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringConcat : BaseKernel {
|
||||
KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringConcat : OrtW::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
void string_concat(const ortc::Tensor<std::string>& left,
|
||||
const ortc::Tensor<std::string>& right,
|
||||
ortc::Tensor<std::string>& output);
|
||||
|
|
|
@@ -13,76 +13,34 @@ KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, co
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}

void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);

std::vector<std::string> str_input, str_pattern, str_rewrite;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);

// Verifications
OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (rewrite_dimensions.Size() != 1) {
ORTX_CXX_API_THROW(MakeString("rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."),
ORT_INVALID_GRAPH);
}
if (str_pattern[0].empty()) {
void KernelStringECMARegexReplace::Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant;
std::vector<std::string> str_input = input.Data();
if (pattern.empty()) {
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
}

// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());

OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
size_t size = input.NumberOfElement();

auto regex_flag = std::regex_constants::optimize | std::regex_constants::ECMAScript;
if (ignore_case_) {
regex_flag |= std::regex_constants::icase;
}

std::regex reg(str_pattern[0], regex_flag);
std::regex reg(pattern.data(), regex_flag);

if (global_replace_) {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0]);
str_input[i] = std::regex_replace(str_input[i], reg, rewrite.data());
}
} else {
for (size_t i = 0; i < size; i++) {
str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0], std::regex_constants::format_first_only);
str_input[i] = std::regex_replace(str_input[i], reg, rewrite.data(), std::regex_constants::format_first_only);
}
}

FillTensorDataString(api_, ort_, context, str_input, output);
auto& dimensions = input.Shape();
output.SetStringOutput(str_input, dimensions);
}

const char* CustomOpStringECMARegexReplace::GetName() const { return "StringECMARegexReplace"; };

size_t CustomOpStringECMARegexReplace::GetInputTypeCount() const {
return 3;
};

ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringECMARegexReplace::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -8,17 +8,12 @@

struct KernelStringECMARegexReplace : BaseKernel {
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view rewrite,
ortc::Tensor<std::string>& output);

protected:
bool global_replace_;
bool ignore_case_;
};

struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -15,40 +15,26 @@ KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(con
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
}

void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
void KernelStringECMARegexSplitWithOffsets::Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2);
auto& str_input = input.Data();

std::vector<std::string> str_input, str_pattern, str_keep_pattern;
GetTensorMutableDataString(api_, ort_, context, input, str_input);
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern);

// Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1)
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_keep_pattern.size() > 1)
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(),
" values."),
ORT_INVALID_GRAPH);
if (str_pattern[0].empty())
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);

OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
auto& dimensions = input.Shape();
bool include_delimiter = !keep_pattern.empty();

auto regex_flag = std::regex_constants::ECMAScript;
if (ignore_case_) {
regex_flag |= std::regex_constants::icase;
}

std::regex reg(str_pattern[0], regex_flag);
std::regex keep_reg(include_delimiter ? str_keep_pattern[0] : "", regex_flag);
std::regex reg(pattern.data(), regex_flag);
std::regex keep_reg(include_delimiter ? keep_pattern.data() : "", regex_flag);

std::vector<std::string> all_tokens;
std::vector<int64_t> all_begin_offsets, all_end_offsets;

@@ -72,48 +58,15 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {

// Setup output
std::vector<int64_t> dim_out{(int64_t)all_tokens.size()};
OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
FillTensorDataString(api_, ort_, context, all_tokens, output_text);
output_text.SetStringOutput(all_tokens, dim_out);

OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size());
int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output);
int64_t* p_output = output1.Allocate(dim_out);
memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));

output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
p_output = output2.Allocate(dim_out);
memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));

std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()};
output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size());
p_output = ort_.GetTensorMutableData<int64_t>(output);
p_output = output3.Allocate(dim_out_row);
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
}

const char* CustomOpStringECMARegexSplitWithOffsets::GetName() const { return "StringECMARegexSplitWithOffsets"; };

size_t CustomOpStringECMARegexSplitWithOffsets::GetInputTypeCount() const {
return 3;
};

ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringECMARegexSplitWithOffsets::GetOutputTypeCount() const {
return 4;
};

ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
case 2:
case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT);
}
};
@@ -10,20 +10,18 @@
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
std::string_view pattern,
std::string_view keep_pattern,
ortc::Tensor<std::string>& output_text,
ortc::Tensor<int64_t>& output1,
ortc::Tensor<int64_t>& output2,
ortc::Tensor<int64_t>& output3);

private:
bool ignore_case_;
};

struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};

template <typename T>
void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
bool include_delimiter, const std::regex& include_delim_regex,
@@ -8,121 +8,37 @@
#include "string_tensor.h"
#include "string_hash.hpp"

KernelStringHash::KernelStringHash(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringHash::Compute(OrtKernelContext* context) {
void string_hash(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, input, str_input);

// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
auto& str_input = input.Data();

// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);
auto& dimensions = input.Shape();
int64_t* out = output.Allocate(dimensions);

OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
size_t size = output.NumberOfElement();

// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
size_t nb = static_cast<size_t>(num_buckets);
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(Hash64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}

const char* CustomOpStringHash::GetName() const { return "StringToHashBucket"; };

size_t CustomOpStringHash::GetInputTypeCount() const {
return 2;
};

ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

size_t CustomOpStringHash::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};

KernelStringHashFast::KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringHashFast::Compute(OrtKernelContext* context) {
void string_hash_fast(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(api_, ort_, context, input, str_input);

// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);

auto& str_input = input.Data();
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
int64_t* out = ort_.GetTensorMutableData<int64_t>(output);

OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
size_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);

auto& dimensions = input.Shape();
int64_t* out = output.Allocate(dimensions);
size_t size = output.NumberOfElement();
// Do computation
size_t nb = static_cast<size_t>(*p_num_buckets);
size_t nb = static_cast<size_t>(num_buckets);
for (size_t i = 0; i < size; i++) {
out[i] = static_cast<int64_t>(util::Fingerprint64(str_input[i].c_str(), str_input[i].size()) % nb);
}
}

const char* CustomOpStringHashFast::GetName() const { return "StringToHashBucketFast"; };

size_t CustomOpStringHashFast::GetInputTypeCount() const {
return 2;
};

ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

size_t CustomOpStringHashFast::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringHashFast::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
@@ -6,28 +6,9 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringHash : BaseKernel {
KernelStringHash(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringHash> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};

struct KernelStringHashFast : BaseKernel {
KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringHashFast : OrtW::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_hash(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output);
void string_hash_fast(const ortc::Tensor<std::string>& input,
int64_t num_buckets,
ortc::Tensor<int64_t>& output);
@@ -4,40 +4,26 @@
#include "string_join.hpp"
#include "string_tensor.h"

KernelStringJoin::KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringJoin::Compute(OrtKernelContext* context) {
void string_join(const ortc::Tensor<std::string>& input_X,
std::string_view input_sep,
int64_t axis,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
std::vector<std::string> X, sep;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);

// Check input
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORTX_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
ORTX_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
auto& X = input_X.Data();
auto& dimensions = input_X.Shape();
if (dimensions.size() == 0) {
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
if (X.size() != 1)
ORTX_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
} else {
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
if (axis < 0 || axis >= static_cast<int64_t>(dimensions.size()))
ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", axis), ORT_INVALID_ARGUMENT);
}

std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
if (dimensions.size() > 1) {
for (size_t i = 0, pos = 0; i < dimensions.size(); ++i) {
if (static_cast<int64_t>(i) == *axis)
if (static_cast<int64_t>(i) == axis)
continue;
dimensions_out[pos++] = dimensions[i];
}

@@ -45,22 +31,19 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
dimensions_out[0] = 1;
}

OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
int64_t size = std::accumulate(dimensions_out.begin(), dimensions_out.end(), 1ULL, std::multiplies<int64_t>());
std::vector<std::string> out(static_cast<size_t>(size));

if (dimensions.size() > 0) {
if (X.size() > 0) {
// Do computation
int64_t h = 1;
for (size_t i = static_cast<size_t>(*axis + 1); i < dimensions.size(); ++i) {
for (size_t i = static_cast<size_t>(axis + 1); i < dimensions.size(); ++i) {
h *= dimensions[i];
}
int64_t left_part = size / h;
int64_t right_part = size / left_part;
int64_t n_red = dimensions[static_cast<size_t>(*axis)] - 1;
int64_t n_red = dimensions[static_cast<size_t>(axis)] - 1;
int64_t inc = right_part * (n_red + 1);
int64_t pos = 0;
for (int64_t li = 0; li < left_part; ++li) {

@@ -68,7 +51,7 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[static_cast<size_t>(index)] << sep[0];
st << X[static_cast<size_t>(index)] << input_sep;
}
st << X[static_cast<size_t>(index)];
out[static_cast<size_t>(pos)] = st.str();

@@ -83,33 +66,5 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
out[0] = X[0];
}

FillTensorDataString(api_, ort_, context, out, output);
output.SetStringOutput(out, dimensions_out);
}

const char* CustomOpStringJoin::GetName() const {
return "StringJoin";
};

size_t CustomOpStringJoin::GetInputTypeCount() const {
return 3;
};

ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

size_t CustomOpStringJoin::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringJoin::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -6,15 +6,7 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringJoin : BaseKernel {
KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringJoin : OrtW::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_join(const ortc::Tensor<std::string>& input_X,
std::string_view input_sep,
int64_t axis,
ortc::Tensor<std::string>& output);
@@ -9,38 +9,15 @@
#include <algorithm>
#include "ustring.h"

KernelStringLength::KernelStringLength(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringLength::Compute(OrtKernelContext* context) {
void string_length(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
auto& input_data = input.Data();

OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
auto* output_data = ort_.GetTensorMutableData<int64_t>(output);
auto& dimensions = input.Shape();
auto* output_data = output.Allocate(dimensions);

for (int i = 0; i < dimensions.Size(); i++) {
for (int i = 0; i < input.NumberOfElement(); i++) {
output_data[i] = ustring(input_data[i]).size();
}
}

const char* CustomOpStringLength::GetName() const { return "StringLength"; };

size_t CustomOpStringLength::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringLength::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringLength::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringLength::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
@@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringLength : BaseKernel {
KernelStringLength(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringLength : OrtW::CustomOpBase<CustomOpStringLength, KernelStringLength> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_length(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& output);
@@ -7,38 +7,14 @@
#include <cmath>
#include <algorithm>

KernelStringLower::KernelStringLower(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringLower::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
void string_lower(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant
std::vector<std::string> X = input.Data();

for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) {return static_cast<char>(ToLower(c));});
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(ToLower(c)); });
}

OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}

const char* CustomOpStringLower::GetName() const { return "StringLower"; };

size_t CustomOpStringLower::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringLower::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringLower::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringLower::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringLower : BaseKernel {
KernelStringLower(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringLower : OrtW::CustomOpBase<CustomOpStringLower, KernelStringLower> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_lower(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);
@@ -21,39 +21,15 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo&
}
}

void KernelStringMapping::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);

OrtTensorDimensions dimensions(ort_, input);
void KernelStringMapping::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// make a copy as input is constant
std::vector<std::string> input_data = input.Data();

for (auto& str : input_data) {
if (map_.find(str) != map_.end()) {
str = map_[str];
}
}

OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());

FillTensorDataString(api_, ort_, context, input_data, output);
output.SetStringOutput(input_data, input.Shape());
}

const char* CustomOpStringMapping::GetName() const { return "StringMapping"; };

size_t CustomOpStringMapping::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringMapping::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringMapping::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringMapping::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -9,16 +9,9 @@

struct KernelStringMapping : BaseKernel {
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);

private:
std::unordered_map<std::string, std::string> map_;
};

struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -4,27 +4,17 @@
#include "string_split.hpp"
#include "string_tensor.h"

KernelStringSplit::KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringSplit::Compute(OrtKernelContext* context) {
void string_split(const ortc::Tensor<std::string>& input_X,
std::string_view sep,
bool skip_empty,
ortc::Tensor<int64_t>& out_indices,
ortc::Tensor<std::string>& out_text,
ortc::Tensor<int64_t>& out_shape) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
std::vector<std::string> X, sep;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
GetTensorMutableDataString(api_, ort_, context, input_sep, sep);
auto& X = input_X.Data();

// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORTX_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
ORTX_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X);
auto& dimensions = input_X.Shape();
if (dimensions.size() != 1)
ORTX_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);

@@ -32,8 +22,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
std::vector<int64_t> indices;
int64_t maxc = 0;
int64_t col;
std::string delimiter = sep[0];
if (delimiter.size() == 0) {
if (sep.size() == 0) {
char word[2] = "a";
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[static_cast<size_t>(row)];

@@ -48,7 +37,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
}
}
} else {
bool keep = !(*skip_empty);
bool keep = !skip_empty;
std::size_t current, previous = 0;
for (int64_t row = 0; row < dimensions[0]; ++row) {
const std::string& str = X[static_cast<size_t>(row)];

@@ -56,7 +45,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
continue;
previous = 0;
col = 0;
current = str.find_first_of(delimiter);
current = str.find_first_of(sep);
while (current != std::string::npos) {
if (keep || current > previous) {
words.push_back(str.substr(previous, current - previous));

@@ -65,7 +54,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
++col;
}
previous = current + 1;
current = str.find_first_of(delimiter, previous);
current = str.find_first_of(sep, previous);
}
current = str.size();
if (keep || current > previous) {

@@ -79,55 +68,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
}

std::vector<int64_t> shape_indices = {static_cast<int64_t>(indices.size()) / 2, 2};
OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 0, shape_indices.data(), shape_indices.size());

int64_t* p_indices = out_indices.Allocate(shape_indices);
std::vector<int64_t> shape_text(1, words.size());
OrtValue* out_text = ort_.KernelContext_GetOutput(context, 1, shape_text.data(), shape_text.size());

std::vector<int64_t> shape_shape(1, 2);
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());

int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
int64_t* p_shape = out_shape.Allocate(shape_shape);

memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
p_shape[0] = dimensions[0];
p_shape[1] = maxc;
FillTensorDataString(api_, ort_, context, words, out_text);
out_text.SetStringOutput(words, shape_text);
}

const char* CustomOpStringSplit::GetName() const {
return "StringSplit";
};

size_t CustomOpStringSplit::GetInputTypeCount() const {
return 3;
};

ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default:
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}
};

size_t CustomOpStringSplit::GetOutputTypeCount() const {
return 3;
};

ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const {
switch (index) {
case 0:
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
default:
ORTX_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}
};
@@ -6,15 +6,9 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringSplit : BaseKernel {
KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringSplit : OrtW::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_split(const ortc::Tensor<std::string>& input_X,
std::string_view sep,
bool skip_empty,
ortc::Tensor<int64_t>& out_indices,
ortc::Tensor<std::string>& out_text,
ortc::Tensor<int64_t>& out_shape);
@@ -9,46 +9,16 @@

const char* WHITE_SPACE_CHARS = " \t\n\r\f\v";

KernelStringStrip::KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringStrip::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);

// For each string in input, replace with whitespace-trimmed version.
void string_strip(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
std::vector<std::string> X = input.Data();
for (size_t i = 0; i < X.size(); ++i) {
size_t nonWhitespaceBegin = X[i].find_first_not_of(WHITE_SPACE_CHARS);
if (nonWhitespaceBegin != std::string::npos) {
size_t nonWhitespaceEnd = X[i].find_last_not_of(WHITE_SPACE_CHARS);
size_t nonWhitespaceRange = nonWhitespaceEnd - nonWhitespaceBegin + 1;

X[i] = X[i].substr(nonWhitespaceBegin, nonWhitespaceRange);
}
}

// Fills the output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}

const char* CustomOpStringStrip::GetName() const { return "StringStrip"; };

size_t CustomOpStringStrip::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringStrip::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringStrip::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringStrip::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringStrip : BaseKernel {
KernelStringStrip(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringStrip : OrtW::CustomOpBase<CustomOpStringStrip, KernelStringStrip> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_strip(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);
@@ -9,7 +9,9 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
ParseUnkownValue(unk);
}

std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<std::vector<int64_t>> result;

// Set output dimension

@@ -107,19 +109,15 @@ KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInf
impl_ = std::make_shared<StringToVectorImpl>(map, unk);
}

void KernelStringToVector::Compute(OrtKernelContext* context) {
void KernelStringToVector::Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out) {
// Setup input
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);
OrtTensorDimensions input_dim(ort_, input);

auto& input_data = input.Data();
// Get output
OrtTensorDimensions output_dim;
auto mapping_result = impl_->Compute(input_data, input_dim, output_dim);
std::vector<int64_t> output_dim;
auto mapping_result = impl_->Compute(input_data, input.Shape(), output_dim);

OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
auto* output_data = ort_.GetTensorMutableData<int64_t>(output);
auto* output_data = out.Allocate(output_dim);

// Set output tensor data
int idx = 0;

@@ -130,21 +128,3 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
}
}
}

const char* CustomOpStringToVector::GetName() const { return "StringToVector"; };

size_t CustomOpStringToVector::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringToVector::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringToVector::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringToVector::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};
@@ -12,7 +12,9 @@
class StringToVectorImpl {
public:
StringToVectorImpl(std::string& map, std::string& unk);
std::vector<std::vector<int64_t>> Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim);
std::vector<std::vector<int64_t>> Compute(const std::vector<std::string>& str_input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);

private:
void ParseMappingTable(std::string& map);

@@ -29,16 +31,9 @@ class StringToVectorImpl {

struct KernelStringToVector : BaseKernel {
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& out);

private:
std::shared_ptr<StringToVectorImpl> impl_;
};

struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -7,39 +7,14 @@
#include <cmath>
#include <algorithm>

KernelStringUpper::KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
}

void KernelStringUpper::Compute(OrtKernelContext* context) {
void string_upper(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> X;
GetTensorMutableDataString(api_, ort_, context, input_X, X);
std::vector<std::string> X = input.Data();

for (size_t i = 0; i < X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(::toupper(c)); });
}

// Fills the output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
FillTensorDataString(api_, ort_, context, X, output);
output.SetStringOutput(X, input.Shape());
}

const char* CustomOpStringUpper::GetName() const { return "StringUpper"; };

size_t CustomOpStringUpper::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringUpper::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpStringUpper::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpStringUpper::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -6,15 +6,5 @@
#include "ocos.h"
#include "string_utils.h"

struct KernelStringUpper : BaseKernel {
KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
};

struct CustomOpStringUpper : OrtW::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
void string_upper(const ortc::Tensor<std::string>& input,
ortc::Tensor<std::string>& output);
@@ -20,28 +20,33 @@
#include "text/re2_strings/string_regex_split.hpp"
#endif // ENABLE_RE2_REGEX

FxLoadCustomOpFactory LoadCustomOpClasses_Text =
LoadCustomOpClasses<CustomOpClassBegin,
const std::vector<const OrtCustomOp*>& TextLoader() {
static OrtOpLoader op_loader(
#if defined(ENABLE_RE2_REGEX)
CustomOpStringRegexReplace,
CustomOpStringRegexSplitWithOffsets,
CustomCpuStruct("StringRegexReplace", KernelStringRegexReplace),
CustomCpuFunc("StringRegexSplitWithOffsets", KernelStringRegexSplitWithOffsets),
#endif // ENABLE_RE2_REGEX
CustomOpRaggedTensorToDense,
CustomOpRaggedTensorToSparse,
CustomOpStringRaggedTensorToDense,
CustomOpStringEqual,
CustomOpStringHash,
CustomOpStringHashFast,
CustomOpStringJoin,
CustomOpStringLower,
CustomOpStringUpper,
CustomOpStringMapping,
CustomOpMaskedFill,
CustomOpStringSplit,
CustomOpStringStrip,
CustomOpStringToVector,
CustomOpVectorToString,
CustomOpStringLength,
CustomOpStringConcat,
CustomOpStringECMARegexReplace,
CustomOpStringECMARegexSplitWithOffsets>;
CustomCpuStruct("RaggedTensorToSparse", KernelRaggedTensoroSparse),
CustomCpuStruct("RaggedTensorToDense", KernelRaggedTensoroDense),
CustomCpuStruct("StringRaggedTensorToDense", KernelStringRaggedTensoroDense),
CustomCpuStruct("StringEqual", KernelStringEqual),
CustomCpuFunc("StringToHashBucket", string_hash),
CustomCpuFunc("StringToHashBucketFast", string_hash_fast),
CustomCpuFunc("StringJoin", string_join),
CustomCpuFunc("StringLower", string_lower),
CustomCpuFunc("StringUpper", string_upper),
CustomCpuStruct("StringMapping", KernelStringMapping),
CustomCpuFunc("MaskedFill", masked_fill),
CustomCpuFunc("StringSplit", string_split),
CustomCpuFunc("StringStrip", string_strip),
CustomCpuStruct("StringToVector", KernelStringToVector),
CustomCpuStruct("VectorToString", KernelVectorToString),
CustomCpuFunc("StringLength", string_length),
CustomCpuFunc("StringConcat", string_concat),
CustomCpuStruct("StringECMARegexReplace", KernelStringECMARegexReplace),
CustomCpuStruct("StringECMARegexSplitWithOffsets", KernelStringECMARegexSplitWithOffsets));
return op_loader.GetCustomOps();
}

FxLoadCustomOpFactory LoadCustomOpClasses_Text = TextLoader;
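With the loader change above, registration is data-driven: stateless, function-style kernels go through CustomCpuFunc and kernels that keep state or read attributes from OrtKernelInfo go through CustomCpuStruct. The following sketch of hooking in one more op is illustrative only; string_identity and ExampleLoader are hypothetical names, and the sketch assumes the OrtOpLoader / CustomCpuFunc helpers shown in this diff.

#include "ocos.h"

// Hypothetical example op, shown only to illustrate the registration pattern:
// it copies the input strings to the output unchanged.
void string_identity(const ortc::Tensor<std::string>& input,
                     ortc::Tensor<std::string>& output) {
  output.SetStringOutput(input.Data(), input.Shape());
}

const std::vector<const OrtCustomOp*>& ExampleLoader() {
  // CustomCpuFunc wraps a free function; CustomCpuStruct would wrap a kernel struct.
  static OrtOpLoader op_loader(CustomCpuFunc("StringIdentity", string_identity));
  return op_loader.GetCustomOps();
}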
@@ -18,16 +18,18 @@ VectorToStringImpl::VectorToStringImpl(std::string& map, std::string& unk) : unk
ParseMappingTable(map);
}

std::vector<std::string> VectorToStringImpl::Compute(const void* input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
std::vector<std::string> VectorToStringImpl::Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim) {
std::vector<std::string> result;

const int64_t* ptr = static_cast<const int64_t*>(input);

if (vector_len_ == 1 && (input_dim.size() == 1 || input_dim.IsScalar())) {
if (vector_len_ == 1 && (input_dim.size() == 1 || input_dim.empty())) {
// only hit when the key is a scalar and the input is a vector
output_dim = input_dim;
} else {
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
if (input_dim.empty() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
}

@@ -36,7 +38,8 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
}

std::vector<int64_t> key(vector_len_);
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
int64_t input_element_size = std::accumulate(input_dim.begin(), input_dim.end(), 1ULL, std::multiplies<int64_t>());
for (int64_t i = 0; i < input_element_size; i = static_cast<int64_t>(i + vector_len_)) {
// construct key
for (size_t j = 0; j < vector_len_; j++) {
key[j] = ptr[j];

@@ -110,33 +113,11 @@ KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInf
impl_ = std::make_shared<VectorToStringImpl>(map, unk);
}

void KernelVectorToString::Compute(OrtKernelContext* context) {
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const void* input_data = ort_.GetTensorData<int64_t>(input);
void KernelVectorToString::Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out) {
const void* input_data = input.Data();

OrtTensorDimensions input_dim(ort_, input);
OrtTensorDimensions output_dim;
std::vector<std::string> mapping_result = impl_->Compute(input_data, input_dim, output_dim);

OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());

FillTensorDataString(api_, ort_, context, mapping_result, output);
std::vector<int64_t> output_dim;
std::vector<std::string> mapping_result = impl_->Compute(input_data, input.Shape(), output_dim);
out.SetStringOutput(mapping_result, output_dim);
}

const char* CustomOpVectorToString::GetName() const { return "VectorToString"; };

size_t CustomOpVectorToString::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpVectorToString::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
};

size_t CustomOpVectorToString::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpVectorToString::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -20,7 +20,9 @@ struct hash<std::vector<T>> {
class VectorToStringImpl {
public:
VectorToStringImpl(std::string& map, std::string& unk);
std::vector<std::string> Compute(const void* input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim);
std::vector<std::string> Compute(const void* input,
const std::vector<int64_t>& input_dim,
std::vector<int64_t>& output_dim);

private:
void ParseMappingTable(std::string& map);

@@ -34,16 +36,9 @@ class VectorToStringImpl {

struct KernelVectorToString : BaseKernel {
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<int64_t>& input,
ortc::Tensor<std::string>& out);

private:
std::shared_ptr<VectorToStringImpl> impl_;
};

struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -91,37 +91,9 @@ KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInf
tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
}

void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
void KernelBasicTokenizer::Compute(std::string_view input,
ortc::Tensor<std::string>& output) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
std::vector<std::string> input_data;
GetTensorMutableDataString(api_, ort_, context, input, input_data);

OrtTensorDimensions dimensions(ort_, input);
if (dimensions.size() != 1 && dimensions[0] != 1) {
ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
}

OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));

FillTensorDataString(api_, ort_, context, result, output);
std::vector<ustring> result = tokenizer_->Tokenize(ustring(input));
output.SetStringOutput({result[0].operator std::string()}, {1});
}

const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };

size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};

size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
return 1;
};

ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
@@ -23,16 +23,9 @@ class BasicTokenizer {

struct KernelBasicTokenizer : BaseKernel {
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
void Compute(OrtKernelContext* context);
void Compute(std::string_view input,
ortc::Tensor<std::string>& output);

private:
std::shared_ptr<BasicTokenizer> tokenizer_;
};

struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
const char* GetName() const;
size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const;
size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const;
};
@@ -292,11 +292,12 @@ KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo&
                                            ustring(suffix_indicator), max_len, truncation_strategy_name);
 }
 
-void KernelBertTokenizer::Compute(OrtKernelContext* context) {
+void KernelBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                  ortc::Tensor<int64_t>& output,
+                                  ortc::Tensor<int64_t>& output1,
+                                  ortc::Tensor<int64_t>& output2) {
   // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
+  auto& input_data = input.Data();
 
   if (input_data.size() != 1 && input_data.size() != 2) {
     ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);

@@ -323,37 +324,23 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
 
   std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
 
-  SetOutput(context, 0, output_dim, input_ids);
-  SetOutput(context, 1, output_dim, token_type_ids);
-  SetOutput(context, 2, output_dim, attention_mask);
-}
-
-const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
-
-size_t CustomOpBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  auto* p_out = output.Allocate(output_dim);
+  std::copy(input_ids.begin(), input_ids.end(), p_out);
+  auto* p_out1 = output1.Allocate(output_dim);
+  std::copy(token_type_ids.begin(), token_type_ids.end(), p_out1);
+  auto* p_out2 = output2.Allocate(output_dim);
+  std::copy(attention_mask.begin(), attention_mask.end(), p_out2);
 }
 
 KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
     : KernelBertTokenizer(api, info) {}
 
-void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
+void KernelHfBertTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                    ortc::Tensor<int64_t>& output,
+                                    ortc::Tensor<int64_t>& output1,
+                                    ortc::Tensor<int64_t>& output2) {
   // Setup inputs
-  const OrtValue* const input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
+  auto& input_data = input.Data();
 
   if (input_data.size() != 2) {
     ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);

@@ -368,33 +355,11 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
   std::vector<int64_t> attention_mask(input_ids.size(), 1LL);
 
   const std::vector<int64_t> outer_dims{1LL, static_cast<int64_t>(input_ids.size())};
-  const std::vector<int64_t> inner_dims{1LL};
-  for (int32_t i = 0; i < 3; ++i) {
-    OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
-    OrtTensorTypeAndShapeInfo* const info = ort_.GetTensorTypeAndShape(value);
-    ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
-    ort_.ReleaseTensorTypeAndShapeInfo(info);
-  }
-
-  SetOutput(context, 0, outer_dims, input_ids);
-  SetOutput(context, 1, outer_dims, attention_mask);
-  SetOutput(context, 2, outer_dims, token_type_ids);
-}
-
-const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
-
-size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetInputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-size_t CustomOpHfBertTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpHfBertTokenizer::GetOutputType(size_t /* index */) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  auto* p_out = output.Allocate(outer_dims);
+  std::copy(input_ids.begin(), input_ids.end(), p_out);
+  auto* p_out1 = output1.Allocate(outer_dims);
+  std::copy(attention_mask.begin(), attention_mask.end(), p_out1);
+  auto* p_out2 = output2.Allocate(outer_dims);
+  std::copy(token_type_ids.begin(), token_type_ids.end(), p_out2);
 }
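One detail from the BertTokenizer hunks above is worth spelling out: where the old code fetched each output by index with KernelContext_GetOutput and pushed data through SetOutput, a lite-API kernel simply takes one ortc::Tensor<int64_t>& per output, calls Allocate with the desired shape, and writes into the returned pointer. A small sketch of that pattern in isolation (the free function and its name are illustrative, not code from this PR; it needs <algorithm> and <vector>):

    // Writes the three BERT-style outputs; the *_ids vectors stand for values
    // the tokenizer has already computed.
    void WriteTokenizerOutputs(const std::vector<int64_t>& input_ids,
                               const std::vector<int64_t>& token_type_ids,
                               const std::vector<int64_t>& attention_mask,
                               ortc::Tensor<int64_t>& output,
                               ortc::Tensor<int64_t>& output1,
                               ortc::Tensor<int64_t>& output2) {
      std::vector<int64_t> output_dim{static_cast<int64_t>(input_ids.size())};
      std::copy(input_ids.begin(), input_ids.end(), output.Allocate(output_dim));
      std::copy(token_type_ids.begin(), token_type_ids.end(), output1.Allocate(output_dim));
      std::copy(attention_mask.begin(), attention_mask.end(), output2.Allocate(output_dim));
    }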
@@ -11,7 +11,6 @@
 #include <unordered_map>
-
 
 class BertTokenizerVocab final {
  public:
  explicit BertTokenizerVocab(std::string_view vocab);

@@ -92,29 +91,19 @@ class BertTokenizer final {
 
 struct KernelBertTokenizer : BaseKernel {
   KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               ortc::Tensor<int64_t>& output,
+               ortc::Tensor<int64_t>& output1,
+               ortc::Tensor<int64_t>& output2);
 
  protected:
   std::unique_ptr<BertTokenizer> tokenizer_;
 };
 
-struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
 struct KernelHfBertTokenizer : KernelBertTokenizer {
   KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
-};
-
-struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
+  void Compute(const ortc::Tensor<std::string>& input,
+               ortc::Tensor<int64_t>& output,
+               ortc::Tensor<int64_t>& output1,
+               ortc::Tensor<int64_t>& output2);
 };
@@ -136,30 +136,30 @@ KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const
                                                  cls_token, mask_token, suffix_indicator);
 }
 
-void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
-  const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
-  const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
-  OrtTensorDimensions ids_dim(ort_, ids);
+void KernelBertTokenizerDecoder::Compute(const ortc::Tensor<int64_t>& ids,
+                                         const ortc::Tensor<int64_t>& positions,
+                                         ortc::Tensor<std::string>& output) {
+  const int64_t* p_ids = ids.Data();
+  auto& ids_dim = ids.Shape();
 
   if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
     ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
   }
 
-  // const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
-  const OrtValue* positions = ort_.KernelContext_GetInput(context, 1);
-  OrtTensorDimensions positions_dim(ort_, positions);
+  auto& positions_dim = positions.Shape();
   if (use_indices_ &&
-      (!((positions_dim.Size() == 0) ||
+      (!((positions.NumberOfElement() == 0) ||
          (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
     ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
   }
 
-  const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
+  const int64_t* p_positions = positions.NumberOfElement() == 0 ? nullptr : positions.Data();
 
   std::vector<std::string> result;
   std::vector<int64_t> output_dim(1);
   if (!use_indices_) {
-    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids_dim.Size()), skip_special_tokens_, clean_up_tokenization_spaces_));
+    result.push_back(decoder_->Decode(std::vector<int64_t>(p_ids, p_ids + ids.NumberOfElement()), skip_special_tokens_, clean_up_tokenization_spaces_));
     output_dim[0] = 1;
   } else {
     if (p_positions != nullptr) {

@@ -172,25 +172,5 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
       output_dim[0] = positions_dim[0];
     }
   }
-  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-
-  FillTensorDataString(api_, ort_, context, result, output);
+  output.SetStringOutput(result, output_dim);
 }
-
-const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
-
-size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {
-  return 2;
-};
-
-ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-};
-
-size_t CustomOpBertTokenizerDecoder::GetOutputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBertTokenizerDecoder::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
@@ -31,7 +31,9 @@ class BertTokenizerDecoder {
 
 struct KernelBertTokenizerDecoder : BaseKernel {
   KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<int64_t>& ids,
+               const ortc::Tensor<int64_t>& positions,
+               ortc::Tensor<std::string>& output);
 
  private:
   std::shared_ptr<BertTokenizerDecoder> decoder_;

@@ -39,11 +41,3 @@ struct KernelBertTokenizerDecoder : BaseKernel {
   bool skip_special_tokens_;
   bool clean_up_tokenization_spaces_;
 };
-
-struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
@@ -27,26 +27,14 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
   max_sentence = TryToGetAttributeWithDefault("max_sentence", -1);
 }
 
-void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
-  // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  OrtTensorDimensions dimensions(ort_, input);
-
-  // TODO: fix this scalar check.
-  if (dimensions.Size() != 1 && dimensions[0] != 1) {
-    ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
-  }
-
-  std::vector<std::string> input_data;
-  GetTensorMutableDataString(api_, ort_, context, input, input_data);
-
-  std::string& input_string = input_data[0];
-  int max_length = static_cast<int>(2 * input_string.size() + 1);
+void KernelBlingFireSentenceBreaker::Compute(std::string_view input,
+                                             ortc::Tensor<std::string>& output) {
+  int max_length = static_cast<int>(2 * input.size() + 1);
   std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
 
-  int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
+  int output_length = TextToSentencesWithOffsetsWithModel(input.data(), static_cast<int>(input.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
   if (output_length < 0) {
-    ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
+    ORTX_CXX_API_THROW(MakeString("splitting input:\"", input, "\" failed"), ORT_INVALID_ARGUMENT);
   }
 
   // inline split output_str by newline '\n'

@@ -72,25 +60,5 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
 
   std::vector<int64_t> output_dimensions(1);
   output_dimensions[0] = output_sentences.size();
-
-  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
-  OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
+  output.SetStringOutput(output_sentences, output_dimensions);
 }
-
-const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
-
-size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
-
-size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
-  return 1;
-};
-
-ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-};
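The BasicTokenizer and BlingFireSentenceBreaker hunks also show the string-scalar convention of the lite API: a string scalar input binds directly to std::string_view, and a string output is produced in one call with SetStringOutput(values, shape), so the old scalar-shape checks, GetTensorMutableDataString copies and FillStringTensor calls disappear. A hedged sketch of that pattern (the function and its splitting logic are illustrative, not part of this PR):

    // Splits a scalar string on spaces and emits a 1-D string tensor.
    void SplitOnSpaces(std::string_view input, ortc::Tensor<std::string>& output) {
      std::vector<std::string> pieces;
      size_t start = 0;
      while (start <= input.size()) {
        size_t end = input.find(' ', start);
        if (end == std::string_view::npos) end = input.size();
        pieces.emplace_back(input.substr(start, end - start));
        start = end + 1;
      }
      output.SetStringOutput(pieces, {static_cast<int64_t>(pieces.size())});
    }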
@ -17,7 +17,8 @@ extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
|||
|
||||
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
void Compute(std::string_view input,
|
||||
ortc::Tensor<std::string>& output);
|
||||
|
||||
private:
|
||||
using ModelPtr = std::shared_ptr<void>;
|
||||
|
@ -25,12 +26,3 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
|
|||
std::string model_data_;
|
||||
int max_sentence;
|
||||
};
|
||||
|
||||
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker,
|
||||
KernelBlingFireSentenceBreaker> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@@ -99,10 +99,10 @@ struct KernelBpeDecoder : public BaseKernel {
     arr_vocab_.shrink_to_fit();
   }
 
-  void Compute(OrtKernelContext* context) {
-    const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
-    const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
-    OrtTensorDimensions ids_dim(ort_, ids);
+  void Compute(const ortc::Tensor<int64_t>& ids,
+               ortc::Tensor<std::string>& output) {
+    const int64_t* p_ids = ids.Data();
+    const auto& ids_dim = ids.Shape();
     std::vector<int64_t> output_dim = {1};
     if (ids_dim.size() > 1) {
       output_dim.resize(ids_dim.size() - 1);

@@ -110,14 +110,15 @@ struct KernelBpeDecoder : public BaseKernel {
     }
 
     size_t seq_len = ids_dim.back();
-    size_t string_batch = ids_dim.Size() / seq_len;
+    size_t string_batch = ids.NumberOfElement() / seq_len;
     std::vector<std::string> decoded_strings;
     decoded_strings.reserve(string_batch);
 
     for (auto n = string_batch; n > 0; n--) {
      std::string text;
      bool f_special_last = false;
      bool f_special = false;
-      auto count = static_cast<size_t>(ids_dim.Size());
+      auto count = static_cast<size_t>(ids.NumberOfElement());
 
      for (size_t tok_idx = 0; tok_idx < count; ++tok_idx) {
        const auto token = *(p_ids + tok_idx);

@@ -163,9 +164,7 @@ struct KernelBpeDecoder : public BaseKernel {
      decoded_strings.emplace_back(std::move(text));
      p_ids += seq_len;
     }
-
-    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-    FillTensorDataString(api_, ort_, context, decoded_strings, output);
+    output.SetStringOutput(decoded_strings, output_dim);
   }
 
  private:

@@ -182,26 +181,4 @@ struct KernelBpeDecoder : public BaseKernel {
   std::map<char32_t, unsigned char> byte_decoder_;
   std::map<int64_t, std::string> added_tokens_;
   std::set<int64_t> all_special_ids_;
-};
-
-struct CustomOpBpeDecoder : OrtW::CustomOpBase<CustomOpBpeDecoder, KernelBpeDecoder> {
-  const char* GetName() const {
-    return "BpeDecoder";
-  }
-
-  size_t GetInputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetInputType(size_t index) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  }
-
-  size_t GetOutputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetOutputType(size_t index) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-  }
 };
@@ -3,6 +3,7 @@
 // Partial code comes from other Microsoft employee.
 #include "clip_tokenizer.hpp"
 #include "narrow.h"
+#include <optional>
 
 KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
     : BaseKernel(api, info) {

@@ -115,13 +116,14 @@ std::vector<int64_t> KernelClipBpeTokenizer::Tokenize(ustring& input, int64_t ma
   return res;
 }
 
-void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
+void KernelClipBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                     ortc::Tensor<int64_t>& tokenize_output,
+                                     std::optional<ortc::Tensor<int64_t>*> attention_mask,
+                                     std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
   // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> str_input;
+  std::vector<std::string> str_input{input.Data()};
   std::list<OffsetMappingType> offset_map;
-  GetTensorMutableDataString(api_, ort_, context, input, str_input);
-  OrtTensorDimensions input_dim(ort_, input);
+  const auto& input_dim = input.Shape();
 
   std::vector<std::vector<int64_t>> tokenize_results;
   for (auto& str : str_input) {

@@ -138,18 +140,15 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
     max_length = static_cast<size_t>(padding_length_);
   }
 
-  OrtTensorDimensions output_dim = input_dim;
+  std::vector<int64_t> output_dim = input_dim;
   output_dim.push_back(max_length);
 
-  OrtTensorDimensions offset_dim = output_dim;
+  std::vector<int64_t> offset_dim = output_dim;
   offset_dim.push_back(2);  // tuple of offsets for each input id
 
-  OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-  OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
-  OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
-  auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  if (attention_mask != nullptr) {
-    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+  auto* token = tokenize_output.Allocate(output_dim);
+  if (attention_mask.has_value()) {
+    auto* mask = (*attention_mask)->Allocate(output_dim);
     int idx = 0;
     for (auto& res : tokenize_results) {
       for (int64_t id : res) {

@@ -163,8 +162,8 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
       }
     }
   }
-  if (offset_mapping != nullptr) {
-    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (offset_mapping.has_value()) {
+    auto* offset = (*offset_mapping)->Allocate(offset_dim);
     int idx2 = 0;
     for (auto& res : offset_map) {
       for (auto& mapping : res) {

@@ -188,32 +187,3 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
     }
   }
 }
-
-const char* CustomOpClipBpeTokenizer::GetName() const {
-  return "CLIPTokenizer";
-}
-
-size_t CustomOpClipBpeTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
-  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-}
-
-size_t CustomOpClipBpeTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpClipBpeTokenizer::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpClipBpeTokenizer::GetOutputCharacteristic(size_t index) const {
-  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
-                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
-}
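The CLIP, GPT-2 and RoBERTa tokenizers keep their optional outputs (attention_mask, offset_mapping) after the migration, but the mechanism changes: instead of GetOutputCharacteristic overrides plus null checks on KernelContext_GetOutput, an optional output becomes a std::optional<ortc::Tensor<int64_t>*> parameter that is only populated when the graph node actually requests it. A hedged sketch of that pattern in isolation (the function name and inputs are illustrative, not from this PR; it needs <optional>, <algorithm> and <vector>):

    // 'ids' stands for token ids the tokenizer has already computed.
    void EmitIdsWithOptionalMask(const std::vector<int64_t>& ids,
                                 ortc::Tensor<int64_t>& tokenize_output,
                                 std::optional<ortc::Tensor<int64_t>*> attention_mask) {
      std::vector<int64_t> output_dim{static_cast<int64_t>(ids.size())};
      std::copy(ids.begin(), ids.end(), tokenize_output.Allocate(output_dim));
      // The optional output is filled only if the node requested it.
      if (attention_mask.has_value()) {
        int64_t* mask = (*attention_mask)->Allocate(output_dim);
        std::fill(mask, mask + ids.size(), 1LL);
      }
    }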
@@ -6,7 +6,10 @@
 
 struct KernelClipBpeTokenizer : BaseKernel {
   KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               ortc::Tensor<int64_t>& tokenize_output,
+               std::optional<ortc::Tensor<int64_t>*> attention_mask,
+               std::optional<ortc::Tensor<int64_t>*> offset_mapping);
 
  private:
   using OffsetMappingType = std::list<std::pair<size_t, size_t>>;

@@ -16,13 +19,3 @@ struct KernelClipBpeTokenizer : BaseKernel {
   std::list<std::pair<int, int>> byte_list_;
   std::shared_ptr<VocabData> bbpe_tokenizer_;
 };
-
-struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokenizer, KernelClipBpeTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
-};
@@ -4,7 +4,6 @@
 
 #include "gpt2_tokenizer.hpp"
-
 
 KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
     : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");

@@ -80,12 +79,12 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(const ustring& input, int64_t
   return res;
 }
 
-void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
+void KernelBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                 ortc::Tensor<int64_t>& tokenize_output,
+                                 std::optional<ortc::Tensor<int64_t>*> attention_mask) {
   // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> str_input;
-  GetTensorMutableDataString(api_, ort_, context, input, str_input);
-  OrtTensorDimensions input_dim(ort_, input);
+  std::vector<std::string> str_input{input.Data()};
+  const auto& input_dim = input.Shape();
 
   std::vector<std::vector<int64_t>> tokenize_results;
   for (auto& str : str_input) {

@@ -101,13 +100,11 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
     max_length = static_cast<size_t>(padding_length_);
   }
 
-  OrtTensorDimensions output_dim = input_dim;
+  std::vector<int64_t> output_dim = input_dim;
   output_dim.push_back(max_length);
-  OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-  OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
-  auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  if (attention_mask != nullptr) {
-    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+  auto* token = tokenize_output.Allocate(output_dim);
+  if (attention_mask.has_value()) {
+    auto* mask = (*attention_mask)->Allocate(output_dim);
     int idx = 0;
     for (auto& res : tokenize_results) {
       for (int64_t id : res) {

@@ -134,32 +131,3 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
       }
     }
   }
-
-const char* CustomOpBpeTokenizer::GetName() const {
-  return "GPT2Tokenizer";
-}
-
-size_t CustomOpBpeTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpBpeTokenizer::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
-  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-}
-
-size_t CustomOpBpeTokenizer::GetOutputTypeCount() const {
-  return 2;
-}
-
-ONNXTensorElementDataType CustomOpBpeTokenizer::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpBpeTokenizer::GetOutputCharacteristic(size_t index) const {
-  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
-                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
-}
@@ -6,7 +6,9 @@
 
 struct KernelBpeTokenizer : BaseKernel {
   KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               ortc::Tensor<int64_t>& tokenize_output,
+               std::optional<ortc::Tensor<int64_t>*> attention_mask);
 
  private:
   std::vector<int64_t> Tokenize(const ustring& input, int64_t max_length);

@@ -15,13 +17,3 @@ struct KernelBpeTokenizer : BaseKernel {
   std::list<std::pair<int, int>> byte_list_;
   std::shared_ptr<VocabData> bbpe_tokenizer_;
 };
-
-struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
-};
@@ -5,7 +5,6 @@
 #include "roberta_tokenizer.hpp"
 #include "narrow.h"
-
 
 KernelRobertaBpeTokenizer::KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
     : BaseKernel(api, info) {
   std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");

@@ -74,7 +73,7 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
     size_t space_dif = 0;
     if (utf8_token.at(0) == ' ') {
       offset++;
-      space_dif = -1; // account for spaces used in offset map algorithm in bpe(byte_list_)
+      space_dif = -1;  // account for spaces used in offset map algorithm in bpe(byte_list_)
     }
 
     // Get byte encodings prior to performing BPE

@@ -108,13 +107,14 @@ std::vector<int64_t> KernelRobertaBpeTokenizer::Tokenize(ustring& input, int64_t
   return res;
 }
 
-void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
+void KernelRobertaBpeTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                        ortc::Tensor<int64_t>& tokenize_output,
+                                        std::optional<ortc::Tensor<int64_t>*> attention_mask,
+                                        std::optional<ortc::Tensor<int64_t>*> offset_mapping) {
   // Setup inputs
-  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> str_input;
+  std::vector<std::string> str_input{input.Data()};
   std::list<OffsetMappingType> offset_map;
-  GetTensorMutableDataString(api_, ort_, context, input, str_input);
-  OrtTensorDimensions input_dim(ort_, input);
+  const auto& input_dim = input.Shape();
 
   std::vector<std::vector<int64_t>> tokenize_results;
   for (auto& str : str_input) {

@@ -131,18 +131,15 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
     max_length = static_cast<size_t>(padding_length_);
   }
 
-  OrtTensorDimensions output_dim = input_dim;
+  std::vector<int64_t> output_dim = input_dim;
   output_dim.push_back(max_length);
 
-  OrtTensorDimensions offset_dim = output_dim;
-  offset_dim.push_back(2); // tuple of offsets for each input id
+  std::vector<int64_t> offset_dim = output_dim;
+  offset_dim.push_back(2);  // tuple of offsets for each input id
 
-  OrtValue* tokenize_output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-  OrtValue* attention_mask = ort_.KernelContext_GetOutput(context, 1, output_dim.data(), output_dim.size());
-  OrtValue* offset_mapping = ort_.KernelContext_GetOutput(context, 2, offset_dim.data(), offset_dim.size());
-  auto* token = ort_.GetTensorMutableData<int64_t>(tokenize_output);
-  if (attention_mask != nullptr) {
-    auto* mask = ort_.GetTensorMutableData<int64_t>(attention_mask);
+  auto* token = tokenize_output.Allocate(output_dim);
+  if (attention_mask.has_value()) {
+    auto* mask = (*attention_mask)->Allocate(output_dim);
     int idx = 0;
     for (auto& res : tokenize_results) {
       for (int64_t id : res) {

@@ -156,8 +153,8 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
       }
     }
   }
-  if (offset_mapping != nullptr) {
-    auto* offset = ort_.GetTensorMutableData<int64_t>(offset_mapping);
+  if (offset_mapping.has_value()) {
+    auto* offset = (*offset_mapping)->Allocate(offset_dim);
     int idx2 = 0;
     for (auto& res : offset_map) {
       for (auto& mapping : res) {

@@ -181,32 +178,3 @@ void KernelRobertaBpeTokenizer::Compute(OrtKernelContext* context) {
     }
   }
 }
-
-const char* CustomOpRobertaBpeTokenizer::GetName() const {
-  return "RobertaTokenizer";
-}
-
-size_t CustomOpRobertaBpeTokenizer::GetInputTypeCount() const {
-  return 1;
-}
-
-ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetInputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetInputCharacteristic(size_t /*index*/) const {
-  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-}
-
-size_t CustomOpRobertaBpeTokenizer::GetOutputTypeCount() const {
-  return 3;
-}
-
-ONNXTensorElementDataType CustomOpRobertaBpeTokenizer::GetOutputType(size_t /*index*/) const {
-  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-}
-
-OrtCustomOpInputOutputCharacteristic CustomOpRobertaBpeTokenizer::GetOutputCharacteristic(size_t index) const {
-  return index == 0 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED
-                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
-}
@@ -6,7 +6,10 @@
 
 struct KernelRobertaBpeTokenizer : BaseKernel {
   KernelRobertaBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               ortc::Tensor<int64_t>& tokenize_output,
+               std::optional<ortc::Tensor<int64_t>*> attention_mask,
+               std::optional<ortc::Tensor<int64_t>*> offset_mapping);
 
  private:
   using OffsetMappingType = std::list<std::pair<size_t, size_t>>;

@@ -16,13 +19,3 @@ struct KernelRobertaBpeTokenizer : BaseKernel {
   std::list<std::pair<int, int>> byte_list_;
   std::shared_ptr<VocabData> bbpe_tokenizer_;
 };
-
-struct CustomOpRobertaBpeTokenizer : OrtW::CustomOpBase<CustomOpRobertaBpeTokenizer, KernelRobertaBpeTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
-};
@@ -22,10 +22,10 @@ struct KernelSentencepieceDecoder : BaseKernel {
     }
   }
 
-  void Compute(OrtKernelContext* context) {
-    const OrtValue* ids = ort_.KernelContext_GetInput(context, 0);
-    const int64_t* p_ids = ort_.GetTensorData<int64_t>(ids);
-    OrtTensorDimensions ids_dim(ort_, ids);
+  void Compute(const ortc::Tensor<int64_t>& ids,
+               ortc::Tensor<std::string>& output) {
+    const int64_t* p_ids = ids.Data();
+    auto& ids_dim = ids.Shape();
 
     if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
       ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);

@@ -34,7 +34,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
     std::string decoded_string;
     std::vector<int64_t> output_dim = {1};
     std::vector<int> tids;
-    std::transform(p_ids, p_ids + ids_dim.Size(),
+    std::transform(p_ids, p_ids + ids.NumberOfElement(),
                    std::back_inserter(tids),
                    [](auto _id) { return static_cast<int>(_id); });
     auto status = tokenizer_.Decode(tids, &decoded_string);

@@ -43,32 +43,9 @@ struct KernelSentencepieceDecoder : BaseKernel {
     }
 
     std::vector<std::string> result = {decoded_string};
-    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dim.data(), output_dim.size());
-    FillTensorDataString(api_, ort_, context, result, output);
+    output.SetStringOutput(result, output_dim);
   }
 
  private:
   sentencepiece::SentencePieceProcessor tokenizer_;
 };
-
-struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
-  const char* GetName() const {
-    return "SentencepieceDecoder";
-  }
-
-  size_t GetInputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetInputType(size_t index) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  }
-
-  size_t GetOutputTypeCount() const {
-    return 1;
-  }
-
-  ONNXTensorElementDataType GetOutputType(size_t index) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-  }
-};
@@ -31,32 +31,16 @@ static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort
                        ORT_INVALID_ARGUMENT);
 }
 
-void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
-  // Update with the new API
-  const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0);
-  std::vector<std::string> str_input;
-  GetTensorMutableDataString(api_, ort_, context, ort_input, str_input);
-  const OrtValue* ort_nbest_size = ort_.KernelContext_GetInput(context, 1);
-  const float* p_nbest_size = ort_.GetTensorData<float>(ort_nbest_size);
-  const OrtValue* ort_alpha = ort_.KernelContext_GetInput(context, 2);
-  const float* p_alpha = ort_.GetTensorData<float>(ort_alpha);
-  const OrtValue* ort_add_bos = ort_.KernelContext_GetInput(context, 3);
-  const bool* p_add_bos = ort_.GetTensorData<bool>(ort_add_bos);
-  const OrtValue* ort_add_eos = ort_.KernelContext_GetInput(context, 4);
-  const bool* p_add_eos = ort_.GetTensorData<bool>(ort_add_eos);
-  const OrtValue* ort_add_rev = ort_.KernelContext_GetInput(context, 5);
-  const bool* p_add_rev = ort_.GetTensorData<bool>(ort_add_rev);
-
-  (void)p_nbest_size;
-  (void)p_alpha;
-
-  // Verifications
-  _check_dimension_constant(ort_, ort_nbest_size, "nbest_size");
-  _check_dimension_constant(ort_, ort_alpha, "alpha");
-  _check_dimension_constant(ort_, ort_add_bos, "add_bos");
-  _check_dimension_constant(ort_, ort_add_eos, "add_eos");
-  _check_dimension_constant(ort_, ort_add_rev, "add_rev");
-
+void KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                           int64_t /*nbest_size*/,
+                                           float /*alpha*/,
+                                           bool add_bos,
+                                           bool add_eos,
+                                           bool add_rev,
+                                           ortc::Tensor<int32_t>& output,
+                                           ortc::Tensor<int64_t>& output1) {
+  auto& str_input = input.Data();
   // computation
 
   std::vector<int64_t> indices;

@@ -68,20 +52,20 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
       ORTX_CXX_API_THROW(MakeString("Unable to encode string '", str_input[i], "'."), ORT_INVALID_ARGUMENT);
     indices.push_back(content.size());
 
-    if (*p_add_rev) {
-      if (*p_add_eos) {
+    if (add_rev) {
+      if (add_eos) {
         content.push_back(tokenizer_.eos_id());
       }
       content.insert(content.end(), inloop.rbegin(), inloop.rend());
-      if (*p_add_bos) {
+      if (add_bos) {
         content.push_back(tokenizer_.bos_id());
       }
     } else {
-      if (*p_add_bos) {
+      if (add_bos) {
        content.push_back(tokenizer_.bos_id());
      }
      content.insert(content.end(), inloop.begin(), inloop.end());
-      if (*p_add_eos) {
+      if (add_eos) {
        content.push_back(tokenizer_.eos_id());
      }
    }

@@ -91,54 +75,12 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
   // Setup output
   std::vector<int64_t> size_content(1);
   size_content[0] = content.size();
-  OrtValue* out_content = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size());
 
   std::vector<int64_t> size_indices(1);
   size_indices[0] = indices.size();
-  OrtValue* out_indices = ort_.KernelContext_GetOutput(context, 1, size_indices.data(), size_indices.size());
 
-  int* ptr_content = ort_.GetTensorMutableData<int>(out_content);
+  int* ptr_content = output.Allocate(size_content);
   memcpy(ptr_content, content.data(), content.size() * sizeof(int));
-  int64_t* ptr_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
+  int64_t* ptr_indices = output1.Allocate(size_indices);
   memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
 }
-
-const char* CustomOpSentencepieceTokenizer::GetName() const {
-  return "SentencepieceTokenizer";
-};
-
-size_t CustomOpSentencepieceTokenizer::GetInputTypeCount() const {
-  return 6;
-};
-
-ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetInputType(size_t index) const {
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-    case 1:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    case 2:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-    case 3:
-    case 4:
-    case 5:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
-    default:
-      ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
-  }
-};
-
-size_t CustomOpSentencepieceTokenizer::GetOutputTypeCount() const {
-  return 2;
-};
-
-ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetOutputType(size_t index) const {
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
-    case 1:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    default:
-      ORTX_CXX_API_THROW(MakeString("[SentencepieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
-  }
-};
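The SentencePiece hunk shows one more convenience of the lite signature: scalar inputs such as nbest_size, alpha, add_bos, add_eos and add_rev no longer have to be fetched as separate OrtValue inputs, dimension-checked with _check_dimension_constant and dereferenced through raw pointers; they bind directly to plain C++ parameters (int64_t, float, bool). A hedged sketch of that idea with a hypothetical kernel (not from this PR; whether every scalar type binds this way is an assumption based on the signature above):

    // Hypothetical kernel: adds or subtracts a scalar offset from every id.
    struct KernelAddOffset : BaseKernel {
      KernelAddOffset(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}

      // 'offset' and 'negate' arrive as scalar tensors in the graph but are
      // presented to the kernel as ordinary values.
      void Compute(const ortc::Tensor<int64_t>& ids,
                   int64_t offset,
                   bool negate,
                   ortc::Tensor<int64_t>& out) {
        const int64_t* p = ids.Data();
        int64_t* q = out.Allocate(ids.Shape());
        const auto n = static_cast<int64_t>(ids.NumberOfElement());
        for (int64_t i = 0; i < n; ++i) {
          q[i] = negate ? p[i] - offset : p[i] + offset;
        }
      }
    };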
@@ -9,17 +9,15 @@
 
 struct KernelSentencepieceTokenizer : BaseKernel {
   KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               int64_t /*nbest_size*/,
+               float /*alpha*/,
+               bool add_bos,
+               bool add_eos,
+               bool add_rev,
+               ortc::Tensor<int32_t>& output,
+               ortc::Tensor<int64_t>& output1);
 
  private:
   sentencepiece::SentencePieceProcessor tokenizer_;
 };
-
-struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer,
-                                                           KernelSentencepieceTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
@@ -27,33 +27,42 @@
 #include "bert_tokenizer_decoder.hpp"
 #endif
 
-FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = LoadCustomOpClasses<
-    CustomOpClassBegin
+const std::vector<const OrtCustomOp*>& TokenizerLoader() {
+  static OrtOpLoader op_loader(
+      []() { return nullptr; }
 #ifdef ENABLE_GPT2_TOKENIZER
-    , CustomOpBpeTokenizer
-    , CustomOpClipBpeTokenizer
-    , CustomOpRobertaBpeTokenizer
-    , CustomOpBpeDecoder
+      ,
+      CustomCpuStruct("GPT2Tokenizer", KernelBpeTokenizer),
+      CustomCpuStruct("CLIPTokenizer", KernelClipBpeTokenizer),
+      CustomCpuStruct("RobertaTokenizer", KernelRobertaBpeTokenizer),
+      CustomCpuStruct("BpeDecoder", KernelBpeDecoder)
 #endif
 
 #ifdef ENABLE_SPM_TOKENIZER
-    , CustomOpSentencepieceTokenizer
-    , CustomOpSentencepieceDecoder
+      ,
+      CustomCpuStruct("SentencepieceTokenizer", KernelSentencepieceTokenizer),
+      CustomCpuStruct("SentencepieceDecoder", KernelSentencepieceDecoder)
 #endif
 
 #ifdef ENABLE_WORDPIECE_TOKENIZER
-    , CustomOpWordpieceTokenizer
+      ,
+      CustomCpuStruct("WordpieceTokenizer", KernelWordpieceTokenizer)
 #endif
 
 #ifdef ENABLE_BERT_TOKENIZER
-    , CustomOpBasicTokenizer
-    , CustomOpBertTokenizer
-    , CustomOpBertTokenizerDecoder
-    , CustomOpHfBertTokenizer
+      ,
+      CustomCpuStruct("BasicTokenizer", KernelBasicTokenizer),
+      CustomCpuStruct("BertTokenizer", KernelBertTokenizer),
+      CustomCpuStruct("BertTokenizerDecoder", KernelBertTokenizerDecoder),
+      CustomCpuStruct("HfBertTokenizer", KernelHfBertTokenizer)
 #endif
 
 #ifdef ENABLE_BLINGFIRE
-    , CustomOpBlingFireSentenceBreaker
+      ,
+      CustomCpuStruct("BlingFireSentenceBreaker", KernelBlingFireSentenceBreaker)
 #endif
->;
+  );
+  return op_loader.GetCustomOps();
+}
+
+FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = TokenizerLoader;
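Registration changes accordingly: the operator list is no longer a LoadCustomOpClasses<CustomOpClassBegin, ...> template instantiation over CustomOp* descriptor types. Instead a static OrtOpLoader is constructed from CustomCpuStruct("OpName", KernelType) entries, the loader function returns its GetCustomOps() vector, and LoadCustomOpClasses_Tokenizer is kept as a function alias so existing callers keep working. A hedged sketch of adding one more lite-API kernel in the same style (KernelMyOp and the domain name are hypothetical, not from this PR):

    const std::vector<const OrtCustomOp*>& MyDomainLoader() {
      static OrtOpLoader op_loader(
          []() { return nullptr; },  // placeholder first entry, as in TokenizerLoader above
          CustomCpuStruct("MyOp", KernelMyOp));
      return op_loader.GetCustomOps();
    }

    FxLoadCustomOpFactory LoadCustomOpClasses_MyDomain = MyDomainLoader;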
@@ -122,14 +122,20 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
     rows.push_back(indices.size());
 }
 
-void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
+void KernelWordpieceTokenizer::Compute(const ortc::Tensor<std::string>& input,
+                                       const ortc::Tensor<int64_t>& row_indices,
+                                       ortc::Tensor<std::string>& output,
+                                       ortc::Tensor<int64_t>& row_lengths,
+                                       ortc::Tensor<int64_t>& out_row_begin,
+                                       ortc::Tensor<int64_t>& output_limit_values) {
   // Update with the new API
-  const OrtValue* ort_input = ort_.KernelContext_GetInput(context, 0);
+  // make a copy as we need ustring
   std::vector<ustring> str_input;
-  GetTensorMutableDataString(api_, ort_, context, ort_input, str_input);
-  const OrtValue* ort_row_indices = ort_.KernelContext_GetInput(context, 1);
-  OrtTensorDimensions ort_row_indices_dim(ort_, ort_row_indices);
-  const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
+  str_input.reserve(input.NumberOfElement());
+  for (auto& str : input.Data()) {
+    str_input.emplace_back(str);
+  }
+  const int64_t* p_row_indices = row_indices.Shape().empty() ? nullptr : row_indices.Data();
 
   std::vector<ustring> tokens;
   std::vector<int32_t> indices;

@@ -137,21 +143,21 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
 
   KernelWordpieceTokenizer_Tokenizer(vocab_, suffix_indicator_, unk_token_, str_input,
                                      tokens, indices, row_begins,
-                                     p_row_indices, ort_row_indices_dim.Size(),
+                                     p_row_indices, row_indices.NumberOfElement(),
                                      max_input_chars_per_word_);
 
   std::vector<int64_t> size_content{(int64_t)indices.size()};
-  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, size_content.data(), size_content.size());
-  FillTensorDataString(api_, ort_, context, tokens, output);
+  // TODO: avoid copy
+  std::vector<std::string> out_content;
+  for (auto& s : tokens)
+    out_content.emplace_back(s);
+  output.SetStringOutput(out_content, size_content);
 
   std::vector<int64_t> size_row_lengths{(int64_t)row_begins.size()};
-  OrtValue* output_row_lengths = ort_.KernelContext_GetOutput(context, 1, size_row_lengths.data(), size_row_lengths.size());
+  int64_t* ptr_row_lengths = row_lengths.Allocate(size_row_lengths);
   --size_row_lengths[0];
-  OrtValue* output_row_begins = ort_.KernelContext_GetOutput(context, 2, size_row_lengths.data(), size_row_lengths.size());
-  OrtValue* output_limit_values = ort_.KernelContext_GetOutput(context, 3, size_row_lengths.data(), size_row_lengths.size());
-  int64_t* ptr_row_lengths = ort_.GetTensorMutableData<int64_t>(output_row_lengths);
-  int64_t* ptr_row_begins = ort_.GetTensorMutableData<int64_t>(output_row_begins);
-  int64_t* ptr_limit_values = ort_.GetTensorMutableData<int64_t>(output_limit_values);
+  int64_t* ptr_row_begins = out_row_begin.Allocate(size_row_lengths);
+  int64_t* ptr_limit_values = output_limit_values.Allocate(size_row_lengths);
 
   int64_t i;
   for (i = 0; i < size_row_lengths[0]; ++i) {

@@ -163,39 +169,3 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
     i = size_row_lengths[0];
     ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
   }
 }
-
-const char* CustomOpWordpieceTokenizer::GetName() const {
-  return "WordpieceTokenizer";
-};
-
-size_t CustomOpWordpieceTokenizer::GetInputTypeCount() const {
-  return 2;
-};
-
-ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index) const {
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-    case 1:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    default:
-      ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
-  }
-};
-
-size_t CustomOpWordpieceTokenizer::GetOutputTypeCount() const {
-  return 4;
-};
-
-ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index) const {
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
-    case 1:
-    case 2:
-    case 3:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    default:
-      ORTX_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
-  }
-};
@@ -10,10 +10,14 @@
 
 #include <unordered_map>
 
 struct KernelWordpieceTokenizer : BaseKernel {
   KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
-  void Compute(OrtKernelContext* context);
+  void Compute(const ortc::Tensor<std::string>& input,
+               const ortc::Tensor<int64_t>& row_indices,
+               ortc::Tensor<std::string>& output,
+               ortc::Tensor<int64_t>& row_lengths,
+               ortc::Tensor<int64_t>& out_row_begin,
+               ortc::Tensor<int64_t>& output_limit_values);
 
  private:
   int64_t max_input_chars_per_word_;

@@ -22,14 +26,6 @@ struct KernelWordpieceTokenizer : BaseKernel {
   std::unordered_map<std::u32string, int32_t> vocab_;
 };
 
-struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
-  const char* GetName() const;
-  size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(size_t index) const;
-  size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(size_t index) const;
-};
-
 void KernelWordpieceTokenizer_Split(const std::u32string& suffix_indicator,
                                     const std::u32string& text,
                                     std::vector<std::u32string>& words);
@@ -9,6 +9,10 @@
 #include <cstdint>
 
+namespace ort_extensions {
+
+void decode_image(const ortc::Tensor<uint8_t>& input,
+                  ortc::Tensor<uint8_t>& output);
 
 struct KernelDecodeImage : BaseKernel {
   KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}

@@ -50,4 +54,5 @@ struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecod
     }
   }
 };
+
+}  // namespace ort_extensions