From ca433cbea706e7c1782df25391f877e28b887d61 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Thu, 30 May 2024 14:16:14 -0700 Subject: [PATCH] Refactor the unit tests and cmake build script (#726) * refine the build script * complete the unit tests. * remove the commented code --- CMakeLists.txt | 9 ++- base/file_sys.h | 22 +++++-- cmake/ext_tests.cmake | 66 ++++++++++++--------- cmake/externals/googletest.cmake | 4 +- cmake/externals/gsl.cmake | 12 +++- cmake/externals/json.cmake | 1 - cmake/externals/opencv.cmake | 20 +++---- operators/vision/draw_bounding_box.cc | 4 +- shared/api/image_processor.cc | 6 +- shared/api/image_transforms.hpp | 46 ++++++++------- shared/api/runner.hpp | 82 +++++++++++++++++++++------ test/pp_api_test/test_processor.cc | 54 +++++++++++++++--- test/pp_api_test/test_tokenizer.cc | 36 ++++++++++++ test/shared_test/test_exceptions.cc | 1 - 14 files changed, 258 insertions(+), 105 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ffce0e4..066e204f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -285,9 +285,11 @@ function(set_msvc_c_cpp_compiler_warning_level warning_level) endif() endfunction() +if (NOT ONNXRUNTIME_INCLUDE_DIR) + include(ext_ortlib) +endif() # set default MSVC warning level to 3 for external dependencies set_msvc_c_cpp_compiler_warning_level(3) -include(ext_ortlib) include(gsl) macro(standardize_output_folder bin_target) @@ -681,7 +683,6 @@ endif() if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER) target_include_directories(ocos_operators PUBLIC ${nlohmann_json_SOURCE_DIR}/single_include) - list(APPEND ocos_libraries nlohmann_json::nlohmann_json) endif() # If building a shared library we can't throw an internal exception type across the library boundary as the type @@ -695,8 +696,6 @@ if(ANDROID) list(APPEND ocos_libraries log) endif() -list(APPEND ocos_libraries Microsoft.GSL::GSL) - list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS) 
target_compile_definitions(noexcep_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS}) if(NOT OCOS_ENABLE_CPP_EXCEPTIONS) @@ -899,7 +898,7 @@ if (_ORTX_STANDALONE_PROJECT) # Run CPack to generate the NuGet package include(CPack) - if(OCOS_ENABLE_CTEST) + if(OCOS_ENABLE_CTEST AND NOT MAC_CATALYST) include(ext_tests) endif() endif() diff --git a/base/file_sys.h b/base/file_sys.h index 2dc0b0bf..3ff42537 100644 --- a/base/file_sys.h +++ b/base/file_sys.h @@ -17,7 +17,20 @@ namespace ort_extensions { class path { public: path() = default; - path(const std::string& path) : path_(path){}; + path(const std::string& path) : path_(path) { +#ifdef _WIN32 + w_path_ = to_wstring(); +#endif // _WIN32 + }; + +#ifdef _WIN32 + path(const std::wstring& wpath) { + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, nullptr, 0, nullptr, nullptr); + std::string utf8_str(size_needed, 0); + WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, &utf8_str[0], size_needed, nullptr, nullptr); + path_ = utf8_str; + } +#endif // _WIN32 static constexpr char separator = #ifdef _WIN32 @@ -30,7 +43,7 @@ class path { std::ifstream open(ios_base::openmode mode = ios_base::in) const { // if Windows, need to convert the string to UTF-16 #ifdef _WIN32 - return std::ifstream(to_wstring(), mode); + return std::ifstream(w_path_, mode); #else return std::ifstream(path_, mode); #endif // _WIN32 @@ -55,7 +68,7 @@ class path { bool is_directory() const { #ifdef _WIN32 struct _stat64 info; - if (_wstat64(to_wstring().c_str(), &info) != 0) { + if (_wstat64(w_path_.c_str(), &info) != 0) { return false; } #else @@ -69,8 +82,9 @@ class path { private: std::string path_; - #ifdef _WIN32 + std::wstring w_path_; + std::wstring to_wstring() const { int size_needed = MultiByteToWideChar(CP_UTF8, 0, path_.c_str(), -1, nullptr, 0); std::wstring utf16_str(size_needed, 0); diff --git a/cmake/ext_tests.cmake b/cmake/ext_tests.cmake index c51535fd..4e39e7ba 100644 --- a/cmake/ext_tests.cmake +++ 
b/cmake/ext_tests.cmake @@ -1,5 +1,3 @@ -if (NOT MAC_CATALYST) - if (OCOS_ENABLE_SELECTED_OPLIST) # currently the tests don't handle operator exclusion cleanly. message(FATAL_ERROR "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run") @@ -55,7 +53,7 @@ function(add_test_target) "${TEST_SRC_DIR}/unittest_main/test_main.cc") target_link_libraries(${ARG_TARGET} PRIVATE ${ARG_LIBRARIES} - gtest gmock) + gtest) if(OCOS_USE_CUDA) target_link_directories(${ARG_TARGET} PRIVATE ${CUDAToolkit_LIBRARY_DIR}) @@ -93,7 +91,7 @@ function(add_test_target) target_link_libraries(${ARG_TARGET} PRIVATE ${ARG_LIBRARIES} - gtest gmock) + gtest) set(test_data_destination_root_directory $) @@ -130,9 +128,40 @@ add_test_target(TARGET ocos_test LIBRARIES ortcustomops ${ocos_libraries}) target_compile_definitions(ocos_test PRIVATE ${OCOS_COMPILE_DEFINITIONS}) +if (OCOS_ENABLE_C_API) + file(GLOB pp_api_TEST_SRC + "${TEST_SRC_DIR}/pp_api_test/*.c" + "${TEST_SRC_DIR}/pp_api_test/*.cc" + "${TEST_SRC_DIR}/pp_api_test/*.h") + + add_test_target(TARGET pp_api_test + TEST_SOURCES ${pp_api_TEST_SRC} + LIBRARIES onnxruntime_extensions ${ocos_libraries} + TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data) + + target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS}) + target_include_directories(pp_api_test PRIVATE + ${PROJECT_SOURCE_DIR}/ + "$" + "$") + + if (ORTX_TEST_DATA2) + file(TO_NATIVE_PATH "${ORTX_TEST_DATA2}/tests/data2" _TEST_DATA2) + add_custom_command(TARGET pp_api_test POST_BUILD + COMMAND ${CMAKE_COMMAND} -E create_symlink ${_TEST_DATA2} ${onnxruntime_extensions_BINARY_DIR}/data2) + endif() +endif() + + # -- shared test (needs onnxruntime) -- -SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) -find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}") +# avoid blindly searching for onnxruntime library +# which leads to an unpredictable result +if (NOT ONNXRUNTIME_LIB_DIR) + set(ONNXRUNTIME "ONNXRUNTIME-NOTFOUND")
+else() + SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}") +endif() if("${ONNXRUNTIME}" STREQUAL "ONNXRUNTIME-NOTFOUND") message(WARNING "The prebuilt onnxruntime library was not found, extensions_test will be skipped.") @@ -197,25 +226,10 @@ else() COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ONNXRUNTIME} ${CMAKE_BINARY_DIR}/lib ) endif() + + if (OCOS_ENABLE_C_API) + # avoid copying the same data directory at the same time. + add_dependencies(extensions_test pp_api_test) + endif() endblock() - - if (OCOS_ENABLE_C_API) - file(GLOB pp_api_TEST_SRC - "${TEST_SRC_DIR}/pp_api_test/*.c" - "${TEST_SRC_DIR}/pp_api_test/*.cc" - "${TEST_SRC_DIR}/pp_api_test/*.h") - - add_test_target(TARGET pp_api_test - TEST_SOURCES ${pp_api_TEST_SRC} - LIBRARIES onnxruntime_extensions ${ocos_libraries} - TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data) - - target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS}) - target_include_directories(pp_api_test PRIVATE - ${PROJECT_SOURCE_DIR}/ - "$" - "$") - endif() -endif() - endif() diff --git a/cmake/externals/googletest.cmake b/cmake/externals/googletest.cmake index 5d310e93..2bc1de35 100644 --- a/cmake/externals/googletest.cmake +++ b/cmake/externals/googletest.cmake @@ -4,8 +4,8 @@ FetchContent_Declare( URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106 ) +set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) +set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) -set_target_properties(gmock PROPERTIES FOLDER "externals/gtest") -set_target_properties(gmock_main PROPERTIES FOLDER "externals/gtest") set_target_properties(gtest PROPERTIES FOLDER "externals/gtest") set_target_properties(gtest_main PROPERTIES FOLDER "externals/gtest") diff --git a/cmake/externals/gsl.cmake b/cmake/externals/gsl.cmake index 377a4065..c4cedc4b 100644 --- a/cmake/externals/gsl.cmake +++ b/cmake/externals/gsl.cmake @@ -14,5 +14,13 @@ else() ) endif() 
-FetchContent_MakeAvailable(GSL) -get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES) +FetchContent_GetProperties(GSL) +string(TOLOWER "GSL" lcName) +if(NOT ${lcName}_POPULATED) + FetchContent_Populate(GSL) +# add_subdirectory(${GSL_SOURCE_DIR} ${GSL_BINARY_DIR} EXCLUDE_FROM_ALL) +endif() + +set(GSL_INCLUDE_DIR ${gsl_SOURCE_DIR}/include) + +#get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES) diff --git a/cmake/externals/json.cmake b/cmake/externals/json.cmake index 9e9f179b..a41676c0 100644 --- a/cmake/externals/json.cmake +++ b/cmake/externals/json.cmake @@ -7,5 +7,4 @@ set(JSON_BuildTests OFF CACHE INTERNAL "") FetchContent_GetProperties(nlohmann_json) if(NOT nlohmann_json_POPULATED) FetchContent_Populate(nlohmann_json) - add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL) endif() diff --git a/cmake/externals/opencv.cmake b/cmake/externals/opencv.cmake index d94683ba..dc595906 100644 --- a/cmake/externals/opencv.cmake +++ b/cmake/externals/opencv.cmake @@ -107,6 +107,16 @@ set(BUILD_TESTS OFF CACHE INTERNAL "") set(CV_TRACE OFF CACHE INTERNAL "") set(CV_DISABLE_OPTIMIZATION ON CACHE INTERNAL "") +set(BUILD_PERF_TESTS OFF CACHE INTERNAL "") +set(BUILD_opencv_java_bindings_generator OFF CACHE INTERNAL "") +set(BUILD_opencv_js_bindings_generator OFF CACHE INTERNAL "") +set(BUILD_opencv_objc_bindings_generator OFF CACHE INTERNAL "") +set(BUILD_opencv_python_bindings_generator OFF CACHE INTERNAL "") +set(BUILD_opencv_python_tests OFF CACHE INTERNAL "") + +set(WITH_ADE OFF CACHE INTERNAL "") +set(VIDEOIO_ENABLE_PLUGINS OFF CACHE INTERNAL "") +set(HIGHGUI_ENABLE_PLUGINS OFF CACHE INTERNAL "") if(IOS) # copy what OpenCV's platforms/ios/build_framework.py does and set CPU_BASELINE=DETECT @@ -157,13 +167,3 @@ endif() # unset it to avoid affecting other projects. 
unset(EXECUTABLE_OUTPUT_PATH CACHE) - -if (CMAKE_SYSTEM_NAME MATCHES "Windows") - set(opencv_projs gen_opencv_java_source gen_opencv_js_source gen_opencv_python_source) - list(APPEND opencv_projs gen_opencv_objc_source gen_opencv_objc_source_ios gen_opencv_objc_source_osx) - list(APPEND opencv_projs opencv_highgui_plugins opencv_videoio_plugins) - foreach(p ${opencv_projs}) - set_target_properties(${p} PROPERTIES FOLDER "externals/opencv") - set_target_properties(${p} PROPERTIES EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_DEFAULT_BUILD TRUE) - endforeach() -endif() diff --git a/operators/vision/draw_bounding_box.cc b/operators/vision/draw_bounding_box.cc index 42b6e1b4..c720d8db 100644 --- a/operators/vision/draw_bounding_box.cc +++ b/operators/vision/draw_bounding_box.cc @@ -48,7 +48,7 @@ class BoxArray { private: void SortBoxesByScore(gsl::span data) { boxes_by_score_.reserve(NumBoxes()); - for (size_t i = 0; i < NumBoxes(); ++i) { + for (size_t i = 0; i < static_cast(NumBoxes()); ++i) { boxes_by_score_.push_back(data.subspan(i * shape_[1], shape_[1])); } @@ -188,7 +188,7 @@ void DrawBoxesForNumClasses(ImageView& image, const BoxArray& boxes, int64_t thi std::unordered_map color_used; std::vector> box_reverse; box_reverse.reserve(boxes.NumBoxes()); - for (size_t i = 0; i < boxes.NumBoxes(); ++i) { + for (size_t i = 0; i < static_cast(boxes.NumBoxes()); ++i) { const auto box = boxes.GetBox(i); if (color_used.find(box[kBoxClassIndex]) == color_used.end()) { if (color_used.size() >= KBGRColorMap.size()) { diff --git a/shared/api/image_processor.cc b/shared/api/image_processor.cc index df9f773f..02801597 100644 --- a/shared/api/image_processor.cc +++ b/shared/api/image_processor.cc @@ -38,9 +38,9 @@ LoadRawImages(const std::initializer_list& image_paths) { } // namespace ort_extensions Operation::KernelRegistry ImageProcessor::kernel_registry_ = { - {"DecodeImage", []() { return DefineKernelFunction(image_decoder); }}, - {"ConvertRGB", []() { return 
DefineKernelFunction(convert_to_rgb); }}, - {"Phi3ImageTransform", []() { return DefineKernelFunction(phi3_hd_transform); }}, + {"DecodeImage", []() { return CreateKernelInstance(image_decoder); }}, + {"ConvertRGB", []() { return CreateKernelInstance(&ConvertToRGB::Compute); }}, + {"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }}, }; OrtxStatus ImageProcessor::Init(std::string_view processor_def) { diff --git a/shared/api/image_transforms.hpp b/shared/api/image_transforms.hpp index 8dc26e68..773d70cc 100644 --- a/shared/api/image_transforms.hpp +++ b/shared/api/image_transforms.hpp @@ -13,30 +13,32 @@ constexpr int image_resized_height = 336; constexpr float OPENAI_CLIP_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f}; constexpr float OPENAI_CLIP_STD[] = {0.26862954f, 0.26130258f, 0.27577711f}; -inline OrtxStatus convert_to_rgb(const ortc::Tensor& input, - ortc::Tensor& output) { - auto& dimensions = input.Shape(); - if (dimensions.size() != 3ULL || dimensions[2] != 3) { - return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"}; - } - - std::uint8_t* p_output_image = output.Allocate(dimensions); - auto* input_data = input.Data(); - auto h = dimensions[0]; - auto w = dimensions[1]; - auto c = dimensions[2]; - - // convert BGR channel layouts to RGB - for (int64_t j = 0; j < h; ++j) { - for (int64_t k = 0; k < w; ++k) { - auto c0_index = j * w * c + k * c; - std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) = - std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]); +struct ConvertToRGB { + OrtxStatus Compute(const ortc::Tensor& input, + ortc::Tensor& output) { + auto& dimensions = input.Shape(); + if (dimensions.size() != 3ULL || dimensions[2] != 3) { + return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"}; } - } - return {}; -} + std::uint8_t* p_output_image = output.Allocate(dimensions); + auto* input_data = 
input.Data(); + auto h = dimensions[0]; + auto w = dimensions[1]; + auto c = dimensions[2]; + + // convert BGR channel layouts to RGB + for (int64_t j = 0; j < h; ++j) { + for (int64_t k = 0; k < w; ++k) { + auto c0_index = j * w * c + k * c; + std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) = + std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]); + } + } + + return {}; + } +}; inline cv::Mat padding_336(const cv::Mat& image) { // def padding_336(b): diff --git a/shared/api/runner.hpp b/shared/api/runner.hpp index a9ee6285..b3170e0a 100644 --- a/shared/api/runner.hpp +++ b/shared/api/runner.hpp @@ -18,10 +18,11 @@ namespace ort_extensions { using json = nlohmann::json; using TensorArgs = std::vector; -class KernelClass { +class KernelDef { public: - KernelClass() = default; - virtual ~KernelClass() = default; + KernelDef() = default; + virtual ~KernelDef() = default; + virtual OrtxStatus Init(std::string_view attr) { return {}; } // no need to be initialized for a kernel function virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0; virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0; @@ -91,17 +92,13 @@ class KernelClass { }; template -class KernelFunction : public KernelClass { +class KernelFunction : public KernelDef { public: KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){}; virtual ~KernelFunction() = default; - OrtxStatus Compute(Args... args) const { - return body_(std::forward(args)...); - } - TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { - auto tensors = KernelClass::AllocateOutput(allocator); + auto tensors = KernelDef::AllocateOutput(allocator); TensorArgs all_args; for (auto& tensor : tensors) { if (tensor != nullptr) { @@ -123,16 +120,64 @@ class KernelFunction : public KernelClass { private: std::function body_; + + OrtxStatus Compute(Args... 
args) const { + return body_(std::forward(args)...); + } +}; + +template +class KernelStruct : public KernelDef { + public: + KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){}; + virtual ~KernelStruct() = default; + + TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { + auto tensors = KernelDef::AllocateOutput(allocator); + TensorArgs all_args; + for (auto& tensor : tensors) { + if (tensor != nullptr) { + all_args.push_back(tensor); + } + } + + return all_args; + } + + template + OrtxStatus Init(DT attr) { + instance_ = std::make_unique(); + return instance_->Init(std::move(attr)); + } + + OrtxStatus Apply(TensorArgs& inputs, TensorArgs& outputs) const override { + TensorArgs all_args; + all_args.reserve(inputs.size() + outputs.size()); + all_args.insert(all_args.end(), inputs.begin(), inputs.end()); + all_args.insert(all_args.end(), outputs.begin(), outputs.end()); + auto args_tuple = std::tuple_cat(CastTensors(all_args)); + return std::apply([this](auto&&... 
args) { + return (instance_.get()->*body_)(std::forward(*args)...); }, std::move(args_tuple)); + } + + private: + OrtxStatus (T::*body_)(Args...){}; + std::unique_ptr instance_; }; template -std::unique_ptr DefineKernelFunction(OrtxStatus (*body)(Args...)) { +std::unique_ptr CreateKernelInstance(OrtxStatus (*body)(Args...)) { return std::make_unique>(body); } +template +std::unique_ptr CreateKernelInstance(OrtxStatus (T::*method)(Args...)) { + return std::make_unique>(method); +} + class Operation { public: - using KernelRegistry = std::unordered_map()>>; + using KernelRegistry = std::unordered_map()>>; Operation(const KernelRegistry& registry) { kernel_registry_ = ®istry; }; OrtxStatus Init(std::string_view op_def) { @@ -162,12 +207,14 @@ class Operation { op_name_ = op_name; kernel_ = kernel_iter->second(); - /* TODO: parse the attributes - if (op_json.contains("attrs")) { - auto attrs = op_json.at("attrs"); - attrs.get_to(attributes_); + if (op_json.contains("attrs")) { + auto attrs = op_json.at("attrs"); + auto status = kernel_->Init(attrs.dump()); + if (!status.IsOk()) { + return status; } - */ + } + return {}; } @@ -192,9 +239,8 @@ class Operation { private: const KernelRegistry* kernel_registry_; - std::unique_ptr kernel_; + std::unique_ptr kernel_; std::string op_name_; - std::unordered_map attributes_; ortc::IAllocator* allocator_{}; }; diff --git a/test/pp_api_test/test_processor.cc b/test/pp_api_test/test_processor.cc index cc21a42e..076c3b9d 100644 --- a/test/pp_api_test/test_processor.cc +++ b/test/pp_api_test/test_processor.cc @@ -1,38 +1,74 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include #include #include +#include +#include #include "gtest/gtest.h" #include "shared/api/image_processor.h" using namespace ort_extensions; -TEST(ProcessorTest, TestMsImage) { +std::vector ReadArrayFromFile(const std::string& filename) { + std::ifstream inFile(filename, std::ios::binary | std::ios::ate); + if (!inFile) { + throw std::runtime_error("Cannot open file for reading."); + } + std::streamsize fileSize = inFile.tellg(); + inFile.seekg(0, std::ios::beg); + std::vector array(fileSize / sizeof(float)); + if (!inFile.read(reinterpret_cast(array.data()), fileSize)) { + throw std::runtime_error("Error reading file."); + } + + return array; +} + +TEST(ProcessorTest, TestPhi3VImageProcessing) { auto [input_data, n_data] = ort_extensions::LoadRawImages( {"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"}); auto proc = OrtxObjectPtr(OrtxCreateProcessor, "data/processor/image_processor.json"); ortc::Tensor* pixel_values; ortc::Tensor* image_sizes; - ortc::Tensor* num_img_takens; + ortc::Tensor* num_img_tokens; auto [status, r] = proc->PreProcess( ort_extensions::span(input_data.get(), (size_t)n_data), &pixel_values, &image_sizes, - &num_img_takens); + &num_img_tokens); ASSERT_TRUE(status.IsOk()); - - // dump the output to a file - // FILE* fp = fopen("ppoutput.bin", "wb"); - // fwrite(pixel_values->Data(), sizeof(float), pixel_values->NumberOfElement(), fp); - // fclose(fp); + int64_t expected_image_size[] = {1344, 1344, 1008, 1344, 1008, 1680}; + int64_t expected_num_token[] = {2509, 1921, 2353}; ASSERT_EQ(pixel_values->Shape(), std::vector({3, 17, 3, 336, 336})); + ASSERT_EQ(image_sizes->Shape(), std::vector({3, 2})); + ASSERT_EQ(num_img_tokens->Shape(), std::vector({3, 1})); + + if (std::filesystem::is_directory("data2/processor")) { + // the test data was dumped in this way + // { + // std::ofstream outFile("data2/processor/img_proc_pixel_values.bin", std::ios::binary); + // 
outFile.write(reinterpret_cast(array.data()), array.size() * sizeof(float)); + // } + + auto expected_output = ReadArrayFromFile("data2/processor/img_proc_pixel_values.bin"); + ASSERT_EQ(pixel_values->NumberOfElement(), expected_output.size()); + for (size_t i = 0; i < expected_output.size(); i++) { + ASSERT_NEAR(pixel_values->Data()[i], expected_output[i], 1e-3); + } + } + + // compare the image sizes + for (size_t i = 0; i < 3; i++) { + ASSERT_EQ(image_sizes->Data()[i * 2], expected_image_size[i * 2]); + ASSERT_EQ(image_sizes->Data()[i * 2 + 1], expected_image_size[i * 2 + 1]); + ASSERT_EQ(num_img_tokens->Data()[i], expected_num_token[i]); + } proc->ClearOutputs(&r); } diff --git a/test/pp_api_test/test_tokenizer.cc b/test/pp_api_test/test_tokenizer.cc index 1022148f..f48ae880 100644 --- a/test/pp_api_test/test_tokenizer.cc +++ b/test/pp_api_test/test_tokenizer.cc @@ -117,6 +117,42 @@ TEST(OrtxTokenizerTest, TicTokenTokenizer) { EXPECT_EQ(out_text[0], input[0]); } +TEST(OrtxTokenizerTest, Phi3_S_Tokenizer) { + if (!std::filesystem::exists("data2/phi-3-small")) { + GTEST_SKIP() << "Skip test as extra test data is not deployed."; + } + + auto tokenizer = std::make_unique(); + auto status = tokenizer->Load("data2/phi-3-small"); + if (!status.IsOk()) { + std::cout << status.ToString() << std::endl; + } + + // validate tokenizer is not null + EXPECT_NE(tokenizer, nullptr); + + std::vector EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13}; + std::vector input = { + "This is a test.", + "the second one", + "I like walking my cute dog\n and\x17 then", + "Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! 
, 1234 15 5,61"}; + std::vector> + token_ids; + status = tokenizer->Tokenize(input, token_ids); + EXPECT_TRUE(status.IsOk()); + DumpTokenIds(token_ids); + + EXPECT_EQ(token_ids.size(), input.size()); + EXPECT_EQ(token_ids[0], EXPECTED_IDS_0); + + std::vector out_text; + std::vector> token_ids_span = {token_ids[0], token_ids[1]}; + status = tokenizer->Detokenize(token_ids_span, out_text); + EXPECT_TRUE(status.IsOk()); + EXPECT_EQ(out_text[0], input[0]); +} + TEST(OrtxTokenizerTest, GemmaTokenizer) { auto tokenizer = std::make_unique(); auto status = tokenizer->Load("data/gemma"); diff --git a/test/shared_test/test_exceptions.cc b/test/shared_test/test_exceptions.cc index ed70dece..953ba8dd 100644 --- a/test/shared_test/test_exceptions.cc +++ b/test/shared_test/test_exceptions.cc @@ -3,7 +3,6 @@ #include #include "gtest/gtest.h" -#include "gmock/gmock.h" #include "ocos.h" #include "test_kernel.hpp"