Run python unit tests on every platform + fix all string operators (#44)

* Use appropriate API for strings
* Modify all string operators
* Enable CI on Linux and macOS
This commit is contained in:
Xavier Dupré 2021-01-21 19:49:16 +01:00 committed by GitHub
Parent 4a0f892949
Commit d48d825a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 891 additions and 745 deletions

View File

@@ -1,3 +1,7 @@
#######
# Linux
#######
jobs:
- job: Linux
pool:
@@ -5,7 +9,7 @@ jobs:
strategy:
matrix:
py37:
py38:
python.version: '3.8'
maxParallel: 1
@@ -31,26 +35,26 @@ jobs:
./ortcustomops_test
displayName: Run the native only unit tests
- script: |
python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- script: python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
displayName: Install pytorch
- script: |
python -m pip install -r requirements-dev.txt
- script: python -m pip install -r requirements-dev.txt
displayName: Install requirements-dev.txt
- script: |
# FIXME: need to check the CI environment for the failure.
# python -m pytest test
- script: python -m pytest test --verbose --verbose
displayName: Run python test
#######
# macOS
#######
- job: macOS
pool:
vmImage: 'macOS-latest'
strategy:
matrix:
py37:
py38:
python.version: '3.8'
maxParallel: 1
@@ -60,14 +64,22 @@ jobs:
versionSpec: '$(python.version)'
addToPath: true
# needed for onnxruntime
- script: brew install libomp
displayName: 'Install omp'
- script: |
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
python -m pip install -r requirements.txt
displayName: Install requirements.txt
- script: python -c "import onnxruntime;print(onnxruntime.__version__)"
displayName: Check installation
- script: |
sh ./build.sh
call activate pyenv
python setup.py develop
displayName: Build the library and tests
@@ -76,6 +88,19 @@ jobs:
./ortcustomops_test
displayName: Run the native only unit tests
- script: python -m pip install -r requirements-dev.txt
displayName: Install requirements-dev.txt
- script: python -m pip install torch torchvision torchaudio
displayName: Install pytorch
- script: python -m pytest test --verbose
displayName: Run python test
#########
# Windows
#########
- job: Windows
pool:
vmImage: 'windows-latest'
@@ -83,7 +108,7 @@ jobs:
- powershell: Write-Host "##vso[task.prependpath]$env:CONDA\Scripts"
displayName: Add conda to PATH
- script: conda create --yes --quiet --name pyenv -c conda-forge python=3.7 numpy
- script: conda create --yes --quiet --name pyenv -c conda-forge python=3.8 numpy
displayName: Create Anaconda environment
- script: |
@@ -106,14 +131,9 @@ jobs:
- script: |
call activate pyenv
python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
displayName: Install pytorch
- script: |
call activate pyenv
python -m pip install -r requirements-dev.txt
displayName: Install requirements-dev.txt
- script: |
call activate pyenv
python -m pytest test

View File

@@ -11,7 +11,7 @@ typedef CPTR_OrtCustomOp (*FxGetSchemaInstance)();
FxGetSchemaInstance const* GetCustomOpSchemaList();
struct BaseKernel {
BaseKernel(OrtApi api) : api_(api), ort_(api_) { }
BaseKernel(OrtApi api) : api_(api), ort_(api_) {}
protected:
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
@@ -25,8 +25,14 @@ struct OrtTensorDimensions : std::vector<int64_t> {
ort.ReleaseTensorTypeAndShapeInfo(info);
}
const std::vector<int64_t>& GetDims() const { return *this; }
int64_t Size() const {
int64_t s = 1;
for (auto it = begin(); it != end(); ++it)
s *= *it;
return s;
}
};
#if defined(ENABLE_TOKENIZER)
const OrtCustomOp** LoadTokenizerSchemaList();
#endif // ENABLE_TEXT_DOMAIN
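The new OrtTensorDimensions::Size() helper computes a tensor's element count as the product of its dimensions. A minimal standalone sketch of the same computation (hypothetical free function, not part of the commit):

#include <cstdint>
#include <vector>

// Element count of a tensor shape: the product of all dimensions.
// An empty shape (a scalar) yields 1, matching the `s = 1` initializer above.
int64_t ShapeSize(const std::vector<int64_t>& dims) {
  int64_t s = 1;
  for (int64_t d : dims)
    s *= d;
  return s;
}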

View File

@@ -2,7 +2,9 @@
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <string>
#include "utils.h"
#include "string_common.h"
template <typename T1, typename T2, typename T3>
class BroadcastIteratorRight {
@@ -140,3 +142,36 @@ void KernelEqual_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
state.loop(cmp, state);
}
}
template <>
void KernelEqual_Compute<std::string>(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
std::vector<std::string> X, Y;
GetTensorMutableDataString(ort_, context, input_X, X);
GetTensorMutableDataString(ort_, context, input_Y, Y);
// Setup output
OrtTensorDimensions dimensions_x(ort_, input_X);
OrtTensorDimensions dimensions_y(ort_, input_Y);
Compare<std::string> cmp;
typename BroadcastIteratorRight<std::string, std::string, bool>::BroadcastIteratorRightState state;
if (Size(dimensions_x) >= Size(dimensions_y)) {
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_x.data(), dimensions_x.size());
bool* out = ort_.GetTensorMutableData<bool>(v);
BroadcastIteratorRight<std::string, std::string, bool> iter(
dimensions_x, dimensions_y, X.data(), Y.data(), out);
state.init(iter);
state.loop(cmp, state);
} else {
// Operator Equal is commutative.
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_y.data(), dimensions_y.size());
bool* out = ort_.GetTensorMutableData<bool>(v);
BroadcastIteratorRight<std::string, std::string, bool> iter(
dimensions_y, dimensions_x, Y.data(), X.data(), out);
state.init(iter);
state.loop(cmp, state);
}
}
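The string specialization mirrors the numeric template above: the input with more elements drives the broadcast iteration, and because Equal is commutative the operands can simply be swapped when Y is larger. A self-contained sketch of the right-broadcast comparison idea (hypothetical helper; simplified to the case where the smaller input covers the trailing dimensions, e.g. shape {2, 3} against {3}):

#include <string>
#include <vector>

// Compare each element of `big` against the right-aligned broadcast of `small`.
// Assumes small.size() evenly divides big.size().
std::vector<bool> EqualBroadcastRight(const std::vector<std::string>& big,
                                      const std::vector<std::string>& small) {
  std::vector<bool> out(big.size());
  for (size_t i = 0; i < big.size(); ++i)
    out[i] = (big[i] == small[i % small.size()]);
  return out;
}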

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "string_common.h"
#include "utils.h"
const OrtApi* _GetApi(Ort::CustomOpApi& ort) {
if (sizeof(Ort::CustomOpApi) == sizeof(OrtApi*)) {
// The following line should be replaced when an accessor is available.
// ort.api_ is not accessible (marked as private)
// Ort::GetApi() returns null.
// InitApi is missing when linking.
// OrtGetApiBase() is missing when linking.
const OrtApi* api = (const OrtApi*)((OrtApi**)&ort) - 1;
/*
// The following code checks that api equals the expected
// pointer stored in Ort::CustomOpApi ort as a private member.
// The following method can be added `const OrtApi& Api() { return api_; }`
// to give access to this value.
if (api != &(ort.Api())) {
// Ort::InitApi(); - missing from link
auto diff = (int64_t)(&(ort.Api()) - api);
throw std::runtime_error(MakeString(
"Internal error, pointers are different: ",
api, "!=", &(ort.Api()), " (expected) (other value ", &(Ort::GetApi()),
", delta=", diff, ")."));
}
*/
return api;
}
throw std::runtime_error(MakeString(
"Unable to get an OrtApi pointer from CustomOpApi. Variable ort is not a pointer. Size ",
sizeof(Ort::CustomOpApi), "!=", sizeof(OrtApi*), " (expected)."));
}
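The cast above only holds if Ort::CustomOpApi stores the OrtApi reference as its sole (private) member, making the wrapper exactly pointer-sized; the runtime_error branch guards that assumption. A minimal model of the assumed layout (hypothetical type, not the real header):

// Hypothetical stand-in for Ort::CustomOpApi's assumed layout:
// a single reference member keeps the wrapper pointer-sized.
struct PointerSizedWrapper {
  const OrtApi& api_;
};
static_assert(sizeof(PointerSizedWrapper) == sizeof(const OrtApi*),
              "wrapper must be pointer-sized for the cast to hold");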
void GetTensorMutableDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output) {
const OrtApi& api = *_GetApi(ort);
OrtTensorDimensions dimensions(ort, value);
size_t len = static_cast<size_t>(dimensions.Size());
size_t data_len;
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
output.resize(len);
std::vector<char> result(data_len + len + 1, '\0');
std::vector<size_t> offsets(len);
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
output.resize(len);
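// Walk backward so that writing the terminator for string i at
// offsets[i + 1] only clobbers the first byte of string i + 1, which was
// already copied out in the previous iteration.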
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
if (i < len - 1)
result[offsets[i + (int64_t)1]] = '\0';
output[i] = result.data() + offsets[i];
}
}
void FillTensorDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output) {
const OrtApi& api = *_GetApi(ort);
std::vector<const char*> temp(value.size());
for (size_t i = 0; i < value.size(); ++i) {
temp[i] = value[i].c_str();
}
api.FillStringTensor(output, temp.data(), value.size());
}

View File

@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include "kernels.h"
// Retrieves a vector of strings if the input type is std::string.
// It is a copy of the input data and can be modified to compute the output.
void GetTensorMutableDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output);
void FillTensorDataString(Ort::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output);
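Together the two helpers give every string kernel the same copy-modify-write pattern. A sketch of a hypothetical Compute body using them (kernel name invented for illustration):

void KernelStringIdentity_Compute(Ort::CustomOpApi& ort, OrtKernelContext* context) {
  const OrtValue* input = ort.KernelContext_GetInput(context, 0);
  std::vector<std::string> data;
  GetTensorMutableDataString(ort, context, input, data);  // copy strings in
  // ... transform `data` here ...
  OrtTensorDimensions dims(ort, input);
  OrtValue* output = ort.KernelContext_GetOutput(context, 0, dims.data(), dims.size());
  FillTensorDataString(ort, context, data, output);       // write strings out
}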

View File

@@ -7,6 +7,7 @@
#include <algorithm>
#include "re2/re2.h"
#include "farmhash.h"
#include "string_common.h"
// Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hash.cc#L28
static inline uint64_t ByteAs64(char c) { return static_cast<uint64_t>(c) & 0xff; }
@@ -81,9 +82,10 @@ KernelStringHash::KernelStringHash(OrtApi api) : BaseKernel(api) {
void KernelStringHash::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const std::string* str_input = ort_.GetTensorData<std::string>(input);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(ort_, context, input, str_input);
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
@@ -143,9 +145,10 @@ KernelStringHashFast::KernelStringHashFast(OrtApi api) : BaseKernel(api) {
void KernelStringHashFast::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const std::string* str_input = ort_.GetTensorData<std::string>(input);
const OrtValue* num_buckets = ort_.KernelContext_GetInput(context, 1);
const int64_t* p_num_buckets = ort_.GetTensorData<int64_t>(num_buckets);
std::vector<std::string> str_input;
GetTensorMutableDataString(ort_, context, input, str_input);
// Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);

View File

@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "string_join.hpp"
#include "string_common.h"
KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
}
@@ -9,11 +10,12 @@ KernelStringJoin::KernelStringJoin(OrtApi api) : BaseKernel(api) {
void KernelStringJoin::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const std::string* X = ort_.GetTensorData<std::string>(input_X);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
const OrtValue* input_axis = ort_.KernelContext_GetInput(context, 2);
const int64_t* axis = ort_.GetTensorData<int64_t>(input_axis);
std::vector<std::string> X, sep;
GetTensorMutableDataString(ort_, context, input_X, X);
GetTensorMutableDataString(ort_, context, input_sep, sep);
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
@@ -38,11 +40,10 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
}
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions_out.data(), dimensions_out.size());
std::string* out = ort_.GetTensorMutableData<std::string>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
std::vector<std::string> out(size);
// Do computation
int64_t h = 1;
@@ -59,12 +60,13 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
std::ostringstream st;
int64_t index = ri + li * inc;
for (int64_t j = 0; j < n_red; ++j, index += h) {
st << X[index] << *sep;
st << X[index] << sep[0];
}
st << X[index];
out[pos] = st.str();
}
}
FillTensorDataString(ort_, context, out, output);
}
void* CustomOpStringJoin::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
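The inner loop joins n_red + 1 elements that sit a stride h apart, separated by sep[0]. An isolated sketch of that reduction (hypothetical helper mirroring the loop above):

#include <sstream>
#include <string>
#include <vector>

// Join n_red + 1 elements of `x`, starting at `start` and stepping by `h`,
// inserting `sep` between consecutive elements.
std::string JoinStrided(const std::vector<std::string>& x, int64_t start,
                        int64_t h, int64_t n_red, const std::string& sep) {
  std::ostringstream st;
  int64_t index = start;
  for (int64_t j = 0; j < n_red; ++j, index += h)
    st << x[index] << sep;
  st << x[index];
  return st.str();
}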

View File

@@ -6,6 +6,7 @@
#include <cmath>
#include <algorithm>
#include "re2/re2.h"
#include "string_common.h"
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api) : BaseKernel(api) {
}
@@ -13,11 +14,12 @@ KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api) : BaseKernel(api)
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
const std::string* str_input = ort_.GetTensorData<std::string>(input);
const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
const std::string* str_pattern = ort_.GetTensorData<std::string>(pattern);
const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
const std::string* str_rewrite = ort_.GetTensorData<std::string>(rewrite);
std::vector<std::string> str_input, str_pattern, str_rewrite;
GetTensorMutableDataString(ort_, context, input, str_input);
GetTensorMutableDataString(ort_, context, pattern, str_pattern);
GetTensorMutableDataString(ort_, context, rewrite, str_rewrite);
// Verifications
OrtTensorDimensions pattern_dimensions(ort_, pattern);
@@ -34,20 +36,19 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
// Setup output
OrtTensorDimensions dimensions(ort_, input);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::string* out = ort_.GetTensorMutableData<std::string>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
re2::StringPiece piece(*str_rewrite);
re2::RE2 reg(*str_pattern);
re2::StringPiece piece(str_rewrite[0]);
re2::RE2 reg(str_pattern[0]);
// Do computation
for (int64_t i = 0; i < size; i++) {
out[i] = str_input[i];
re2::RE2::GlobalReplace(out + i, reg, piece);
re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
}
FillTensorDataString(ort_, context, str_input, output);
}
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
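RE2::GlobalReplace rewrites the string in place and returns the number of replacements, which is why the new code can reuse str_input as the output buffer. A minimal standalone usage sketch:

#include <string>
#include "re2/re2.h"

// Replace every match of `pattern` in `text` with `rewrite`;
// returns the number of replacements performed.
int ReplaceAll(std::string& text, const std::string& pattern,
               const std::string& rewrite) {
  re2::RE2 reg(pattern);
  return re2::RE2::GlobalReplace(&text, reg, rewrite);
}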

View File

@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "string_split.hpp"
#include "string_common.h"
KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
}
@@ -9,11 +10,12 @@ KernelStringSplit::KernelStringSplit(OrtApi api) : BaseKernel(api) {
void KernelStringSplit::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const std::string* X = ort_.GetTensorData<std::string>(input_X);
const OrtValue* input_sep = ort_.KernelContext_GetInput(context, 1);
const std::string* sep = ort_.GetTensorData<std::string>(input_sep);
const OrtValue* input_skip_empty = ort_.KernelContext_GetInput(context, 2);
const bool* skip_empty = ort_.GetTensorData<bool>(input_skip_empty);
std::vector<std::string> X, sep;
GetTensorMutableDataString(ort_, context, input_X, X);
GetTensorMutableDataString(ort_, context, input_sep, sep);
// Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep);
@@ -30,7 +32,7 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
std::vector<int64_t> indices;
int64_t maxc = 0;
int64_t col;
std::string delimiter = *sep;
std::string delimiter = sep[0];
if (delimiter.size() == 0) {
char word[2] = "a";
for (int64_t row = 0; row < dimensions[0]; ++row) {
@@ -86,13 +88,12 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
OrtValue* out_shape = ort_.KernelContext_GetOutput(context, 2, shape_shape.data(), shape_shape.size());
int64_t* p_indices = ort_.GetTensorMutableData<int64_t>(out_indices);
std::string* p_text = ort_.GetTensorMutableData<std::string>(out_text);
int64_t* p_shape = ort_.GetTensorMutableData<int64_t>(out_shape);
memcpy(p_indices, indices.data(), indices.size() * sizeof(int64_t));
p_shape[0] = dimensions[0];
p_shape[1] = maxc;
std::copy(words.begin(), words.end(), p_text);
FillTensorDataString(ort_, context, words, out_text);
}
void* CustomOpStringSplit::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {

View File

@@ -2,33 +2,31 @@
// Licensed under the MIT License.
#include "string_upper.hpp"
#include "string_common.h"
#include <vector>
#include <cmath>
#include <algorithm>
KernelStringUpper::KernelStringUpper(OrtApi api) : BaseKernel(api) {
}
void KernelStringUpper::Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const std::string* X = ort_.GetTensorData<std::string>(input_X);
std::vector<std::string> X;
GetTensorMutableDataString(ort_, context, input_X, X);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
std::string* out = ort_.GetTensorMutableData<std::string>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
for (int64_t i = 0; i < size; i++) {
out[i] = X[i];
std::transform(out[i].begin(), out[i].end(), out[i].begin(), ::toupper);
for (int64_t i = 0; i < (int64_t)X.size(); ++i) {
std::transform(X[i].begin(), X[i].end(), X[i].begin(), ::toupper);
}
// Fills the output
FillTensorDataString(ort_, context, X, output);
}
void* CustomOpStringUpper::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
@@ -52,4 +50,3 @@ size_t CustomOpStringUpper::GetOutputTypeCount() const {
ONNXTensorElementDataType CustomOpStringUpper::GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
};
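std::transform with ::toupper is byte-wise and ASCII-oriented, which is why the Python test later in this diff expects "Abcé" to become "ABCé" with the "é" untouched; passing a plain char that may be negative is also technically undefined behavior. A defensive variant (a sketch, not what the commit does):

#include <algorithm>
#include <cctype>
#include <string>

// Byte-wise ASCII upper-casing; multi-byte UTF-8 sequences pass through.
void AsciiUpperInPlace(std::string& s) {
  std::transform(s.begin(), s.end(), s.begin(),
                 [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
}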

View File

@@ -20,6 +20,7 @@
#include "utils.h"
#include "pykernel.h"
#include "kernels/string_hash.hpp"
#include "kernels/string_common.h"
namespace py = pybind11;
@@ -34,11 +35,11 @@ const int PyCustomOpDef::dt_int64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; // m
const int PyCustomOpDef::dt_string = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; // maps to c++ type std::string
const int PyCustomOpDef::dt_bool = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
const int PyCustomOpDef::dt_float16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
// complex with float32 real and imaginary components
const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
// complex with float64 real and imaginary components
const int PyCustomOpDef::dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
// Non-IEEE floating-point format based on IEEE754 single-precision
@@ -112,7 +113,7 @@ static size_t element_size(ONNXTensorElementDataType dt) {
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
return sizeof(std::complex<double>);
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
return sizeof(std::string*);
throw std::runtime_error("OrtValue content cannot be casted into std::string*.");
default:
throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
}
@@ -166,7 +167,9 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
return c;
}
static py::object BuildPyObjFromTensor(const void* p, const shape_t& shape, ONNXTensorElementDataType dtype) {
static py::object BuildPyObjFromTensor(
Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
const shape_t& shape, ONNXTensorElementDataType dtype) {
std::vector<npy_intp> npy_dims;
for (auto n : shape) {
npy_dims.push_back(n);
@@ -180,11 +183,13 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
py::object* outObj = static_cast<py::object*>(out_ptr);
auto size = calc_size_from_shape(shape);
const std::string* src = (const std::string*)p;
for (int i = 0; i < size; i++, src++) {
outObj[i] = py::cast(*src);
std::vector<std::string> src;
GetTensorMutableDataString(ort, context, value, src);
for (int i = 0; i < size; ++i) {
outObj[i] = py::cast(src[i]);
}
} else {
const void* p = (const void*)ort.GetTensorData<char>(value);
size_t size_type = element_size(dtype);
memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
}
@@ -215,7 +220,7 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
// std::unique_lock<std::mutex> lck(op_mutex);
// is_ready = true;
// op_cv.notify_all();
// std::this_thread::sleep_for(std::chrono::milliseconds(5000));
size_t n_inputs = ort_.KernelContext_GetInputCount(context);
size_t n_outputs = ort_.KernelContext_GetOutputCount(context);
@@ -252,7 +257,7 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
py::list pyinputs;
for (auto it = inputs.begin(); it != inputs.end(); ++it) {
py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
(const void*)ort_.GetTensorData<float>(it->input_X), it->dimensions, it->dtype);
ort_, context, it->input_X, it->dimensions, it->dtype);
pyinputs.append(input0);
}
@@ -267,19 +272,14 @@ void PyCustomOpKernel::Compute(OrtKernelContext* context) {
OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
const void* Y = (const void*)ort_.GetTensorData<float>(output);
ort_.ReleaseTensorTypeAndShapeInfo(o_info);
void* out = (void*)ort_.GetTensorMutableData<float>(output);
if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
auto retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
std::string* type_outPtr = (std::string*)out;
std::string* end = type_outPtr + retval.size();
const std::string* source = (const std::string*)retval.data();
for (; type_outPtr != end; ++type_outPtr, ++source) {
*type_outPtr = *source;
}
std::vector<std::string> retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
FillTensorDataString(ort_, context, retval, output);
} else {
const void* Y = (const void*)ort_.GetTensorData<float>(output);
void* out = (void*)ort_.GetTensorMutableData<float>(output);
py::array retval = fetch[2 + no * 2].cast<py::array>();
if (element_size(o_dtype) != retval.itemsize()) {
switch (o_dtype) {
@@ -395,7 +395,7 @@ const OrtCustomOp* FetchPyCustomOps(size_t& count) {
return ptr;
}
bool EnablePyCustomOps(bool enabled){
bool EnablePyCustomOps(bool enabled) {
static bool f_pyop_enabled = true;
bool last = f_pyop_enabled;
f_pyop_enabled = enabled;
@@ -430,8 +430,7 @@ void AddObjectMethods(pybind11::module& m) {
.def_readwrite("obj_id", &PyCustomOpDef::obj_id)
.def_readwrite("input_types", &PyCustomOpDef::input_types)
.def_readwrite("output_types", &PyCustomOpDef::output_types)
.def_static("install_hooker", [](py::object obj) {
PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
.def_static("install_hooker", [](py::object obj) { PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
.def_readonly_static("undefined", &PyCustomOpDef::undefined)
.def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
.def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)

View File

@@ -347,9 +347,11 @@ class TestPythonOpString(unittest.TestCase):
onnx_model = _create_test_model_string_upper('')
self.assertIn('op_type: "StringUpper"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so)
input_1 = np.array([["Abcé"]])
input_1 = np.array([["R"], ["Abcé"], ["ABC"], ["A"]])
txout = sess.run(None, {'input_1': input_1})
self.assertEqual(txout[0].tolist(), np.array([["ABCé"]]).tolist())
self.assertEqual(
txout[0].tolist(),
np.array([["R"], ["ABCé"], ["ABC"], ["A"]]).tolist())
def test_string_upper_python(self):
so = _ort.SessionOptions()

File diff suppressed because it is too large. Load Diff