add a native unit test for regex_split op (#166)
* add a native unit test for regex_split op * fix the case of shape [1, 0] * Update mshost.yaml * downgrade the test model version. * upgrade torch version on Windows CI * disable windows python 3.7 pipeline.
This commit is contained in:
Родитель
4d7004bf6e
Коммит
70aa18e14e
|
@ -45,7 +45,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
cd out/Linux/RelWithDebInfo
|
||||
ctest -C RelWithDebInfo
|
||||
ctest -C RelWithDebInfo --output-on-failure
|
||||
displayName: Run C++ native tests
|
||||
|
||||
- task: UsePythonVersion@0
|
||||
|
@ -120,7 +120,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
cd out/Darwin/RelWithDebInfo
|
||||
ctest -C RelWithDebInfo
|
||||
ctest -C RelWithDebInfo --output-on-failure
|
||||
displayName: Run C++ native tests
|
||||
|
||||
#############
|
||||
|
@ -229,7 +229,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
cd out/Windows
|
||||
ctest -C RelWithDebInfo
|
||||
ctest -C RelWithDebInfo --output-on-failure
|
||||
displayName: Run C++ native tests
|
||||
|
||||
################
|
||||
|
@ -282,7 +282,7 @@ jobs:
|
|||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
python -m pip install torch==1.8.2+cpu torchvision==0.9.2+cpu torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: |
|
||||
|
|
|
@ -49,12 +49,8 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
const std::vector<int64_t>& GetDims() const { return *this; }
|
||||
int64_t Size() const {
|
||||
if (empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t Size() const {
|
||||
int64_t s = 1.;
|
||||
for (auto it = begin(); it != end(); ++it)
|
||||
s *= *it;
|
||||
|
|
|
@ -16,7 +16,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
|||
|
||||
std::vector<std::int64_t> segment_value;
|
||||
std::vector<std::int64_t> segment_position;
|
||||
for (int i = 0; i < input_dim.Size(); i++) {
|
||||
for (std::int64_t i = 0; i < input_dim.Size(); i++) {
|
||||
if (!p_data[i]) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -20,8 +20,8 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
|||
ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
||||
if (dim_data[0] != dim_seg[0])
|
||||
ORT_CXX_API_THROW(MakeString(
|
||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data.GetDims(),
|
||||
" segment_ids shape: ", dim_seg.GetDims()), ORT_INVALID_GRAPH);
|
||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
||||
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
|
||||
|
||||
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
|
||||
OrtTensorDimensions dim_out = dim_data;
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include "ocos.h"
|
||||
|
||||
template <typename T>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept {
|
||||
|
@ -22,6 +23,11 @@ inline void MakeStringInternal(std::ostringstream& ss, const std::vector<int64_t
|
|||
ss << "]";
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const OrtTensorDimensions& t) noexcept {
|
||||
MakeStringInternal(ss, static_cast<const std::vector<int64_t>&>(t));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void MakeStringInternal(std::ostringstream& ss, const std::vector<std::string>& t) noexcept {
|
||||
ss << "[";
|
||||
|
|
|
@ -143,12 +143,13 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions positions_dim(ort_, positions);
|
||||
if (use_indices_ &&
|
||||
(!(positions_dim.empty() ||
|
||||
(positions_dim.Size() == 0) ||
|
||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||
(positions_dim.Size() == 0) ||
|
||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
|
||||
const int64_t* p_positions =
|
||||
positions_dim.empty() || positions_dim.Size() == 0? nullptr : ort_.GetTensorData<int64_t>(positions);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::vector<int64_t> output_dim(1);
|
||||
|
|
|
@ -33,6 +33,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
|
||||
// TODO: fix this scalar check.
|
||||
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
||||
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
|
Двоичный файл не отображается.
|
@ -13,6 +13,7 @@ struct TestValue {
|
|||
std::vector<int64_t> dims;
|
||||
std::vector<float> values_float;
|
||||
std::vector<int32_t> values_int32;
|
||||
std::vector<int64_t> values_int64;
|
||||
std::vector<std::string> values_string;
|
||||
};
|
||||
|
||||
|
|
|
@ -168,6 +168,9 @@ void RunSession(Ort::Session& session_object,
|
|||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
_emplace_back(memory_info, ort_inputs, inputs[i].values_int32, inputs[i].dims);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
_emplace_back(memory_info, ort_inputs, inputs[i].values_int64, inputs[i].dims);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
|
||||
Ort::Value& ort_value = ort_inputs.emplace_back(
|
||||
Ort::Value::CreateTensor(allocator, inputs[i].dims.data(), inputs[i].dims.size(),
|
||||
|
@ -208,6 +211,9 @@ void RunSession(Ort::Session& session_object,
|
|||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
_assert_eq(*output_tensor, expected.values_int32, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
_assert_eq(*output_tensor, expected.values_int64, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
|
||||
std::vector<std::string> output_string;
|
||||
GetTensorMutableDataString(Ort::GetApi(), *output_tensor, output_string);
|
||||
|
|
|
@ -30,3 +30,42 @@ TEST(utils, test_string_lower) {
|
|||
model_path /= "custom_op_string_lower.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
|
||||
TEST(utils, test_regex_split_with_offsets) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
std::vector<TestValue> inputs(1);
|
||||
inputs[0].name = "input:0";
|
||||
inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
inputs[0].dims = {2};
|
||||
inputs[0].values_string = {"a Test 1 2 3 ♠♣", "Hi there test test ♥♦"};
|
||||
|
||||
std::vector<TestValue> outputs(4);
|
||||
outputs[0].name = "output:0";
|
||||
outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
outputs[0].dims = {11};
|
||||
outputs[0].values_string = {"a", "Test", "1", "2", "3", "♠♣", "Hi", "there", "test", "test", "♥♦"};
|
||||
|
||||
outputs[1].name = "output1:0";
|
||||
outputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
outputs[1].dims = {11};
|
||||
outputs[1].values_int64 = {0, 2, 7, 9, 11, 13, 0, 3, 9, 14, 19};
|
||||
|
||||
outputs[2].name = "output2:0";
|
||||
outputs[2].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
outputs[2].dims = {11};
|
||||
outputs[2].values_int64 = {1, 6, 8, 10, 12, 19, 2, 8, 13, 18, 25};
|
||||
|
||||
outputs[3].name = "output3:0";
|
||||
outputs[3].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
outputs[3].dims = {3};
|
||||
outputs[3].values_int64 = {0, 6, 11};
|
||||
|
||||
std::filesystem::path model_path = __FILE__;
|
||||
model_path = model_path.parent_path();
|
||||
model_path /= "..";
|
||||
model_path /= "data";
|
||||
model_path /= "test_regex_split_with_offsets.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import io
|
||||
import onnx
|
||||
import unittest
|
||||
import platform
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from onnxruntime_extensions import PyOrtFunction, hook_model_op, PyOp
|
||||
from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
|
||||
from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model
|
||||
|
||||
|
||||
@unittest.skipIf(platform.python_version_tuple()[0:2] == (
|
||||
'3', '7'), 'Windows CI pipeline failed on the version temporarily.')
|
||||
class TestTorchE2E(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
Загрузка…
Ссылка в новой задаче