Split out some miscellaneous changes from refactoring the azure ops. (#506)

- ifdef out some test code that requires RE2 if RE2 is not enabled
- add ability to plugin custom output validator for C++ unit tests
  - OpenAI responses can have different punctuation. used in the new unit tests that will be in the refactoring PR
This commit is contained in:
Scott McKay 2023-08-04 17:53:11 +10:00 коммит произвёл GitHub
Родитель 911c2b2340
Коммит d9fa8ea060
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 66 добавлений и 47 удалений

Просмотреть файл

@ -382,6 +382,8 @@ struct Variadic : public TensorBase {
}
tensors_.emplace_back(tensor.release());
} // for
} else {
// a Variadic used for output is populated by the Compute so leave tensors_ empty here
}
}
template <typename T>

Просмотреть файл

@ -40,11 +40,21 @@ struct TestValue {
std::vector<bool> values_bool;
};
// output_validator is optional if you need custom validation for one or more outputs.
// for any output you do not have custom validation for call ValidateOutputEqual
using OutputValidator = std::function<void(size_t output_idx, Ort::Value& actual, TestValue expected)>;
void ValidateOutputEqual(size_t output_idx, Ort::Value& actual, TestValue expected);
void RunSession(Ort::Session& session_object,
const std::vector<TestValue>& inputs,
const std::vector<TestValue>& outputs);
const std::vector<TestValue>& outputs,
OutputValidator output_validator = nullptr);
void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
const std::vector<TestValue>& inputs,
const std::vector<TestValue>& outputs,
const char* custom_op_library_filename);
const char* custom_op_library_filename,
OutputValidator output_validator = nullptr);
void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output);

Просмотреть файл

@ -209,7 +209,8 @@ void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::v
void RunSession(Ort::Session& session_object,
const std::vector<TestValue>& inputs,
const std::vector<TestValue>& outputs) {
const std::vector<TestValue>& outputs,
OutputValidator output_validator) {
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_names;
std::vector<const char*> output_names;
@ -267,38 +268,48 @@ void RunSession(Ort::Session& session_object,
ASSERT_EQ(output_type, expected.element_type);
std::vector<int64_t> dimension = type_info.GetShape();
ASSERT_EQ(dimension, expected.dims);
size_t total_len = type_info.GetElementCount();
switch (expected.element_type) {
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
_assert_eq(*output_tensor, expected.values_float, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
_assert_eq(*output_tensor, expected.values_uint8, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
_assert_eq(*output_tensor, expected.values_int32, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
_assert_eq(*output_tensor, expected.values_int64, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
std::vector<std::string> output_string;
GetTensorMutableDataString(Ort::GetApi(), *output_tensor, output_string);
ASSERT_EQ(expected.values_string, output_string);
break;
}
default:
throw std::runtime_error(MakeString(
"Unable to handle output ", index, " type ", expected.element_type,
" is not implemented yet."));
if (output_validator != nullptr) {
output_validator(index, *output_tensor, expected);
} else {
ValidateOutputEqual(index, *output_tensor, expected);
}
}
}
void ValidateOutputEqual(size_t output_idx, Ort::Value& actual, TestValue expected) {
size_t total_len = actual.GetTensorTypeAndShapeInfo().GetElementCount();
switch (expected.element_type) {
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
_assert_eq(actual, expected.values_float, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
_assert_eq(actual, expected.values_uint8, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
_assert_eq(actual, expected.values_int32, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
_assert_eq(actual, expected.values_int64, total_len);
break;
case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
std::vector<std::string> output_string;
GetTensorMutableDataString(Ort::GetApi(), actual, output_string);
ASSERT_EQ(expected.values_string, output_string);
break;
}
default:
throw std::runtime_error(MakeString(
"Unable to handle output ", output_idx, " type ", expected.element_type,
" is not implemented yet."));
}
}
void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
const std::vector<TestValue>& inputs,
const std::vector<TestValue>& outputs,
const char* custom_op_library_filename) {
const char* custom_op_library_filename,
OutputValidator output_validator) {
Ort::SessionOptions session_options;
void* handle = nullptr;
if (custom_op_library_filename) {
@ -310,7 +321,7 @@ void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
Ort::Session session(env, model_uri, session_options);
// Now run
RunSession(session, inputs, outputs);
RunSession(session, inputs, outputs, output_validator);
}
static CustomOpOne op_1st;

Просмотреть файл

@ -27,6 +27,7 @@ TEST(string_operator, test_string_lower) {
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
#ifdef ENABLE_RE2_REGEX
TEST(string_operator, test_regex_split_with_offsets) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
@ -61,6 +62,7 @@ TEST(string_operator, test_regex_split_with_offsets) {
model_path /= "test_regex_split_with_offsets.onnx";
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
#endif
TEST(string_operator, test_string_ecmaregex_replace) {
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

Просмотреть файл

@ -1,32 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_CV2
#include <filesystem>
#include <fstream>
#include <vector>
#ifdef ENABLE_CV2
#include "gtest/gtest.h"
#include "opencv2/imgcodecs.hpp"
#include "ocos.h"
#include "test_kernel.hpp"
#include "test_utils.hpp"
namespace {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
using namespace std;
ifstream ifs(filename, ios::binary | ios::ate);
ifstream::pos_type pos = ifs.tellg();
std::vector<uint8_t> input_bytes(pos);
ifs.seekg(0, ios::beg);
// we want uint8_t values so reinterpret_cast so we don't have to read chars and copy to uint8_t after.
ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);
return input_bytes;
}
} // namespace
using namespace ort_extensions::test;
// Test DecodeImage and EncodeImage by providing a jpg image. Model will decode to BGR, encode to PNG and decode
// again to BGR. We validate that the BGR output from that matches the original image.

Просмотреть файл

@ -3,7 +3,9 @@
#include "gtest/gtest.h"
#include "string_utils.h"
#ifdef ENABLE_RE2_REGEX
#include "text/re2_strings/string_regex_split_re.hpp"
#endif
#include "text/string_ecmaregex_split.hpp"
TEST(strings, std_regex_test) {
@ -17,7 +19,7 @@ TEST(strings, std_regex_test) {
std::cout << result << std::endl;
}
#ifdef ENABLE_RE2_REGEX
TEST(strings, regex_split) {
std::string input = "hello world";
re2::RE2 reg("(\\s)");
@ -49,6 +51,7 @@ TEST(strings, regex_split_skip) {
EXPECT_EQ(expected_begin_offsets, begin_offsets);
EXPECT_EQ(expected_end_offsets, end_offsets);
}
#endif
TEST(strings, regex_split_no_matched) {
std::string input = "helloworld";
@ -81,4 +84,3 @@ TEST(strings, regex_split_begin_end_delim) {
EXPECT_EQ(expected_begin_offsets, begin_offsets);
EXPECT_EQ(expected_end_offsets, end_offsets);
}

Просмотреть файл

@ -2,7 +2,9 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
#ifdef ENABLE_RE2_REGEX
#include "re2/re2.h"
#endif
#include "nlohmann/json.hpp"
#include "string_utils.h"
#include "ustring.h"
@ -13,10 +15,12 @@ TEST(utils, make_string) {
EXPECT_EQ(res, "ab0");
}
#ifdef ENABLE_RE2_REGEX
TEST(utils, re2_basic) {
re2::StringPiece piece("1234");
re2::RE2 reg("[0-9]*");
}
#endif
TEST(utils, json) {
nlohmann::json j;