Starter changes for supporting pre/post processing for vision models. (#312)

* Initial changes for supporting mobilenet and super resolution:
  - script to update a model with pre/post processing
  - custom ops for decode/encode, so the user only has to provide jpg or png bytes
  - super resolution can return the updated image as jpg or png
  - models for testing

  Updated the cmake setup to enable building of the vision pre/post processing ops:
  - opencv2 is treated as an internal dependency rather than the mechanism for selecting which operators to include.

* Add extra check in decode.
This commit is contained in:
Parent: 52ae76d3df
Commit: 1cab9711ff
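Usage sketch (not part of the commit): with the extensions library built, a model updated by the new script can be fed raw jpg/png bytes directly. The model file name and the "image" input name below are illustrative assumptions; the registration APIs are the standard onnxruntime / onnxruntime-extensions ones.

    # Hypothetical end-to-end usage; file names and the "image" input name
    # are illustrative, not part of this commit.
    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())  # makes DecodeImage/EncodeImage resolvable

    session = ort.InferenceSession("superresolution.updated.onnx", so)

    # The updated model consumes raw image bytes; no client-side decoding is needed.
    jpg_bytes = np.fromfile("input.jpg", dtype=np.uint8)
    (encoded,) = session.run(None, {"image": jpg_bytes})  # assumes a single, already-encoded output
    encoded.tofile("output.png")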
@@ -361,7 +361,8 @@ jobs:
      -DCMAKE_TOOLCHAIN_FILE=$(Build.BinariesDirectory)/emsdk/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake \
      -DOCOS_ENABLE_SPM_TOKENIZER=ON \
      -DOCOS_BUILD_PYTHON=OFF \
      -DOCOS_ENABLE_OPENCV=OFF
      -DOCOS_ENABLE_CV2=OFF \
      -DOCOS_ENABLE_VISION=OFF
    displayName: build the customop library with onnxruntime
    # TODO add unittest for webassembly

@@ -49,3 +49,4 @@ tutorials/*/app/libs
*.so
*.dylib
*.pyd
/test/data/ppp_vision/*.updated.onnx

@@ -37,8 +37,9 @@ option(OCOS_ENABLE_BERT_TOKENIZER "Enable the BertTokenizer building" ON)
option(OCOS_ENABLE_BLINGFIRE "Enable operators depending on the Blingfire library" ON)
option(OCOS_ENABLE_MATH "Enable math tensor operators building" ON)
option(OCOS_ENABLE_DLIB "Enable operators like Inverse depending on DLIB" ON)
option(OCOS_ENABLE_OPENCV "Enable operators depending on opencv" ON)
option(OCOS_ENABLE_OPENCV_CODECS "Enable operators depending on opencv imgcodecs" ON)
option(OCOS_ENABLE_OPENCV_CODECS "Enable cv2 and vision operators that require opencv imgcodecs." ON)
option(OCOS_ENABLE_CV2 "Enable the operators in `operators/cv2`" ON)
option(OCOS_ENABLE_VISION "Enable the operators in `operators/vision`" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
option(OCOS_BUILD_PYTHON "Enable building the Python package" OFF)

@@ -55,7 +56,8 @@ function(disable_all_operators)
  set(OCOS_ENABLE_BLINGFIRE OFF CACHE INTERNAL "")
  set(OCOS_ENABLE_MATH OFF CACHE INTERNAL "")
  set(OCOS_ENABLE_DLIB OFF CACHE INTERNAL "")
  set(OCOS_ENABLE_OPENCV OFF CACHE INTERNAL "")
  set(OCOS_ENABLE_CV2 OFF CACHE INTERNAL "")
  set(OCOS_ENABLE_VISION OFF CACHE INTERNAL "")
endfunction()

if(NOT CC_OPTIMIZE)

@@ -174,11 +176,25 @@ if (OCOS_ENABLE_MATH)
  list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
endif()

if (OCOS_ENABLE_OPENCV)
# enable the opencv dependency if we have ops that require it
if (OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
  set(_ENABLE_OPENCV ON)
  message(STATUS "Fetch opencv")
  include(opencv)
  file(GLOB TARGET_SRC_CV "operators/cv2/*.cc" "operators/cv2/*.h*")
  list(APPEND TARGET_SRC ${TARGET_SRC_CV})
endif()

if (OCOS_ENABLE_CV2)
  file(GLOB TARGET_SRC_CV2 "operators/cv2/*.cc" "operators/cv2/*.h*")
  list(APPEND TARGET_SRC ${TARGET_SRC_CV2})
endif()

if (OCOS_ENABLE_VISION)
  if (NOT OCOS_ENABLE_OPENCV_CODECS)
    message(FATAL_ERROR "OCOS_ENABLE_VISION requires OCOS_ENABLE_OPENCV_CODECS to be ON")
  endif()

  file(GLOB TARGET_SRC_VISION "operators/vision/*.cc" "operators/vision/*.h*")
  list(APPEND TARGET_SRC ${TARGET_SRC_VISION})
endif()

set(_HAS_TOKENIZER OFF)

@@ -240,6 +256,8 @@ endif()
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
add_library(ocos_operators STATIC ${TARGET_SRC})
set_target_properties(ocos_operators PROPERTIES FOLDER "operators")
source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${TARGET_SRC})
standardize_output_folder(ocos_operators)

target_include_directories(ocos_operators PUBLIC

@@ -276,15 +294,23 @@ if (OCOS_ENABLE_MATH)
  # The dlib matrix implementation is all in the headers, no library compiling needed.
endif()

if (OCOS_ENABLE_OPENCV)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV)
  if (OCOS_ENABLE_OPENCV_CODECS)
    list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV ENABLE_OPENCV_CODEC)
  endif()
if (_ENABLE_OPENCV)
  list(APPEND ocos_libraries ${opencv_LIBS})
  target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
endif()

if (OCOS_ENABLE_OPENCV_CODECS)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV_CODECS)
endif()

if (OCOS_ENABLE_CV2)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_CV2)
endif()

if (OCOS_ENABLE_VISION)
  list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
endif()

if (OCOS_ENABLE_GPT2_TOKENIZER)
  # GPT2
  target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)

@@ -356,15 +382,18 @@ target_link_libraries(ortcustomops PUBLIC ocos_operators)
if (_BUILD_SHARED_LIBRARY)
  file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
  add_library(extensions_shared SHARED ${shared_TARGET_SRC})
  source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${shared_TARGET_SRC})
  standardize_output_folder(extensions_shared)
  if (CMAKE_SYSTEM_NAME STREQUAL "Android")
    if (OCOS_ENABLE_SPM_TOKENIZER)
      target_link_libraries(extensions_shared PUBLIC log)
    endif()
  endif()

  if (LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Android")
    set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
  endif()

  target_include_directories(extensions_shared PUBLIC
    "$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>")
  target_link_libraries(extensions_shared PRIVATE ortcustomops)

@@ -394,7 +423,10 @@ foreach(nf ${NO_USE_FILES})
endforeach()

# test section
if (OCOS_ENABLE_CTEST)
if(OCOS_ENABLE_CTEST AND OCOS_ENABLE_SELECTED_OPLIST)
  # currently the tests don't handle operator exclusion cleanly.
  message(WARNING "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
elseif(OCOS_ENABLE_CTEST AND NOT OCOS_ENABLE_SELECTED_OPLIST)
  # Enable CTest
  enable_testing()
  message(STATUS "Fetch CTest")

@@ -409,7 +441,6 @@ if (OCOS_ENABLE_CTEST)
  target_link_libraries(ocos_test PRIVATE gtest_main ocos_operators ${ocos_libraries})
  add_test(NAME ocos_test COMMAND $<TARGET_FILE:ocos_test>)

SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
if (ONNXRUNTIME-NOTFOUND)

@@ -54,10 +54,10 @@ set(WITH_GSTREAMER OFF CACHE INTERNAL "")
set(WITH_GTK OFF CACHE INTERNAL "")
set(WITH_HALIDE OFF CACHE INTERNAL "")
set(WITH_HPX OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_HDR OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_PFM OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_PXM OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_SUNRASTER OFF CACHE INTERNAL "")
set(WITH_IMGCODEC_HDR OFF CACHE INTERNAL "")
set(WITH_IMGCODEC_PFM OFF CACHE INTERNAL "")
set(WITH_IMGCODEC_PXM OFF CACHE INTERNAL "")
set(WITH_IMGCODEC_SUNRASTER OFF CACHE INTERNAL "")
set(WITH_INF_ENGINE OFF CACHE INTERNAL "")
set(WITH_IPP OFF CACHE INTERNAL "")
set(WITH_ITT OFF CACHE INTERNAL "")

@@ -78,7 +78,7 @@ set(WITH_PTHREADS_PF OFF CACHE INTERNAL "")
set(WITH_QUIRC OFF CACHE INTERNAL "")
set(WITH_TBB OFF CACHE INTERNAL "")
set(WITH_TENGINE OFF CACHE INTERNAL "")
# set(WITH_TIFF OFF CACHE INTERNAL "")
set(WITH_TIFF OFF CACHE INTERNAL "")
set(WITH_V4L OFF CACHE INTERNAL "")
set(WITH_VULKAN OFF CACHE INTERNAL "")
set(WITH_WEBP OFF CACHE INTERNAL "")

@@ -90,23 +90,17 @@ if (OCOS_ENABLE_OPENCV_CODECS)
  set(BUILD_JPEG ON CACHE INTERNAL "")
  set(BUILD_OPENJPEG ON CACHE INTERNAL "")
  set(BUILD_PNG ON CACHE INTERNAL "")
  set(BUILD_TIFF ON CACHE INTERNAL "")

  set(WITH_IMGCODEC_HDR ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_PFM ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_PXM ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_SUNRASTER ON CACHE INTERNAL "")
  set(WITH_JPEG ON CACHE INTERNAL "")
  set(WITH_OPENJPEG ON CACHE INTERNAL "")
  set(WITH_PNG ON CACHE INTERNAL "")
  set(WITH_TIFF ON CACHE INTERNAL "")

  set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "")
  set(BUILD_DOCS OFF CACHE INTERNAL "")
  set(BUILD_EXAMPLES OFF CACHE INTERNAL "")
  set(BUILD_TESTS OFF CACHE INTERNAL "")
endif()

set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "")
set(BUILD_DOCS OFF CACHE INTERNAL "")
set(BUILD_EXAMPLES OFF CACHE INTERNAL "")
set(BUILD_TESTS OFF CACHE INTERNAL "")

FetchContent_Declare(
  opencv
  GIT_REPOSITORY https://github.com/opencv/opencv.git

@@ -130,6 +124,7 @@ list(APPEND opencv_INCLUDE_DIRS

set(opencv_LIBS "")
list(APPEND opencv_LIBS opencv_core opencv_imgproc)

if (OCOS_ENABLE_OPENCV_CODECS)
  list(APPEND opencv_INCLUDE_DIRS ${OPENCV_MODULE_opencv_imgcodecs_LOCATION}/include)
  list(APPEND opencv_LIBS opencv_imgcodecs)

@@ -3,19 +3,21 @@

#pragma once

#include <vector>
#include <algorithm>
#include <functional>
#include <iterator>
#include <vector>

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

// A helper API to support test kernels.
// Must be invoked before RegisterCustomOps.
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);

const char c_OpDomain[] = "ai.onnx.contrib";
constexpr const char* c_OpDomain = "ai.onnx.contrib";
constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";

struct BaseKernel {
  BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}

@@ -33,7 +35,7 @@ struct BaseKernel {
    return result;
  }

  void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);

 protected:
  OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);

@@ -57,55 +59,42 @@ struct OrtTensorDimensions : std::vector<int64_t> {
    return s;
  }

  bool IsScalar() const{
  bool IsScalar() const {
    return empty();
  }

  bool IsVector() const{
  bool IsVector() const {
    return size() == 1;
  }
};

template <typename... Args>
class CuopContainer {
 public:
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
  CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
    ocos_list_.push_back(nullptr);
  CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
    ocos_list_.reserve(op_instances_.size());
    std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
                   [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
  }

  ~CuopContainer() {
    if (0 < ocos_list_.size()) {
      for (size_t i = 0; i < ocos_list_.size() - 1; i++) {
        delete ocos_list_[i];
      }
    }
    ocos_list_.clear();
  }
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
  const OrtCustomOp** GetList() {
    return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
  const std::vector<const OrtCustomOp*>& GetCustomOps() const {
    return ocos_list_;
  }

 private:
  std::vector<OrtCustomOp*> ocos_list_;
  std::vector<const OrtCustomOp*> ocos_list_;
  std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;  // use shared_ptr to capture the type-specific deleter
};

struct CustomOpClassBegin{
struct CustomOpClassBegin {
};

typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
using FxLoadCustomOpFactory = std::function<const std::vector<const OrtCustomOp*>&()>;

template <typename _Begin_place_holder, typename... Args>
const OrtCustomOp** LoadCustomOpClasses() {
const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
  static CuopContainer<Args...> ctr;  // let the C++ runtime take care of the thread-safe static initialization
  return ctr.GetList();
  return ctr.GetCustomOps();
}

#if defined(PYTHON_OP_SUPPORT)

@@ -119,12 +108,16 @@ extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;

#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#ifdef ENABLE_TF_STRING
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING

#ifdef ENABLE_OPENCV
extern FxLoadCustomOpFactory LoadCustomOpClasses_OpenCV;
#ifdef ENABLE_CV2
extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
#endif  // ENABLE_OPENCV

#ifdef ENABLE_VISION
extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#endif

@@ -1,20 +1,20 @@
#include "ocos.h"
#include "gaussian_blur.hpp"
#ifdef ENABLE_OPENCV_CODEC
#ifdef ENABLE_OPENCV_CODECS
#include "imread.hpp"
#include "imdecode.hpp"
#include "super_resolution_preprocess.hpp"
#include "super_resolution_postprocess.hpp"
#endif // ENABLE_OPENCV_CODEC
#endif // ENABLE_OPENCV_CODECS

FxLoadCustomOpFactory LoadCustomOpClasses_OpenCV =
FxLoadCustomOpFactory LoadCustomOpClasses_CV2 =
    LoadCustomOpClasses<CustomOpClassBegin
                        , CustomOpGaussianBlur
#ifdef ENABLE_OPENCV_CODEC
#ifdef ENABLE_OPENCV_CODECS
                        , CustomOpImageReader
                        , CustomOpImageDecoder
                        , CustomOpSuperResolutionPreProcess
                        , CustomOpSuperResolutionPostProcess
#endif // ENABLE_OPENCV_CODEC
#endif // ENABLE_OPENCV_CODECS
                        >;

@@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "decode_image.hpp"

#include <opencv2/imgcodecs.hpp>

namespace ort_extensions {

void KernelDecodeImage::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
  OrtTensorDimensions dimensions(ort_, inputs);
  if (dimensions.size() != 1ULL) {
    ORT_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
  }

  OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
  const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
  ort_.ReleaseTensorTypeAndShapeInfo(input_info);

  // Decode the image
  const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
  const void* encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);  // uint8 data
  const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1, const_cast<void*>(encoded_image_data));
  const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);

  if (decoded_image.data == nullptr) {
    ORT_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT);
  }

  // Setup output & copy to destination
  const cv::Size decoded_image_size = decoded_image.size();
  const int64_t colors = decoded_image.elemSize();  // == 3 as it's BGR

  const std::vector<int64_t> output_dims{decoded_image_size.height, decoded_image_size.width, colors};
  OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
  uint8_t* decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
  memcpy(decoded_image_data, decoded_image.data, decoded_image_size.height * decoded_image_size.width * colors);
}
}  // namespace ort_extensions

@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "string_utils.h"

#include <cstdint>

namespace ort_extensions {
struct KernelDecodeImage : BaseKernel {
  KernelDecodeImage(const OrtApi& api) : BaseKernel(api) {}

  void Compute(OrtKernelContext* context);
};

struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new KernelDecodeImage(api);
  }

  void KernelDestroy(void* op_kernel) {
    delete static_cast<KernelDecodeImage*>(op_kernel);
  }

  const char* GetName() const {
    return "DecodeImage";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t index) const {
    switch (index) {
      case 0:
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
      default:
        ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
    }
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t index) const {
    switch (index) {
      case 0:
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
      default:
        ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
    }
  }
};
}  // namespace ort_extensions

@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "encode_image.hpp"

#include <opencv2/imgcodecs.hpp>

namespace ort_extensions {

void KernelEncodeImage::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input_bgr = ort_.KernelContext_GetInput(context, 0ULL);
  const OrtTensorDimensions dimensions_bgr(ort_, input_bgr);

  if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
    // expect {H, W, C} as that's the inverse of what decode_image produces.
    // we have no way to check if it's BGR or RGB though
    ORT_CXX_API_THROW("[EncodeImage] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
  }

  // Get the data & the length
  std::vector<int32_t> height_x_width{static_cast<int32_t>(dimensions_bgr[0]),   // H
                                      static_cast<int32_t>(dimensions_bgr[1])};  // W

  // data is const uint8_t but opencv2 wants void*.
  const void* bgr_data = ort_.GetTensorData<uint8_t>(input_bgr);
  const cv::Mat bgr_image(height_x_width, CV_8UC3, const_cast<void*>(bgr_data));

  // we don't know the output size ahead of time, so we need to encode and then copy to the output
  std::vector<uint8_t> encoded_image;
  if (!cv::imencode(extension_, bgr_image, encoded_image)) {
    ORT_CXX_API_THROW("[EncodeImage] Image encoding failed.", ORT_INVALID_ARGUMENT);
  }

  // Setup output & copy to destination
  std::vector<int64_t> output_dimensions{static_cast<int64_t>(encoded_image.size())};
  OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0,
                                                        output_dimensions.data(),
                                                        output_dimensions.size());

  uint8_t* data = ort_.GetTensorMutableData<uint8_t>(output_value);
  memcpy(data, encoded_image.data(), encoded_image.size());
}
}  // namespace ort_extensions

@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ocos.h"
#include "string_utils.h"

#include <cstdint>

namespace ort_extensions {
struct KernelEncodeImage : BaseKernel {
  KernelEncodeImage(const OrtApi& api, const std::string& format)
      : BaseKernel{api},
        extension_{std::string(".") + format} {
  }

  void Compute(OrtKernelContext* context);

 private:
  const std::string extension_;
};

/// <summary>
/// EncodeImage
///
/// Converts rank 3 BGR input with channels last ordering to the requested file type.
/// Default is 'jpg'.
/// </summary>
struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    Ort::CustomOpApi op_api{api};
    std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
    if (format != "jpg" && format != "png") {
      ORT_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
    }

    return new KernelEncodeImage(api, format);
  }

  const char* GetName() const {
    return "EncodeImage";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t index) const {
    switch (index) {
      case 0:
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
      default:
        ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
    }
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t index) const {
    switch (index) {
      case 0:
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
      default:
        ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
    }
  }
};
}  // namespace ort_extensions

@@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "ocos.h"
#include "decode_image.hpp"
#include "encode_image.hpp"

FxLoadCustomOpFactory LoadCustomOpClasses_Vision =
    LoadCustomOpClasses<CustomOpClassBegin,
                        ort_extensions::CustomOpDecodeImage,
                        ort_extensions::CustomOpEncodeImage>;
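For illustration (not part of the commit): the two vision ops above are registered under the new com.microsoft.extensions domain, so a graph using them can also be composed by hand with onnx.helper. A sketch with made-up tensor names and illustrative opset versions; the types and the 'format' attribute follow the kernel definitions in this change:

    import onnx
    from onnx import TensorProto, helper

    EXT_DOMAIN = "com.microsoft.extensions"  # c_ComMsExtOpDomain above

    # DecodeImage: 1-D jpg/png bytes -> {H, W, 3} BGR uint8.
    decode = helper.make_node("DecodeImage", ["image_bytes"], ["bgr"], domain=EXT_DOMAIN)
    # EncodeImage: {H, W, 3} BGR uint8 -> 1-D encoded bytes; 'format' must be 'jpg' or 'png'.
    encode = helper.make_node("EncodeImage", ["bgr"], ["png_bytes"], domain=EXT_DOMAIN, format="png")

    graph = helper.make_graph(
        [decode, encode], "decode_encode",
        [helper.make_tensor_value_info("image_bytes", TensorProto.UINT8, ["num_bytes"])],
        [helper.make_tensor_value_info("png_bytes", TensorProto.UINT8, ["num_encoded_bytes"])])

    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 16),
                                                    helper.make_operatorsetid(EXT_DOMAIN, 1)])
    onnx.save_model(model, "decode_encode.onnx")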

@@ -1,3 +1,6 @@
[build-system]
# Minimum requirements for the build system to execute.
requires = ["setuptools", "wheel", "numpy>=1.18.5"]  # PEP 508 specifications.

[tool.black]
line-length = 120

@@ -72,7 +72,7 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
  if (status = RegisterPythonDomainAndOps(options, ortApi); status) {
    return status;
  }
#endif  // PYTHON_OP_SUPPORT

  if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain); status) {
    return status;

@@ -97,28 +97,31 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
  static std::vector<FxLoadCustomOpFactory> c_factories = {
      LoadCustomOpClasses<CustomOpClassBegin>
#if defined(ENABLE_TF_STRING)
      , LoadCustomOpClasses_Text
#endif // ENABLE_TF_STRING
      ,
      LoadCustomOpClasses_Text
#endif  // ENABLE_TF_STRING
#if defined(ENABLE_MATH)
      , LoadCustomOpClasses_Math
      ,
      LoadCustomOpClasses_Math
#endif
#if defined(ENABLE_TOKENIZER)
      , LoadCustomOpClasses_Tokenizer
      ,
      LoadCustomOpClasses_Tokenizer
#endif
#if defined(ENABLE_OPENCV)
      , LoadCustomOpClasses_OpenCV
#if defined(ENABLE_CV2)
      ,
      LoadCustomOpClasses_CV2
#endif
  };

  for (auto fx : c_factories) {
    auto ops = fx();
    while (*ops != nullptr) {
      if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
        if (status = ortApi->CustomOpDomain_Add(domain, *ops); status) {
  for (const auto& fx : c_factories) {
    const auto& ops = fx();
    for (const OrtCustomOp* op : ops) {
      if (pyop_nameset.find(op->GetName(op)) == pyop_nameset.end()) {
        if (status = ortApi->CustomOpDomain_Add(domain, op); status) {
          return status;
        }
      }
      ++ops;
    }
  }

@@ -133,5 +136,33 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
    }
  }

  if (status = ortApi->AddCustomOpDomain(options, domain); status) {
    return status;
  }

  // Create domain for ops using the new domain name.
  if (status = ortApi->CreateCustomOpDomain(c_ComMsExtOpDomain, &domain); status) {
    return status;
  }

  AddOrtCustomOpDomainToContainer(domain, ortApi);

  static std::vector<FxLoadCustomOpFactory> new_domain_factories = {
      LoadCustomOpClasses<CustomOpClassBegin>
#if defined(ENABLE_VISION)
      ,
      LoadCustomOpClasses_Vision
#endif
  };

  for (const auto& fx : new_domain_factories) {
    const auto& ops = fx();
    for (const OrtCustomOp* op : ops) {
      if (status = ortApi->CustomOpDomain_Add(domain, op); status) {
        return status;
      }
    }
  }

  return ortApi->AddCustomOpDomain(options, domain);
}

File diff not shown because it is too large.

@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import enum
import onnx
import os
import sys
from pathlib import Path

# add tools dir where pre_post_processing folder is to sys path
script_dir = os.path.dirname(os.path.realpath(__file__))
ort_ext_root = os.path.abspath(os.path.join(script_dir, "..", "..", ".."))
tools_dir = os.path.join(ort_ext_root, "tools")
sys.path.append(tools_dir)

from pre_post_processing import PrePostProcessor
from pre_post_processing.Steps import *
from pre_post_processing.utils import create_named_value, PRE_POST_PROCESSING_ONNX_OPSET


def create_model(output_file: Path):
    """
    Create the unit test model. If the input is bytes from a jpg we do the following:
      - DecodeImage: jpg to BGR
      - EncodeImage: BGR to png (output format is set in the node)
      - DecodeImage: png to BGR

    This is slightly easier to test, as we can set the expected output by decoding the original image in the unit test.
    """
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    pipeline = PrePostProcessor(inputs)
    pipeline.add_pre_processing(
        [
            ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout
        ]
    )

    pipeline.add_post_processing(
        [
            ConvertBGRToImage(image_format="png"),  # jpg or png are supported
            ConvertImageToBGR(),  # png to BGR in HWC layout
        ]
    )

    g = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Identity", ["bgr_data_in"], ["bgr_data_out"])
        ],
        "empty",
        [
            onnx.helper.make_tensor_value_info("bgr_data_in", onnx.TensorProto.UINT8, ['h', 'w', 3])
        ],
        [
            onnx.helper.make_tensor_value_info("bgr_data_out", onnx.TensorProto.UINT8, ['h', 'w', 3])
        ]
    )

    onnx_import = onnx.helper.make_operatorsetid('', PRE_POST_PROCESSING_ONNX_OPSET)
    model = onnx.helper.make_model(g, opset_imports=[onnx_import])
    new_model = pipeline.run(model)
    new_model.doc_string = "Model for testing DecodeImage and EncodeImage."
    new_model.graph.doc_string = ""  # clear out all the messages from graph merges
    onnx.save_model(new_model, str(output_file.resolve()))


if __name__ == "__main__":
    create_model(Path('decode_encode_decode_test.onnx'))
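A sketch of how the generated model could be exercised (not part of the commit; assumes onnxruntime, onnxruntime-extensions and opencv-python are installed, and the test image path is hypothetical). Because png encoding is lossless, the jpg -> png -> BGR round trip should match decoding the original jpg directly:

    import cv2
    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = ort.InferenceSession("decode_encode_decode_test.onnx", so)

    image_bytes = np.fromfile("test.jpg", dtype=np.uint8)  # hypothetical test image
    (actual,) = sess.run(None, {"image": image_bytes})

    # DecodeImage and cv2.imdecode should agree: both produce BGR in HWC layout.
    expected = cv2.imdecode(image_bytes, cv2.IMREAD_COLOR)
    np.testing.assert_array_equal(actual, expected)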

Binary file not shown.
Binary file not shown.
Binary file not shown.

@@ -0,0 +1,6 @@

Model sources:

PyTorch Mobilenet v2: https://pytorch.org/hub/pytorch_vision_mobilenet_v2/
PyTorch Super Resolution: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
Tensorflow Lite Mobilenet v2: https://tfhub.dev/iree/lite-model/mobilenet_v2_100_224/fp32/1

Binary file not shown.
Binary file not shown.
After: Width | Height | Size: 130 KiB
Binary file not shown.
After: Width | Height | Size: 51 KiB

@@ -8,14 +8,35 @@
const char* GetLibraryPath();

struct TestValue {
  TestValue(const char* name_in, const std::vector<float>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, values_float{values_in}, dims{dims_in} {}

  TestValue(const char* name_in, const std::vector<uint8_t>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, values_uint8{values_in}, dims{dims_in} {}

  TestValue(const char* name_in, const std::vector<int32_t>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, values_int32{values_in}, dims{dims_in} {}

  TestValue(const char* name_in, const std::vector<int64_t>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, values_int64{values_in}, dims{dims_in} {}

  TestValue(const char* name_in, const std::vector<std::string>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING}, values_string{values_in}, dims{dims_in} {}

  TestValue(const char* name_in, const std::vector<bool>& values_in, const std::vector<int64_t>& dims_in)
      : name{name_in}, element_type{ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL}, values_bool{values_in}, dims{dims_in} {}

  TestValue() = default;

  const char* name = nullptr;
  ONNXTensorElementDataType element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  std::vector<int64_t> dims;
  std::vector<float> values_float;
  std::vector<uint8_t> values_uint8;
  std::vector<int32_t> values_int32;
  std::vector<int64_t> values_int64;
  std::vector<std::string> values_string;
  std::vector<bool> value_bool;
  std::vector<bool> values_bool;
};

void RunSession(Ort::Session& session_object,

@@ -10,7 +10,6 @@
#include "string_tensor.h"
#include "test_kernel.hpp"

const char* GetLibraryPath() {
#if defined(_WIN32)
  return "ortextensions.dll";

@@ -58,13 +57,13 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
  size_t GetInputTypeCount() const {
    return 2;
  };
  ONNXTensorElementDataType GetInputType(size_t index) const {
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };
  size_t GetOutputTypeCount() const {
    return 1;
  };
  ONNXTensorElementDataType GetOutputType(size_t index) const {
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };
};

@@ -101,13 +100,13 @@ struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
  size_t GetInputTypeCount() const {
    return 1;
  };
  ONNXTensorElementDataType GetInputType(size_t index) const {
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };
  size_t GetOutputTypeCount() const {
    return 1;
  };
  ONNXTensorElementDataType GetOutputType(size_t index) const {
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  };
};

@@ -134,6 +133,7 @@ struct KernelThree : BaseKernel {
      out[i] = input_strs[i].find(substr_);
    }
  }

 private:
  std::string substr_;
};

@@ -141,20 +141,20 @@ struct KernelThree : BaseKernel {
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return CreateKernelImpl(api, info);
  };
  const char* GetName() const {
    return "CustomOpThree";
  };
  size_t GetInputTypeCount() const {
    return 1;
  };
  ONNXTensorElementDataType GetInputType(size_t index) const {
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  };
  size_t GetOutputTypeCount() const {
    return 1;
  };
  ONNXTensorElementDataType GetOutputType(size_t index) const {
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  };
};

@@ -175,7 +175,7 @@ void _emplace_back(Ort::MemoryInfo& memory_info, std::vector<Ort::Value>& ort_in
  }

  ort_inputs.emplace_back(Ort::Value::CreateTensor<>(
      memory_info, reinterpret_cast<bool*>(convertor.data()) , values.size(), dims.data(), dims.size()));
      memory_info, reinterpret_cast<bool*>(convertor.data()), values.size(), dims.data(), dims.size()));
}

template <typename T>

@@ -219,11 +219,14 @@ void RunSession(Ort::Session& session_object,
    input_names.emplace_back(inputs[i].name);
    switch (inputs[i].element_type) {
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
        _emplace_back(memory_info, ort_inputs, inputs[i].value_bool, inputs[i].dims);
        _emplace_back(memory_info, ort_inputs, inputs[i].values_bool, inputs[i].dims);
        break;
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
        _emplace_back(memory_info, ort_inputs, inputs[i].values_float, inputs[i].dims);
        break;
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
        _emplace_back(memory_info, ort_inputs, inputs[i].values_uint8, inputs[i].dims);
        break;
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
        _emplace_back(memory_info, ort_inputs, inputs[i].values_int32, inputs[i].dims);
        break;

@@ -267,6 +270,9 @@ void RunSession(Ort::Session& session_object,
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
        _assert_eq(*output_tensor, expected.values_float, total_len);
        break;
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
        _assert_eq(*output_tensor, expected.values_uint8, total_len);
        break;
      case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
        _assert_eq(*output_tensor, expected.values_int32, total_len);
        break;

@@ -294,7 +300,8 @@ void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
  Ort::SessionOptions session_options;
  void* handle = nullptr;
  if (custom_op_library_filename) {
    Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &handle));
    Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options,
                                                             custom_op_library_filename, &handle));
  }

  // if session creation passes, model loads fine

@@ -377,7 +384,7 @@ TEST(utils, test_get_str_attr) {
}

TEST(ustring, tensor_operator) {
  OrtValue *tensor;
  OrtValue* tensor;
  OrtAllocator* allocator;

  const auto* api_base = OrtGetApiBase();

@@ -13,11 +13,6 @@ TEST(hf_bert_tokenizer_opertor, test_default) {
  const std::filesystem::path model_path = "data/bert-large-uncased-whole-word-masking-finetuned-squad-tokenizer.onnx";

  const std::vector<int64_t> input_dims{2};
  const std::vector<float> empty_float;
  const std::vector<int32_t> empty_int32;
  const std::vector<int64_t> empty_int64;
  const std::vector<std::string> empty_string;
  const std::vector<bool> empty_bool;
  const std::string context1 =
      "John is a 10 year old boy. "
      "He is the son of Robert Smith. "

@@ -32,124 +27,113 @@ TEST(hf_bert_tokenizer_opertor, test_default) {
      "He lives in Seattle, Washington.";

  std::vector<std::vector<TestValue>> inputs = {
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Who is John's sister?", context1}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Where does sophia study?", context1}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Who is John's mom?", context1}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Where does John's father's wife teach?", context1}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Who is John's father's wife's daughter's brother?", context1}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Who is John's friend?", context2}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Where does John's friend live?", context2}, empty_bool} },
      { TestValue{"text", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, input_dims, empty_float, empty_int32, empty_int64,
                  {"Which state does John's friend live?", context2}, empty_bool} }
  };
      {TestValue{"text", {"Who is John's sister?", context1}, input_dims}},
      {TestValue{"text", {"Where does sophia study?", context1}, input_dims}},
      {TestValue{"text", {"Who is John's mom?", context1}, input_dims}},
      {TestValue{"text", {"Where does John's father's wife teach?", context1}, input_dims}},
      {TestValue{"text", {"Who is John's father's wife's daughter's brother?", context1}, input_dims}},
      {TestValue{"text", {"Who is John's friend?", context2}, input_dims}},
      {TestValue{"text", {"Where does John's friend live?", context2}, input_dims}},
      {TestValue{"text", {"Which state does John's friend live?", context2}, input_dims}}};

  std::vector<std::vector<TestValue>> outputs = {
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    {101, 2040, 2003, 2198, 1005, 1055, 2905, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044,
                     1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870, 1005, 1055, 2684, 1012,
                     2016, 2913, 2012, 15384, 4482, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    std::vector<int64_t>(53, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 51}, empty_float, empty_int32,
                    {101, 2073, 2515, 9665, 2817, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728,
                     3044, 1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870,
                     1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 51}, empty_float, empty_int32,
                    std::vector<int64_t>(51, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 51}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    {101, 2040, 2003, 2198, 1005, 1055, 3566, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044,
                     1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870, 1005, 1055, 2684, 1012,
                     2016, 2913, 2012, 15384, 4482, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    std::vector<int64_t>(53, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 53}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 57}, empty_float, empty_int32,
                    {101, 2073, 2515, 2198, 1005, 1055, 2269, 1005, 1055, 2564, 6570, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879, 1012, 2002, 2003, 1996,
                     2365, 1997, 2728, 3044, 1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870,
                     1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 57}, empty_float, empty_int32,
                    std::vector<int64_t>(57, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 57}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 62}, empty_float, empty_int32,
                    {101, 2040, 2003, 2198, 1005, 1055, 2269, 1005, 1055, 2564, 1005, 1055, 2684, 1005, 1055, 2567, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214,
                     2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044, 1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256,
                     1012, 9665, 3044, 2003, 3870, 1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 62}, empty_float, empty_int32,
                    std::vector<int64_t>(62, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 62}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 35}, empty_float, empty_int32,
                    {101, 2040, 2003, 2198, 1005, 1055, 2767, 1029, 102, 2026, 2171, 2003, 2198, 1012, 1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012, 6487, 2003,
                     2026, 2767, 1012, 2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 35}, empty_float, empty_int32,
                    std::vector<int64_t>(35, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 35}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 36}, empty_float, empty_int32,
                    {101, 2073, 2515, 2198, 1005, 1055, 2767, 2444, 1029, 102, 2026, 2171, 2003, 2198, 1012, 1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012, 6487,
                     2003, 2026, 2767, 1012, 2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 36}, empty_float, empty_int32,
                    std::vector<int64_t>(36, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 36}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      },
      {
          TestValue{"input_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 37}, empty_float, empty_int32,
                    {101, 2029, 2110, 2515, 2198, 1005, 1055, 2767, 2444, 1029, 102, 2026, 2171, 2003, 2198, 1012, 1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012,
                     6487, 2003, 2026, 2767, 1012, 2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                    empty_string, empty_bool},
          TestValue{"attention_mask", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 37}, empty_float, empty_int32,
                    std::vector<int64_t>(37, 1LL), empty_string, empty_bool},
          TestValue{"token_type_ids", ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, {1, 37}, empty_float, empty_int32,
                    {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                    empty_string, empty_bool}
      }
  };
      {TestValue{
           "input_ids",
           std::vector<int64_t>{101, 2040, 2003, 2198, 1005, 1055, 2905, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214,
                                2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044, 1012, 3870, 4482, 2003, 2728,
                                1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870,
                                1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
           {1, 53}},
       TestValue{"attention_mask", std::vector<int64_t>(53, 1), {1, 53}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 53}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2073, 2515, 9665, 2817, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879,
                                      1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044, 1012, 3870, 4482, 2003, 2728,
                                      1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003,
                                      3870, 1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                 {1, 51}},
       TestValue{"attention_mask", std::vector<int64_t>(51, 1), {1, 51}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 51}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2040, 2003, 2198, 1005, 1055, 3566, 1029, 102, 2198, 2003, 1037, 2184, 2095,
                                      2214, 2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728, 3044, 1012, 3870, 4482,
                                      2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665,
                                      3044, 2003, 3870, 1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                 {1, 53}},
       TestValue{"attention_mask", std::vector<int64_t>(53, 1), {1, 53}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 53}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2073, 2515, 2198, 1005, 1055, 2269, 1005, 1055, 2564, 6570, 1029, 102, 2198,
                                      2003, 1037, 2184, 2095, 2214, 2879, 1012, 2002, 2003, 1996, 2365, 1997, 2728,
                                      3044, 1012, 3870, 4482, 2003, 2728, 1005, 1055, 2564, 1012, 2016, 12011, 2012,
                                      15384, 8256, 1012, 9665, 3044, 2003, 3870, 1005, 1055, 2684, 1012, 2016, 2913,
                                      2012, 15384, 4482, 102},
                 {1, 57}},
       TestValue{"attention_mask", std::vector<int64_t>(57, 1), {1, 57}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1},
                 {1, 57}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2040, 2003, 2198, 1005, 1055, 2269, 1005, 1055, 2564, 1005, 1055, 2684,
                                      1005, 1055, 2567, 1029, 102, 2198, 2003, 1037, 2184, 2095, 2214, 2879, 1012,
                                      2002, 2003, 1996, 2365, 1997, 2728, 3044, 1012, 3870, 4482, 2003, 2728, 1005,
                                      1055, 2564, 1012, 2016, 12011, 2012, 15384, 8256, 1012, 9665, 3044, 2003, 3870,
                                      1005, 1055, 2684, 1012, 2016, 2913, 2012, 15384, 4482, 102},
                 {1, 62}},
       TestValue{"attention_mask", std::vector<int64_t>(62, 1), {1, 62}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 62}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2040, 2003, 2198, 1005, 1055, 2767, 1029, 102, 2026, 2171, 2003, 2198, 1012,
                                      1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012, 6487, 2003, 2026, 2767, 1012,
                                      2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                 {1, 35}},
       TestValue{"attention_mask", std::vector<int64_t>(35, 1), {1, 35}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 35}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2073, 2515, 2198, 1005, 1055, 2767, 2444, 1029, 102, 2026, 2171, 2003, 2198,
                                      1012, 1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012, 6487, 2003, 2026, 2767,
                                      1012, 2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                 {1, 36}},
       TestValue{"attention_mask", std::vector<int64_t>(36, 1), {1, 36}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 36}}},

      {TestValue{"input_ids",
                 std::vector<int64_t>{101, 2029, 2110, 2515, 2198, 1005, 1055, 2767, 2444, 1029, 102, 2026, 2171, 2003,
                                      2198, 1012, 1045, 2444, 1999, 2624, 4560, 1010, 2662, 1012, 6487, 2003, 2026,
                                      2767, 1012, 2002, 3268, 1999, 5862, 1010, 2899, 1012, 102},
                 {1, 37}},
       TestValue{"attention_mask", std::vector<int64_t>(37, 1LL), {1, 37}},
       TestValue{"token_type_ids",
                 std::vector<int64_t>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                      1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
                 {1, 37}}}};

  ASSERT_EQ(inputs.size(), outputs.size());

@@ -7,7 +7,6 @@
#include "test_kernel.hpp"
#include "text/string_lower.hpp"

TEST(string_operator, test_string_lower) {
  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

@@ -28,7 +27,6 @@ TEST(string_operator, test_string_lower) {
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}

TEST(string_operator, test_regex_split_with_offsets) {
  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

@@ -64,7 +62,6 @@ TEST(string_operator, test_regex_split_with_offsets) {
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}

TEST(string_operator, test_string_ecmaregex_replace) {
  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

@@ -84,14 +81,12 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"$010"};

  std::vector<TestValue> outputs(1);
  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {1};
  outputs[0].values_string = {"a Test 10 20 30 ♠♣"};

  std::filesystem::path model_path = "data";
  model_path /= "test_string_ecmaregex_replace.onnx";
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

@@ -111,14 +106,12 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"$010"};

  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {1};
  outputs[0].values_string = {"a Test 1000 2000 3000 ♠♣"};
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

  inputs[0].name = "input";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1};

@@ -134,7 +127,6 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"$010"};

  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {1};

@@ -156,7 +148,6 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"$1+"};

  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {1};

@@ -205,7 +196,6 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  outputs[0].values_string = {"Test 10 20 30 ", "Test 40 50 60 ", " Test 70 80 90 "};
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

  inputs[0].name = "input";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1};

@@ -221,7 +211,6 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"aa"};

  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {1};

@@ -249,7 +238,6 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  inputs[2].dims = {1};
  inputs[2].values_string = {"$1+"};

  outputs[0].name = "output";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  outputs[0].dims = {3};

@@ -257,15 +245,14 @@ TEST(string_operator, test_string_ecmaregex_replace) {
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}

TEST(utils, test_string_join) {
  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");

  std::vector<TestValue> inputs(3);
  inputs[0].name = "text";
  inputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  inputs[0].dims = {1,3};
  inputs[0].values_string = {"abc","zzz","efg"};
  inputs[0].dims = {1, 3};
  inputs[0].values_string = {"abc", "zzz", "efg"};

  inputs[1].name = "sep";
  inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;

@@ -459,10 +446,9 @@ TEST(string_operator, test_string_to_vector) {
  std::vector<TestValue> outputs(1);
  outputs[0].name = "token_ids";
  outputs[0].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  outputs[0].dims = {5,3};
  outputs[0].dims = {5, 3};
  outputs[0].values_int64 = {0, 1, 2, 2, 3, 4, 0, 1, 2, 3, 4, 4, -1, -1, -1};

  std::filesystem::path model_path = "data";
  model_path /= "test_string_to_vector.onnx";
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

@@ -509,7 +495,6 @@ TEST(string_operator, test_string_mapping) {
  outputs[0].dims = {5};
  outputs[0].values_string = {"Maybe", "也不知道可不可以", "No color", "OK", "Not OK"};

  std::filesystem::path model_path = "data";
  model_path /= "test_string_mapping.onnx";
  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());

@@ -541,7 +526,7 @@ TEST(string_operator, test_masked_fill) {
  inputs[1].name = "mask";
  inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
inputs[1].element_type = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
inputs[1].dims = {5};
|
||||
inputs[1].value_bool = {true, false, true, false, true};
|
||||
inputs[1].values_bool = {true, false, true, false, true};
|
||||
|
||||
std::vector<TestValue> outputs(1);
|
||||
outputs[0].name = "output";
|
||||
|
@ -549,7 +534,6 @@ TEST(string_operator, test_masked_fill) {
|
|||
outputs[0].dims = {3};
|
||||
outputs[0].values_string = {"Orange and Yellow", "No color", "white"};
|
||||
|
||||
|
||||
std::filesystem::path model_path = "data";
|
||||
model_path /= "test_masked_fill.onnx";
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
|
@ -558,7 +542,7 @@ TEST(string_operator, test_masked_fill) {
|
|||
inputs[0].values_string = {"Orange and Yellow", "不知道啥颜色", "No color", "black", "white"};
|
||||
|
||||
inputs[1].dims = {5};
|
||||
inputs[1].value_bool = {false, false, false, false, false};
|
||||
inputs[1].values_bool = {false, false, false, false, false};
|
||||
|
||||
outputs[0].dims = {0};
|
||||
outputs[0].values_string = {};
|
||||
|
@ -568,7 +552,7 @@ TEST(string_operator, test_masked_fill) {
|
|||
inputs[0].values_string = {"Orange and Yellow", "不知道啥颜色", "No color", "black", "white"};
|
||||
|
||||
inputs[1].dims = {5};
|
||||
inputs[1].value_bool = {true, true, true, true, true};
|
||||
inputs[1].values_bool = {true, true, true, true, true};
|
||||
|
||||
outputs[0].dims = {5};
|
||||
outputs[0].values_string = {"Orange and Yellow", "不知道啥颜色", "No color", "black", "white"};
|
||||
|
@ -578,7 +562,7 @@ TEST(string_operator, test_masked_fill) {
|
|||
inputs[0].values_string = {"a"};
|
||||
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].value_bool = {false};
|
||||
inputs[1].values_bool = {false};
|
||||
|
||||
outputs[0].dims = {0};
|
||||
outputs[0].values_string = {};
|
||||
|
@ -588,7 +572,7 @@ TEST(string_operator, test_masked_fill) {
|
|||
inputs[0].values_string = {"a"};
|
||||
|
||||
inputs[1].dims = {1};
|
||||
inputs[1].value_bool = {true};
|
||||
inputs[1].values_bool = {true};
|
||||
|
||||
outputs[0].dims = {1};
|
||||
outputs[0].values_string = {"a"};
|
||||
|
@ -598,9 +582,9 @@ TEST(string_operator, test_masked_fill) {
|
|||
inputs[0].values_string = {};
|
||||
|
||||
inputs[1].dims = {0};
|
||||
inputs[1].value_bool = {};
|
||||
inputs[1].values_bool = {};
|
||||
|
||||
outputs[0].dims = {0};
|
||||
outputs[0].values_string = {};
|
||||
TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include <fstream>
#include <vector>
#include "gtest/gtest.h"
#include "ocos.h"
#include "test_kernel.hpp"

#include <opencv2/imgcodecs.hpp>

namespace {
std::vector<uint8_t> LoadBytesFromFile(const std::filesystem::path& filename) {
  using namespace std;
  ifstream ifs(filename, ios::binary | ios::ate);
  ifstream::pos_type pos = ifs.tellg();
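  // the stream was opened with ios::ate, so it is positioned at the end of the file and tellg() returns the file size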

  std::vector<uint8_t> input_bytes(pos);
  ifs.seekg(0, ios::beg);
  // we want uint8_t values, so reinterpret_cast lets us read directly instead of reading chars and copying afterwards
  ifs.read(reinterpret_cast<char*>(input_bytes.data()), pos);

  return input_bytes;
}

void FixCurrentDir() {
  // adjust for the Google Test Adapter in Visual Studio not setting the current path to $(ProjectDir),
  // which results in us being 2 levels below where the `data` folder is copied to and where the extensions
  // library is
  auto cur = std::filesystem::current_path();

  do {
    auto data_dir = cur / "data";

    if (std::filesystem::exists(data_dir) && std::filesystem::is_directory(data_dir)) {
      break;
    }

    cur = cur.parent_path();
    ASSERT_NE(cur, cur.root_path()) << "Reached root directory without finding 'data' directory.";
  } while (true);

  // set current path as the extensions library is also loaded from that directory by TestInference
  std::filesystem::current_path(cur);
}
}  // namespace

// Test DecodeImage and EncodeImage by providing a jpg image. Model will decode to BGR, encode to PNG and decode
// again to BGR. We validate that the BGR output from that matches the original image.
TEST(VisionOps, image_decode_encode) {
  FixCurrentDir();

  std::string ort_version{OrtGetApiBase()->GetVersionString()};

  // the test model requires ONNX opset 16, which requires ORT version 1.11 or later.
  // skip test if the CI doesn't have that ORT version.
  // the CI only has a few ORT versions so we don't worry about versions <= 1.2
  if (ort_version.compare(0, 3, "1.1") != 0 ||   // earlier than 1.10
      ort_version.compare(0, 4, "1.10") == 0) {  // earlier than 1.11
    return;
  }

  auto data_dir = std::filesystem::current_path() / "data";
  auto model_path = data_dir / "ppp_vision" / "decode_encode_decode_test.onnx";
  auto image_path = data_dir / "test_colors.jpg";

  auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
  std::vector<uint8_t> image_data = LoadBytesFromFile(image_path);

  // decode image to get expected output
  const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(image_data.size())};
  const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1, static_cast<void*>(image_data.data()));
  const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
  ASSERT_NE(decoded_image.data, nullptr) << "imdecode failed";

  const cv::Size decoded_image_size = decoded_image.size();
  const int64_t colors = 3;
  const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
  // decoded_image.total() is num pixels. elemSize is 3 (BGR value per pixel)
  const auto num_output_bytes = decoded_image.total() * decoded_image.elemSize();
  std::vector<uint8_t> expected_output(num_output_bytes, 0);
  memcpy(expected_output.data(), decoded_image.data, num_output_bytes);

  std::vector<TestValue> inputs{TestValue("image", image_data, {static_cast<int64_t>(image_data.size())})};
  std::vector<TestValue> outputs{TestValue("bgr_data", expected_output, output_dimensions)};

  TestInference(*ort_env, model_path.c_str(), inputs, outputs, GetLibraryPath());
}
@ -0,0 +1,175 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import unittest

import io
import numpy as np
import onnxruntime as ort
import os
import sys

from PIL import Image
from pathlib import Path
from onnxruntime_extensions import get_library_path

# add tools dir where pre_post_processing folder is to sys path
# TODO: Move this script to the test folder so this isn't needed
script_dir = os.path.dirname(os.path.realpath(__file__))
ort_ext_root = os.path.abspath(os.path.join(script_dir, ".."))
tools_dir = os.path.join(ort_ext_root, "tools")
test_data_dir = os.path.join(ort_ext_root, "test", "data", "ppp_vision")
sys.path.append(tools_dir)

import add_pre_post_processing_to_model as add_ppp


# Function to read the mobilenet labels and adjust for PT vs TF training if needed
# def _get_labels(is_pytorch: bool = True):
#     labels_file = os.path.join(test_data_dir, "TF.ImageNetLabels.txt")
#     labels = []
#     with open(labels_file, 'r') as infile:
#         # skip first 'background' entry if pytorch as that model was not trained with it
#         if is_pytorch:
#             _ = infile.readline()
#
#         for line in infile:
#             labels.append(line.strip())
#
#     assert(len(labels) == 1000 if is_pytorch else 1001)
#     return labels


class TestToolsAddPrePostProcessingToModel(unittest.TestCase):
    def test_pytorch_mobilenet(self):
        input_model = os.path.join(test_data_dir, "pytorch_mobilenet_v2.onnx")
        output_model = os.path.join(test_data_dir, "pytorch_mobilenet_v2.updated.onnx")
        input_image_path = os.path.join(test_data_dir, "wolves.jpg")

        add_ppp.mobilenet(Path(input_model), Path(output_model))

        def orig_output():
            from torchvision import transforms
            input_image = Image.open(input_image_path)
            preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            input_tensor = preprocess(input_image)
            # create a mini-batch as expected by the model
            input_batch = input_tensor.unsqueeze(0).detach().cpu().numpy()

            s = ort.InferenceSession(input_model)
            scores = s.run(None, {'x': np.array(input_batch)})
            scores = np.squeeze(scores)

            def softmax(x):
                e_x = np.exp(x - np.max(x))
                return e_x / e_x.sum()

            probabilities = softmax(scores)
            return probabilities

        def new_output():
            input_bytes = np.fromfile(input_image_path, dtype=np.uint8)
            so = ort.SessionOptions()
            so.register_custom_ops_library(get_library_path())
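            # the added pre/post processing uses custom ops from onnxruntime-extensions
            # (e.g. the image decode step), so the extensions library must be registered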

            s = ort.InferenceSession(output_model, so)
            probabilities = s.run(None, {'image': np.array(input_bytes)})[0]
            probabilities = np.squeeze(probabilities)  # remove batch dim
            return probabilities

        orig_results = orig_output()
        new_results = new_output()

        orig_idx = np.argmax(orig_results)
        new_idx = np.argmax(new_results)
        self.assertEqual(orig_idx, new_idx)
        # check within 1%. probability values are in range 0..1
        self.assertTrue(abs(orig_results[orig_idx] - new_results[new_idx]) < 0.01)

    def test_tflite_mobilenet(self):
        input_model = os.path.join(test_data_dir, "tflite_mobilenet_v2.onnx")
        output_model = os.path.join(test_data_dir, "tflite_mobilenet_v2.updated.onnx")
        input_image_path = os.path.join(test_data_dir, "wolves.jpg")

        add_ppp.mobilenet(Path(input_model), Path(output_model), add_ppp.ModelSource.TENSORFLOW)

        def orig_output():
            # can still use PT pre-processing as it's using PIL for images.
            # Update the Normalize values to match TF requirements.
            from torchvision import transforms
            input_image = Image.open(input_image_path)
            preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])
            input_tensor = preprocess(input_image)
            # create a mini-batch as expected by the model
            input_batch = input_tensor.unsqueeze(0).detach().cpu().numpy()
            input_batch = np.transpose(input_batch, (0, 2, 3, 1))  # to NHWC format for TF input

            s = ort.InferenceSession(input_model)
            probabilities = s.run(None, {'input': np.array(input_batch)})[0]
            return np.squeeze(probabilities)

        def new_output():
            # TODO: Should we get the ortextensions library path from an env var and if provided run the model?
            input_bytes = np.fromfile(input_image_path, dtype=np.uint8)

            so = ort.SessionOptions()
            so.register_custom_ops_library(get_library_path())

            s = ort.InferenceSession(output_model, so)
            probabilities = s.run(None, {'image': np.array(input_bytes)})[0]
            return np.squeeze(probabilities)  # remove batch dim

        orig_results = orig_output()
        new_results = new_output()

        orig_idx = np.argmax(orig_results)
        new_idx = np.argmax(new_results)
        self.assertEqual(orig_idx, new_idx)
        # check within 1%. probability values are in range 0..1
        self.assertTrue(abs(orig_results[orig_idx] - new_results[new_idx]) < 0.01)

    def test_pytorch_superresolution(self):
        input_model = os.path.join(test_data_dir, "pytorch_super_resolution.onnx")
        output_model = os.path.join(test_data_dir, "pytorch_super_resolution.updated.onnx")
        input_image_path = os.path.join(test_data_dir, "..", "test_supres.jpg")

        # expected output is manually inspected result of running the model.
        # there are still some diffs in the resized Cb and Cr values that get merged in during post-processing due to
        # the ONNX Resize not supporting anti-aliasing. That _should_ be added in the next ORT release as the ONNX spec
        # has added anti-aliasing.
        expected_output_image_path = os.path.join(test_data_dir, "..", "test_supres_expected.jpg")

        add_ppp.superresolution(Path(input_model), Path(output_model))

        input_bytes = np.fromfile(input_image_path, dtype=np.uint8)

        so = ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        s = ort.InferenceSession(output_model, so)

        result_bytes = s.run(None, {'image': np.array(input_bytes)})[0]

        # convert from jpg to RGB to remove any jpg encoding diffs
        result = np.array(Image.open(io.BytesIO(result_bytes)).convert('RGB'))
        expected = np.array(Image.open(expected_output_image_path).convert('RGB'))

        # check all pixel values are within 1.
        #
        # we expect some variance from the floating point operations involved during Resize and conversion of the
        # original image to/from YCbCr. the different instructions used on different hardware can cause diffs, such as
        # whether avx512 is used or not.
        self.assertTrue(np.allclose(expected, result, atol=1, rtol=0))


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,216 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import enum
import onnx
import os

from pathlib import Path

from pre_post_processing import PrePostProcessor
from pre_post_processing.steps import *
from pre_post_processing.utils import create_named_value, IoMapEntry


class ModelSource(enum.Enum):
    PYTORCH = 0
    TENSORFLOW = 1
    OTHER = 2


def imagenet_preprocessing(model_source: ModelSource = ModelSource.PYTORCH):
    """
    Common pre-processing for an imagenet trained model.

    - Resize so smallest side is 256
    - Centered crop to 224 x 224
    - Convert image bytes to floating point values in range 0..1
    - [Channels last to channels first (convert to ONNX layout) if model came from pytorch and has NCHW layout]
    - Normalize
      - (value - mean) / stddev
      - for a pytorch model, this applies per-channel normalization parameters
      - for a tensorflow model this simply moves the image bytes into the range -1..1
    - adds a batch dimension with a value of 1
    """

    # These utils cover both cases of typical pytorch/tensorflow pre-processing for an imagenet trained model
    # https://github.com/keras-team/keras/blob/b80dd12da9c0bc3f569eca3455e77762cf2ee8ef/keras/applications/imagenet_utils.py#L177

    steps = [
        Resize(256),
        CenterCrop(224, 224),
        ImageBytesToFloat()
    ]

    if model_source == ModelSource.PYTORCH:
        # pytorch model has NCHW layout
        steps.extend([
            ChannelsLastToChannelsFirst(),
            Normalize([(0.485, 0.229), (0.456, 0.224), (0.406, 0.225)], layout="CHW")
        ])
    else:
        # TF processing involves moving the data into the range -1..1 instead of 0..1.
        # ImageBytesToFloat converts to range 0..1, so we use 0.5 for the mean to move into the range -0.5..0.5
        # and 0.5 for the stddev to expand to -1..1
        steps.append(Normalize([(0.5, 0.5)], layout="HWC"))

    steps.append(Unsqueeze([0]))  # add batch dim

    return steps


def mobilenet(model_file: Path, output_file: Path, model_source: ModelSource = ModelSource.PYTORCH):
    model = onnx.load(str(model_file.resolve(strict=True)))
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    pipeline = PrePostProcessor(inputs)

    # support user providing encoded image bytes
    preprocessing = [
        ConvertImageToBGR(),  # custom op to convert jpg/png to BGR (output is HWC)
        ReverseAxis(axis=2, dim_value=3, name="BGR_to_RGB"),  # Normalization params are for RGB ordering
    ]
    # plug in default imagenet pre-processing
    preprocessing.extend(imagenet_preprocessing(model_source))

    pipeline.add_pre_processing(preprocessing)

    # for mobilenet we convert the score to probabilities with softmax if necessary. the TF model includes Softmax
    if model.graph.node[-1].op_type != "Softmax":
        pipeline.add_post_processing([Softmax()])

    new_model = pipeline.run(model)

    onnx.save_model(new_model, str(output_file.resolve()))


def superresolution(model_file: Path, output_file: Path):
    # TODO: There seems to be a split with some super resolution models processing RGB input and some processing
    # the Y channel after converting to YCbCr.
    # For the sake of this example implementation we do the trickier YCbCr processing as that involves joining the
    # Cb and Cr channels with the model output to create the resized image.
    # Model is from https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    model = onnx.load(str(model_file.resolve(strict=True)))
    inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]

    # assuming input is *CHW, infer the input sizes from the model.
    # requires the model input and output to have fixed values for the height and width dimensions.
    model_input_shape = model.graph.input[0].type.tensor_type.shape
    model_output_shape = model.graph.output[0].type.tensor_type.shape
    assert model_input_shape.dim[-1].HasField("dim_value")
    assert model_input_shape.dim[-2].HasField("dim_value")
    assert model_output_shape.dim[-1].HasField("dim_value")
    assert model_output_shape.dim[-2].HasField("dim_value")

    w_in = model_input_shape.dim[-1].dim_value
    h_in = model_input_shape.dim[-2].dim_value
    h_out = model_output_shape.dim[-2].dim_value
    w_out = model_output_shape.dim[-1].dim_value

    # pre/post processing for https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    pipeline = PrePostProcessor(inputs)
    pipeline.add_pre_processing(
        [
            ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout
            Resize((h_in, w_in)),
            CenterCrop(h_in, w_in),
            # this produces Y, Cb and Cr outputs. each has shape {h_in, w_in}. only Y is input to model
            PixelsToYCbCr(layout="BGR"),
            # if you inserted this Debug step here the 3 outputs from PixelsToYCbCr would also be model outputs
            # Debug(num_inputs=3),
            ImageBytesToFloat(),  # Convert Y to float in range 0..1
            Unsqueeze([0, 1]),  # add batch and channels dim to Y so shape is {1, 1, h_in, w_in}
        ]
    )

    # Post-processing is complicated here. resize the Cb and Cr outputs from the pre-processing to match
    # the model output size, merge those with the Y' model output, and convert back to RGB.

    # create the Steps we need to use in the manual connections
    pipeline.add_post_processing(
        [
            Squeeze([0, 1]),  # remove batch and channels dims from Y'
            FloatToImageBytes(name="Yout_to_bytes"),  # convert Y' to uint8 in range 0..255
            # Resize the Cb values (output 1 from PixelsToYCbCr)
            (
                Resize((h_out, w_out), "HW"),
                [IoMapEntry(producer="PixelsToYCbCr", producer_idx=1, consumer_idx=0)],
            ),
            # the Cb and Cr values are already in the range 0..255 so the multiplier is 1. we're using the step to round
            # for accuracy (a direct Cast would just truncate) and clip (to ensure range 0..255) the values post-Resize
            FloatToImageBytes(multiplier=1.0, name="Resized_Cb"),
            (Resize((h_out, w_out), "HW"), [IoMapEntry("PixelsToYCbCr", 2, 0)]),
            FloatToImageBytes(multiplier=1.0, name="Resized_Cr"),
            # as we're selecting outputs from multiple previous steps we need to map them to the inputs using step names
            (
                YCbCrToPixels(layout="BGR"),
                [
                    IoMapEntry("Yout_to_bytes", 0, 0),  # uint8 Y' with shape {h, w}
                    IoMapEntry("Resized_Cb", 0, 1),     # uint8 Cb'
                    IoMapEntry("Resized_Cr", 0, 2),     # uint8 Cr'
                ],
            ),
            ConvertBGRToImage(image_format="jpg"),  # jpg or png are supported
        ]
    )

    new_model = pipeline.run(model)
    onnx.save_model(new_model, str(output_file.resolve()))


def main():
    parser = argparse.ArgumentParser(
        os.path.basename(__file__),
        description="""Add pre and post processing to a model.

        Currently supports updating:
          - super resolution with YCbCr input
          - imagenet trained mobilenet

        To customize, the logic in the `mobilenet` and `superresolution` functions can be used as a guide.
        Create a pipeline and add the required pre/post processing 'Steps' in the order required. Configure
        individual steps as needed.

        The updated model will be written in the same location as the original model, with '.onnx' updated to
        '.with_pre_post_processing.onnx'
        """,
    )

    parser.add_argument(
        "-t",
        "--model_type",
        type=str,
        required=True,
        choices=["superresolution", "mobilenet"],
        help="Model type.",
    )

    parser.add_argument(
        "--model_source",
        type=str,
        required=False,
        choices=["pytorch", "tensorflow"],
        default="pytorch",
        help="""
        Framework the model came from. In some cases there are known differences that can be taken into account when
        adding the pre/post processing to the model. Currently this equates to choosing different normalization
        behavior for mobilenet models.
        """,
    )

    parser.add_argument("model", type=Path, help="Provide path to ONNX model to update.")

    args = parser.parse_args()

    model_path = args.model.resolve(strict=True)
    new_model_path = model_path.with_suffix(".with_pre_post_processing.onnx")

    if args.model_type == "mobilenet":
        source = ModelSource.PYTORCH if args.model_source == "pytorch" else ModelSource.TENSORFLOW
        mobilenet(model_path, new_model_path, source)
    else:
        superresolution(model_path, new_model_path)


if __name__ == "__main__":
    main()
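
# Example invocation (paths and file names are illustrative):
#   python add_pre_post_processing_to_model.py -t mobilenet --model_source pytorch /models/pytorch_mobilenet_v2.onnx
# which writes /models/pytorch_mobilenet_v2.with_pre_post_processing.onnx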
@ -26,19 +26,22 @@ OPMAP_TO_CMAKE_FLAGS = {
    'BertTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
    'BasicTokenizer': 'OCOS_ENABLE_BERT_TOKENIZER',
    'BertTokenizerDecoder': 'OCOS_ENABLE_BERT_TOKENIZER',
    'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER'
    'SentencepieceTokenizer': 'OCOS_ENABLE_SPM_TOKENIZER',
    'ImageDecode': 'OCOS_ENABLE_VISION',
    'ImageEncode': 'OCOS_ENABLE_VISION',
}


def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
    ext_domain = "ai.onnx.contrib"  # default_opset_domain()
    new_ext_domain = "com.microsoft.extensions"
    ext_domain_cnt = 0
    cmake_options = set()
    with open(oplist_cmake_file, 'w') as f:
        print("# Auto-Generated File, please do not edit!!!", file=f)
        with open(opconfig_file, 'r') as opfile:
            for _ln in opfile:
                if _ln.startswith(ext_domain):
                if _ln.startswith(ext_domain) or _ln.startswith(new_ext_domain):
                    ext_domain_cnt += 1
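                    # a matching line in the ops config file is expected to look like
                    # `<domain>;<opset>;<op1>,<op2>,...` e.g. (illustrative): ai.onnx.contrib;1;ImageDecode,ImageEncode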
                    items = _ln.strip().split(';')
                    if len(items) < 3:

@ -55,8 +58,8 @@ def gen_cmake_oplist(opconfig_file, oplist_cmake_file='_selectedoplist.cmake'):
        print("# End of Building the Operator CMake variables", file=f)

    if ext_domain_cnt == 0:
        print('[onnxruntime-extensions] warning: lines starting with extension domain (ai.onnx.contrib) in operators'
              ' config file is 0')
        print('[onnxruntime-extensions] warning: lines starting with extension domains of ai.onnx.contrib or '
              'com.microsoft.extensions in operators config file is 0')

    print('[onnxruntime-extensions] The cmake tool file has been generated successfully.')
@ -0,0 +1,205 @@
# Example usage of the PrePostProcessor

The PrePostProcessor can be used to add pre and post processing operations to an existing model.

Currently the easiest way to use it is to download this folder and import PrePostProcessor and the Steps into your python script.
We will provide a python package that includes it in the next release.


## Initial imports

Import the PrePostProcessor, the steps and a utility to simplify creating new model inputs.

```py
import onnx
from pre_post_processing import PrePostProcessor
from pre_post_processing.steps import *
from pre_post_processing.utils import create_named_value
```

## Example of creating the pre and post processing pipelines

The following is an example pre-processing pipeline to update a model to take the bytes of a jpg or png image as input.
The original model input was pre-processed float data with shape {1, channels, 224, 224}, requiring the user to
manually convert their input image to this format.

### Create new input/s for the model

First, if you're adding pre-processing you need to create new inputs to the model that the pre-processing will use.

In our example we'll create a new input called 'image' containing uint8 data of length 'num_bytes'.

```py
new_input = create_named_value('image', onnx.TensorProto.UINT8, ['num_bytes'])
```

### Create PrePostProcessor

Create our PrePostProcessor instance with the new input/s.

```py
pipeline = PrePostProcessor([new_input])
```

### Add pre-processing steps

Add the pre-processing steps to the PrePostProcessor in the desired order.
You can pick-and-choose from the predefined steps in the pre_post_processing.Steps module or create your own custom steps.
If there's some common pre or post processing functionality that is missing please reach out and we'll look at adding
the necessary Step implementations for it.

Configure the steps as needed.

```py
pipeline.add_pre_processing(
    [
        ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout. output shape is {h_in, w_in, channels}
        Resize(256),  # resize so smallest side is 256.
        CenterCrop(224, 224),
        ChannelsLastToChannelsFirst(),  # ONNX models are typically channels first. output shape is {channels, 224, 224}
        ImageBytesToFloat(),  # Convert uint8 values in range 0..255 to float values in range 0..1
        Unsqueeze(axes=[0]),  # add batch dim so shape is {1, channels, 224, 224}. we now match the original model input
    ]
)
```

Outputs from the previous step are automatically connected to the next step (or to the model in the case of the last step)
in the same order: the first output of the previous step is connected to the first input of the next step, and so on,
until we run out of outputs or inputs (whichever happens first).

It is also possible to manually specify connections. See [IoMapEntry usage](#iomapentry-usage).


### Add post-processing steps

The post-processing pipeline is assembled the same way. Let's say it's simply a case of applying Softmax to the
first model output:

```py
pipeline.add_post_processing(
    [
        Softmax()
    ]
)
```

Neither pre-processing nor post-processing is required. Simply add what you need for your model.

### Execute pipeline

Once we have assembled our pipeline we simply run it with the original model, and save the output.

The last pre-processing step is automatically connected to the original model inputs,
and the first post-processing step is automatically connected to the original model outputs.

```py
model = onnx.load('my_model.onnx')
new_model = pipeline.run(model)
onnx.save_model(new_model, 'my_model.with_pre_post_processing.onnx')
```
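
To run the updated model, the onnxruntime-extensions custom op library needs to be registered with ONNX Runtime.
A minimal sketch of doing that (the file names here are illustrative):

```py
import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
session = ort.InferenceSession('my_model.with_pre_post_processing.onnx', so)

# the new 'image' input takes the raw bytes of a jpg or png file
image_bytes = np.fromfile('input.jpg', dtype=np.uint8)
results = session.run(None, {'image': image_bytes})
```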


## Helper to create new named model inputs

The `create_named_value` helper from [pre_post_processing.utils](./docs/pre_post_processing/utils.md) can be used
to create model inputs.

- The `name` value must be unique for the model.
- The `data_type` should be an onnx.TensorProto value like onnx.TensorProto.UINT8 or onnx.TensorProto.FLOAT from the
  list defined [here](https://github.com/onnx/onnx/blob/759907808db622938082c6eeaa8f685dee3dc868/onnx/onnx.proto#L483).
- The `shape` specifies the input shape. Use int for dimensions with known values and strings for symbolic dimensions.
  e.g. ['batch_size', 1024] would be a rank 2 tensor with a symbolic first dimension named 'batch_size'.
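
For example, a short sketch (the input names here are illustrative):

```py
# uint8 input with a single symbolic dimension for the number of bytes
image_input = create_named_value('image', onnx.TensorProto.UINT8, ['num_bytes'])

# float input with a symbolic batch dimension and a fixed second dimension
scores_input = create_named_value('scores', onnx.TensorProto.FLOAT, ['batch_size', 1024])
```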


## IoMapEntry usage

When the automatic connection of outputs from the previous step to inputs of the current step is insufficient,
an IoMapEntry can be used to explicitly specify connections.

As an example, let's look at a subset of the operations in the pre and post processing for a super resolution model.
In the pre-processing we convert the input from RGB to YCbCr using `PixelsToYCbCr`.
That step produces 3 separate outputs - `Y`, `Cb` and `Cr`. The model has one input and is automatically connected
to the `Y` output when PixelsToYCbCr is the last pre-processing step.
We want to consume the `Cb` and `Cr` outputs in the post-processing by joining them with the new `Y'` model output.

```py
pipeline = PrePostProcessor(inputs)
pipeline.add_pre_processing(
    [
        ...
        # this produces Y, Cb and Cr outputs. each has shape {h_in, w_in}. only Y is input to model
        PixelsToYCbCr(layout="BGR"),
    ]
)
```

In order to do that, the post-processing entry can be specified as a tuple of the Step and a list of IoMapEntries.
Each IoMapEntry has a simple structure of `IoMapEntry(producer, producer_idx, consumer_idx)`. The `producer` is the
name of the Step that produces the output. The `producer_idx` is the index of the output from that step. The `consumer_idx`
is the input number of the Step that we want to connect to.

```py
pipeline.add_post_processing(
    [
        # as we're selecting outputs from multiple previous steps we need to map them to the inputs using step names
        (
            YCbCrToPixels(layout="BGR"),
            [
                # the first model output is automatically joined to consumer_idx=0
                IoMapEntry("PixelsToYCbCr", producer_idx=1, consumer_idx=1),  # Cb value
                IoMapEntry("PixelsToYCbCr", producer_idx=2, consumer_idx=2)   # Cr value
            ],
        ),
        ConvertBGRToImage(image_format="png")
    ]
)
```

By default the name for each Step is the class name. When instantiating a step you can override the `name` property
to provide a more descriptive name or resolve ambiguity (e.g. if there are multiple steps of the same type).

In our example, if we used `PixelsToYCbCr(layout="BGR", name="ImageConverter")` in the pre-processing step,
we would use `IoMapEntry("ImageConverter", producer_idx=1, consumer_idx=1)` in the post-processing step to match that
name.
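
Putting that together, a sketch of just the renamed step and its matching IoMapEntries (other steps elided):

```py
pipeline.add_pre_processing([
    # ...
    PixelsToYCbCr(layout="BGR", name="ImageConverter"),
])

pipeline.add_post_processing([
    (
        YCbCrToPixels(layout="BGR"),
        [
            IoMapEntry("ImageConverter", producer_idx=1, consumer_idx=1),  # Cb
            IoMapEntry("ImageConverter", producer_idx=2, consumer_idx=2),  # Cr
        ],
    ),
])
```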

Note that the automatic connection between steps will still occur. The list of IoMapEntry values is used to override the
automatic connections, so you only need to provide an IoMapEntry for connections that need customization. In our
example the model output is automatically connected to the first input of the `YCbCrToPixels` step so it wasn't
necessary to provide an IoMapEntry for consumer_idx=0.


## Debug step usage

If you are creating your own pipeline it can sometimes be necessary to inspect the output of a pre or post processing
step if the final results are unexpected. The easiest way to do this is to insert a `Debug` step into the pipeline.

The Debug step will create graph outputs for the outputs from the previous step. That means they will be available
as outputs when running the updated model, and can be inspected.

The Debug step will also pass through its inputs to the next step, so no other changes to the pipeline are required.

Considering our pre-processing example, if we wanted to inspect the result of the conversion from an input image
we can insert a Debug step like below. The existing steps remain unchanged.

```py
pipeline.add_pre_processing(
    [
        ConvertImageToBGR(),  # jpg/png image to BGR in HWC layout. output shape is {h_in, w_in, channels}
        Debug(),
        Resize(256),  # resize so smallest side is 256.
```

The model will now have an additional output called 'bgr_data' (the default output name of the ConvertImageToBGR step).

Note that if the previous step produces multiple outputs the Debug step must be configured with this information.

e.g.

```py
PixelsToYCbCr(layout="BGR"),
Debug(num_inputs=3),
...
```
@ -0,0 +1,4 @@
from .pre_post_processor import PrePostProcessor
from .step import Step
from .utils import *
from .steps import *

@ -0,0 +1,9 @@
Module pre_post_processing
==========================

Sub-modules
-----------
* pre_post_processing.pre_post_processor
* pre_post_processing.step
* pre_post_processing.steps
* pre_post_processing.utils

@ -0,0 +1,39 @@
Module pre_post_processing.pre_post_processor
=============================================

Classes
-------

`PrePostProcessor(inputs: List[onnx.onnx_ml_pb2.ValueInfoProto] = None, outputs: List[onnx.onnx_ml_pb2.ValueInfoProto] = None)`
:   Class to handle running all the pre/post processing steps and updating the model.

    ### Methods

    `add_post_processing(self, items: List[Union[pre_post_processing.step.Step, Tuple[pre_post_processing.step.Step, List[pre_post_processing.utils.IoMapEntry]]]])`
    :   Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

    `add_pre_processing(self, items: List[Union[pre_post_processing.step.Step, Tuple[pre_post_processing.step.Step, List[pre_post_processing.utils.IoMapEntry]]]])`
    :   Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.

    `run(self, model: onnx.onnx_ml_pb2.ModelProto)`
    :   Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.

        Returns:
            model with pre/post processing in it.

@ -0,0 +1,65 @@
Module pre_post_processing.step
===============================

Classes
-------

`Debug(num_inputs: int = 1, name: Optional[str] = None)`
:   Step that can be arbitrarily inserted in the pre or post processing pipeline.
    It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.

    NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
    may or may not have '_debug' as a suffix.
    TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
    to rename so it's more consistent.

    Initialize Debug step
    Args:
        num_inputs: Number of inputs from previous Step to make graph outputs.
        name: Optional name for Step. Defaults to 'Debug'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Step(inputs: List[str], outputs: List[str], name: Optional[str] = None)`
:   Base class for a pre or post processing step.

    Initialize the step.

    Args:
        inputs: List of default input names.
        outputs: List of default output names.
        name: Step name. Defaults to the derived class name.

    ### Descendants

    * pre_post_processing.step.Debug
    * pre_post_processing.steps.general.ReverseAxis
    * pre_post_processing.steps.general.Softmax
    * pre_post_processing.steps.general.Squeeze
    * pre_post_processing.steps.general.Transpose
    * pre_post_processing.steps.general.Unsqueeze
    * pre_post_processing.steps.vision.CenterCrop
    * pre_post_processing.steps.vision.ConvertBGRToImage
    * pre_post_processing.steps.vision.ConvertImageToBGR
    * pre_post_processing.steps.vision.FloatToImageBytes
    * pre_post_processing.steps.vision.ImageBytesToFloat
    * pre_post_processing.steps.vision.Normalize
    * pre_post_processing.steps.vision.PixelsToYCbCr
    * pre_post_processing.steps.vision.Resize
    * pre_post_processing.steps.vision.YCbCrToPixels

    ### Class variables

    `prefix`
    :

    ### Methods

    `apply(self, graph: onnx.onnx_ml_pb2.GraphProto)`
    :   Append the nodes that implement this step to the provided graph.

    `connect(self, entry: pre_post_processing.utils.IoMapEntry)`
    :   Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

@ -0,0 +1,70 @@
Module pre_post_processing.steps.general
========================================

Classes
-------

`ReverseAxis(axis: int = -1, dim_value: int = -1, name: Optional[str] = None)`
:   Reverses the data in an axis by splitting and concatenating in reverse order.
    e.g. convert RGB ordered data to BGR.
    Output data type and shape is the same as the input.

    Args:
        axis: Axis to reverse. Default is last axis.
        dim_value: Explicit value for size of dimension being reversed.
            This can be provided if the axis being reversed currently has a symbolic value.
            Note that this will fail during graph execution if the actual value at runtime does not match.
            If not provided, the size of the dimension to reverse is inferred from the input shape.
        name: Optional Step name. Defaults to 'ReverseAxis'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Softmax(name: Optional[str] = None)`
:   ONNX Softmax

    Args:
        name: Optional Step name. Defaults to 'Softmax'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Squeeze(axes: Optional[List[int]] = None, name: Optional[str] = None)`
:   ONNX Squeeze

    Args:
        axes: Axes to remove.
            If None, remove all axes with size of 1. Requires all dimensions to have explicit values.
        name: Optional Step name. Defaults to 'Squeeze'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Transpose(perms: List[int], name: Optional[str] = None)`
:   ONNX Transpose.

    Args:
        perms: List of integers with permutations to apply.
        name: Optional Step name. Defaults to 'Transpose'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

    ### Descendants

    * pre_post_processing.steps.vision.ChannelsLastToChannelsFirst

`Unsqueeze(axes: List[int], name: Optional[str] = None)`
:   ONNX Unsqueeze

    Args:
        axes: List of integers indicating the dimensions to be inserted.
        name: Optional Step name. Defaults to 'Unsqueeze'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

@ -0,0 +1,7 @@
Module pre_post_processing.steps
================================

Sub-modules
-----------
* pre_post_processing.steps.general
* pre_post_processing.steps.vision
@ -0,0 +1,143 @@
Module pre_post_processing.steps.vision
=======================================

Classes
-------

`CenterCrop(height: int, width: int, name: Optional[str] = None)`
:   Crop the input to the requested dimensions, with the crop being centered.

    Args:
        height: Height of area to crop.
        width: Width of area to crop.
        name: Optional step name. Defaults to 'CenterCrop'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`ChannelsLastToChannelsFirst(has_batch_dim: bool = False, name: Optional[str] = None)`
:   Convert channels last data to channels first.
    Input can be NHWC or HWC.

    Args:
        has_batch_dim: Set to True if the input has a batch dimension (i.e. is NHWC)
        name: Optional step name. Defaults to 'ChannelsLastToChannelsFirst'

    ### Ancestors (in MRO)

    * pre_post_processing.steps.general.Transpose
    * pre_post_processing.step.Step

`ConvertBGRToImage(image_format: str = 'jpg', name: Optional[str] = None)`
:   Convert BGR ordered uint8 data into an encoded image.
    Supported output formats: jpg, png
    Input shape: {input_image_height, input_image_width, 3}
    Output shape: {num_encoded_bytes}

    Args:
        image_format: Format to encode to. jpg and png are supported.
        name: Optional step name. Defaults to 'ConvertBGRToImage'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`ConvertImageToBGR(name: Optional[str] = None)`
:   Convert the bytes of an image by decoding to BGR ordered uint8 values.
    Supported input formats: jpg, png
    Input shape: {num_encoded_bytes}
    Output shape: {input_image_height, input_image_width, 3}

    Args:
        name: Optional name of step. Defaults to 'ConvertImageToBGR'

    NOTE: Input image format is inferred and does not need to be specified.

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`FloatToImageBytes(multiplier: float = 255.0, name: Optional[str] = None)`
:   Convert floating point values to uint8 values in range 0..255.
    Typically this reverses ImageBytesToFloat by converting input data in the range 0..1, but an optional multiplier
    can be specified if the input data has a different range.
    Values will be rounded prior to clipping and conversion to uint8.

    Args:
        multiplier: Optional multiplier. Currently, the expected values are 255 (input data is in range 0..1), or
            1 (input data is in range 0..255).
        name: Optional step name. Defaults to 'FloatToImageBytes'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`ImageBytesToFloat(name: Optional[str] = None)`
:   Convert uint8 or float values in range 0..255 to floating point values in range 0..1

    Args:
        name: Optional step name. Defaults to 'ImageBytesToFloat'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Normalize(normalization_values: List[Tuple[float, float]], layout: str = 'CHW', name: Optional[str] = None)`
:   Normalize input data on a per-channel basis.
    `x -> (x - mean) / stddev`
    Output is float with same shape as input.

    Args:
        normalization_values: Tuple with (mean, stddev). One entry per channel.
            If a single entry is provided it will be used for all channels.
        layout: Input layout. Can be 'CHW' or 'HWC'
        name: Optional step name. Defaults to 'Normalize'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`PixelsToYCbCr(layout: str = 'BGR', name: Optional[str] = None)`
:   Convert RGB or BGR pixel data to YCbCr format.
    Input shape: {height, width, 3}
    Output shape is the same.
    Output data is float, but rounded and clipped to the range 0..255 as per the spec for YCbCr conversion.

    Args:
        layout: Input data layout. Can be 'BGR' or 'RGB'
        name: Optional step name. Defaults to 'PixelsToYCbCr'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`Resize(resize_to: Union[int, Tuple[int, int]], layout: str = 'HWC', name: Optional[str] = None)`
:   Resize input data. Aspect ratio is maintained.
    e.g. if the image is 1200 x 600 and 300 x 300 is requested the result will be 600 x 300

    Args:
        resize_to: Target size. Can be a single value or a tuple with (target_height, target_width).
            The aspect ratio will be maintained and neither height nor width in the result will be smaller
            than the requested value.
        layout: Input layout. 'CHW', 'HWC' and 'HW' are supported.
        name: Optional name. Defaults to 'Resize'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step

`YCbCrToPixels(layout: str = 'BGR', name: Optional[str] = None)`
:   Convert YCbCr input to RGB or BGR.

    Input data can be uint8 or float but all inputs must use the same type.
    Input shape: {height, width, 3}
    Output shape is the same.

    Args:
        layout: Output layout. Can be 'BGR' or 'RGB'
        name: Optional step name. Defaults to 'YCbCrToPixels'

    ### Ancestors (in MRO)

    * pre_post_processing.step.Step
@ -0,0 +1,60 @@
Module pre_post_processing.utils
================================

Functions
---------

`create_custom_op_checker_context()`
:   Create an ONNX checker context that includes the ort-extensions custom op domains so that custom ops don't
    cause failure when running onnx.checker.check_graph.

    Returns:
        The checker context.

`create_named_value(name: str, data_type: int, shape: List[Union[str, int]])`
:   Helper to create a new model input.

    Args:
        name: Name for input. Must not already be in use in the model being updated.
        data_type: onnx.TensorProto data type. e.g. onnx.TensorProto.FLOAT, onnx.TensorProto.UINT8
        shape: Input shape. Use int for dimensions with known values and strings for symbolic dimensions.
            e.g. ['batch_size', 256, 256] would be a rank 3 tensor with a symbolic first dimension named 'batch_size'

    Returns:
        An onnx.ValueInfoProto that can be used as a new model input.

`get_opset_imports()`
:   Get the opset imports for a model updated by the PrePostProcessor.

`sanitize_output_names(graph: onnx.onnx_ml_pb2.GraphProto)`
:   Convert any usage of invalid characters like '/' and ';' in value names to '_'.
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we can
    leave the internals of the graph intact to minimize changes.

    Args:
        graph: Graph to check and update any invalid names

Classes
-------

`IoMapEntry(producer: Union[ForwardRef('Step'), str] = None, producer_idx: int = 0, consumer_idx: int = 0)`
:   Entry to map the output index from a producer step to the input index of a consumer step.

    ### Class variables

    `consumer_idx: int`
    :

    `producer: Union[Step, str]`
    :

    `producer_idx: int`
    :
@ -0,0 +1,334 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx

from onnx import version_converter
from typing import List, Tuple, Union

from .utils import (
    IoMapEntry,
    get_opset_imports,
    sanitize_output_names,
    PRE_POST_PROCESSING_ONNX_OPSET,
    TENSOR_TYPE_TO_ONNX_TYPE,
)
from .step import Step


class PrePostProcessor:
    """
    Class to handle running all the pre/post processing steps and updating the model.
    """

    def __init__(self, inputs: List[onnx.ValueInfoProto] = None, outputs: List[onnx.ValueInfoProto] = None):
        self.pre_processors = []
        self.post_processors = []

        # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors
        self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
        self._post_processor_connections = []  # type: List[List[IoMapEntry]]

        # explicitly join outputs from Steps in pre_processors to inputs of the original model
        # format is Step or step name, step_idx, name of graph input/output
        #
        # Pre-processing we connect Step output to original model:
        #   - step_idx is for Step.output_names, and name is in graph.input
        #
        # Post-processing we connect the original model output to the Step input
        #   - step_idx is for Step.input_names, and name is in graph.output
        self._pre_processing_joins = None  # type: Union[None, List[Tuple[Union[Step, str], int, str]]]
        self._post_processing_joins = None  # type: Union[None, List[Tuple[Union[Step, str], int, str]]]

        self._inputs = inputs if inputs else []
        self._outputs = outputs if outputs else []

    def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.
        """
        self.__add_processing(self.pre_processors, self._pre_processor_connections, items)

    def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.
        """
        self.__add_processing(self.post_processors, self._post_processor_connections, items)

    def _add_connection(self, consumer: Step, entry: IoMapEntry):
        producer = self.__producer_from_step_or_str(entry.producer)

        # Black does annoying things with the multi-line 'if' conditions making the code far less readable
        # fmt: off
        if not ((producer in self.pre_processors or producer in self.post_processors) and
                (consumer in self.pre_processors or consumer in self.post_processors)):
            raise ValueError("Producer and Consumer processors must both be registered")

        if producer in self.pre_processors:
            if (consumer in self.pre_processors and
                    self.pre_processors.index(producer) > self.pre_processors.index(consumer)):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        elif producer in self.post_processors:
            if consumer not in self.post_processors:
                raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
            elif self.post_processors.index(producer) > self.post_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        # fmt: on

        assert isinstance(producer, Step)
        consumer.connect(entry)

    def run(self, model: onnx.ModelProto):
        """
        Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.

        Returns:
            model with pre/post processing in it.
        """

        # update the input model to the ONNX opset we're using. this is required as we implement the steps based on
        # the operator specs for this opset.
        model_opset = [
            entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"
        ][0]

        if model_opset > PRE_POST_PROCESSING_ONNX_OPSET:
            # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model
            # but there are no guarantees.
            # Would only break if ONNX operators used in the pre/post processing graphs have had spec changes.
            raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
        elif model_opset < PRE_POST_PROCESSING_ONNX_OPSET:
            model = onnx.version_converter.convert_version(model, PRE_POST_PROCESSING_ONNX_OPSET)

        def name_nodes(new_graph: onnx.GraphProto, prefix: str):
            # simple helper so all nodes are named. this makes it far easier to debug any issues.
            idx = 0
            for n in new_graph.node:
                if not n.name:
                    n.name = prefix + str(idx)
                    idx += 1

        def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
            for connection in connections:
                assert connection.producer
                self._add_connection(processor, connection)

            return processor.apply(graph)

        # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
        if self.post_processors:
            sanitize_output_names(model.graph)

        graph = model.graph
        # add pre-processing
        if self.pre_processors:
            # create empty graph with pass through of the requested input name
            pre_process_graph = onnx.GraphProto()
            for i in self._inputs:
                pre_process_graph.input.append(i)
                pre_process_graph.output.append(i)

            for idx, step in enumerate(self.pre_processors):
                pre_process_graph = connect_and_run(pre_process_graph, step, self._pre_processor_connections[idx])

            # name all the nodes for easier debugging
            name_nodes(pre_process_graph, "pre_process_")

            if not self._pre_processing_joins:
                # default to 1:1 between outputs of last step with inputs of original model
                last_step = self.pre_processors[-1]
                num_entries = min(len(last_step.output_names), len(graph.input))
                self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]

            # map the pre-processing outputs to graph inputs
            io_map = []  # type: List[Tuple[str, str]]
            for step, step_idx, graph_input in self._pre_processing_joins:
                io_map.append((step.output_names[step_idx], graph_input))

            graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map)

        # add post-processing
        if self.post_processors:
            orig_model_outputs = [o.name for o in model.graph.output]
            graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing

            # create default joins if needed
            if not self._post_processing_joins:
                # default to 1:1 between outputs of original model with inputs of first post-processing step
                first_step = self.post_processors[0]
                num_entries = min(len(first_step.input_names), len(orig_model_outputs))
                self._post_processing_joins = [(first_step, i, orig_model_outputs[i]) for i in range(0, num_entries)]

            # update the input names for the steps to match the values produced by the model
            for step, step_idx, graph_output in self._post_processing_joins:
                assert graph_output in graph_outputs
                step.input_names[step_idx] = graph_output

            # create empty graph with the values that will be available to the post-processing
            post_process_graph = onnx.GraphProto()
            for o in graph.output:
                post_process_graph.input.append(o)
                post_process_graph.output.append(o)

            for idx, step in enumerate(self.post_processors):
                post_process_graph = connect_and_run(post_process_graph, step, self._post_processor_connections[idx])

            name_nodes(post_process_graph, "post_process_")

            # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
            io_map = [(o, o) for o in graph_outputs]
            graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)

        # Make the output names nicer by removing prefixing from naming that occurred when applying the steps
        graph = PrePostProcessor.__cleanup_graph_output_names(graph)

        opset_imports = [onnx.helper.make_operatorsetid(domain, opset) for domain, opset in get_opset_imports().items()]
        new_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

        onnx.checker.check_model(new_model)

        return new_model

    def __add_processing(
        self,
        processors: List[Step],
        processor_connections: List[List[IoMapEntry]],
        items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
    ):
        """
        Add the pre/post processing steps and join with existing steps.

        Args:
            processors: List of processors to add items to.
            processor_connections: Populated with connections between each step. 1:1 with entries in processors.
            items: Items to add to processors.
                Can be:
                  A Step instance. This will be implicitly joined to the immediately previous Step if one exists.
                  A tuple of (Step instance, list of IoMapEntry)
                    The IoMapEntry values are used to manually join an output from a producer Step to an input
                    of the current Step.
                    In each IoMapEntry, if a step name is provided the producer Step will be searched for in all
                    predecessor steps. It is valid for a post-processor step to consume output from a
                    pre-processor step.
        """

        for item in items:
            step = None
            explicit_io_map_entries = None

            if isinstance(item, Step):
                step = item
            elif isinstance(item, tuple):
                step, explicit_io_map_entries = item
            else:
                raise ValueError("Unexpected type " + str(type(item)))

            # start with implicit joins and replace with explicitly provided ones
            # this allows the user to specify the minimum number of manual joins.
            io_map_entries = [None] * len(step.input_names)  # type: List[Union[None, IoMapEntry]]
            prev_step = None if len(processors) == 0 else processors[-1]
            if prev_step:
                # default is connecting as many outputs from the previous step as possible
                for i in range(0, min(len(prev_step.output_names), len(step.input_names))):
                    io_map_entries[i] = IoMapEntry(prev_step, i, i)

            # add explicit connections
            if explicit_io_map_entries:
                for entry in explicit_io_map_entries:
                    if not entry.producer:
                        producer = prev_step
                    else:
                        producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found

                    io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)

            processors.append(step)
            processor_connections.append([entry for entry in io_map_entries if entry is not None])

    def __producer_from_step_or_str(self, entry: Union[Step, str]):
        if isinstance(entry, Step):
            return entry

        if isinstance(entry, str):
            match = (next((s for s in self.pre_processors if s.name == entry), None) or
                     next((s for s in self.post_processors if s.name == entry), None))  # fmt: skip

            if not match:
                raise ValueError(f"Step named {entry} was not found")

            return match

    @staticmethod
    def __cleanup_graph_output_names(graph: onnx.GraphProto):
        """
        Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
        Not essential but makes the graph outputs look far nicer.
        """

        # for each output create identity node to remove prefixing
        io_map = []
        fixes = onnx.GraphProto()
        fixes.input.extend(graph.output)

        # manually handle naming clashes
        input_names = set([i.name for i in graph.input])
        used_names = set(input_names)
        conflicts = 0

        for o in graph.output:
            if not o.name.startswith(Step.prefix):
                continue

            # we will create a small graph to do the renames so the output of the original graph will be an input
            # to that 'fixer' graph
            io_map.append((o.name, o.name))
            clean_name = o.name
            while clean_name.startswith(Step.prefix):
                # output from last step will have one prefixing stage that adds Step._prefix + '_'
                # e.g. '_ppp8_<orig_name>'
                next_underscore = clean_name.find("_", 1)
                if next_underscore > 0:
                    # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
                    if len(clean_name) > next_underscore + 1:
                        next_underscore += 1
                    clean_name = clean_name[next_underscore:]

            # handle things like super resolution where there's an 'image' input and 'image' output
            if clean_name in input_names:
                clean_name += "_out"

            orig_clean_name = clean_name
            while clean_name in used_names:
                conflicts += 1
                clean_name = f"{orig_clean_name}{conflicts}"

            used_names.add(clean_name)

            renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
            fixes.node.append(renamer)

            new_output = fixes.output.add()
            new_output.name = clean_name
            new_output.type.CopyFrom(o.type)

        # merge if we have any renaming to do
        if io_map:
            graph = onnx.compose.merge_graphs(graph, fixes, io_map)

        return graph
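Taken together, a minimal usage sketch of this class. The step classes come from the `steps` modules added in this change; the import paths, model path, and sizes below are assumptions for illustration rather than fixed API.

```python
import onnx

# Assumed import layout; adjust to where the package actually lives.
from pre_post_processing.pre_post_processor import PrePostProcessor
from pre_post_processing.utils import create_named_value
from pre_post_processing.steps import (
    ConvertImageToBGR, Resize, CenterCrop, ImageBytesToFloat,
    ChannelsLastToChannelsFirst, Unsqueeze, Softmax,
)

model = onnx.load("mobilenet.onnx")  # hypothetical model path

# new overall model input: the raw jpg/png bytes
new_input = create_named_value("image_bytes", onnx.TensorProto.UINT8, ["num_bytes"])

pipeline = PrePostProcessor(inputs=[new_input])
pipeline.add_pre_processing([
    ConvertImageToBGR(),            # decode bytes to {h, w, 3} BGR uint8
    Resize(256),                    # aspect ratio kept; smaller side becomes 256
    CenterCrop(224, 224),
    ImageBytesToFloat(),            # uint8 0..255 -> float 0..1
    ChannelsLastToChannelsFirst(),  # HWC -> CHW
    Unsqueeze([0]),                 # CHW -> NCHW to match the model input
])
pipeline.add_post_processing([Softmax()])

new_model = pipeline.run(model)
onnx.save_model(new_model, "mobilenet.with_ppp.onnx")
```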
@ -0,0 +1,11 @@
Documentation was generated with pdoc3 (`pip install pdoc3`).
From the parent directory:
`python -m pdoc pre_post_processing -o ./pre_post_processing/docs --filter pre_post_processing`

This was just a quick way to get some initial docs.
There are probably better Python doc generation tools in the CI that could be used.

It's not ideal in that there are no links between the different md files,
e.g. the doc for the base Step class mentions the derived classes but doesn't provide links to read their docs.

However it does seem to document each class fairly well.
@ -0,0 +1,205 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import abc
import onnx

from onnx import parser
from typing import List, Optional, Tuple

from .utils import (
    IoMapEntry,
    create_custom_op_checker_context,
    TENSOR_TYPE_TO_ONNX_TYPE,
)


class Step(object):
    """Base class for a pre or post processing step."""

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step
    _custom_op_checker_context = create_custom_op_checker_context()

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names.
            outputs: List of default output names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        self.input_names = inputs
        self.output_names = outputs
        self.name = name if name else f"{self.__class__.__name__}"
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.
        """
        assert len(entry.producer.output_names) > entry.producer_idx
        assert len(self.input_names) > entry.consumer_idx
        assert isinstance(entry.producer, Step)

        self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]

    def apply(self, graph: onnx.GraphProto):
        """Append the nodes that implement this step to the provided graph."""

        graph_for_step = self._create_graph_for_step(graph)
        onnx.checker.check_graph(graph_for_step, Step._custom_op_checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = [o.name for o in result.output]

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto):
        """Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step."""
        pass

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto):
        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
        io_map = []
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            for i in second.input:
                if i.name == prefixed_output:
                    io_map.append((o.name, i.name))

        outputs_to_preserve = None

        # special handling of Debug class.
        if isinstance(self, Debug):
            # preserve outputs of the first graph so they're available downstream. otherwise they are consumed by
            # the Debug node and disappear during the ONNX graph_merge as it considers consumed values to be
            # internal - which is entirely reasonable when merging graphs.
            # the issue we have is that we don't know what future steps might want things to remain as outputs.
            # the current approach is to insert a Debug step which simply duplicates the values so that they are
            # guaranteed not to be consumed (only one of the two copies will be used).
            # doesn't change the number of outputs from the previous step, so it can be transparently inserted in the
            # pre/post processing pipeline.
            # need to also list the second graph's outputs when manually specifying outputs.
            outputs_to_preserve = [o.name for o in first.output] + [o.name for o in second.output]

        # merge with existing graph
        merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=outputs_to_preserve)

        return merged_graph

    @staticmethod
    def _elem_type_str(elem_type: int):
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """Returns the values from the shape as a comma separated string."""

        def dim_to_str(dim):
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            elif dim.HasField("dim_param"):
                return dim.dim_param
            else:
                return ""

        shape_str = ",".join([dim_to_str(dim) for dim in shape.dim])
        return shape_str

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TensorProto:
        """Get the onnx.TensorProto for the input from the outputs of the graph we're appending to."""

        input_type = None
        for o in graph.output:
            if o.name == self.input_names[input_num]:
                input_type = o.type.tensor_type
                break

        if not input_type:
            raise ValueError(f"Input {self.input_names[input_num]} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)


# special case. we include the helper Debug step here as logic in the base class is conditional on it.
class Debug(Step):
    """
    Step that can be arbitrarily inserted in the pre or post processing pipeline.
    It will make the outputs of the previous Step also become graph outputs so their value can be more easily debugged.

    NOTE: Depending on when the previous Step's outputs are consumed in the pipeline the graph output for it
    may or may not have '_debug' as a suffix.
    TODO: PrePostProcessor __cleanup_graph_output_names could also hide the _debug by inserting an Identity node
    to rename so it's more consistent.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize Debug step
        Args:
            num_inputs: Number of inputs from previous Step to make graph outputs.
            name: Optional name for Step. Defaults to 'Debug'
        """
        self._num_inputs = num_inputs
        input_names = [f"input{i}" for i in range(0, num_inputs)]
        output_names = [f"debug{i}" for i in range(0, num_inputs)]

        super().__init__(input_names, output_names, name)

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_str = ""
        output_str = ""
        output_debug_str = ""
        nodes_str = ""

        # update output names so we preserve info from the latest input names
        self.output_names = [f"{name}_debug" for name in self.input_names]

        for i in range(0, self._num_inputs):
            input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i)
            if i > 0:
                input_str += ", "
                output_str += ", "
                output_debug_str += ", "
                nodes_str += "\n"

            input_str += f"{input_type_str}[{input_shape_str}] {self.input_names[i]}"
            output_str += f"{input_type_str}[{input_shape_str}] {self.output_names[i]}"
            nodes_str += f"{self.output_names[i]} = Identity({self.input_names[i]})\n"

        debug_graph = onnx.parser.parse_graph(
            f"""\
            debug ({input_str})
                => ({output_str})
            {{
                {nodes_str}
            }}
            """
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph
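To illustrate the contract above, a sketch of a hypothetical derived step: implement `_create_graph_for_step`, parse a small graph whose value names come from `self.input_names`/`self.output_names`, and let `Step.apply` handle prefixing and merging. `Abs` here is just a placeholder operation, and the import path is an assumption.

```python
import onnx
from onnx import parser
from typing import Optional

from pre_post_processing.step import Step  # assumed import path


class AbsValue(Step):
    """Hypothetical example: apply ONNX Abs. Output type and shape match the input."""

    def __init__(self, name: Optional[str] = None):
        # default input/output value names, renamed by the base class when steps are connected
        super().__init__(["data"], ["abs_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        # query type/shape of our input from the outputs of the graph built so far
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        abs_graph = onnx.parser.parse_graph(
            f"""\
            abs_value ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = Abs ({self.input_names[0]})
            }}
            """
        )

        return abs_graph
```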
@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from .general import *
from .vision import *
@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx
from typing import List, Optional
from ..step import Step


class ReverseAxis(Step):
    """
    Reverses the data in an axis by splitting and concatenating in reverse order.
    e.g. convert RGB ordered data to BGR.
    Output data type and shape is the same as the input.
    """

    def __init__(self, axis: int = -1, dim_value: int = -1, name: Optional[str] = None):
        """
        Args:
            axis: Axis to reverse. Default is last axis.
            dim_value: Explicit value for size of dimension being reversed.
                       This can be provided if the axis being reversed currently has a symbolic value.
                       Note that this will fail during graph execution if the actual value at runtime does not match.
                       If not provided, the size of the dimension to reverse is inferred from the input shape.
            name: Optional Step name. Defaults to 'ReverseAxis'
        """
        super().__init__(["data"], ["data_with_reversed_axis"], name)
        self._axis = axis
        self._dim_value = dim_value

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        input_dims = input_shape_str.split(",")
        split_dim = input_dims[self._axis]

        if split_dim.isdigit():
            dim_value = int(split_dim)
            if self._dim_value != -1:
                # TODO: Technically we don't require a match here. For now expect it to match.
                assert dim_value == self._dim_value
            else:
                self._dim_value = dim_value

        split_outs = []
        for i in range(0, self._dim_value):
            split_outs.append(f"split_out_{i}")

        reverse_graph = onnx.parser.parse_graph(
            f"""\
            reverse_axis ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
            {{
                {','.join(split_outs)} = Split <axis = {self._axis}> ({self.input_names[0]})
                {self.output_names[0]} = Concat <axis = {self._axis}> ({','.join(reversed(split_outs))})
            }}
            """
        )

        return reverse_graph
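As a quick sanity check of the Split/Concat approach used by ReverseAxis, a small numpy sketch (numpy is already a dependency of the vision steps):

```python
import numpy as np

bgr = np.arange(2 * 2 * 3, dtype=np.uint8).reshape(2, 2, 3)

# Split the last axis into single-channel slices and concatenate them in reverse
# order, mirroring what the generated Split/Concat nodes do at runtime.
channels = np.split(bgr, bgr.shape[-1], axis=-1)
rgb = np.concatenate(channels[::-1], axis=-1)

assert np.array_equal(rgb, np.flip(bgr, axis=-1))
```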
class Squeeze(Step):
    """
    ONNX Squeeze
    """

    def __init__(self, axes: Optional[List[int]] = None, name: Optional[str] = None):
        """
        Args:
            axes: Axes to remove.
                  If None, remove all axes with size of 1. Requires all dimensions to have explicit values.
            name: Optional Step name. Defaults to 'Squeeze'
        """
        super().__init__(["data"], ["squeezed"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        axes = self._axes
        if not axes:
            axes = []
            for idx, dim in enumerate(dims):
                if not dim.isnumeric():
                    # we can't infer the output shape if there are symbolic dims
                    raise ValueError("Axes must be specified if there are symbolic dimensions.")

                if dim == "1":
                    axes.append(idx)

        output_dims = [dim for idx, dim in enumerate(dims) if idx not in axes]
        output_shape_str = ",".join(output_dims)

        axes_strs = [str(axis) for axis in axes]

        squeeze_graph = onnx.parser.parse_graph(
            f"""\
            squeeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Squeeze({self.input_names[0]}, axes)
            }}
            """
        )

        return squeeze_graph


class Transpose(Step):
    """
    ONNX Transpose.
    """

    def __init__(self, perms: List[int], name: Optional[str] = None):
        """
        Args:
            perms: List of integers with permutations to apply.
            name: Optional Step name. Defaults to 'Transpose'
        """
        super().__init__(["X"], ["transposed"], name)
        self.perms = perms

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        perms_str = ",".join([str(idx) for idx in self.perms])
        dims = input_shape_str.split(",")
        output_dims = [dims[axis] for axis in self.perms]
        output_shape_str = ",".join(output_dims)

        transpose_graph = onnx.parser.parse_graph(
            f"""\
            transpose ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = Transpose <perm = [{perms_str}]> ({self.input_names[0]})
            }}
            """
        )

        return transpose_graph


class Softmax(Step):
    """
    ONNX Softmax
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional Step name. Defaults to 'Softmax'
        """
        super().__init__(["data"], ["probabilities"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)

        softmax_graph = onnx.parser.parse_graph(
            f"""\
            softmax ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = Softmax ({self.input_names[0]})
            }}
            """
        )

        return softmax_graph


class Unsqueeze(Step):
    """
    ONNX Unsqueeze
    """

    def __init__(self, axes: List[int], name: Optional[str] = None):
        """
        Args:
            axes: List of integers indicating the dimensions to be inserted.
            name: Optional Step name. Defaults to 'Unsqueeze'
        """
        super().__init__(["data"], ["expanded"], name)
        self._axes = axes

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        for idx in self._axes:
            dims.insert(idx, "1")

        output_shape_str = ",".join(dims)
        axes_strs = [str(axis) for axis in self._axes]

        unsqueeze_graph = onnx.parser.parse_graph(
            f"""\
            unsqueeze ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                axes = Constant <value = int64[{len(self._axes)}] {{{','.join(axes_strs)}}}> ()
                {self.output_names[0]} = Unsqueeze ({self.input_names[0]}, axes)
            }}
            """
        )

        return unsqueeze_graph
@ -0,0 +1,549 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx
import numpy as np

from typing import List, Optional, Tuple, Union
from ..step import Step
from .general import Transpose


#
# Image conversion
#
class ConvertImageToBGR(Step):
    """
    Convert the bytes of an image by decoding to BGR ordered uint8 values.
    Supported input formats: jpg, png
    Input shape: {num_encoded_bytes}
    Output shape: {input_image_height, input_image_width, 3}
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional name of step. Defaults to 'ConvertImageToBGR'

        NOTE: Input image format is inferred and does not need to be specified.
        """
        super().__init__(["image"], ["bgr_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        assert input_type_str == "uint8"
        output_shape_str = f"to_bgr_ppp_{self.step_num}_h, to_bgr_ppp_{self.step_num}_w, 3"

        converter_graph = onnx.parser.parse_graph(
            f"""\
            image_to_bgr (uint8[{input_shape_str}] {self.input_names[0]})
                => (uint8[{output_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = com.microsoft.extensions.DecodeImage({self.input_names[0]})
            }}
            """
        )

        return converter_graph


class ConvertBGRToImage(Step):
    """
    Convert BGR ordered uint8 data into an encoded image.
    Supported output formats: jpg, png
    Input shape: {input_image_height, input_image_width, 3}
    Output shape: {num_encoded_bytes}
    """

    def __init__(self, image_format: str = "jpg", name: Optional[str] = None):
        """
        Args:
            image_format: Format to encode to. jpg and png are supported.
            name: Optional step name. Defaults to 'ConvertBGRToImage'
        """
        super().__init__(["bgr_data"], ["image"], name)
        assert image_format == "jpg" or image_format == "png"
        self._format = image_format

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        assert input_type_str == "uint8"
        output_shape_str = f"to_image_ppp_{self.step_num}_num_bytes"

        converter_graph = onnx.parser.parse_graph(
            f"""\
            bgr_to_image (uint8[{input_shape_str}] {self.input_names[0]})
                => (uint8[{output_shape_str}] {self.output_names[0]})
            {{
                {self.output_names[0]} = com.microsoft.extensions.EncodeImage ({self.input_names[0]})
            }}
            """
        )

        # as this is a custom op we have to add the attribute for `format` directly to the node.
        # parse_graph doesn't have a schema for the operator and fails attempting to validate the attribute.
        format_attr = converter_graph.node[0].attribute.add()
        format_attr.name = "format"
        format_attr.type = onnx.AttributeProto.AttributeType.STRING
        format_attr.s = bytes(self._format, "utf-8")

        return converter_graph


class PixelsToYCbCr(Step):
    """
    Convert RGB or BGR pixel data to YCbCr format.
    Input shape: {height, width, 3}
    Output shape is the same.
    Output data is float, but rounded and clipped to the range 0..255 as per the spec for YCbCr conversion.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Input data layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'PixelsToYCbCr'
        """
        super().__init__(["pixels"], ["Y", "Cb", "Cr"], name)
        assert layout == "RGB" or layout == "BGR"
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        # input should be uint8 data HWC
        input_dims = input_shape_str.split(",")
        assert input_type_str == "uint8" and len(input_dims) == 3 and input_dims[2] == "3"

        rgb_weights = np.array([[0.299, 0.587, 0.114],
                                [-0.299 / 1.772, -0.587 / 1.772, 0.500],
                                [0.500, -0.587 / 1.402, -0.114 / 1.402]],
                               dtype=np.float32)  # fmt: skip

        bias = [0.0, 128.0, 128.0]

        if self._layout == "RGB":
            weights = rgb_weights
        else:
            weights = rgb_weights[:, ::-1]  # reverse the order of the last dim to match the BGR channel ordering

        # Weights are transposed for usage in matmul.
        weights_shape = "3, 3"
        weights = ",".join([str(w) for w in weights.T.flatten()])

        bias_shape = "3"
        bias = ",".join([str(b) for b in bias])

        # each output is {h, w}. TBD if input is CHW or HWC though. Once we figure that out we could copy values from
        # the input shape
        output_shape_str = f"YCbCr_ppp_{self.step_num}_h, YCbCr_ppp_{self.step_num}_w"
        assert input_type_str == "uint8"

        # convert to float for MatMul
        # apply weights and bias
        # round and clip so it's in the range 0..255
        # split into channels. shape will be {h, w, 1}
        # remove the trailing '1' so output is {h, w}
        converter_graph = onnx.parser.parse_graph(
            f"""\
            pixels_to_YCbCr (uint8[{input_shape_str}] {self.input_names[0]})
                => (float[{output_shape_str}] {self.output_names[0]},
                    float[{output_shape_str}] {self.output_names[1]},
                    float[{output_shape_str}] {self.output_names[2]})
            {{
                kWeights = Constant <value = float[{weights_shape}] {{{weights}}}> ()
                kBias = Constant <value = float[{bias_shape}] {{{bias}}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()

                f_pixels = Cast <to = 1> ({self.input_names[0]})
                f_weighted = MatMul(f_pixels, kWeights)
                f_biased = Add(f_weighted, kBias)
                f_rounded = Round(f_biased)
                f_clipped = Clip (f_rounded, f_0, f_255)
                split_Y, split_Cb, split_Cr = Split <axis = -1>(f_clipped)
                {self.output_names[0]} = Squeeze (split_Y, i64_neg1)
                {self.output_names[1]} = Squeeze (split_Cb, i64_neg1)
                {self.output_names[2]} = Squeeze (split_Cr, i64_neg1)
            }}
            """
        )

        return converter_graph
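The matrix above is the full-range conversion from the T.871 reference: Y = 0.299R + 0.587G + 0.114B, Cb = (B - Y)/1.772 + 128, Cr = (R - Y)/1.402 + 128. A numpy sketch checking one pixel against those formulas (the pixel value is arbitrary):

```python
import numpy as np

rgb_weights = np.array([[0.299, 0.587, 0.114],
                        [-0.299 / 1.772, -0.587 / 1.772, 0.500],
                        [0.500, -0.587 / 1.402, -0.114 / 1.402]],
                       dtype=np.float32)

pixel = np.array([200.0, 100.0, 50.0], dtype=np.float32)  # R, G, B
y, cb, cr = rgb_weights @ pixel + np.array([0.0, 128.0, 128.0], dtype=np.float32)

# Same result via the textbook formulas.
r, g, b = pixel
y_ref = 0.299 * r + 0.587 * g + 0.114 * b
cb_ref = (b - y_ref) / 1.772 + 128.0
cr_ref = (r - y_ref) / 1.402 + 128.0

assert np.allclose([y, cb, cr], [y_ref, cb_ref, cr_ref], atol=1e-3)
```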
class YCbCrToPixels(Step):
    """
    Convert YCbCr input to RGB or BGR.

    Input data can be uint8 or float but all inputs must use the same type.
    Input shape: {height, width, 3}
    Output shape is the same.
    """

    def __init__(self, layout: str = "BGR", name: Optional[str] = None):
        """
        Args:
            layout: Output layout. Can be 'BGR' or 'RGB'
            name: Optional step name. Defaults to 'YCbCrToPixels'
        """
        super().__init__(["Y", "Cb", "Cr"], ["bgr_data"], name)
        assert layout == "RGB" or layout == "BGR"
        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0)
        input_type_str1, input_shape_str1 = self._get_input_type_and_shape_strs(graph, 1)
        input_type_str2, input_shape_str2 = self._get_input_type_and_shape_strs(graph, 2)
        assert (input_type_str0 == "uint8" and input_type_str1 == "uint8" and input_type_str2 == "uint8") or (
            input_type_str0 == "float" and input_type_str1 == "float" and input_type_str2 == "float"
        )

        assert (
            len(input_shape_str0.split(",")) == 2
            and len(input_shape_str1.split(",")) == 2
            and len(input_shape_str2.split(",")) == 2
        )

        output_shape_str = f"{input_shape_str0}, 3"

        # fmt: off
        # https://en.wikipedia.org/wiki/YCbCr
        # exact weights from https://www.itu.int/rec/T-REC-T.871-201105-I/en
        ycbcr_to_rgb_weights = np.array([[1, 0, 1.402],
                                         [1, -0.114*1.772/0.587, -0.299*1.402/0.587],
                                         [1, 1.772, 0]],
                                        dtype=np.float32)

        # reverse 2nd and 3rd entry in each row (YCbCr to YCrCb so blue and red are flipped)
        ycbcr_to_bgr_weights = np.array([[1, 1.402, 0],
                                         [1, -0.299*1.402/0.587, -0.114*1.772/0.587],
                                         [1, 0, 1.772]],
                                        dtype=np.float32)
        # fmt: on

        weights = ycbcr_to_bgr_weights if self._layout == "BGR" else ycbcr_to_rgb_weights
        bias = [0.0, 128.0, 128.0]

        weights_shape = "3, 3"
        # transpose weights for use in matmul
        weights = ",".join([str(w) for w in weights.T.flatten()])

        bias_shape = "3"
        bias = ",".join([str(b) for b in bias])

        # unsqueeze the {h, w} inputs to add channels dim. new shape is {h, w, 1}
        # merge Y, Cb, Cr data on the new channel axis
        # convert to float to apply weights etc.
        # remove bias
        # apply weights
        # round and clip to 0..255
        # convert to uint8.
        converter_graph = onnx.parser.parse_graph(
            f"""\
            YCbCr_to_RGB ({input_type_str0}[{input_shape_str0}] {self.input_names[0]},
                          {input_type_str1}[{input_shape_str1}] {self.input_names[1]},
                          {input_type_str2}[{input_shape_str2}] {self.input_names[2]})
                => (uint8[{output_shape_str}] {self.output_names[0]})
            {{
                kWeights = Constant <value = float[{weights_shape}] {{{weights}}}> ()
                kBias = Constant <value = float[{bias_shape}] {{{bias}}}> ()
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()
                i64_neg1 = Constant <value = int64[1] {{-1}}> ()

                Y1 = Unsqueeze({self.input_names[0]}, i64_neg1)
                Cb1 = Unsqueeze({self.input_names[1]}, i64_neg1)
                Cr1 = Unsqueeze({self.input_names[2]}, i64_neg1)
                YCbCr = Concat <axis = -1> (Y1, Cb1, Cr1)
                f_YCbCr = Cast <to = 1> (YCbCr)
                f_unbiased = Sub (f_YCbCr, kBias)
                f_pixels = MatMul (f_unbiased, kWeights)
                f_rounded = Round (f_pixels)
                clipped = Clip (f_rounded, f_0, f_255)
                {self.output_names[0]} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )

        return converter_graph


#
# Pre-processing
#
class Resize(Step):
    """
    Resize input data. Aspect ratio is maintained.
    e.g. if image is 1200 x 600 and 300 x 300 is requested the result will be 600 x 300
    """

    def __init__(self, resize_to: Union[int, Tuple[int, int]], layout: str = "HWC", name: Optional[str] = None):
        """
        Args:
            resize_to: Target size. Can be a single value or a tuple with (target_height, target_width).
                       The aspect ratio will be maintained and neither height nor width in the result will be smaller
                       than the requested value.
            layout: Input layout. 'CHW', 'HWC' and 'HW' are supported.
            name: Optional name. Defaults to 'Resize'
        """
        super().__init__(["image"], ["resized_image"], name)
        if isinstance(resize_to, int):
            self._height = self._width = resize_to
        else:
            assert isinstance(resize_to, tuple)
            self._height, self._width = resize_to

        self._layout = layout

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")

        # adjust for layout
        # resize will use the largest ratio so both sides won't necessarily match the requested height and width.
        # use symbolic names for the output dims as we have to provide values. prefix the names to try and
        # avoid any clashes
        scales_constant_str = "f_1 = Constant <value = float[1] {1.0}> ()"
        if self._layout == "HWC":
            assert len(dims) == 3
            split_str = "h, w, c"
            scales_str = "ratio_resize, ratio_resize, f_1"
            output_shape_str = f"resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w, {dims[-1]}"
        elif self._layout == "CHW":
            assert len(dims) == 3
            split_str = "c, h, w"
            scales_str = "f_1, ratio_resize, ratio_resize"
            output_shape_str = f"{dims[0]}, resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w"
        elif self._layout == "HW":
            assert len(dims) == 2
            split_str = "h, w"
            scales_str = "ratio_resize, ratio_resize"
            scales_constant_str = ""
            output_shape_str = f"resize_ppp_{self.step_num}_h, resize_ppp_{self.step_num}_w"
        else:
            raise ValueError(f"Unsupported layout of {self._layout}")

        # TODO: Make this configurable. Matching PIL resize for now
        resize_attributes = 'mode = "linear", nearest_mode = "floor"'

        resize_graph = onnx.parser.parse_graph(
            f"""\
            resize ({input_type_str}[{input_shape_str}] {self.input_names[0]}) =>
                ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                target_size = Constant <value = float[2] {{{float(self._height)}, {float(self._width)}}}> ()
                image_shape = Shape ({self.input_names[0]})
                {split_str} = Split <axis = 0> (image_shape)
                hw = Concat <axis = 0> (h, w)
                f_hw = Cast <to = 1> (hw)
                ratios = Div (target_size, f_hw)
                ratio_resize = ReduceMax (ratios)

                {scales_constant_str}
                scales_resize = Concat <axis = 0> ({scales_str})
                {self.output_names[0]} = Resize <{resize_attributes}> ({self.input_names[0]}, , scales_resize)
            }}
            """
        )

        return resize_graph
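The generated Div/ReduceMax nodes pick the larger of the two target/input ratios, so the smaller side ends up at the requested size and the other side stays proportionally larger. A numpy sketch of the docstring example (1200 x 600 input, 300 x 300 requested):

```python
import numpy as np

image_hw = np.array([1200.0, 600.0])
target_hw = np.array([300.0, 300.0])

# ReduceMax over the per-side ratios, matching the generated nodes.
ratio = np.max(target_hw / image_hw)  # max(0.25, 0.5) -> 0.5

resized_hw = np.round(image_hw * ratio).astype(int)
assert ratio == 0.5 and list(resized_hw) == [600, 300]
```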
class CenterCrop(Step):
    """
    Crop the input to the requested dimensions, with the crop being centered.
    """

    def __init__(self, height: int, width: int, name: Optional[str] = None):
        """
        Args:
            height: Height of area to crop.
            width: Width of area to crop.
            name: Optional step name. Defaults to 'CenterCrop'
        """
        super().__init__(["image"], ["cropped_image"], name)
        self._height = height
        self._width = width

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        dims = input_shape_str.split(",")
        output_shape_str = f"{self._height}, {self._width}, {dims[-1]}"

        crop_graph = onnx.parser.parse_graph(
            f"""\
            crop ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => ({input_type_str}[{output_shape_str}] {self.output_names[0]})
            {{
                target_crop = Constant <value = int64[2] {{{self._height}, {self._width}}}> ()
                i64_2 = Constant <value = int64[1] {{2}}> ()
                axes = Constant <value = int64[2] {{0, 1}}> ()
                x_shape = Shape ({self.input_names[0]})
                h, w, c = Split <axis = 0> (x_shape)
                hw = Concat <axis = 0> (h, w)
                hw_diff = Sub (hw, target_crop)
                start_xy = Div (hw_diff, i64_2)
                end_xy = Add (start_xy, target_crop)
                {self.output_names[0]} = Slice ({self.input_names[0]}, start_xy, end_xy, axes)
            }}
            """
        )

        return crop_graph
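The Slice bounds computed above are standard center-crop arithmetic: start = (input_size - crop_size) / 2 and end = start + crop_size, applied to the H and W axes. The same computation in numpy, with an assumed 240 x 320 HWC input:

```python
import numpy as np

image = np.zeros((240, 320, 3), dtype=np.uint8)  # h, w, c
crop_h, crop_w = 224, 224

start = (np.array(image.shape[:2]) - [crop_h, crop_w]) // 2  # [8, 48]
end = start + [crop_h, crop_w]

cropped = image[start[0]:end[0], start[1]:end[1], :]
assert cropped.shape == (224, 224, 3)
```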
class Normalize(Step):
    """
    Normalize input data on a per-channel basis.
        `x -> (x - mean) / stddev`
    Output is float with same shape as input.
    """

    def __init__(self, normalization_values: List[Tuple[float, float]], layout: str = "CHW", name: Optional[str] = None):
        """
        Args:
            normalization_values: Tuple with (mean, stddev). One entry per channel.
                                  If a single entry is provided it will be used for all channels.
            layout: Input layout. Can be 'CHW' or 'HWC'
            name: Optional step name. Defaults to 'Normalize'
        """
        super().__init__(["data"], ["normalized_data"], name)

        # duplicate for each channel if needed
        if len(normalization_values) == 1:
            normalization_values *= 3

        assert len(normalization_values) == 3
        self._normalization_values = normalization_values
        assert layout == "HWC" or layout == "CHW"
        self._hwc_layout = True if layout == "HWC" else False

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        mean0 = self._normalization_values[0][0]
        mean1 = self._normalization_values[1][0]
        mean2 = self._normalization_values[2][0]
        stddev0 = self._normalization_values[0][1]
        stddev1 = self._normalization_values[1][1]
        stddev2 = self._normalization_values[2][1]

        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        values_shape = "3" if self._hwc_layout else "3, 1, 1"

        normalize_graph = onnx.parser.parse_graph(
            f"""\
            normalize ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => (float[{input_shape_str}] {self.output_names[0]})
            {{
                kMean = Constant <value = float[{values_shape}] {{{mean0}, {mean1}, {mean2}}}> ()
                kStddev = Constant <value = float[{values_shape}] {{{stddev0}, {stddev1}, {stddev2}}}> ()
                f_input = Cast <to = 1> ({self.input_names[0]})
                f_sub_mean = Sub (f_input, kMean)
                {self.output_names[0]} = Div (f_sub_mean, kStddev)
            }}
            """
        )

        onnx.checker.check_graph(normalize_graph)
        return normalize_graph


#
# Utilities
#
class ImageBytesToFloat(Step):
    """
    Convert uint8 or float values in range 0..255 to floating point values in range 0..1
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: Optional step name. Defaults to 'ImageBytesToFloat'
        """
        super().__init__(["data"], ["float_data"], name)

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        if input_type_str == "uint8":
            optional_cast = f"""\
                input_f = Cast <to = 1> ({self.input_names[0]})
            """
        else:
            # no-op that optimizer will remove
            optional_cast = f"input_f = Identity ({self.input_names[0]})"

        byte_to_float_graph = onnx.parser.parse_graph(
            f"""\
            byte_to_float ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                => (float[{input_shape_str}] {self.output_names[0]})
            {{
                f_255 = Constant <value = float[1] {{255.0}}> ()

                {optional_cast}
                {self.output_names[0]} = Div(input_f, f_255)
            }}
            """
        )

        onnx.checker.check_graph(byte_to_float_graph)
        return byte_to_float_graph


class FloatToImageBytes(Step):
    """
    Convert floating point values to uint8 values in range 0..255.
    Typically this reverses ImageBytesToFloat by converting input data in the range 0..1, but an optional multiplier
    can be specified if the input data has a different range.
    Values will be rounded prior to clipping and conversion to uint8.
    """

    def __init__(self, multiplier: float = 255.0, name: Optional[str] = None):
        """
        Args:
            multiplier: Optional multiplier. Currently, the expected values are 255 (input data is in range 0..1), or
                        1 (input data is in range 0..255).
            name: Optional step name. Defaults to 'FloatToImageBytes'
        """
        super().__init__(["float_data"], ["pixel_data"], name)
        self._multiplier = multiplier

    def _create_graph_for_step(self, graph: onnx.GraphProto):
        input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
        assert input_type_str == "float"

        float_to_byte_graph = onnx.parser.parse_graph(
            f"""\
            float_to_type (float[{input_shape_str}] {self.input_names[0]})
                => (uint8[{input_shape_str}] {self.output_names[0]})
            {{
                f_0 = Constant <value = float[1] {{0.0}}> ()
                f_255 = Constant <value = float[1] {{255.0}}> ()
                f_multiplier = Constant <value = float[1] {{{self._multiplier}}}> ()

                scaled_input = Mul ({self.input_names[0]}, f_multiplier)
                rounded = Round (scaled_input)
                clipped = Clip (rounded, f_0, f_255)
                {self.output_names[0]} = Cast <to = {onnx.TensorProto.UINT8}> (clipped)
            }}
            """
        )

        onnx.checker.check_graph(float_to_byte_graph)
        return float_to_byte_graph
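With the default multiplier of 255 this step is the inverse of ImageBytesToFloat above. A numpy sketch of the round trip through the generated Div and Mul/Round/Clip/Cast nodes:

```python
import numpy as np

pixels = np.array([0, 1, 127, 254, 255], dtype=np.uint8)

floats = pixels.astype(np.float32) / 255.0                              # ImageBytesToFloat
restored = np.clip(np.round(floats * 255.0), 0, 255).astype(np.uint8)   # FloatToImageBytes

assert np.array_equal(pixels, restored)
```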
class ChannelsLastToChannelsFirst(Transpose):
    """
    Convert channels last data to channels first.
    Input can be NHWC or HWC.
    """

    def __init__(self, has_batch_dim: bool = False, name: Optional[str] = None):
        """
        Args:
            has_batch_dim: Set to True if the input has a batch dimension (i.e. is NHWC)
            name: Optional step name. Defaults to 'ChannelsLastToChannelsFirst'
        """
        perms = [0, 3, 1, 2] if has_batch_dim else [2, 0, 1]
        super().__init__(perms, name)
@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import onnx

from dataclasses import dataclass
from typing import Dict, List, Union


def create_named_value(name: str, data_type: int, shape: List[Union[str, int]]):
    """
    Helper to create a new model input.

    Args:
        name: Name for input. Must not already be in use in the model being updated.
        data_type: onnx.TensorProto data type. e.g. onnx.TensorProto.FLOAT, onnx.TensorProto.UINT8
        shape: Input shape. Use int for dimensions with known values and strings for symbolic dimensions.
               e.g. ['batch_size', 256, 256] would be a rank 3 tensor with a symbolic first dimension named 'batch_size'

    Returns:
        An onnx.ValueInfoProto that can be used as a new model input.
    """
    tensor_type = onnx.helper.make_tensor_type_proto(elem_type=data_type, shape=shape)
    return onnx.helper.make_value_info(name, tensor_type)
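For example (the value names are arbitrary, and the import path assumes the package layout of this change):

```python
import onnx

from pre_post_processing.utils import create_named_value  # assumed import path

# 1-D uint8 tensor with a symbolic dimension for the number of encoded bytes.
image_bytes = create_named_value("image_bytes", onnx.TensorProto.UINT8, ["num_bytes"])

# rank 3 float tensor with a symbolic batch dimension.
float_input = create_named_value("data", onnx.TensorProto.FLOAT, ["batch_size", 224, 224])
```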
# We need to use an opset that's valid for the pre/post processing operators we add.
# Could alternatively use onnx.defs.onnx_opset_version to match the onnx version installed, but that's not
# deterministic. For now the default is an arbitrary choice of ONNX opset 16.
# NOTE: If we update this value we need to make sure the operators used in all steps are also updated if their spec
# has changed.
PRE_POST_PROCESSING_ONNX_OPSET = 16


def get_opset_imports():
    """Get the opset imports for a model updated by the PrePostProcessor."""
    return {
        "": PRE_POST_PROCESSING_ONNX_OPSET,
        "com.microsoft.extensions": 1
    }  # fmt: skip


def create_custom_op_checker_context():
    """
    Create an ONNX checker context that includes the ort-extensions custom op domains so that custom ops don't
    cause failure when running onnx.checker.check_graph.

    Returns:
        An onnx.checker.C.CheckerContext with the custom op domain in its opset imports.
    """
    context = onnx.checker.C.CheckerContext()
    context.ir_version = onnx.checker.DEFAULT_CONTEXT.ir_version
    context.opset_imports = get_opset_imports()

    return context


# The ONNX graph parser has its own map of type names just to be special
# https://github.com/onnx/onnx/blob/604af9cb28f63a6b9924237dcb91530649233db9/onnx/defs/parser.h#L72
TENSOR_TYPE_TO_ONNX_TYPE = {
    int(onnx.TensorProto.FLOAT): "float",
    int(onnx.TensorProto.UINT8): "uint8",
    int(onnx.TensorProto.INT8): "int8",
    int(onnx.TensorProto.UINT16): "uint16",
    int(onnx.TensorProto.INT16): "int16",
    int(onnx.TensorProto.INT32): "int32",
    int(onnx.TensorProto.INT64): "int64",
    int(onnx.TensorProto.STRING): "string",
    int(onnx.TensorProto.BOOL): "bool",
    int(onnx.TensorProto.FLOAT16): "float16",
    int(onnx.TensorProto.DOUBLE): "double",
    int(onnx.TensorProto.UINT32): "uint32",
    int(onnx.TensorProto.UINT64): "uint64",
    int(onnx.TensorProto.COMPLEX64): "complex64",
    int(onnx.TensorProto.COMPLEX128): "complex128",
    int(onnx.TensorProto.BFLOAT16): "bfloat16",
}


@dataclass
class IoMapEntry:
    """Entry to map the output index from a producer step to the input index of a consumer step."""

    # optional producer.
    # Uses Step if provided.
    # If a str with a previous Step name is provided the PrePostProcessor will find the relevant Step.
    # If neither is provided the producer is inferred to be the immediately previous Step in the pipeline.
    producer: Union["Step", str] = None
    # output index from the producer step
    producer_idx: int = 0
    # input index of the consumer step
    consumer_idx: int = 0


def sanitize_output_names(graph: onnx.GraphProto):
    """
    Convert any usage of invalid characters like '/' and ';' in value names to '_'.
    This is common in models exported from TensorFlow [Lite].

    ONNX parse_graph does not allow for that in a value name, and technically it's a violation of the ONNX spec as per
    https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph

    We do this for the original graph outputs only. The invalid naming has not been seen in model inputs, and we can
    leave the internals of the graph intact to minimize changes.

    Args:
        graph: Graph to check and update any invalid names
    """

    bad_output_names = [o.name for o in graph.output if "/" in o.name or ";" in o.name]
    if not bad_output_names:
        return graph

    renames = {}
    for n in bad_output_names:
        renames[n] = n.replace("/", "_").replace(";", "_")

    for o in graph.output:
        if o.name in bad_output_names:
            # Add Identity node to rename the output, and update the name in graph.output
            rename = onnx.helper.make_node("Identity", [o.name], [renames[o.name]], f"Rename {o.name}")
            graph.node.append(rename)
            o.name = renames[o.name]
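A small sketch of the effect, assuming this module is importable as `pre_post_processing.utils`. The graph is built inline, and the function renames the output in place via the added Identity node:

```python
import onnx
from onnx import helper, TensorProto

from pre_post_processing.utils import sanitize_output_names  # assumed import path

# Build a tiny graph whose output name contains characters the ONNX parser rejects.
bad_name = "model/dense/BiasAdd;0"
node = helper.make_node("Identity", ["X"], [bad_name])
graph = helper.make_graph(
    [node], "example",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info(bad_name, TensorProto.FLOAT, [1])],
)

sanitize_output_names(graph)
assert graph.output[0].name == "model_dense_BiasAdd_0"
```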