Refactor the unit tests and cmake build script (#726)

* refine the build script

* complete the unit tests.

* remove the commented code
This commit is contained in:
Wenbing Li 2024-05-30 14:16:14 -07:00 коммит произвёл GitHub
Родитель b60df02fd0
Коммит ca433cbea7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
14 изменённых файлов: 258 добавлений и 105 удалений

Просмотреть файл

@ -285,9 +285,11 @@ function(set_msvc_c_cpp_compiler_warning_level warning_level)
endif()
endfunction()
if (NOT ONNXRUNTIME_INCLUDE_DIR)
include(ext_ortlib)
endif()
# set default MSVC warning level to 3 for external dependencies
set_msvc_c_cpp_compiler_warning_level(3)
include(ext_ortlib)
include(gsl)
macro(standardize_output_folder bin_target)
@ -681,7 +683,6 @@ endif()
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
target_include_directories(ocos_operators PUBLIC ${nlohmann_json_SOURCE_DIR}/single_include)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif()
# If building a shared library we can't throw an internal exception type across the library boundary as the type
@ -695,8 +696,6 @@ if(ANDROID)
list(APPEND ocos_libraries log)
endif()
list(APPEND ocos_libraries Microsoft.GSL::GSL)
list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)
target_compile_definitions(noexcep_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
@ -899,7 +898,7 @@ if (_ORTX_STANDALONE_PROJECT)
# Run CPack to generate the NuGet package
include(CPack)
if(OCOS_ENABLE_CTEST)
if(OCOS_ENABLE_CTEST AND NOT MAC_CATALYST)
include(ext_tests)
endif()
endif()

Просмотреть файл

@ -17,7 +17,20 @@ namespace ort_extensions {
class path {
public:
path() = default;
path(const std::string& path) : path_(path){};
path(const std::string& path) : path_(path) {
#ifdef _WIN32
w_path_ = to_wstring();
#endif // _WIN32
};
#ifdef _WIN32
path(const std::wstring& wpath) {
int size_needed = WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, nullptr, 0, nullptr, nullptr);
std::string utf8_str(size_needed, 0);
WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, &utf8_str[0], size_needed, nullptr, nullptr);
path_ = utf8_str;
}
#endif // _WIN32
static constexpr char separator =
#ifdef _WIN32
@ -30,7 +43,7 @@ class path {
std::ifstream open(ios_base::openmode mode = ios_base::in) const {
// if Windows, need to convert the string to UTF-16
#ifdef _WIN32
return std::ifstream(to_wstring(), mode);
return std::ifstream(w_path_, mode);
#else
return std::ifstream(path_, mode);
#endif // _WIN32
@ -55,7 +68,7 @@ class path {
bool is_directory() const {
#ifdef _WIN32
struct _stat64 info;
if (_wstat64(to_wstring().c_str(), &info) != 0) {
if (_wstat64(w_path_.c_str(), &info) != 0) {
return false;
}
#else
@ -69,8 +82,9 @@ class path {
private:
std::string path_;
#ifdef _WIN32
std::wstring w_path_;
std::wstring to_wstring() const {
int size_needed = MultiByteToWideChar(CP_UTF8, 0, path_.c_str(), -1, nullptr, 0);
std::wstring utf16_str(size_needed, 0);

Просмотреть файл

@ -1,5 +1,3 @@
if (NOT MAC_CATALYST)
if (OCOS_ENABLE_SELECTED_OPLIST)
# currently the tests don't handle operator exclusion cleanly.
message(FATAL_ERROR "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
@ -55,7 +53,7 @@ function(add_test_target)
"${TEST_SRC_DIR}/unittest_main/test_main.cc")
target_link_libraries(${ARG_TARGET} PRIVATE
${ARG_LIBRARIES}
gtest gmock)
gtest)
if(OCOS_USE_CUDA)
target_link_directories(${ARG_TARGET} PRIVATE ${CUDAToolkit_LIBRARY_DIR})
@ -93,7 +91,7 @@ function(add_test_target)
target_link_libraries(${ARG_TARGET} PRIVATE
${ARG_LIBRARIES}
gtest gmock)
gtest)
set(test_data_destination_root_directory $<TARGET_FILE_DIR:${dummy_testee_target}>)
@ -130,9 +128,40 @@ add_test_target(TARGET ocos_test
LIBRARIES ortcustomops ${ocos_libraries})
target_compile_definitions(ocos_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
if (OCOS_ENABLE_C_API)
file(GLOB pp_api_TEST_SRC
"${TEST_SRC_DIR}/pp_api_test/*.c"
"${TEST_SRC_DIR}/pp_api_test/*.cc"
"${TEST_SRC_DIR}/pp_api_test/*.h")
add_test_target(TARGET pp_api_test
TEST_SOURCES ${pp_api_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(pp_api_test PRIVATE
${PROJECT_SOURCE_DIR}/
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>"
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
if (ORTX_TEST_DATA2)
file(TO_NATIVE_PATH "${ORTX_TEST_DATA2}/tests/data2" _TEST_DATA2)
add_custom_command(TARGET pp_api_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E create_symlink ${_TEST_DATA2} ${onnxruntime_extensions_BINARY_DIR}/data2)
endif()
endif()
# -- shared test (needs onnxruntime) --
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
# avoid blindling searching for onnxruntime library
# wbhich leads to a unpredictable result
if (NOT ONNXRUNTIME_LIB_DIR)
set(ONNXRUNTIME "ONNXRUNTIME-NOTFOUND")
else()
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
endif()
if("${ONNXRUNTIME}" STREQUAL "ONNXRUNTIME-NOTFOUND")
message(WARNING "The prebuilt onnxruntime library was not found, extensions_test will be skipped.")
@ -197,25 +226,10 @@ else()
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ONNXRUNTIME} ${CMAKE_BINARY_DIR}/lib
)
endif()
if (OCOS_ENABLE_C_API)
# avoid copying the same data directory at the same time.
add_dependencies(extensions_test pp_api_test)
endif()
endblock()
if (OCOS_ENABLE_C_API)
file(GLOB pp_api_TEST_SRC
"${TEST_SRC_DIR}/pp_api_test/*.c"
"${TEST_SRC_DIR}/pp_api_test/*.cc"
"${TEST_SRC_DIR}/pp_api_test/*.h")
add_test_target(TARGET pp_api_test
TEST_SOURCES ${pp_api_TEST_SRC}
LIBRARIES onnxruntime_extensions ${ocos_libraries}
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(pp_api_test PRIVATE
${PROJECT_SOURCE_DIR}/
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>"
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
endif()
endif()
endif()

4
cmake/externals/googletest.cmake поставляемый
Просмотреть файл

@ -4,8 +4,8 @@ FetchContent_Declare(
URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
)
set(BUILD_GMOCK OFF CACHE BOOL "" FORCE)
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
set_target_properties(gmock PROPERTIES FOLDER "externals/gtest")
set_target_properties(gmock_main PROPERTIES FOLDER "externals/gtest")
set_target_properties(gtest PROPERTIES FOLDER "externals/gtest")
set_target_properties(gtest_main PROPERTIES FOLDER "externals/gtest")

12
cmake/externals/gsl.cmake поставляемый
Просмотреть файл

@ -14,5 +14,13 @@ else()
)
endif()
FetchContent_MakeAvailable(GSL)
get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES)
FetchContent_GetProperties(GSL)
string(TOLOWER "GSL" lcName)
if(NOT ${lcName}_POPULATED)
FetchContent_Populate(GSL)
# add_subdirectory(${GSL_SOURCE_DIR} ${GSL_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()
set(GSL_INCLUDE_DIR ${gsl_SOURCE_DIR}/include)
#get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES)

1
cmake/externals/json.cmake поставляемый
Просмотреть файл

@ -7,5 +7,4 @@ set(JSON_BuildTests OFF CACHE INTERNAL "")
FetchContent_GetProperties(nlohmann_json)
if(NOT nlohmann_json_POPULATED)
FetchContent_Populate(nlohmann_json)
add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()

20
cmake/externals/opencv.cmake поставляемый
Просмотреть файл

@ -107,6 +107,16 @@ set(BUILD_TESTS OFF CACHE INTERNAL "")
set(CV_TRACE OFF CACHE INTERNAL "")
set(CV_DISABLE_OPTIMIZATION ON CACHE INTERNAL "")
set(BUILD_PERF_TESTS OFF CACHE INTERNAL "")
set(BUILD_opencv_java_bindings_generator OFF CACHE INTERNAL "")
set(BUILD_opencv_js_bindings_generator OFF CACHE INTERNAL "")
set(BUILD_opencv_objc_bindings_generator OFF CACHE INTERNAL "")
set(BUILD_opencv_python_bindings_generator OFF CACHE INTERNAL "")
set(BUILD_opencv_python_tests OFF CACHE INTERNAL "")
set(WITH_ADE OFF CACHE INTERNAL "")
set(VIDEOIO_ENABLE_PLUGINS OFF CACHE INTERNAL "")
set(HIGHGUI_ENABLE_PLUGINS OFF CACHE INTERNAL "")
if(IOS)
# copy what OpenCV's platforms/ios/build_framework.py does and set CPU_BASELINE=DETECT
@ -157,13 +167,3 @@ endif()
# unset it to avoid affecting other projects.
unset(EXECUTABLE_OUTPUT_PATH CACHE)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(opencv_projs gen_opencv_java_source gen_opencv_js_source gen_opencv_python_source)
list(APPEND opencv_projs gen_opencv_objc_source gen_opencv_objc_source_ios gen_opencv_objc_source_osx)
list(APPEND opencv_projs opencv_highgui_plugins opencv_videoio_plugins)
foreach(p ${opencv_projs})
set_target_properties(${p} PROPERTIES FOLDER "externals/opencv")
set_target_properties(${p} PROPERTIES EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_DEFAULT_BUILD TRUE)
endforeach()
endif()

Просмотреть файл

@ -48,7 +48,7 @@ class BoxArray {
private:
void SortBoxesByScore(gsl::span<const float> data) {
boxes_by_score_.reserve(NumBoxes());
for (size_t i = 0; i < NumBoxes(); ++i) {
for (size_t i = 0; i < static_cast<size_t>(NumBoxes()); ++i) {
boxes_by_score_.push_back(data.subspan(i * shape_[1], shape_[1]));
}
@ -188,7 +188,7 @@ void DrawBoxesForNumClasses(ImageView& image, const BoxArray& boxes, int64_t thi
std::unordered_map<float, size_t> color_used;
std::vector<std::pair<size_t, int64_t>> box_reverse;
box_reverse.reserve(boxes.NumBoxes());
for (size_t i = 0; i < boxes.NumBoxes(); ++i) {
for (size_t i = 0; i < static_cast<size_t>(boxes.NumBoxes()); ++i) {
const auto box = boxes.GetBox(i);
if (color_used.find(box[kBoxClassIndex]) == color_used.end()) {
if (color_used.size() >= KBGRColorMap.size()) {

Просмотреть файл

@ -38,9 +38,9 @@ LoadRawImages(const std::initializer_list<const char*>& image_paths) {
} // namespace ort_extensions
Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
{"DecodeImage", []() { return DefineKernelFunction(image_decoder); }},
{"ConvertRGB", []() { return DefineKernelFunction(convert_to_rgb); }},
{"Phi3ImageTransform", []() { return DefineKernelFunction(phi3_hd_transform); }},
{"DecodeImage", []() { return CreateKernelInstance(image_decoder); }},
{"ConvertRGB", []() { return CreateKernelInstance(&ConvertToRGB::Compute); }},
{"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }},
};
OrtxStatus ImageProcessor::Init(std::string_view processor_def) {

Просмотреть файл

@ -13,30 +13,32 @@ constexpr int image_resized_height = 336;
constexpr float OPENAI_CLIP_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};
constexpr float OPENAI_CLIP_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};
inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"};
}
std::uint8_t* p_output_image = output.Allocate(dimensions);
auto* input_data = input.Data();
auto h = dimensions[0];
auto w = dimensions[1];
auto c = dimensions[2];
// convert BGR channel layouts to RGB
for (int64_t j = 0; j < h; ++j) {
for (int64_t k = 0; k < w; ++k) {
auto c0_index = j * w * c + k * c;
std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) =
std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]);
struct ConvertToRGB {
OrtxStatus Compute(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"};
}
}
return {};
}
std::uint8_t* p_output_image = output.Allocate(dimensions);
auto* input_data = input.Data();
auto h = dimensions[0];
auto w = dimensions[1];
auto c = dimensions[2];
// convert BGR channel layouts to RGB
for (int64_t j = 0; j < h; ++j) {
for (int64_t k = 0; k < w; ++k) {
auto c0_index = j * w * c + k * c;
std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) =
std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]);
}
}
return {};
}
};
inline cv::Mat padding_336(const cv::Mat& image) {
// def padding_336(b):

Просмотреть файл

@ -18,10 +18,11 @@ namespace ort_extensions {
using json = nlohmann::json;
using TensorArgs = std::vector<ortc::TensorBase*>;
class KernelClass {
class KernelDef {
public:
KernelClass() = default;
virtual ~KernelClass() = default;
KernelDef() = default;
virtual ~KernelDef() = default;
virtual OrtxStatus Init(std::string_view attr) { return {}; } // no need to be initialized for a kernel function
virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;
@ -91,17 +92,13 @@ class KernelClass {
};
template <typename... Args>
class KernelFunction : public KernelClass {
class KernelFunction : public KernelDef {
public:
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){};
virtual ~KernelFunction() = default;
OrtxStatus Compute(Args... args) const {
return body_(std::forward<Args>(args)...);
}
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
auto tensors = KernelClass::AllocateOutput<Args...>(allocator);
auto tensors = KernelDef::AllocateOutput<Args...>(allocator);
TensorArgs all_args;
for (auto& tensor : tensors) {
if (tensor != nullptr) {
@ -123,16 +120,64 @@ class KernelFunction : public KernelClass {
private:
std::function<OrtxStatus(Args...)> body_;
OrtxStatus Compute(Args... args) const {
return body_(std::forward<Args>(args)...);
}
};
template <typename T, typename... Args>
class KernelStruct : public KernelDef {
public:
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){};
virtual ~KernelStruct() = default;
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
auto tensors = KernelDef::AllocateOutput<Args...>(allocator);
TensorArgs all_args;
for (auto& tensor : tensors) {
if (tensor != nullptr) {
all_args.push_back(tensor);
}
}
return all_args;
}
template <typename DT>
OrtxStatus Init(DT attr) {
instance_ = std::make_unique<T>();
return instance_->Init(std::move(attr));
}
OrtxStatus Apply(TensorArgs& inputs, TensorArgs& outputs) const override {
TensorArgs all_args;
all_args.reserve(inputs.size() + outputs.size());
all_args.insert(all_args.end(), inputs.begin(), inputs.end());
all_args.insert(all_args.end(), outputs.begin(), outputs.end());
auto args_tuple = std::tuple_cat(CastTensors<Args...>(all_args));
return std::apply([this](auto&&... args) {
return (instance_.get()->*body_)(std::forward<decltype(*args)>(*args)...); }, std::move(args_tuple));
}
private:
OrtxStatus (T::*body_)(Args...){};
std::unique_ptr<T> instance_;
};
template <typename... Args>
std::unique_ptr<KernelClass> DefineKernelFunction(OrtxStatus (*body)(Args...)) {
std::unique_ptr<KernelDef> CreateKernelInstance(OrtxStatus (*body)(Args...)) {
return std::make_unique<KernelFunction<Args...>>(body);
}
template <typename T, typename... Args>
std::unique_ptr<KernelDef> CreateKernelInstance(OrtxStatus (T::*method)(Args...)) {
return std::make_unique<KernelStruct<T, Args...>>(method);
}
class Operation {
public:
using KernelRegistry = std::unordered_map<std::string_view, std::function<std::unique_ptr<KernelClass>()>>;
using KernelRegistry = std::unordered_map<std::string_view, std::function<std::unique_ptr<KernelDef>()>>;
Operation(const KernelRegistry& registry) { kernel_registry_ = &registry; };
OrtxStatus Init(std::string_view op_def) {
@ -162,12 +207,14 @@ class Operation {
op_name_ = op_name;
kernel_ = kernel_iter->second();
/* TODO: parse the attributes
if (op_json.contains("attrs")) {
auto attrs = op_json.at("attrs");
attrs.get_to(attributes_);
if (op_json.contains("attrs")) {
auto attrs = op_json.at("attrs");
auto status = kernel_->Init(attrs.dump());
if (!status.IsOk()) {
return status;
}
*/
}
return {};
}
@ -192,9 +239,8 @@ class Operation {
private:
const KernelRegistry* kernel_registry_;
std::unique_ptr<KernelClass> kernel_;
std::unique_ptr<KernelDef> kernel_;
std::string op_name_;
std::unordered_map<std::string, std::string> attributes_;
ortc::IAllocator* allocator_{};
};

Просмотреть файл

@ -1,38 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <fstream>
#include <vector>
#include <tuple>
#include <fstream>
#include <filesystem>
#include "gtest/gtest.h"
#include "shared/api/image_processor.h"
using namespace ort_extensions;
TEST(ProcessorTest, TestMsImage) {
std::vector<float> ReadArrayFromFile(const std::string& filename) {
std::ifstream inFile(filename, std::ios::binary | std::ios::ate);
if (!inFile) {
throw std::runtime_error("Cannot open file for reading.");
}
std::streamsize fileSize = inFile.tellg();
inFile.seekg(0, std::ios::beg);
std::vector<float> array(fileSize / sizeof(float));
if (!inFile.read(reinterpret_cast<char*>(array.data()), fileSize)) {
throw std::runtime_error("Error reading file.");
}
return array;
}
TEST(ProcessorTest, TestPhi3VImageProcessing) {
auto [input_data, n_data] = ort_extensions::LoadRawImages(
{"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"});
auto proc = OrtxObjectPtr<ImageProcessor>(OrtxCreateProcessor, "data/processor/image_processor.json");
ortc::Tensor<float>* pixel_values;
ortc::Tensor<int64_t>* image_sizes;
ortc::Tensor<int64_t>* num_img_takens;
ortc::Tensor<int64_t>* num_img_tokens;
auto [status, r] = proc->PreProcess(
ort_extensions::span(input_data.get(), (size_t)n_data),
&pixel_values,
&image_sizes,
&num_img_takens);
&num_img_tokens);
ASSERT_TRUE(status.IsOk());
// dump the output to a file
// FILE* fp = fopen("ppoutput.bin", "wb");
// fwrite(pixel_values->Data(), sizeof(float), pixel_values->NumberOfElement(), fp);
// fclose(fp);
int64_t expected_image_size[] = {1344, 1344, 1008, 1344, 1008, 1680};
int64_t expected_num_token[] = {2509, 1921, 2353};
ASSERT_EQ(pixel_values->Shape(), std::vector<int64_t>({3, 17, 3, 336, 336}));
ASSERT_EQ(image_sizes->Shape(), std::vector<int64_t>({3, 2}));
ASSERT_EQ(num_img_tokens->Shape(), std::vector<int64_t>({3, 1}));
if (std::filesystem::is_directory("data2/processor")) {
// the test data was dumped in this way
// {
// std::ofstream outFile("data2/processor/img_proc_pixel_values.bin", std::ios::binary);
// outFile.write(reinterpret_cast<const char*>(array.data()), array.size() * sizeof(float));
// }
auto expected_output = ReadArrayFromFile("data2/processor/img_proc_pixel_values.bin");
ASSERT_EQ(pixel_values->NumberOfElement(), expected_output.size());
for (size_t i = 0; i < expected_output.size(); i++) {
ASSERT_NEAR(pixel_values->Data()[i], expected_output[i], 1e-3);
}
}
// compare the image sizes
for (size_t i = 0; i < 3; i++) {
ASSERT_EQ(image_sizes->Data()[i * 2], expected_image_size[i * 2]);
ASSERT_EQ(image_sizes->Data()[i * 2 + 1], expected_image_size[i * 2 + 1]);
ASSERT_EQ(num_img_tokens->Data()[i], expected_num_token[i]);
}
proc->ClearOutputs(&r);
}

Просмотреть файл

@ -117,6 +117,42 @@ TEST(OrtxTokenizerTest, TicTokenTokenizer) {
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerTest, Phi3_S_Tokenizer) {
if (!std::filesystem::exists("data2/phi-3-small")) {
GTEST_SKIP() << "Skip test as extra test data is not deployed.";
}
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data2/phi-3-small");
if (!status.IsOk()) {
std::cout << status.ToString() << std::endl;
}
// validate tokenizer is not null
EXPECT_NE(tokenizer, nullptr);
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
std::vector<std::string_view> input = {
"This is a test.",
"the second one",
"I like walking my cute dog\n and\x17 then",
"Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
std::vector<std::vector<extTokenId_t>>
token_ids;
status = tokenizer->Tokenize(input, token_ids);
EXPECT_TRUE(status.IsOk());
DumpTokenIds(token_ids);
EXPECT_EQ(token_ids.size(), input.size());
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
std::vector<std::string> out_text;
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
status = tokenizer->Detokenize(token_ids_span, out_text);
EXPECT_TRUE(status.IsOk());
EXPECT_EQ(out_text[0], input[0]);
}
TEST(OrtxTokenizerTest, GemmaTokenizer) {
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
auto status = tokenizer->Load("data/gemma");

Просмотреть файл

@ -3,7 +3,6 @@
#include <filesystem>
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "ocos.h"
#include "test_kernel.hpp"