add a native unit test for regex_split op (#166)

* add a native unit test for regex_split op

* fix the case of shape [1, 0]

* Update mshost.yaml

* downgrade the test model version.

* upgrade torch version on Windows CI

* disable windows python 3.7 pipeline.
This commit is contained in:
Wenbing Li 2021-10-06 15:58:46 -07:00 коммит произвёл GitHub
Родитель 4d7004bf6e
Коммит 70aa18e14e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
12 изменённых файлов: 68 добавлений и 16 удалений

Просмотреть файл

@ -45,7 +45,7 @@ jobs:
- script: |
cd out/Linux/RelWithDebInfo
ctest -C RelWithDebInfo
ctest -C RelWithDebInfo --output-on-failure
displayName: Run C++ native tests
- task: UsePythonVersion@0
@ -120,7 +120,7 @@ jobs:
- script: |
cd out/Darwin/RelWithDebInfo
ctest -C RelWithDebInfo
ctest -C RelWithDebInfo --output-on-failure
displayName: Run C++ native tests
#############
@ -229,7 +229,7 @@ jobs:
- script: |
cd out/Windows
ctest -C RelWithDebInfo
ctest -C RelWithDebInfo --output-on-failure
displayName: Run C++ native tests
################
@ -282,7 +282,7 @@ jobs:
- script: |
call activate pyenv
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install torch==1.8.2+cpu torchvision==0.9.2+cpu torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
displayName: Install pytorch
- script: |

Просмотреть файл

@ -49,12 +49,8 @@ struct OrtTensorDimensions : std::vector<int64_t> {
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
ort.ReleaseTensorTypeAndShapeInfo(info);
}
const std::vector<int64_t>& GetDims() const { return *this; }
int64_t Size() const {
if (empty()) {
return 0;
}
int64_t Size() const {
int64_t s = 1.;
for (auto it = begin(); it != end(); ++it)
s *= *it;

Просмотреть файл

@ -16,7 +16,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
std::vector<std::int64_t> segment_value;
std::vector<std::int64_t> segment_position;
for (int i = 0; i < input_dim.Size(); i++) {
for (std::int64_t i = 0; i < input_dim.Size(); i++) {
if (!p_data[i]) {
continue;
}

Просмотреть файл

@ -20,8 +20,8 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
if (dim_data[0] != dim_seg[0])
ORT_CXX_API_THROW(MakeString(
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data.GetDims(),
" segment_ids shape: ", dim_seg.GetDims()), ORT_INVALID_GRAPH);
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
OrtTensorDimensions dim_out = dim_data;

Просмотреть файл

@ -4,6 +4,7 @@
#include <iostream>
#include <sstream>
#include <vector>
#include "ocos.h"
template <typename T>
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
@ -22,6 +23,11 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
ss << "]";
}
template <>
inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
}
template <>
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
ss << "[";

Просмотреть файл

@ -143,12 +143,13 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
OrtTensorDimensions positions_dim(ort_, positions);
if (use_indices_ &&
(!(positions_dim.empty() ||
(positions_dim.Size() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
(positions_dim.Size() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
}
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
const int64_t* p_positions =
positions_dim.empty() || positions_dim.Size() == 0? nullptr : ort_.GetTensorData<int64_t>(positions);
std::vector<std::string> result;
std::vector<int64_t> output_dim(1);

Просмотреть файл

@ -33,6 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
OrtTensorDimensions dimensions(ort_, input);
// TODO: fix this scalar check.
if (dimensions.Size() != 1 && dimensions[0] != 1) {
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
}

Двоичные данные
test/data/test_regex_split_with_offsets.onnx Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -13,6 +13,7 @@ struct TestValue {
std::vector<int64_t> dims;
std::vector<float> values_float;
std::vector<int32_t> values_int32;
std::vector<int64_t> values_int64;
std::vector<std::string> values_string;
};

Просмотреть файл

@ -168,6 +168,9 @@ void RunSession(Ort::Session& session_object,
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
_emplace_back(memory_info, ort_inputs, inputs[i].values_int32, inputs[i].dims);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
_emplace_back(memory_info, ort_inputs, inputs[i].values_int64, inputs[i].dims);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
Ort::Value& ort_value = ort_inputs.emplace_back(
Ort::Value::CreateTensor(allocator, inputs[i].dims.data(), inputs[i].dims.size(),
@ -208,6 +211,9 @@ void RunSession(Ort::Session& session_object,
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
_assert_eq(*output_tensor, expected.values_int32, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
_assert_eq(*output_tensor, expected.values_int64, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
std::vector<std::string> output_string;
GetTensorMutableDataString(Ort::GetApi(), *output_tensor, output_string);

Просмотреть файл

@ -30,3 +30,42 @@ TEST(utils, test_string_lower) {
model_path /= "custom_op_string_lower.onnx";
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
TEST(utils, test_regex_split_with_offsets) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
std::vector<TestValue> inputs(1);
inputs[0].name = "input:0";
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
inputs[0].dims = {2};
inputs[0].values_string = {"a Test 1 2 3 ♠♣", "Hi there test test ♥♦"};
std::vector<TestValue> outputs(4);
outputs[0].name = "output:0";
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
outputs[0].dims = {11};
outputs[0].values_string = {"a", "Test", "1", "2", "3", "♠♣", "Hi", "there", "test", "test", "♥♦"};
outputs[1].name = "output1:0";
outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[1].dims = {11};
outputs[1].values_int64 = {0, 2, 7, 9, 11, 13, 0, 3, 9, 14, 19};
outputs[2].name = "output2:0";
outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[2].dims = {11};
outputs[2].values_int64 = {1, 6, 8, 10, 12, 19, 2, 8, 13, 18, 25};
outputs[3].name = "output3:0";
outputs[3].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
outputs[3].dims = {3};
outputs[3].values_int64 = {0, 6, 11};
std::filesystem::path model_path = __FILE__;
model_path = model_path.parent_path();
model_path /= "..";
model_path /= "data";
model_path /= "test_regex_split_with_offsets.onnx";
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}

Просмотреть файл

@ -1,13 +1,15 @@
import io
import onnx
import unittest
import platform
import torchvision
import numpy as np
from onnxruntime_extensions import PyOrtFunction, hook_model_op, PyOp
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
@unittest.skipIf(platform.python_version_tuple()[0:2] == (
'3', '7'), 'Windows CI pipeline failed on the version temporarily.')
class TestTorchE2E(unittest.TestCase):
@classmethod
def setUpClass(cls):