Split out some miscellaneous changes from refactoring the azure ops. (#506)
- ifdef out some test code that requires RE2 if RE2 is not enabled - add ability to plugin custom output validator for C++ unit tests - OpenAI responses can have different punctuation. used in the new unit tests that will be in the refactoring PR
This commit is contained in:
Родитель
911c2b2340
Коммит
d9fa8ea060
|
@ -382,6 +382,8 @@ struct Variadic : public TensorBase {
|
|||
}
|
||||
tensors_.emplace_back(tensor.release());
|
||||
} // for
|
||||
} else {
|
||||
// a Variadic used for output is populated by the Compute so leave tensors_ empty here
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
|
|
|
@ -40,11 +40,21 @@ struct TestValue {
|
|||
std::vector<bool> values_bool;
|
||||
};
|
||||
|
||||
// output_validator is optional if you need custom validation for one or more outputs.
|
||||
// for any output you do not have custom validation for call ValidateOutputEqual
|
||||
using OutputValidator = std::function<void(size_t output_idx, Ort::Value& actual, TestValue expected)>;
|
||||
|
||||
void ValidateOutputEqual(size_t output_idx, Ort::Value& actual, TestValue expected);
|
||||
|
||||
void RunSession(Ort::Session& session_object,
|
||||
const std::vector<TestValue>& inputs,
|
||||
const std::vector<TestValue>& outputs);
|
||||
const std::vector<TestValue>& outputs,
|
||||
OutputValidator output_validator = nullptr);
|
||||
|
||||
void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
|
||||
const std::vector<TestValue>& inputs,
|
||||
const std::vector<TestValue>& outputs,
|
||||
const char* custom_op_library_filename);
|
||||
const char* custom_op_library_filename,
|
||||
OutputValidator output_validator = nullptr);
|
||||
|
||||
void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output);
|
||||
|
|
|
@ -209,7 +209,8 @@ void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::v
|
|||
|
||||
void RunSession(Ort::Session& session_object,
|
||||
const std::vector<TestValue>& inputs,
|
||||
const std::vector<TestValue>& outputs) {
|
||||
const std::vector<TestValue>& outputs,
|
||||
OutputValidator output_validator) {
|
||||
std::vector<Ort::Value> ort_inputs;
|
||||
std::vector<const char*> input_names;
|
||||
std::vector<const char*> output_names;
|
||||
|
@ -267,38 +268,48 @@ void RunSession(Ort::Session& session_object,
|
|||
ASSERT_EQ(output_type, expected.element_type);
|
||||
std::vector<int64_t> dimension = type_info.GetShape();
|
||||
ASSERT_EQ(dimension, expected.dims);
|
||||
size_t total_len = type_info.GetElementCount();
|
||||
switch (expected.element_type) {
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
_assert_eq(*output_tensor, expected.values_float, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
_assert_eq(*output_tensor, expected.values_uint8, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
_assert_eq(*output_tensor, expected.values_int32, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
_assert_eq(*output_tensor, expected.values_int64, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
|
||||
std::vector<std::string> output_string;
|
||||
GetTensorMutableDataString(Ort::GetApi(), *output_tensor, output_string);
|
||||
ASSERT_EQ(expected.values_string, output_string);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to handle output ", index, " type ", expected.element_type,
|
||||
" is not implemented yet."));
|
||||
if (output_validator != nullptr) {
|
||||
output_validator(index, *output_tensor, expected);
|
||||
} else {
|
||||
ValidateOutputEqual(index, *output_tensor, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ValidateOutputEqual(size_t output_idx, Ort::Value& actual, TestValue expected) {
|
||||
size_t total_len = actual.GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
|
||||
switch (expected.element_type) {
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
_assert_eq(actual, expected.values_float, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
_assert_eq(actual, expected.values_uint8, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
_assert_eq(actual, expected.values_int32, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
_assert_eq(actual, expected.values_int64, total_len);
|
||||
break;
|
||||
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
|
||||
std::vector<std::string> output_string;
|
||||
GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
|
||||
ASSERT_EQ(expected.values_string, output_string);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to handle output ", output_idx, " type ", expected.element_type,
|
||||
" is not implemented yet."));
|
||||
}
|
||||
}
|
||||
|
||||
void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
|
||||
const std::vector<TestValue>& inputs,
|
||||
const std::vector<TestValue>& outputs,
|
||||
const char* custom_op_library_filename) {
|
||||
const char* custom_op_library_filename,
|
||||
OutputValidator output_validator) {
|
||||
Ort::SessionOptions session_options;
|
||||
void* handle = nullptr;
|
||||
if (custom_op_library_filename) {
|
||||
|
@ -310,7 +321,7 @@ void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
|
|||
Ort::Session session(env, model_uri, session_options);
|
||||
|
||||
// Now run
|
||||
RunSession(session, inputs, outputs);
|
||||
RunSession(session, inputs, outputs, output_validator);
|
||||
}
|
||||
|
||||
static CustomOpOne op_1st;
|
||||
|
|
|
@ -27,6 +27,7 @@ TEST(string_operator, test_string_lower) {
|
|||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
TEST(string_operator, test_regex_split_with_offsets) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
||||
|
@ -61,6 +62,7 @@ TEST(string_operator, test_regex_split_with_offsets) {
|
|||
model_path /= "test_regex_split_with_offsets.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(string_operator, test_string_ecmaregex_replace) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
|
|
|
@ -1,32 +1,20 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_CV2
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#ifdef ENABLE_CV2
|
||||
#include "gtest/gtest.h"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
|
||||
#include "ocos.h"
|
||||
#include "test_kernel.hpp"
|
||||
#include "test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
|
||||
using namespace std;
|
||||
ifstream ifs(filename, ios::binary | ios::ate);
|
||||
ifstream::pos_type pos = ifs.tellg();
|
||||
|
||||
std::vector<uint8_t> input_bytes(pos);
|
||||
ifs.seekg(0, ios::beg);
|
||||
// we want uint8_t values so reinterpret_cast so we don't have to read chars and copy to uint8_t after.
|
||||
ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);
|
||||
|
||||
return input_bytes;
|
||||
}
|
||||
} // namespace
|
||||
using namespace ort_extensions::test;
|
||||
|
||||
// Test DecodeImage and EncodeImage by providing a jpg image. Model will decode to BGR, encode to PNG and decode
|
||||
// again to BGR. We validate that the BGR output from that matches the original image.
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "string_utils.h"
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
#include "text/re2_strings/string_regex_split_re.hpp"
|
||||
#endif
|
||||
#include "text/string_ecmaregex_split.hpp"
|
||||
|
||||
TEST(strings, std_regex_test) {
|
||||
|
@ -17,7 +19,7 @@ TEST(strings, std_regex_test) {
|
|||
std::cout << result << std::endl;
|
||||
}
|
||||
|
||||
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
TEST(strings, regex_split) {
|
||||
std::string input = "hello world";
|
||||
re2::RE2 reg("(\\s)");
|
||||
|
@ -49,6 +51,7 @@ TEST(strings, regex_split_skip) {
|
|||
EXPECT_EQ(expected_begin_offsets, begin_offsets);
|
||||
EXPECT_EQ(expected_end_offsets, end_offsets);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(strings, regex_split_no_matched) {
|
||||
std::string input = "helloworld";
|
||||
|
@ -81,4 +84,3 @@ TEST(strings, regex_split_begin_end_delim) {
|
|||
EXPECT_EQ(expected_begin_offsets, begin_offsets);
|
||||
EXPECT_EQ(expected_end_offsets, end_offsets);
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
#include "re2/re2.h"
|
||||
#endif
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "string_utils.h"
|
||||
#include "ustring.h"
|
||||
|
@ -13,10 +15,12 @@ TEST(utils, make_string) {
|
|||
EXPECT_EQ(res, "ab0");
|
||||
}
|
||||
|
||||
#ifdef ENABLE_RE2_REGEX
|
||||
TEST(utils, re2_basic) {
|
||||
re2::StringPiece piece("1234");
|
||||
re2::RE2 reg("[0-9]*");
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(utils, json) {
|
||||
nlohmann::json j;
|
||||
|
|
Загрузка…
Ссылка в новой задаче