Refactor the unit tests and cmake build script (#726)
* refine the build script * complete the unit tests. * remove the commented code
This commit is contained in:
Родитель
b60df02fd0
Коммит
ca433cbea7
|
@ -285,9 +285,11 @@ function(set_msvc_c_cpp_compiler_warning_level warning_level)
|
|||
endif()
|
||||
endfunction()
|
||||
|
||||
if (NOT ONNXRUNTIME_INCLUDE_DIR)
|
||||
include(ext_ortlib)
|
||||
endif()
|
||||
# set default MSVC warning level to 3 for external dependencies
|
||||
set_msvc_c_cpp_compiler_warning_level(3)
|
||||
include(ext_ortlib)
|
||||
include(gsl)
|
||||
|
||||
macro(standardize_output_folder bin_target)
|
||||
|
@ -681,7 +683,6 @@ endif()
|
|||
|
||||
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||
target_include_directories(ocos_operators PUBLIC ${nlohmann_json_SOURCE_DIR}/single_include)
|
||||
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
||||
endif()
|
||||
|
||||
# If building a shared library we can't throw an internal exception type across the library boundary as the type
|
||||
|
@ -695,8 +696,6 @@ if(ANDROID)
|
|||
list(APPEND ocos_libraries log)
|
||||
endif()
|
||||
|
||||
list(APPEND ocos_libraries Microsoft.GSL::GSL)
|
||||
|
||||
list(REMOVE_DUPLICATES OCOS_COMPILE_DEFINITIONS)
|
||||
target_compile_definitions(noexcep_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
|
||||
|
@ -899,7 +898,7 @@ if (_ORTX_STANDALONE_PROJECT)
|
|||
# Run CPack to generate the NuGet package
|
||||
include(CPack)
|
||||
|
||||
if(OCOS_ENABLE_CTEST)
|
||||
if(OCOS_ENABLE_CTEST AND NOT MAC_CATALYST)
|
||||
include(ext_tests)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -17,7 +17,20 @@ namespace ort_extensions {
|
|||
class path {
|
||||
public:
|
||||
path() = default;
|
||||
path(const std::string& path) : path_(path){};
|
||||
path(const std::string& path) : path_(path) {
|
||||
#ifdef _WIN32
|
||||
w_path_ = to_wstring();
|
||||
#endif // _WIN32
|
||||
};
|
||||
|
||||
#ifdef _WIN32
|
||||
path(const std::wstring& wpath) {
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, nullptr, 0, nullptr, nullptr);
|
||||
std::string utf8_str(size_needed, 0);
|
||||
WideCharToMultiByte(CP_UTF8, 0, wpath.c_str(), -1, &utf8_str[0], size_needed, nullptr, nullptr);
|
||||
path_ = utf8_str;
|
||||
}
|
||||
#endif // _WIN32
|
||||
|
||||
static constexpr char separator =
|
||||
#ifdef _WIN32
|
||||
|
@ -30,7 +43,7 @@ class path {
|
|||
std::ifstream open(ios_base::openmode mode = ios_base::in) const {
|
||||
// if Windows, need to convert the string to UTF-16
|
||||
#ifdef _WIN32
|
||||
return std::ifstream(to_wstring(), mode);
|
||||
return std::ifstream(w_path_, mode);
|
||||
#else
|
||||
return std::ifstream(path_, mode);
|
||||
#endif // _WIN32
|
||||
|
@ -55,7 +68,7 @@ class path {
|
|||
bool is_directory() const {
|
||||
#ifdef _WIN32
|
||||
struct _stat64 info;
|
||||
if (_wstat64(to_wstring().c_str(), &info) != 0) {
|
||||
if (_wstat64(w_path_.c_str(), &info) != 0) {
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
|
@ -69,8 +82,9 @@ class path {
|
|||
|
||||
private:
|
||||
std::string path_;
|
||||
|
||||
#ifdef _WIN32
|
||||
std::wstring w_path_;
|
||||
|
||||
std::wstring to_wstring() const {
|
||||
int size_needed = MultiByteToWideChar(CP_UTF8, 0, path_.c_str(), -1, nullptr, 0);
|
||||
std::wstring utf16_str(size_needed, 0);
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
if (NOT MAC_CATALYST)
|
||||
|
||||
if (OCOS_ENABLE_SELECTED_OPLIST)
|
||||
# currently the tests don't handle operator exclusion cleanly.
|
||||
message(FATAL_ERROR "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
|
||||
|
@ -55,7 +53,7 @@ function(add_test_target)
|
|||
"${TEST_SRC_DIR}/unittest_main/test_main.cc")
|
||||
target_link_libraries(${ARG_TARGET} PRIVATE
|
||||
${ARG_LIBRARIES}
|
||||
gtest gmock)
|
||||
gtest)
|
||||
|
||||
if(OCOS_USE_CUDA)
|
||||
target_link_directories(${ARG_TARGET} PRIVATE ${CUDAToolkit_LIBRARY_DIR})
|
||||
|
@ -93,7 +91,7 @@ function(add_test_target)
|
|||
|
||||
target_link_libraries(${ARG_TARGET} PRIVATE
|
||||
${ARG_LIBRARIES}
|
||||
gtest gmock)
|
||||
gtest)
|
||||
|
||||
set(test_data_destination_root_directory $<TARGET_FILE_DIR:${dummy_testee_target}>)
|
||||
|
||||
|
@ -130,9 +128,40 @@ add_test_target(TARGET ocos_test
|
|||
LIBRARIES ortcustomops ${ocos_libraries})
|
||||
target_compile_definitions(ocos_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
|
||||
if (OCOS_ENABLE_C_API)
|
||||
file(GLOB pp_api_TEST_SRC
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.c"
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.cc"
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.h")
|
||||
|
||||
add_test_target(TARGET pp_api_test
|
||||
TEST_SOURCES ${pp_api_TEST_SRC}
|
||||
LIBRARIES onnxruntime_extensions ${ocos_libraries}
|
||||
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
|
||||
|
||||
target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
target_include_directories(pp_api_test PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/
|
||||
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>"
|
||||
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||
|
||||
if (ORTX_TEST_DATA2)
|
||||
file(TO_NATIVE_PATH "${ORTX_TEST_DATA2}/tests/data2" _TEST_DATA2)
|
||||
add_custom_command(TARGET pp_api_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E create_symlink ${_TEST_DATA2} ${onnxruntime_extensions_BINARY_DIR}/data2)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# -- shared test (needs onnxruntime) --
|
||||
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
|
||||
# avoid blindling searching for onnxruntime library
|
||||
# wbhich leads to a unpredictable result
|
||||
if (NOT ONNXRUNTIME_LIB_DIR)
|
||||
set(ONNXRUNTIME "ONNXRUNTIME-NOTFOUND")
|
||||
else()
|
||||
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
|
||||
endif()
|
||||
|
||||
if("${ONNXRUNTIME}" STREQUAL "ONNXRUNTIME-NOTFOUND")
|
||||
message(WARNING "The prebuilt onnxruntime library was not found, extensions_test will be skipped.")
|
||||
|
@ -197,25 +226,10 @@ else()
|
|||
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ONNXRUNTIME} ${CMAKE_BINARY_DIR}/lib
|
||||
)
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_C_API)
|
||||
# avoid copying the same data directory at the same time.
|
||||
add_dependencies(extensions_test pp_api_test)
|
||||
endif()
|
||||
endblock()
|
||||
|
||||
if (OCOS_ENABLE_C_API)
|
||||
file(GLOB pp_api_TEST_SRC
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.c"
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.cc"
|
||||
"${TEST_SRC_DIR}/pp_api_test/*.h")
|
||||
|
||||
add_test_target(TARGET pp_api_test
|
||||
TEST_SOURCES ${pp_api_TEST_SRC}
|
||||
LIBRARIES onnxruntime_extensions ${ocos_libraries}
|
||||
TEST_DATA_DIRECTORIES ${TEST_SRC_DIR}/data)
|
||||
|
||||
target_compile_definitions(pp_api_test PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
target_include_directories(pp_api_test PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/
|
||||
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>"
|
||||
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
|
|
@ -4,8 +4,8 @@ FetchContent_Declare(
|
|||
URL_HASH SHA1=06096d3900c356e468ba060a609642c635131106
|
||||
)
|
||||
|
||||
set(BUILD_GMOCK OFF CACHE BOOL "" FORCE)
|
||||
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
|
||||
FetchContent_MakeAvailable(googletest)
|
||||
set_target_properties(gmock PROPERTIES FOLDER "externals/gtest")
|
||||
set_target_properties(gmock_main PROPERTIES FOLDER "externals/gtest")
|
||||
set_target_properties(gtest PROPERTIES FOLDER "externals/gtest")
|
||||
set_target_properties(gtest_main PROPERTIES FOLDER "externals/gtest")
|
||||
|
|
|
@ -14,5 +14,13 @@ else()
|
|||
)
|
||||
endif()
|
||||
|
||||
FetchContent_MakeAvailable(GSL)
|
||||
get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES)
|
||||
FetchContent_GetProperties(GSL)
|
||||
string(TOLOWER "GSL" lcName)
|
||||
if(NOT ${lcName}_POPULATED)
|
||||
FetchContent_Populate(GSL)
|
||||
# add_subdirectory(${GSL_SOURCE_DIR} ${GSL_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
|
||||
set(GSL_INCLUDE_DIR ${gsl_SOURCE_DIR}/include)
|
||||
|
||||
#get_target_property(GSL_INCLUDE_DIR Microsoft.GSL::GSL INTERFACE_INCLUDE_DIRECTORIES)
|
||||
|
|
|
@ -7,5 +7,4 @@ set(JSON_BuildTests OFF CACHE INTERNAL "")
|
|||
FetchContent_GetProperties(nlohmann_json)
|
||||
if(NOT nlohmann_json_POPULATED)
|
||||
FetchContent_Populate(nlohmann_json)
|
||||
add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
|
|
|
@ -107,6 +107,16 @@ set(BUILD_TESTS OFF CACHE INTERNAL "")
|
|||
set(CV_TRACE OFF CACHE INTERNAL "")
|
||||
|
||||
set(CV_DISABLE_OPTIMIZATION ON CACHE INTERNAL "")
|
||||
set(BUILD_PERF_TESTS OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_java_bindings_generator OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_js_bindings_generator OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_objc_bindings_generator OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python_bindings_generator OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python_tests OFF CACHE INTERNAL "")
|
||||
|
||||
set(WITH_ADE OFF CACHE INTERNAL "")
|
||||
set(VIDEOIO_ENABLE_PLUGINS OFF CACHE INTERNAL "")
|
||||
set(HIGHGUI_ENABLE_PLUGINS OFF CACHE INTERNAL "")
|
||||
|
||||
if(IOS)
|
||||
# copy what OpenCV's platforms/ios/build_framework.py does and set CPU_BASELINE=DETECT
|
||||
|
@ -157,13 +167,3 @@ endif()
|
|||
|
||||
# unset it to avoid affecting other projects.
|
||||
unset(EXECUTABLE_OUTPUT_PATH CACHE)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
set(opencv_projs gen_opencv_java_source gen_opencv_js_source gen_opencv_python_source)
|
||||
list(APPEND opencv_projs gen_opencv_objc_source gen_opencv_objc_source_ios gen_opencv_objc_source_osx)
|
||||
list(APPEND opencv_projs opencv_highgui_plugins opencv_videoio_plugins)
|
||||
foreach(p ${opencv_projs})
|
||||
set_target_properties(${p} PROPERTIES FOLDER "externals/opencv")
|
||||
set_target_properties(${p} PROPERTIES EXCLUDE_FROM_ALL TRUE EXCLUDE_FROM_DEFAULT_BUILD TRUE)
|
||||
endforeach()
|
||||
endif()
|
||||
|
|
|
@ -48,7 +48,7 @@ class BoxArray {
|
|||
private:
|
||||
void SortBoxesByScore(gsl::span<const float> data) {
|
||||
boxes_by_score_.reserve(NumBoxes());
|
||||
for (size_t i = 0; i < NumBoxes(); ++i) {
|
||||
for (size_t i = 0; i < static_cast<size_t>(NumBoxes()); ++i) {
|
||||
boxes_by_score_.push_back(data.subspan(i * shape_[1], shape_[1]));
|
||||
}
|
||||
|
||||
|
@ -188,7 +188,7 @@ void DrawBoxesForNumClasses(ImageView& image, const BoxArray& boxes, int64_t thi
|
|||
std::unordered_map<float, size_t> color_used;
|
||||
std::vector<std::pair<size_t, int64_t>> box_reverse;
|
||||
box_reverse.reserve(boxes.NumBoxes());
|
||||
for (size_t i = 0; i < boxes.NumBoxes(); ++i) {
|
||||
for (size_t i = 0; i < static_cast<size_t>(boxes.NumBoxes()); ++i) {
|
||||
const auto box = boxes.GetBox(i);
|
||||
if (color_used.find(box[kBoxClassIndex]) == color_used.end()) {
|
||||
if (color_used.size() >= KBGRColorMap.size()) {
|
||||
|
|
|
@ -38,9 +38,9 @@ LoadRawImages(const std::initializer_list<const char*>& image_paths) {
|
|||
} // namespace ort_extensions
|
||||
|
||||
Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
|
||||
{"DecodeImage", []() { return DefineKernelFunction(image_decoder); }},
|
||||
{"ConvertRGB", []() { return DefineKernelFunction(convert_to_rgb); }},
|
||||
{"Phi3ImageTransform", []() { return DefineKernelFunction(phi3_hd_transform); }},
|
||||
{"DecodeImage", []() { return CreateKernelInstance(image_decoder); }},
|
||||
{"ConvertRGB", []() { return CreateKernelInstance(&ConvertToRGB::Compute); }},
|
||||
{"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }},
|
||||
};
|
||||
|
||||
OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
|
||||
|
|
|
@ -13,30 +13,32 @@ constexpr int image_resized_height = 336;
|
|||
constexpr float OPENAI_CLIP_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};
|
||||
constexpr float OPENAI_CLIP_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};
|
||||
|
||||
inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
|
||||
return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"};
|
||||
}
|
||||
|
||||
std::uint8_t* p_output_image = output.Allocate(dimensions);
|
||||
auto* input_data = input.Data();
|
||||
auto h = dimensions[0];
|
||||
auto w = dimensions[1];
|
||||
auto c = dimensions[2];
|
||||
|
||||
// convert BGR channel layouts to RGB
|
||||
for (int64_t j = 0; j < h; ++j) {
|
||||
for (int64_t k = 0; k < w; ++k) {
|
||||
auto c0_index = j * w * c + k * c;
|
||||
std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) =
|
||||
std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]);
|
||||
struct ConvertToRGB {
|
||||
OrtxStatus Compute(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
|
||||
return {kOrtxErrorInvalidArgument, "[ConvertToRGB]: input is not (H, W, C)"};
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
std::uint8_t* p_output_image = output.Allocate(dimensions);
|
||||
auto* input_data = input.Data();
|
||||
auto h = dimensions[0];
|
||||
auto w = dimensions[1];
|
||||
auto c = dimensions[2];
|
||||
|
||||
// convert BGR channel layouts to RGB
|
||||
for (int64_t j = 0; j < h; ++j) {
|
||||
for (int64_t k = 0; k < w; ++k) {
|
||||
auto c0_index = j * w * c + k * c;
|
||||
std::tie(p_output_image[c0_index], p_output_image[c0_index + 1], p_output_image[c0_index + 2]) =
|
||||
std::make_tuple(input_data[c0_index + 2], input_data[c0_index + 1], input_data[c0_index]);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
inline cv::Mat padding_336(const cv::Mat& image) {
|
||||
// def padding_336(b):
|
||||
|
|
|
@ -18,10 +18,11 @@ namespace ort_extensions {
|
|||
using json = nlohmann::json;
|
||||
using TensorArgs = std::vector<ortc::TensorBase*>;
|
||||
|
||||
class KernelClass {
|
||||
class KernelDef {
|
||||
public:
|
||||
KernelClass() = default;
|
||||
virtual ~KernelClass() = default;
|
||||
KernelDef() = default;
|
||||
virtual ~KernelDef() = default;
|
||||
virtual OrtxStatus Init(std::string_view attr) { return {}; } // no need to be initialized for a kernel function
|
||||
virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
|
||||
virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;
|
||||
|
||||
|
@ -91,17 +92,13 @@ class KernelClass {
|
|||
};
|
||||
|
||||
template <typename... Args>
|
||||
class KernelFunction : public KernelClass {
|
||||
class KernelFunction : public KernelDef {
|
||||
public:
|
||||
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){};
|
||||
virtual ~KernelFunction() = default;
|
||||
|
||||
OrtxStatus Compute(Args... args) const {
|
||||
return body_(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
|
||||
auto tensors = KernelClass::AllocateOutput<Args...>(allocator);
|
||||
auto tensors = KernelDef::AllocateOutput<Args...>(allocator);
|
||||
TensorArgs all_args;
|
||||
for (auto& tensor : tensors) {
|
||||
if (tensor != nullptr) {
|
||||
|
@ -123,16 +120,64 @@ class KernelFunction : public KernelClass {
|
|||
|
||||
private:
|
||||
std::function<OrtxStatus(Args...)> body_;
|
||||
|
||||
OrtxStatus Compute(Args... args) const {
|
||||
return body_(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... Args>
|
||||
class KernelStruct : public KernelDef {
|
||||
public:
|
||||
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){};
|
||||
virtual ~KernelStruct() = default;
|
||||
|
||||
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
|
||||
auto tensors = KernelDef::AllocateOutput<Args...>(allocator);
|
||||
TensorArgs all_args;
|
||||
for (auto& tensor : tensors) {
|
||||
if (tensor != nullptr) {
|
||||
all_args.push_back(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
return all_args;
|
||||
}
|
||||
|
||||
template <typename DT>
|
||||
OrtxStatus Init(DT attr) {
|
||||
instance_ = std::make_unique<T>();
|
||||
return instance_->Init(std::move(attr));
|
||||
}
|
||||
|
||||
OrtxStatus Apply(TensorArgs& inputs, TensorArgs& outputs) const override {
|
||||
TensorArgs all_args;
|
||||
all_args.reserve(inputs.size() + outputs.size());
|
||||
all_args.insert(all_args.end(), inputs.begin(), inputs.end());
|
||||
all_args.insert(all_args.end(), outputs.begin(), outputs.end());
|
||||
auto args_tuple = std::tuple_cat(CastTensors<Args...>(all_args));
|
||||
return std::apply([this](auto&&... args) {
|
||||
return (instance_.get()->*body_)(std::forward<decltype(*args)>(*args)...); }, std::move(args_tuple));
|
||||
}
|
||||
|
||||
private:
|
||||
OrtxStatus (T::*body_)(Args...){};
|
||||
std::unique_ptr<T> instance_;
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
std::unique_ptr<KernelClass> DefineKernelFunction(OrtxStatus (*body)(Args...)) {
|
||||
std::unique_ptr<KernelDef> CreateKernelInstance(OrtxStatus (*body)(Args...)) {
|
||||
return std::make_unique<KernelFunction<Args...>>(body);
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
std::unique_ptr<KernelDef> CreateKernelInstance(OrtxStatus (T::*method)(Args...)) {
|
||||
return std::make_unique<KernelStruct<T, Args...>>(method);
|
||||
}
|
||||
|
||||
class Operation {
|
||||
public:
|
||||
using KernelRegistry = std::unordered_map<std::string_view, std::function<std::unique_ptr<KernelClass>()>>;
|
||||
using KernelRegistry = std::unordered_map<std::string_view, std::function<std::unique_ptr<KernelDef>()>>;
|
||||
Operation(const KernelRegistry& registry) { kernel_registry_ = ®istry; };
|
||||
|
||||
OrtxStatus Init(std::string_view op_def) {
|
||||
|
@ -162,12 +207,14 @@ class Operation {
|
|||
op_name_ = op_name;
|
||||
kernel_ = kernel_iter->second();
|
||||
|
||||
/* TODO: parse the attributes
|
||||
if (op_json.contains("attrs")) {
|
||||
auto attrs = op_json.at("attrs");
|
||||
attrs.get_to(attributes_);
|
||||
if (op_json.contains("attrs")) {
|
||||
auto attrs = op_json.at("attrs");
|
||||
auto status = kernel_->Init(attrs.dump());
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -192,9 +239,8 @@ class Operation {
|
|||
private:
|
||||
const KernelRegistry* kernel_registry_;
|
||||
|
||||
std::unique_ptr<KernelClass> kernel_;
|
||||
std::unique_ptr<KernelDef> kernel_;
|
||||
std::string op_name_;
|
||||
std::unordered_map<std::string, std::string> attributes_;
|
||||
ortc::IAllocator* allocator_{};
|
||||
};
|
||||
|
||||
|
|
|
@ -1,38 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "shared/api/image_processor.h"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
TEST(ProcessorTest, TestMsImage) {
|
||||
std::vector<float> ReadArrayFromFile(const std::string& filename) {
|
||||
std::ifstream inFile(filename, std::ios::binary | std::ios::ate);
|
||||
if (!inFile) {
|
||||
throw std::runtime_error("Cannot open file for reading.");
|
||||
}
|
||||
std::streamsize fileSize = inFile.tellg();
|
||||
inFile.seekg(0, std::ios::beg);
|
||||
std::vector<float> array(fileSize / sizeof(float));
|
||||
if (!inFile.read(reinterpret_cast<char*>(array.data()), fileSize)) {
|
||||
throw std::runtime_error("Error reading file.");
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
TEST(ProcessorTest, TestPhi3VImageProcessing) {
|
||||
auto [input_data, n_data] = ort_extensions::LoadRawImages(
|
||||
{"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"});
|
||||
|
||||
auto proc = OrtxObjectPtr<ImageProcessor>(OrtxCreateProcessor, "data/processor/image_processor.json");
|
||||
ortc::Tensor<float>* pixel_values;
|
||||
ortc::Tensor<int64_t>* image_sizes;
|
||||
ortc::Tensor<int64_t>* num_img_takens;
|
||||
ortc::Tensor<int64_t>* num_img_tokens;
|
||||
|
||||
auto [status, r] = proc->PreProcess(
|
||||
ort_extensions::span(input_data.get(), (size_t)n_data),
|
||||
&pixel_values,
|
||||
&image_sizes,
|
||||
&num_img_takens);
|
||||
&num_img_tokens);
|
||||
|
||||
ASSERT_TRUE(status.IsOk());
|
||||
|
||||
// dump the output to a file
|
||||
// FILE* fp = fopen("ppoutput.bin", "wb");
|
||||
// fwrite(pixel_values->Data(), sizeof(float), pixel_values->NumberOfElement(), fp);
|
||||
// fclose(fp);
|
||||
int64_t expected_image_size[] = {1344, 1344, 1008, 1344, 1008, 1680};
|
||||
int64_t expected_num_token[] = {2509, 1921, 2353};
|
||||
|
||||
ASSERT_EQ(pixel_values->Shape(), std::vector<int64_t>({3, 17, 3, 336, 336}));
|
||||
ASSERT_EQ(image_sizes->Shape(), std::vector<int64_t>({3, 2}));
|
||||
ASSERT_EQ(num_img_tokens->Shape(), std::vector<int64_t>({3, 1}));
|
||||
|
||||
if (std::filesystem::is_directory("data2/processor")) {
|
||||
// the test data was dumped in this way
|
||||
// {
|
||||
// std::ofstream outFile("data2/processor/img_proc_pixel_values.bin", std::ios::binary);
|
||||
// outFile.write(reinterpret_cast<const char*>(array.data()), array.size() * sizeof(float));
|
||||
// }
|
||||
|
||||
auto expected_output = ReadArrayFromFile("data2/processor/img_proc_pixel_values.bin");
|
||||
ASSERT_EQ(pixel_values->NumberOfElement(), expected_output.size());
|
||||
for (size_t i = 0; i < expected_output.size(); i++) {
|
||||
ASSERT_NEAR(pixel_values->Data()[i], expected_output[i], 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
// compare the image sizes
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
ASSERT_EQ(image_sizes->Data()[i * 2], expected_image_size[i * 2]);
|
||||
ASSERT_EQ(image_sizes->Data()[i * 2 + 1], expected_image_size[i * 2 + 1]);
|
||||
ASSERT_EQ(num_img_tokens->Data()[i], expected_num_token[i]);
|
||||
}
|
||||
|
||||
proc->ClearOutputs(&r);
|
||||
}
|
||||
|
|
|
@ -117,6 +117,42 @@ TEST(OrtxTokenizerTest, TicTokenTokenizer) {
|
|||
EXPECT_EQ(out_text[0], input[0]);
|
||||
}
|
||||
|
||||
TEST(OrtxTokenizerTest, Phi3_S_Tokenizer) {
|
||||
if (!std::filesystem::exists("data2/phi-3-small")) {
|
||||
GTEST_SKIP() << "Skip test as extra test data is not deployed.";
|
||||
}
|
||||
|
||||
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
|
||||
auto status = tokenizer->Load("data2/phi-3-small");
|
||||
if (!status.IsOk()) {
|
||||
std::cout << status.ToString() << std::endl;
|
||||
}
|
||||
|
||||
// validate tokenizer is not null
|
||||
EXPECT_NE(tokenizer, nullptr);
|
||||
|
||||
std::vector<extTokenId_t> EXPECTED_IDS_0 = {2028, 374, 264, 1296, 13};
|
||||
std::vector<std::string_view> input = {
|
||||
"This is a test.",
|
||||
"the second one",
|
||||
"I like walking my cute dog\n and\x17 then",
|
||||
"Hey<|endoftext|>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"};
|
||||
std::vector<std::vector<extTokenId_t>>
|
||||
token_ids;
|
||||
status = tokenizer->Tokenize(input, token_ids);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
DumpTokenIds(token_ids);
|
||||
|
||||
EXPECT_EQ(token_ids.size(), input.size());
|
||||
EXPECT_EQ(token_ids[0], EXPECTED_IDS_0);
|
||||
|
||||
std::vector<std::string> out_text;
|
||||
std::vector<ort_extensions::span<extTokenId_t const>> token_ids_span = {token_ids[0], token_ids[1]};
|
||||
status = tokenizer->Detokenize(token_ids_span, out_text);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
EXPECT_EQ(out_text[0], input[0]);
|
||||
}
|
||||
|
||||
TEST(OrtxTokenizerTest, GemmaTokenizer) {
|
||||
auto tokenizer = std::make_unique<ort_extensions::TokenizerImpl>();
|
||||
auto status = tokenizer->Load("data/gemma");
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#include <filesystem>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
#include "ocos.h"
|
||||
#include "test_kernel.hpp"
|
||||
|
|
Загрузка…
Ссылка в новой задаче