Run python unit tests on every platform + fix all string operators (#44)
* Use appropriate API for strings * Modify all string operators * Enable ci on linux and MacOs
This commit is contained in:
Родитель
4a0f892949
Коммит
d48d825a66
|
@ -1,3 +1,7 @@
|
|||
#######
|
||||
# Linux
|
||||
#######
|
||||
|
||||
jobs:
|
||||
- job: Linux
|
||||
pool:
|
||||
|
@ -5,7 +9,7 @@ jobs:
|
|||
|
||||
strategy:
|
||||
matrix:
|
||||
py37:
|
||||
py38:
|
||||
python.version: '3.8'
|
||||
maxParallel: 1
|
||||
|
||||
|
@ -31,26 +35,26 @@ jobs:
|
|||
./ortcustomops_test
|
||||
displayName: Run the native only unit tests
|
||||
|
||||
- script: |
|
||||
python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- script: python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: |
|
||||
python -m pip install -r requirements-dev.txt
|
||||
- script: python -m pip install -r requirements-dev.txt
|
||||
displayName: Install requirements-dev.txt
|
||||
|
||||
- script: |
|
||||
# FIXME: need check the CI environment for the failure.
|
||||
# python -m pytest test
|
||||
- script: python -m pytest test --verbose --verbose
|
||||
displayName: Run python test
|
||||
|
||||
#######
|
||||
# macOS
|
||||
#######
|
||||
|
||||
- job: macOS
|
||||
pool:
|
||||
vmImage: 'macOS-latest'
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
py37:
|
||||
py38:
|
||||
python.version: '3.8'
|
||||
maxParallel: 1
|
||||
|
||||
|
@ -60,14 +64,22 @@ jobs:
|
|||
versionSpec: '$(python.version)'
|
||||
addToPath: true
|
||||
|
||||
# needed for onnxruntime
|
||||
- script: brew install libomp
|
||||
displayName: 'Install omp'
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install --upgrade setuptools
|
||||
python -m pip install -r requirements.txt
|
||||
displayName: Install requirements.txt
|
||||
|
||||
- script: python -c "import onnxruntime;print(onnxruntime.__version__)"
|
||||
displayName: Check installation
|
||||
|
||||
- script: |
|
||||
sh ./build.sh
|
||||
call activate pyenv
|
||||
python setup.py develop
|
||||
displayName: Build the library and tests
|
||||
|
||||
|
@ -76,6 +88,19 @@ jobs:
|
|||
./ortcustomops_test
|
||||
displayName: Run the native only unit tests
|
||||
|
||||
- script: python -m pip install -r requirements-dev.txt
|
||||
displayName: Install requirements-dev.txt
|
||||
|
||||
- script: python -m pip install torch torchvision torchaudio
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: python -m pytest test --verbose
|
||||
displayName: Run python test
|
||||
|
||||
#########
|
||||
# Windows
|
||||
#########
|
||||
|
||||
- job: Windows
|
||||
pool:
|
||||
vmImage: 'windows-latest'
|
||||
|
@ -83,7 +108,7 @@ jobs:
|
|||
- powershell: Write-Host "##vso[task.prependpath]$env:CONDA\Scripts"
|
||||
displayName: Add conda to PATH
|
||||
|
||||
- script: conda create --yes --quiet --name pyenv -c conda-forge python=3.7 numpy
|
||||
- script: conda create --yes --quiet --name pyenv -c conda-forge python=3.8 numpy
|
||||
displayName: Create Anaconda environment
|
||||
|
||||
- script: |
|
||||
|
@ -106,14 +131,9 @@ jobs:
|
|||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pip install -r requirements-dev.txt
|
||||
displayName: Install requirements-dev.txt
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pytest test
|
||||
|
|
|
@ -11,7 +11,7 @@ typedef CPTR_OrtCustomOp (*FxGetSchemaInstance)();
|
|||
FxGetSchemaInstance const* GetCustomOpSchemaList();
|
||||
|
||||
struct BaseKernel {
|
||||
BaseKernel(OrtApi api) : api_(api), ort_(api_) { }
|
||||
BaseKernel(OrtApi api) : api_(api), ort_(api_) {}
|
||||
|
||||
protected:
|
||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
||||
|
@ -25,8 +25,14 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
const std::vector<int64_t>& GetDims() const { return *this; }
|
||||
int64_t Size() const {
|
||||
int64_t s = 1.;
|
||||
for (auto it = begin(); it != end(); ++it)
|
||||
s *= *it;
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(ENABLE_TOKENIZER)
|
||||
const OrtCustomOp** LoadTokenizerSchemaList();
|
||||
#endif // ENABLE_TEXT_DOMAIN
|
||||
#endif // ENABLE_TEXT_DOMAIN
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "utils.h"
|
||||
#include "string_common.h"
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
class BroadcastIteratorRight {
|
||||
|
@ -140,3 +142,36 @@ void KernelEqual_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
|||
state.loop(cmp, state);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void KernelEqual_Compute<std::string>(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
std::vector<std::string> X, Y;
|
||||
GetTensorMutableDataString(ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(ort_, context, input_Y, Y);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_x(ort_, input_X);
|
||||
OrtTensorDimensions dimensions_y(ort_, input_Y);
|
||||
Compare<std::string> cmp;
|
||||
|
||||
typename BroadcastIteratorRight<std::string, std::string, bool>::BroadcastIteratorRightState state;
|
||||
if (Size(dimensions_x) >= Size(dimensions_y)) {
|
||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_x.data(), dimensions_x.size());
|
||||
bool* out = ort_.GetTensorMutableData<bool>(v);
|
||||
BroadcastIteratorRight<std::string, std::string, bool> iter(
|
||||
dimensions_x, dimensions_y, X.data(), Y.data(), out);
|
||||
state.init(iter);
|
||||
state.loop(cmp, state);
|
||||
} else {
|
||||
// Operator Equal is commutative.
|
||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_y.data(), dimensions_y.size());
|
||||
bool* out = ort_.GetTensorMutableData<bool>(v);
|
||||
BroadcastIteratorRight<std::string, std::string, bool> iter(
|
||||
dimensions_y, dimensions_x, Y.data(), X.data(), out);
|
||||
state.init(iter);
|
||||
state.loop(cmp, state);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_common.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Recovers a pointer to the OrtApi structure associated with `ort`.
//
// Ort::CustomOpApi offers no accessor for its OrtApi, and the alternatives
// were ruled out by the original author:
//   - ort.api_ is not accessible (marked as private),
//   - Ort::GetApi() returns null,
//   - InitApi and OrtGetApiBase() are missing when linking.
// This function should be replaced once an accessor such as
// `const OrtApi& Api() { return api_; }` is available on Ort::CustomOpApi;
// the result could then be validated against `&ort.Api()`.
//
// Throws std::runtime_error when Ort::CustomOpApi does not have the size of a
// single pointer, since the trick below would then be meaningless.
const OrtApi* _GetApi(Ort::CustomOpApi& ort) {
  if (sizeof(Ort::CustomOpApi) != sizeof(OrtApi*)) {
    throw std::runtime_error(MakeString(
        "Unable to get an OrtApi pointer from CustomOpApi. Variable ort is not a pointer. Size ",
        sizeof(Ort::CustomOpApi), "!=", sizeof(OrtApi*), " (expected)."));
  }

  // The address of `ort` itself is reinterpreted as an OrtApi pointer and
  // moved back by one whole OrtApi object.
  // NOTE(review): this presumably relies on the caller storing its OrtApi copy
  // immediately before the CustomOpApi member (as BaseKernel appears to do
  // with api_ followed by ort_) -- confirm for every caller.
  const OrtApi* address_of_ort = reinterpret_cast<const OrtApi*>(&ort);
  return address_of_ort - 1;
}
|
||||
|
||||
// Copies the content of a string tensor into a vector of std::string.
//
// ort:     custom-op API wrapper used to obtain the underlying OrtApi.
// context: kernel context (unused here, kept for interface symmetry).
// value:   the string tensor to read.
// output:  resized to the number of elements of `value` and filled with a
//          copy of every string; callers may freely modify it.
//
// Throws (through Ort::ThrowOnError) if the runtime reports an error.
void GetTensorMutableDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
                                const OrtValue* value, std::vector<std::string>& output) {
  const OrtApi& api = *_GetApi(ort);
  OrtTensorDimensions dimensions(ort, value);
  const size_t len = static_cast<size_t>(dimensions.Size());

  // Total number of bytes of all strings, concatenated without terminators.
  size_t data_len = 0;
  Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));

  // One extra byte keeps the buffer pointer valid even when data_len is 0.
  std::vector<char> buffer(data_len + 1, '\0');
  std::vector<size_t> offsets(len);
  Ort::ThrowOnError(api, api.GetStringTensorContent(value, static_cast<void*>(buffer.data()), data_len,
                                                    offsets.data(), offsets.size()));

  // offsets[i] is where string i starts; it ends where string i + 1 starts
  // (or at data_len for the last one). Building each string from an explicit
  // length keeps embedded '\0' bytes instead of truncating at the first one,
  // and leaves the buffer untouched.
  output.resize(len);
  for (size_t i = 0; i < len; ++i) {
    const size_t first = offsets[i];
    const size_t last = (i + 1 < len) ? offsets[i + 1] : data_len;
    output[i].assign(buffer.data() + first, last - first);
  }
}
|
||||
|
||||
// Writes a vector of strings into an already allocated string tensor.
//
// ort:     custom-op API wrapper used to obtain the underlying OrtApi.
// context: kernel context (unused here, kept for interface symmetry).
// value:   strings to store; passed to the runtime as C strings.
// output:  destination string tensor.
//
// Throws (through Ort::ThrowOnError) if the runtime rejects the data, instead
// of silently dropping the returned status.
void FillTensorDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
                          const std::vector<std::string>& value, OrtValue* output) {
  const OrtApi& api = *_GetApi(ort);

  // FillStringTensor expects an array of C-string pointers; they stay valid
  // because `value` outlives the call.
  std::vector<const char*> c_strings;
  c_strings.reserve(value.size());
  for (const std::string& text : value) {
    c_strings.push_back(text.c_str());
  }

  Ort::ThrowOnError(api, api.FillStringTensor(output, c_strings.data(), c_strings.size()));
}
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "kernels.h"
|
||||
|
||||
// Retrieves a vector of strings if the input type is std::string.
|
||||
// It is a copy of the input data and can be modified to compute the output.
|
||||
void GetTensorMutableDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const OrtValue* value, std::vector<std::string>& output);
|
||||
|
||||
void FillTensorDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
|
||||
const std::vector<std::string>& value, OrtValue* output);
|
|
@ -7,6 +7,7 @@
|
|||
#include <algorithm>
|
||||
#include "re2/re2.h"
|
||||
#include "farmhash.h"
|
||||
#include "string_common.h"
|
||||
|
||||
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
|
||||
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
|
||||
|
@ -81,9 +82,10 @@ KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
|
|||
void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* str_input = ort_.GetTensorData<std::string>(input);
|
||||
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
|
||||
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
|
||||
std::vector<std::string> str_input;
|
||||
GetTensorMutableDataString(ort_, context, input, str_input);
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||
|
@ -143,9 +145,10 @@ KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) {
|
|||
void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* str_input = ort_.GetTensorData<std::string>(input);
|
||||
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
|
||||
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
|
||||
std::vector<std::string> str_input;
|
||||
GetTensorMutableDataString(ort_, context, input, str_input);
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_join.hpp"
|
||||
#include "string_common.h"
|
||||
|
||||
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
@ -9,11 +10,12 @@ KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
|
|||
void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* X = ort_.GetTensorData<std::string>(input_X);
|
||||
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
|
||||
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
|
||||
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
|
||||
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
|
||||
std::vector<std::string> X, sep;
|
||||
GetTensorMutableDataString(ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(ort_, context, input_sep, sep);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
|
@ -38,11 +40,10 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
|
||||
std::string* out = ort_.GetTensorMutableData<std::string>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
std::vector<std::string> out(size);
|
||||
|
||||
// Do computation
|
||||
int64_t h = 1;
|
||||
|
@ -59,12 +60,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
|||
std::ostringstream st;
|
||||
int64_t index = ri + li * inc;
|
||||
for (int64_t j = 0; j < n_red; ++j, index += h) {
|
||||
st << X[index] << *sep;
|
||||
st << X[index] << sep[0];
|
||||
}
|
||||
st << X[index];
|
||||
out[pos] = st.str();
|
||||
}
|
||||
}
|
||||
FillTensorDataString(ort_, context, out, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include "re2/re2.h"
|
||||
#include "string_common.h"
|
||||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
@ -13,11 +14,12 @@ KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api) : BaseKernel(api)
|
|||
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* str_input = ort_.GetTensorData<std::string>(input);
|
||||
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
|
||||
const std::string* str_pattern = ort_.GetTensorData<std::string>(pattern);
|
||||
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
|
||||
const std::string* str_rewrite = ort_.GetTensorData<std::string>(rewrite);
|
||||
std::vector<std::string> str_input, str_pattern, str_rewrite;
|
||||
GetTensorMutableDataString(ort_, context, input, str_input);
|
||||
GetTensorMutableDataString(ort_, context, pattern, str_pattern);
|
||||
GetTensorMutableDataString(ort_, context, rewrite, str_rewrite);
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
||||
|
@ -34,20 +36,19 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
std::string* out = ort_.GetTensorMutableData<std::string>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
re2::StringPiece piece(*str_rewrite);
|
||||
re2::RE2 reg(*str_pattern);
|
||||
re2::StringPiece piece(str_rewrite[0]);
|
||||
re2::RE2 reg(str_pattern[0]);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = str_input[i];
|
||||
re2::RE2::GlobalReplace(out + i, reg, piece);
|
||||
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
|
||||
}
|
||||
FillTensorDataString(ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_split.hpp"
|
||||
#include "string_common.h"
|
||||
|
||||
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
@ -9,11 +10,12 @@ KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
|
|||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* X = ort_.GetTensorData<std::string>(input_X);
|
||||
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
|
||||
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
|
||||
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
|
||||
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
|
||||
std::vector<std::string> X, sep;
|
||||
GetTensorMutableDataString(ort_, context, input_X, X);
|
||||
GetTensorMutableDataString(ort_, context, input_sep, sep);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||
|
@ -30,7 +32,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
std::vector<int64_t> indices;
|
||||
int64_t maxc = 0;
|
||||
int64_t col;
|
||||
std::string delimiter = *sep;
|
||||
std::string delimiter = sep[0];
|
||||
if (delimiter.size() == 0) {
|
||||
char word[2] = "a";
|
||||
for (int64_t row = 0; row < dimensions[0]; ++row) {
|
||||
|
@ -86,13 +88,12 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
|||
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
|
||||
|
||||
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
|
||||
std::string* p_text = ort_.GetTensorMutableData<std::string>(out_text);
|
||||
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
|
||||
|
||||
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||
p_shape[0] = dimensions[0];
|
||||
p_shape[1] = maxc;
|
||||
std::copy(words.begin(), words.end(), p_text);
|
||||
FillTensorDataString(ort_, context, words, out_text);
|
||||
}
|
||||
|
||||
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
|
|
|
@ -2,33 +2,31 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "string_upper.hpp"
|
||||
#include "string_common.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelStringUpper::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const std::string* X = ort_.GetTensorData<std::string>(input_X);
|
||||
std::vector<std::string> X;
|
||||
GetTensorMutableDataString(ort_, context, input_X, X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
std::string* out = ort_.GetTensorMutableData<std::string>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i];
|
||||
std::transform(out[i].begin(), out[i].end(), out[i].begin(), ::toupper);
|
||||
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), ::toupper);
|
||||
}
|
||||
|
||||
// Fills the output
|
||||
FillTensorDataString(ort_, context, X, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
|
@ -52,4 +50,3 @@ size_t CustomOpStringUpper::GetOutputTypeCount() const {
|
|||
ONNXTensorElementDataType CustomOpStringUpper::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "utils.h"
|
||||
#include "pykernel.h"
|
||||
#include "kernels/string_hash.hpp"
|
||||
#include "kernels/string_common.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@ -34,11 +35,11 @@ const int PyCustomOpDef::dt_int64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; // m
|
|||
const int PyCustomOpDef::dt_string = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; // maps to c++ type std::string
|
||||
const int PyCustomOpDef::dt_bool = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
const int PyCustomOpDef::dt_float16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
||||
const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
|
||||
const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
|
||||
const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
|
||||
const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
|
||||
const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
|
||||
const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
|
||||
// complex with float32 real and imaginary components
|
||||
const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
|
||||
const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
|
||||
// complex with float64 real and imaginary components
|
||||
const int PyCustomOpDef::dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
|
@ -112,7 +113,7 @@ static size_t element_size(ONNXTensorElementDataType dt) {
|
|||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
|
||||
return sizeof(std::complex<double>);
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
|
||||
return sizeof(std::string*);
|
||||
throw std::runtime_error("OrtValue content cannot be casted into std::string*.");
|
||||
default:
|
||||
throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
|
||||
}
|
||||
|
@ -166,7 +167,9 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
|
|||
return c;
|
||||
}
|
||||
|
||||
static py::object BuildPyObjFromTensor(const void* p, const shape_t& shape, ONNXTensorElementDataType dtype) {
|
||||
static py::object BuildPyObjFromTensor(
|
||||
Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
||||
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
||||
std::vector<npy_intp> npy_dims;
|
||||
for (auto n : shape) {
|
||||
npy_dims.push_back(n);
|
||||
|
@ -180,11 +183,13 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
|
|||
if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
py::object* outObj = static_cast<py::object*>(out_ptr);
|
||||
auto size = calc_size_from_shape(shape);
|
||||
const std::string* src = (const std::string*)p;
|
||||
for (int i = 0; i < size; i++, src++) {
|
||||
outObj[i] = py::cast(*src);
|
||||
std::vector<std::string> src;
|
||||
GetTensorMutableDataString(ort, context, value, src);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
outObj[i] = py::cast(src[i]);
|
||||
}
|
||||
} else {
|
||||
const void* p = (const void*)ort.GetTensorData<char>(value);
|
||||
size_t size_type = element_size(dtype);
|
||||
memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
|
||||
}
|
||||
|
@ -215,7 +220,7 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
|
|||
// std::unique_lock<std::mutex> lck(op_mutex);
|
||||
// is_ready = true;
|
||||
// op_cv.notify_all();
|
||||
// std::this_thread::sleep_for(std::chrono::milliseconds(5000));
|
||||
// std::this_thread::sleep_for(std::chrono::milliseconds(5000));
|
||||
size_t n_inputs = ort_.KernelContext_GetInputCount(context);
|
||||
size_t n_outputs = ort_.KernelContext_GetOutputCount(context);
|
||||
|
||||
|
@ -252,7 +257,7 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
|
|||
py::list pyinputs;
|
||||
for (auto it = inputs.begin(); it != inputs.end(); ++it) {
|
||||
py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
|
||||
(const void*)ort_.GetTensorData<float>(it->input_X), it->dimensions, it->dtype);
|
||||
ort_, context, it->input_X, it->dimensions, it->dtype);
|
||||
pyinputs.append(input0);
|
||||
}
|
||||
|
||||
|
@ -267,19 +272,14 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
|
|||
OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
|
||||
OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
|
||||
ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
|
||||
const void* Y = (const void*)ort_.GetTensorData<float>(output);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(o_info);
|
||||
void* out = (void*)ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
|
||||
auto retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
|
||||
std::string* type_outPtr = (std::string*)out;
|
||||
std::string* end = type_outPtr + retval.size();
|
||||
const std::string* source = (const std::string*)retval.data();
|
||||
for (; type_outPtr != end; ++type_outPtr, ++source) {
|
||||
*type_outPtr = *source;
|
||||
}
|
||||
std::vector<std::string> retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
|
||||
FillTensorDataString(ort_, context, retval, output);
|
||||
} else {
|
||||
const void* Y = (const void*)ort_.GetTensorData<float>(output);
|
||||
void* out = (void*)ort_.GetTensorMutableData<float>(output);
|
||||
py::array retval = fetch[2 + no * 2].cast<py::array>();
|
||||
if (element_size(o_dtype) != retval.itemsize()) {
|
||||
switch (o_dtype) {
|
||||
|
@ -395,7 +395,7 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count) {
|
|||
return ptr;
|
||||
}
|
||||
|
||||
bool EnablePyCustomOps(bool enabled){
|
||||
bool EnablePyCustomOps(bool enabled) {
|
||||
static bool f_pyop_enabled = true;
|
||||
bool last = f_pyop_enabled;
|
||||
f_pyop_enabled = enabled;
|
||||
|
@ -430,8 +430,7 @@ void AddObjectMethods(pybind11::module& m) {
|
|||
.def_readwrite("obj_id", &PyCustomOpDef::obj_id)
|
||||
.def_readwrite("input_types", &PyCustomOpDef::input_types)
|
||||
.def_readwrite("output_types", &PyCustomOpDef::output_types)
|
||||
.def_static("install_hooker", [](py::object obj) {
|
||||
PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
|
||||
.def_static("install_hooker", [](py::object obj) { PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
|
||||
.def_readonly_static("undefined", &PyCustomOpDef::undefined)
|
||||
.def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
|
||||
.def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)
|
||||
|
|
|
@ -347,9 +347,11 @@ class TestPythonOpString(unittest.TestCase):
|
|||
onnx_model = _create_test_model_string_upper('')
|
||||
self.assertIn('op_type: "StringUpper"', str(onnx_model))
|
||||
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
|
||||
input_1 = np.array([["Abcé"]])
|
||||
input_1 = np.array([["R"], ["Abcé"], ["ABC"], ["A"]])
|
||||
txout = sess.run(None, {'input_1': input_1})
|
||||
self.assertEqual(txout[0].tolist(), np.array([["ABCé"]]).tolist())
|
||||
self.assertEqual(
|
||||
txout[0].tolist(),
|
||||
np.array([["R"], ["ABCé"], ["ABC"], ["A"]]).tolist())
|
||||
|
||||
def test_string_upper_python(self):
|
||||
so = _ort.SessionOptions()
|
||||
|
|
1344
tokenizer/gpt2tok.cc
1344
tokenizer/gpt2tok.cc
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Загрузка…
Ссылка в новой задаче