OpenCV Image Decoder & SuperResolution CustomOps
This commit is contained in:
Родитель
d29f6d0f42
Коммит
78d8dd5705
|
@ -35,7 +35,7 @@ option(OCOS_ENABLE_BLINGFIRE "Enable operators depending on the Blingfire librar
|
|||
option(OCOS_ENABLE_MATH "Enable math tensor operators building" ON)
|
||||
option(OCOS_ENABLE_DLIB "Enable operators like Inverse depending on DLIB" ON)
|
||||
option(OCOS_ENABLE_OPENCV "Enable operators depending on opencv" ON)
|
||||
option(OCOS_ENABLE_OPENCV_CODECS "Enable operators depending on opencv imgcodecs" OFF)
|
||||
option(OCOS_ENABLE_OPENCV_CODECS "Enable operators depending on opencv imgcodecs" ON)
|
||||
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
|
||||
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
|
||||
|
||||
|
@ -265,7 +265,7 @@ if (OCOS_ENABLE_OPENCV)
|
|||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV ENABLE_OPENCV_CODEC)
|
||||
endif()
|
||||
list(APPEND ocos_libraries ${opencv_LIBS})
|
||||
target_include_directories(ocos_operators PRIVATE ${opencv_INCLUDE_DIRS})
|
||||
target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
|
|
|
@ -1,108 +1,124 @@
|
|||
set(BUILD_TIFF OFF CACHE INTERNAL "")
|
||||
set(BUILD_OPENJPEG OFF CACHE INTERNAL "")
|
||||
set(BUILD_JASPER OFF CACHE INTERNAL "")
|
||||
set(BUILD_JPEG OFF CACHE INTERNAL "")
|
||||
set(BUILD_PNG OFF CACHE INTERNAL "")
|
||||
set(BUILD_OPENEXR OFF CACHE INTERNAL "")
|
||||
set(BUILD_WEBP OFF CACHE INTERNAL "")
|
||||
set(BUILD_TBB OFF CACHE INTERNAL "")
|
||||
set(BUILD_IPP_IW OFF CACHE INTERNAL "")
|
||||
set(BUILD_ITT OFF CACHE INTERNAL "")
|
||||
set(WITH_AVFOUNDATION OFF CACHE INTERNAL "")
|
||||
set(WITH_CAP_IOS OFF CACHE INTERNAL "")
|
||||
set(WITH_CAROTENE OFF CACHE INTERNAL "")
|
||||
set(WITH_CPUFEATURES OFF CACHE INTERNAL "")
|
||||
set(WITH_EIGEN OFF CACHE INTERNAL "")
|
||||
set(WITH_FFMPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_GSTREAMER OFF CACHE INTERNAL "")
|
||||
set(WITH_GTK OFF CACHE INTERNAL "")
|
||||
set(WITH_IPP OFF CACHE INTERNAL "")
|
||||
set(WITH_HALIDE OFF CACHE INTERNAL "")
|
||||
set(WITH_VULKAN OFF CACHE INTERNAL "")
|
||||
set(WITH_INF_ENGINE OFF CACHE INTERNAL "")
|
||||
set(WITH_NGRAPH OFF CACHE INTERNAL "")
|
||||
set(WITH_JASPER OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENJPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_JPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_WEBP OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENEXR OFF CACHE INTERNAL "")
|
||||
set(WITH_PNG OFF CACHE INTERNAL "")
|
||||
set(WITH_TIFF OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENVX OFF CACHE INTERNAL "")
|
||||
set(WITH_GDCM OFF CACHE INTERNAL "")
|
||||
set(WITH_TBB OFF CACHE INTERNAL "")
|
||||
set(WITH_HPX OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENMP OFF CACHE INTERNAL "")
|
||||
set(WITH_PTHREADS_PF OFF CACHE INTERNAL "")
|
||||
set(WITH_V4L OFF CACHE INTERNAL "")
|
||||
set(WITH_CLP OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENCL OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENCL_SVM OFF CACHE INTERNAL "")
|
||||
set(WITH_ITT OFF CACHE INTERNAL "")
|
||||
set(WITH_PROTOBUF OFF CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_HDR OFF CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_SUNRASTER OFF CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_PXM OFF CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_PFM OFF CACHE INTERNAL "")
|
||||
set(WITH_QUIRC OFF CACHE INTERNAL "")
|
||||
set(WITH_ANDROID_MEDIANDK OFF CACHE INTERNAL "")
|
||||
set(WITH_TENGINE OFF CACHE INTERNAL "")
|
||||
set(WITH_ONNX OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_apps OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_PROJECTS OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_EXAMPLES OFF CACHE INTERNAL "")
|
||||
set(BUILD_DOCS OFF CACHE INTERNAL "")
|
||||
set(BUILD_WITH_STATIC_CRT OFF CACHE INTERNAL "")
|
||||
set(BUILD_FAT_JAVA_LIB OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_SERVICE OFF CACHE INTERNAL "")
|
||||
set(BUILD_JAVA OFF CACHE INTERNAL "")
|
||||
set(BUILD_OBJC OFF CACHE INTERNAL "")
|
||||
set(ENABLE_PRECOMPILED_HEADERS OFF CACHE INTERNAL "")
|
||||
set(ENABLE_FAST_MATH OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_java OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_gapi OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_objc OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_js OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_ts OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_features2d OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_photo OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_video OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python2 OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python3 OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_dnn OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_imgcodecs OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_videoio OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_calib3d OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_highgui OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_flann OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_objdetect OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_stitching OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_ml OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_world OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_EXAMPLES OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_PROJECTS OFF CACHE INTERNAL "")
|
||||
set(BUILD_ANDROID_SERVICE OFF CACHE INTERNAL "")
|
||||
set(BUILD_DOCS OFF CACHE INTERNAL "")
|
||||
set(BUILD_FAT_JAVA_LIB OFF CACHE INTERNAL "")
|
||||
set(BUILD_IPP_IW OFF CACHE INTERNAL "")
|
||||
set(BUILD_ITT OFF CACHE INTERNAL "")
|
||||
set(BUILD_JASPER OFF CACHE INTERNAL "")
|
||||
set(BUILD_JAVA OFF CACHE INTERNAL "")
|
||||
# set(BUILD_JPEG OFF CACHE INTERNAL "")
|
||||
set(BUILD_OBJC OFF CACHE INTERNAL "")
|
||||
# set(BUILD_OPENJPEG OFF CACHE INTERNAL "")
|
||||
# set(BUILD_PNG OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_apps OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_calib3d OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_dnn OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_features2d OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_flann OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_gapi OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_highgui OFF CACHE INTERNAL "")
|
||||
# set(BUILD_opencv_imgcodecs OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_java OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_js OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_ml OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_objc OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_objdetect OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_photo OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python2 OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_python3 OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_stitching OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_ts OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_video OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_videoio OFF CACHE INTERNAL "")
|
||||
set(BUILD_opencv_world OFF CACHE INTERNAL "")
|
||||
set(BUILD_OPENEXR OFF CACHE INTERNAL "")
|
||||
set(BUILD_TBB OFF CACHE INTERNAL "")
|
||||
set(BUILD_TIFF OFF CACHE INTERNAL "")
|
||||
set(BUILD_WEBP OFF CACHE INTERNAL "")
|
||||
set(BUILD_WITH_STATIC_CRT OFF CACHE INTERNAL "")
|
||||
set(BUILD_ZLIB OFF CACHE INTERNAL "")
|
||||
set(ENABLE_FAST_MATH OFF CACHE INTERNAL "")
|
||||
set(ENABLE_PRECOMPILED_HEADERS OFF CACHE INTERNAL "")
|
||||
set(WITH_ANDROID_MEDIANDK OFF CACHE INTERNAL "")
|
||||
set(WITH_AVFOUNDATION OFF CACHE INTERNAL "")
|
||||
set(WITH_CAP_IOS OFF CACHE INTERNAL "")
|
||||
set(WITH_CAROTENE OFF CACHE INTERNAL "")
|
||||
set(WITH_CLP OFF CACHE INTERNAL "")
|
||||
set(WITH_CPUFEATURES OFF CACHE INTERNAL "")
|
||||
set(WITH_DSHOW OFF CACHE INTERNAL "")
|
||||
set(WITH_EIGEN OFF CACHE INTERNAL "")
|
||||
set(WITH_FFMPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_GDCM OFF CACHE INTERNAL "")
|
||||
set(WITH_GSTREAMER OFF CACHE INTERNAL "")
|
||||
set(WITH_GTK OFF CACHE INTERNAL "")
|
||||
set(WITH_HALIDE OFF CACHE INTERNAL "")
|
||||
set(WITH_HPX OFF CACHE INTERNAL "")
|
||||
# set(WITH_IMGCODEC_HDR OFF CACHE INTERNAL "")
|
||||
# set(WITH_IMGCODEC_PFM OFF CACHE INTERNAL "")
|
||||
# set(WITH_IMGCODEC_PXM OFF CACHE INTERNAL "")
|
||||
# set(WITH_IMGCODEC_SUNRASTER OFF CACHE INTERNAL "")
|
||||
set(WITH_INF_ENGINE OFF CACHE INTERNAL "")
|
||||
set(WITH_IPP OFF CACHE INTERNAL "")
|
||||
set(WITH_ITT OFF CACHE INTERNAL "")
|
||||
set(WITH_JASPER OFF CACHE INTERNAL "")
|
||||
# set(WITH_JPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_MSMF OFF CACHE INTERNAL "")
|
||||
set(WITH_NGRAPH OFF CACHE INTERNAL "")
|
||||
set(WITH_ONNX OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENCL OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENCL_SVM OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENEXR OFF CACHE INTERNAL "")
|
||||
# set(WITH_OPENJPEG OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENMP OFF CACHE INTERNAL "")
|
||||
set(WITH_OPENVX OFF CACHE INTERNAL "")
|
||||
# set(WITH_PNG OFF CACHE INTERNAL "")
|
||||
set(WITH_PROTOBUF OFF CACHE INTERNAL "")
|
||||
set(WITH_PTHREADS_PF OFF CACHE INTERNAL "")
|
||||
set(WITH_QUIRC OFF CACHE INTERNAL "")
|
||||
set(WITH_TBB OFF CACHE INTERNAL "")
|
||||
set(WITH_TENGINE OFF CACHE INTERNAL "")
|
||||
# set(WITH_TIFF OFF CACHE INTERNAL "")
|
||||
set(WITH_V4L OFF CACHE INTERNAL "")
|
||||
set(WITH_VULKAN OFF CACHE INTERNAL "")
|
||||
set(WITH_WEBP OFF CACHE INTERNAL "")
|
||||
set(WITH_WIN32UI OFF CACHE INTERNAL "")
|
||||
|
||||
if (OCOS_ENABLE_OPENCV_CODECS)
|
||||
set(BUILD_OPENJPEG ON CACHE INTERNAL "")
|
||||
set(BUILD_JPEG ON CACHE INTERNAL "")
|
||||
set(BUILD_PNG ON CACHE INTERNAL "")
|
||||
set(WITH_OPENJPEG ON CACHE INTERNAL "")
|
||||
set(WITH_JPEG ON CACHE INTERNAL "")
|
||||
set(WITH_PNG ON CACHE INTERNAL "")
|
||||
set(BUILD_opencv_imgcodecs ON CACHE INTERNAL "")
|
||||
set(BUILD_opencv_imgcodecs ON CACHE INTERNAL "")
|
||||
|
||||
set(BUILD_JPEG ON CACHE INTERNAL "")
|
||||
set(BUILD_OPENJPEG ON CACHE INTERNAL "")
|
||||
set(BUILD_PNG ON CACHE INTERNAL "")
|
||||
set(BUILD_TIFF ON CACHE INTERNAL "")
|
||||
|
||||
set(WITH_IMGCODEC_HDR ON CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_PFM ON CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_PXM ON CACHE INTERNAL "")
|
||||
set(WITH_IMGCODEC_SUNRASTER ON CACHE INTERNAL "")
|
||||
set(WITH_JPEG ON CACHE INTERNAL "")
|
||||
set(WITH_OPENJPEG ON CACHE INTERNAL "")
|
||||
set(WITH_PNG ON CACHE INTERNAL "")
|
||||
set(WITH_TIFF ON CACHE INTERNAL "")
|
||||
|
||||
set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "")
|
||||
set(BUILD_DOCS OFF CACHE INTERNAL "")
|
||||
set(BUILD_EXAMPLES OFF CACHE INTERNAL "")
|
||||
set(BUILD_TESTS OFF CACHE INTERNAL "")
|
||||
endif()
|
||||
|
||||
|
||||
# Declare how the OpenCV sources are obtained: a shallow clone of release tag 4.5.4
# from the upstream GitHub repository, patched with the project's no-RTTI patch.
#
# Only download/patch options are listed. FetchContent disables the configure, build
# and install steps of the underlying ExternalProject, so "-D<var>=<value>" arguments
# placed in this call (BUILD_DOCS, BUILD_EXAMPLES, BUILD_TESTS, BUILD_SHARED_LIBS,
# CMAKE_INSTALL_PREFIX, CV_TRACE) never reached OpenCV's configure; they were only
# swallowed as extra values of the preceding keyword. OpenCV is configured in-tree by
# FetchContent_MakeAvailable(), which sees the cache variables set with
# set(<var> ... CACHE INTERNAL "") earlier in this file.
#
# PATCH_COMMAND resets the checkout before applying the patch (presumably so that a
# re-run of the patch step does not fail on an already-patched tree).
FetchContent_Declare(
  opencv
  GIT_REPOSITORY https://github.com/opencv/opencv.git
  GIT_TAG        4.5.4
  GIT_SHALLOW    TRUE
  PATCH_COMMAND  git checkout . && git apply --whitespace=fix --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/cmake/externals/opencv-no-rtti.patch
)
|
||||
|
||||
FetchContent_MakeAvailable(opencv)
|
||||
|
|
|
@ -316,6 +316,21 @@ class GaussianBlur(CustomOp):
|
|||
]
|
||||
|
||||
|
||||
class ImageDecoder(CustomOp):
    """Input/output schema of the ``ImageDecoder`` custom op."""

    @classmethod
    def get_inputs(cls):
        """Describe the single input: a uint8 tensor named ``raw_input_image`` declared with shape ``[]``."""
        raw_image = cls.io_def('raw_input_image', onnx_proto.TensorProto.UINT8, [])
        return [raw_image]

    @classmethod
    def get_outputs(cls):
        """Describe the single output: a uint8 tensor named ``decoded_image`` of shape ``[None, None, 3]``."""
        decoded = cls.io_def('decoded_image', onnx_proto.TensorProto.UINT8, [None, None, 3])
        return [decoded]
|
||||
|
||||
|
||||
class SingleOpGraph:
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
struct KernelImageDecoder : BaseKernel {
|
||||
KernelImageDecoder(const OrtApi& api) : BaseKernel(api) {}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
OrtTensorDimensions dimensions(ort_, inputs);
|
||||
if (dimensions.size() != 1ULL) {
|
||||
ORT_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
// Get data & the length
|
||||
const uint8_t* const encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
|
||||
const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
|
||||
|
||||
// Decode the image
|
||||
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
|
||||
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
|
||||
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
|
||||
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
|
||||
|
||||
// Setup output & copy to destination
|
||||
const cv::Size decoded_image_size = decoded_image.size();
|
||||
const int64_t colors = 3;
|
||||
|
||||
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
|
||||
OrtValue *const output_value = ort_.KernelContext_GetOutput(
|
||||
context, 0, output_dimensions.data(), output_dimensions.size());
|
||||
uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
|
||||
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
|
||||
}
|
||||
};
|
||||
|
||||
// Operator descriptor for "ImageDecoder": one uint8 input, one uint8 output,
// executed by KernelImageDecoder.
struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
  // Allocates the kernel instance; the kernel info argument is not used.
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /*info*/) const {
    return new KernelImageDecoder(api);
  }

  const char* GetName() const { return "ImageDecoder"; }

  size_t GetInputTypeCount() const { return 1; }

  // Only index 0 is valid; any other index raises ORT_INVALID_ARGUMENT.
  ONNXTensorElementDataType GetInputType(size_t index) const {
    if (index != 0) {
      ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
    }
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }

  size_t GetOutputTypeCount() const { return 1; }

  // Only index 0 is valid; any other index raises ORT_INVALID_ARGUMENT.
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    if (index != 0) {
      ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
    }
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }
};
|
|
@ -2,6 +2,9 @@
|
|||
#include "gaussian_blur.hpp"
|
||||
#ifdef ENABLE_OPENCV_CODEC
|
||||
#include "imread.hpp"
|
||||
#include "imdecode.hpp"
|
||||
#include "super_resolution_preprocess.hpp"
|
||||
#include "super_resolution_postprocess.hpp"
|
||||
#endif // ENABLE_OPENCV_CODEC
|
||||
|
||||
|
||||
|
@ -10,5 +13,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_OpenCV =
|
|||
, CustomOpGaussianBlur
|
||||
#ifdef ENABLE_OPENCV_CODEC
|
||||
, CustomOpImageReader
|
||||
, CustomOpImageDecoder
|
||||
, CustomOpSuperResolutionPreProcess
|
||||
, CustomOpSuperResolutionPostProcess
|
||||
#endif // ENABLE_OPENCV_CODEC
|
||||
>;
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "super_resolution_postprocess.hpp"
|
||||
#include "string_utils.h"
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
KernelSuperResolutionPostProcess::KernelSuperResolutionPostProcess(const OrtApi& api) : BaseKernel(api) {}
|
||||
|
||||
void KernelSuperResolutionPostProcess::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* const input_y = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
const OrtValue* const input_cr = ort_.KernelContext_GetInput(context, 1ULL);
|
||||
const OrtValue* const input_cb = ort_.KernelContext_GetInput(context, 2ULL);
|
||||
|
||||
const OrtTensorDimensions dimensions_y(ort_, input_y);
|
||||
const OrtTensorDimensions dimensions_cr(ort_, input_cr);
|
||||
const OrtTensorDimensions dimensions_cb(ort_, input_cb);
|
||||
if ((dimensions_y.size() != 4ULL) || (dimensions_cr.size() != 4ULL) || (dimensions_cb.size() != 4ULL)) {
|
||||
throw std::runtime_error("Expecting 3 channels y, cr, and cb.");
|
||||
}
|
||||
|
||||
// Get data & the length
|
||||
const float* const channel_y_data = ort_.GetTensorData<float>(input_y);
|
||||
const float* const channel_cr_data = ort_.GetTensorData<float>(input_cr);
|
||||
const float* const channel_cb_data = ort_.GetTensorData<float>(input_cb);
|
||||
|
||||
cv::Mat y(
|
||||
std::vector<int32_t>{static_cast<int32_t>(dimensions_y[2]), static_cast<int32_t>(dimensions_y[3])},
|
||||
CV_32F, const_cast<void*>(static_cast<const void*>(channel_y_data)));
|
||||
cv::Mat cr(
|
||||
std::vector<int32_t>{static_cast<int32_t>(dimensions_cr[2]), static_cast<int32_t>(dimensions_cr[3])},
|
||||
CV_32F, const_cast<void*>(static_cast<const void*>(channel_cr_data)));
|
||||
cv::Mat cb(
|
||||
std::vector<int32_t>{static_cast<int32_t>(dimensions_cb[2]), static_cast<int32_t>(dimensions_cb[3])},
|
||||
CV_32F, const_cast<void*>(static_cast<const void*>(channel_cb_data)));
|
||||
|
||||
// Scale the individual channels
|
||||
y *= 255.0;
|
||||
cv::resize(cr, cr, y.size(), 0, 0, cv::INTER_CUBIC);
|
||||
cv::resize(cb, cb, y.size(), 0, 0, cv::INTER_CUBIC);
|
||||
|
||||
// Merge the channels
|
||||
const cv::Mat channels[] = {y, cr, cb};
|
||||
cv::Mat ycrcb_image;
|
||||
cv::merge(channels, 3, ycrcb_image);
|
||||
|
||||
// Convert it back to BGR format
|
||||
cv::Mat bgr_image;
|
||||
cv::cvtColor(ycrcb_image, bgr_image, cv::COLOR_YCrCb2BGR);
|
||||
|
||||
// Encode it as jpg
|
||||
std::vector<uchar> encoded_image;
|
||||
cv::imencode(".jpg", bgr_image, encoded_image);
|
||||
|
||||
// Setup output & copy to destination
|
||||
const std::vector<int64_t> output_dimensions{1LL, static_cast<int64_t>(encoded_image.size())};
|
||||
OrtValue* const output_value = ort_.KernelContext_GetOutput(
|
||||
context, 0, output_dimensions.data(), output_dimensions.size());
|
||||
float* const data = ort_.GetTensorMutableData<float>(output_value);
|
||||
memcpy(data, encoded_image.data(), encoded_image.size());
|
||||
}
|
||||
|
||||
const char* CustomOpSuperResolutionPostProcess::GetName() const {
|
||||
return "SuperResolutionPostProcess";
|
||||
}
|
||||
|
||||
size_t CustomOpSuperResolutionPostProcess::GetInputTypeCount() const {
|
||||
return 3;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 1:
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
default:
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
size_t CustomOpSuperResolutionPostProcess::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelSuperResolutionPostProcess : BaseKernel {
|
||||
KernelSuperResolutionPostProcess(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
// Operator descriptor pairing the op's metadata with KernelSuperResolutionPostProcess.
// All members are declared here and defined out of line.
struct CustomOpSuperResolutionPostProcess
    : Ort::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> {
  // Registered operator name.
  const char* GetName() const;

  // Number of inputs / outputs.
  size_t GetInputTypeCount() const;
  size_t GetOutputTypeCount() const;

  // Element type of the input / output at `index`.
  ONNXTensorElementDataType GetInputType(size_t index) const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "super_resolution_preprocess.hpp"
|
||||
#include "string_utils.h"
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
KernelSuperResolutionPreProcess::KernelSuperResolutionPreProcess(const OrtApi& api) : BaseKernel(api) {}
|
||||
|
||||
void KernelSuperResolutionPreProcess::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
OrtTensorDimensions dimensions(ort_, inputs);
|
||||
if (dimensions.size() != 1ULL) {
|
||||
throw std::runtime_error("Only raw image formats are supported.");
|
||||
}
|
||||
|
||||
// Get data & the length
|
||||
const uint8_t* const encoded_bgr_image_data = ort_.GetTensorData<uint8_t>(inputs);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
|
||||
const int64_t encoded_bgr_image_data_len = ort_.GetTensorShapeElementCount(input_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
|
||||
|
||||
// Decode the image
|
||||
const std::vector<int32_t> encoded_bgr_image_sizes{1, static_cast<int32_t>(encoded_bgr_image_data_len)};
|
||||
const cv::Mat encoded_bgr_image(encoded_bgr_image_sizes, CV_8UC1,
|
||||
const_cast<void*>(static_cast<const void*>(encoded_bgr_image_data)));
|
||||
// OpenCV decodes images in BGR format.
|
||||
// Ref: https://stackoverflow.com/a/44359400
|
||||
const cv::Mat decoded_bgr_image = cv::imdecode(encoded_bgr_image, cv::IMREAD_COLOR);
|
||||
|
||||
cv::Mat normalized_bgr_image;
|
||||
decoded_bgr_image.convertTo(normalized_bgr_image, CV_32F);
|
||||
|
||||
cv::Mat ycrcb_image;
|
||||
cv::cvtColor(normalized_bgr_image, ycrcb_image, cv::COLOR_BGR2YCrCb);
|
||||
|
||||
cv::Mat channels[3];
|
||||
cv::split(ycrcb_image, channels);
|
||||
channels[0] /= 255.0;
|
||||
|
||||
// Setup output & copy to destination
|
||||
for (int32_t i = 0; i < 3; ++i) {
|
||||
const cv::Mat& channel = channels[i];
|
||||
const cv::Size size = channel.size();
|
||||
|
||||
const std::vector<int64_t> output_dimensions{1LL, 1LL, size.height, size.width};
|
||||
OrtValue* const output_value = ort_.KernelContext_GetOutput(
|
||||
context, i, output_dimensions.data(), output_dimensions.size());
|
||||
float* const data = ort_.GetTensorMutableData<float>(output_value);
|
||||
memcpy(data, channel.data, channel.total() * channel.elemSize());
|
||||
}
|
||||
}
|
||||
|
||||
const char* CustomOpSuperResolutionPreProcess::GetName() const {
|
||||
return "SuperResolutionPreProcess";
|
||||
}
|
||||
|
||||
size_t CustomOpSuperResolutionPreProcess::GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
size_t CustomOpSuperResolutionPreProcess::GetOutputTypeCount() const {
|
||||
return 3;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
case 1:
|
||||
case 2:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
default:
|
||||
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
struct KernelSuperResolutionPreProcess : BaseKernel {
|
||||
KernelSuperResolutionPreProcess(const OrtApi& api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
// Operator descriptor pairing the op's metadata with KernelSuperResolutionPreProcess.
// All members are declared here and defined out of line.
struct CustomOpSuperResolutionPreProcess
    : Ort::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> {
  // Registered operator name.
  const char* GetName() const;

  // Number of inputs / outputs.
  size_t GetInputTypeCount() const;
  size_t GetOutputTypeCount() const;

  // Element type of the input / output at `index`.
  ONNXTensorElementDataType GetInputType(size_t index) const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};
|
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 111 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 8.5 KiB |
|
@ -44,6 +44,31 @@ class TestOpenCV(unittest.TestCase):
|
|||
# convimg.save('temp_pineapple.jpg')
|
||||
self.assertFalse(np.allclose(np.asarray(img), np.asarray(convimg)))
|
||||
|
||||
def test_image_decoder(self):
    """Decode test_colors.jpg with the ImageDecoder op and compare against PIL's decoding.

    The PIL result is reordered from RGB to BGR before comparing, and the pixel
    values are allowed to differ by at most 1.
    """
    input_image_file = util.get_test_data_file("data", "test_colors.jpg")

    model = OrtPyFunction.from_customop("ImageDecoder")
    # Read through a context manager so the file handle is closed deterministically.
    with open(input_image_file, 'rb') as strm:
        input_data = strm.read()
    raw_input_image = np.frombuffer(input_data, dtype=np.uint8)

    actual = model(raw_input_image)
    actual = np.asarray(actual, dtype=np.uint8)
    self.assertEqual(actual.shape[2], 3)

    with Image.open(input_image_file) as image:
        expected = np.asarray(image.convert('RGB'), dtype=np.uint8).copy()

    # Convert the image to BGR format since cv2 is default BGR format.
    red = expected[:, :, 0].copy()
    expected[:, :, 0] = expected[:, :, 2].copy()
    expected[:, :, 2] = red

    self.assertEqual(actual.shape[0], expected.shape[0])
    self.assertEqual(actual.shape[1], expected.shape[1])
    self.assertEqual(actual.shape[2], expected.shape[2])

    self.assertTrue(np.allclose(actual, expected, atol=1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
import io
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import unittest
|
||||
from PIL import Image
|
||||
from onnxruntime_extensions import OrtPyFunction, util
|
||||
|
||||
|
||||
_input_image_filepath = util.get_test_data_file("data", "test_supres.jpg")
|
||||
_onnx_model_filepath = util.get_test_data_file("data", "supres.onnx")
|
||||
_torch_model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
|
||||
|
||||
|
||||
class SuperResolutionNet(nn.Module):
    """Four-convolution super-resolution network ending in a pixel shuffle.

    Takes a 1-channel input and upsamples its spatial size by ``upscale_factor``:
    the last convolution emits ``upscale_factor ** 2`` channels that
    ``nn.PixelShuffle`` rearranges into a single, larger plane.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        stride, pad1 = (1, 1), (1, 1)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), stride, (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), stride, pad1)
        self.conv3 = nn.Conv2d(64, 32, (3, 3), stride, pad1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), stride, pad1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Apply conv1..conv3 (each followed by ReLU), then conv4 and the pixel shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))

        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise all conv weights; conv1..conv3 use the ReLU gain."""
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.orthogonal_(conv.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.conv4.weight)
|
||||
|
||||
|
||||
def _run_torch_inferencing():
    """Run the reference PyTorch super-resolution pipeline on the test image.

    Loads the pretrained weights from ``_torch_model_url``, upsamples the Y channel
    of ``_input_image_filepath`` with the network, bicubically resizes Cb/Cr to the
    same size, and returns the merged result as an RGB ``uint8`` numpy array.
    """
    # Create the super-resolution model by using the above model definition.
    torch_model = SuperResolutionNet(upscale_factor=3)

    # Initialize & load model with the pretrained weights
    map_location = lambda storage, loc: storage
    if torch.cuda.is_available():
        map_location = None
    torch_model.load_state_dict(model_zoo.load_url(_torch_model_url, map_location=map_location))

    # set the model to inferencing mode
    torch_model.eval()

    input_image_ycbcr = Image.open(_input_image_filepath).convert('YCbCr')
    input_image_y, input_image_cb, input_image_cr = input_image_ycbcr.split()
    input_image_y = torch.from_numpy(np.asarray(input_image_y, dtype=np.uint8)).float()
    input_image_y /= 255.0
    # The array is (height, width); add batch and channel axes without reordering the
    # data. Viewing it as (1, 1, width, height) would scramble non-square images.
    height, width = input_image_y.shape
    input_image_y = input_image_y.view(1, 1, height, width)

    output_image_y = torch_model(input_image_y)
    output_image_y = output_image_y.detach().cpu().numpy()
    output_image_y = Image.fromarray(np.uint8((output_image_y[0] * 255.0).clip(0, 255)[0]), mode='L')

    # get the output image follow post-processing step from PyTorch implementation
    output_image = Image.merge(
        "YCbCr", [
            output_image_y,
            input_image_cb.resize(output_image_y.size, Image.BICUBIC),
            input_image_cr.resize(output_image_y.size, Image.BICUBIC),
        ]).convert("RGB")

    # Uncomment to create a local file
    #
    # output_image_filepath = util.get_test_data_file("data", "test_supres_torch.jpg")
    # output_image.save(output_image_filepath)

    output_image = np.asarray(output_image, dtype=np.uint8)
    return output_image
|
||||
|
||||
|
||||
def _run_onnx_inferencing():
    """Run the ONNX model (with the custom pre/post-processing ops) on the test image.

    Feeds the encoded bytes of ``_input_image_filepath`` to the model at
    ``_onnx_model_filepath``, decodes the encoded image the model returns, and
    returns it as an RGB ``uint8`` numpy array.
    """
    # Read through a context manager so the file handle is closed deterministically.
    with open(_input_image_filepath, 'rb') as strm:
        encoded_input_image = strm.read()
    encoded_input_image = np.frombuffer(encoded_input_image, dtype=np.uint8)

    onnx_model = OrtPyFunction.from_model(_onnx_model_filepath)
    encoded_output_image = onnx_model(encoded_input_image)

    encoded_output_image = encoded_output_image.tobytes()

    # Uncomment to create a local file
    #
    # output_image_filepath = util.get_test_data_file("data", "test_supres_onnx.jpg")
    # with open(output_image_filepath, 'wb') as strm:
    #     strm.write(encoded_output_image)
    #     strm.flush()

    # convert() forces the pixel data to be read while the in-memory stream is still open.
    with io.BytesIO(encoded_output_image) as strm:
        decoded_output_image = Image.open(strm).convert('RGB')

    decoded_output_image = np.asarray(decoded_output_image, dtype=np.uint8)
    return decoded_output_image
|
||||
|
||||
|
||||
class TestSupres(unittest.TestCase):
    """End-to-end check of the ONNX super-resolution pipeline against PyTorch."""

    @classmethod
    def setUpClass(cls):
        # No shared fixtures are needed.
        pass

    def test_e2e(self):
        """The ONNX result must match the PyTorch result in shape and within 20 per pixel value."""
        actual = _run_onnx_inferencing()
        expected = _run_torch_inferencing()

        for axis in range(3):
            self.assertEqual(actual.shape[axis], expected.shape[axis])

        self.assertTrue(np.allclose(actual, expected, atol=20))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner.
    unittest.main()
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 10 KiB |
|
@ -0,0 +1,199 @@
|
|||
import io
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
# ref: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
|
||||
|
||||
|
||||
# Directory that contains this script.
_this_dirpath = os.path.dirname(os.path.abspath(__file__))
# 'data' folder located next to this script.
_data_dirpath = os.path.join(_this_dirpath, 'data')
# Destination of the combined ONNX model written by this script.
_onnx_model_filepath = os.path.join(_data_dirpath, 'supres.onnx')

# Domain assigned to the custom pre/post-processing nodes.
_domain = 'ai.onnx.contrib'
# Opset version declared for both the default domain and the custom domain.
_opset_version = 11
|
||||
|
||||
|
||||
class SuperResolutionNet(nn.Module):
    """Four-convolution network followed by a pixel shuffle.

    Args:
        upscale_factor: factor handed to ``nn.PixelShuffle``; the last
            convolution produces ``upscale_factor ** 2`` channels.
        inplace: forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Apply conv1..conv3 with ReLU, then conv4 and the pixel shuffle."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = self.relu(conv(hidden))

        return self.pixel_shuffle(self.conv4(hidden))

    def _initialize_weights(self):
        """Orthogonally initialise every convolution weight.

        The three ReLU-activated layers use the 'relu' gain; the final layer
        uses the default gain.
        """
        relu_gain = nn.init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.orthogonal_(conv.weight, relu_gain)
        nn.init.orthogonal_(self.conv4.weight)
|
||||
|
||||
|
||||
# Build the network from the class definition above with a 3x upscale factor.
torch_model = SuperResolutionNet(upscale_factor=3)

# URL of the pretrained weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number

# When CUDA is unavailable, pass an identity remap that returns each storage
# unchanged; otherwise leave map_location unset (None).
map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)
pretrained_state = model_zoo.load_url(model_url, map_location=map_location)
torch_model.load_state_dict(pretrained_state)

# Put the model into inference mode
torch_model.eval()
|
||||
|
||||
# Random input tensor handed to the exporter and reused for the checks below
input = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# Export the model into an in-memory buffer
with io.BytesIO() as strm:
    torch.onnx.export(
        torch_model,                # model being run
        input,                      # model input (or a tuple for multiple inputs)
        strm,                       # where to save the model (can be a file or file-like object)
        export_params=True,         # store the trained parameter weights inside the model file
        opset_version=10,           # the ONNX version to export the model to
        do_constant_folding=True,   # whether to execute constant folding for optimization
        input_names=['input'],      # the model's input names
        output_names=['output'],    # the model's output names
        dynamic_axes={              # variable length axes
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )

    # Grab the serialized bytes while the buffer is still open.
    exported_bytes = strm.getvalue()

    onnx_model = onnx.load_model_from_string(exported_bytes)
    onnx.checker.check_model(onnx_model)

    ort_session = onnxruntime.InferenceSession(exported_bytes)

# Reference output from PyTorch for the same input
torch_out = torch_model(input)
|
||||
|
||||
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detaching it first if it tracks gradients."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
|
||||
|
||||
# Run the exported model through ONNX Runtime on the same input
first_input_name = ort_session.get_inputs()[0].name
ort_inputs = {first_input_name: to_numpy(input)}
ort_outputs = ort_session.run(None, ort_inputs)

# The ONNX Runtime result must agree with the PyTorch result within tolerance
np.testing.assert_allclose(to_numpy(torch_out), ort_outputs[0], rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
# Generate a model pipeline with pre/post nodes
# Short alias for the value-info factory used for the graph inputs/outputs.
mkv = onnx.helper.make_tensor_value_info
# Opset declarations for the default domain ('') and the custom-op domain,
# both at _opset_version.
onnx_opsetids = [
    onnx.helper.make_opsetid('', _opset_version),
    onnx.helper.make_opsetid(_domain, _opset_version)
]

# Create custom op node for pre-processing
# It consumes the graph input and produces the tensor named 'input' (the name
# the exported network reads) plus 'cr' and 'cb', which are passed to the
# post-processing node.
preprocess_node = onnx.helper.make_node(
    'SuperResolutionPreProcess',
    inputs=['raw_input_image'],
    outputs=['input', 'cr', 'cb'],
    name='Preprocess',
    doc_string='Preprocessing node',
    domain=_domain)

# process_model is the same object as onnx_model, so it is modified in place.
process_model = onnx_model
# Drop the last opset entry and declare the two opsets above instead.
# NOTE(review): presumably the exported model carries a single opset entry
# (the default domain) - confirm.
process_model.opset_import.pop()
process_model.opset_import.extend(onnx_opsetids)
onnx.checker.check_model(process_model)
process_graph = process_model.graph

# Create custom op node for post-processing
# It consumes the network's 'output' together with 'cr' and 'cb' and produces
# the graph output.
postprocess_node = onnx.helper.make_node(
    'SuperResolutionPostProcess',
    inputs=['output', 'cr', 'cb'],
    outputs=['raw_output_image'],
    name='Postprocess',
    doc_string='Postprocessing node',
    domain=_domain)

# Graph-level input/output declarations: UINT8 tensors with an empty shape list.
# NOTE(review): an empty shape list may be read as rank 0, while the script
# later feeds a 1-D byte array - confirm this is intended.
inputs = [mkv('raw_input_image', onnx.onnx_pb.TensorProto.UINT8, [])]
outputs = [mkv('raw_output_image', onnx.onnx_pb.TensorProto.UINT8, [])]
# Pre-process node, then the exported network's nodes, then the post-process node.
nodes = [preprocess_node] + list(process_graph.node) + [postprocess_node]
# The exported network's initializers are carried over into the new graph.
graph = onnx.helper.make_graph(
    nodes, 'supres_graph', inputs, outputs,
    initializer=list(process_graph.initializer))
onnx_model = onnx.helper.make_model(graph)
# Replace the opset entry that make_model added with the two opsets above.
onnx_model.opset_import.pop()
onnx_model.opset_import.extend(onnx_opsetids)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||
|
||||
# Render the model as text in two forms so the presence of the custom nodes
# can be verified by substring search.
onnx_model_as_string = str(onnx_model)
onnx_model_as_text = onnx.helper.printable_graph(onnx_model.graph)

# Python 3 only allows raising BaseException subclasses; raising a bare string
# would surface as an unrelated TypeError, so use RuntimeError with the message.
if 'op_type: "SuperResolutionPreProcess"' not in onnx_model_as_string:
    raise RuntimeError("Failed to add pre-process to onnx graph")

if 'op_type: "SuperResolutionPostProcess"' not in onnx_model_as_string:
    raise RuntimeError("Failed to add post-process to onnx graph")

if 'SuperResolutionPreProcess(%raw_input_image)' not in onnx_model_as_text:
    raise RuntimeError("Failed to add pre-process to onnx graph")

if 'SuperResolutionPostProcess(%output, %cr, %cb)' not in onnx_model_as_text:
    raise RuntimeError("Failed to add post-process to onnx graph")

# Final structural check, then write the combined model to disk.
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, _onnx_model_filepath)
|
||||
|
||||
# Test with an inferencing session
import numpy as np
from onnxruntime_extensions import OrtPyFunction

_input_image_filepath = os.path.join(_data_dirpath, 'cat_224x224.jpg')
_output_image_filepath = os.path.join(_data_dirpath, 'cat_672x672.jpg')

# Read the encoded image through a context manager so the file handle is
# closed deterministically.
with open(_input_image_filepath, 'rb') as strm:
    encoded_input_image = strm.read()
raw_input_image = np.frombuffer(encoded_input_image, dtype=np.uint8)

# Load the saved model and run it on the encoded bytes.
model_func = OrtPyFunction.from_model(_onnx_model_filepath)
raw_output_image = model_func(raw_input_image)

# Write the model output to disk byte-for-byte.
encoded_output_image = raw_output_image.tobytes()
with open(_output_image_filepath, 'wb') as strm:
    strm.write(encoded_output_image)
    strm.flush()
|
||||
|
||||
'''
|
||||
Steps to integrate the generated model in Android app:
|
||||
|
||||
Assuming the app is using extensions as a separate
|
||||
binary rather than embedding extensions into the ORT build.
|
||||
|
||||
1. Drop the generated extensions binary (libortcustomops.so) into your app's
|
||||
resources folder (usually app/src/main/jniLibs/armeabi-v7a)
|
||||
|
||||
2. When creating an OrtSession, add the following statement to register extensions
|
||||
|
||||
val options = OrtSession.SessionOptions()
|
||||
options.registerCustomOpLibrary("libortcustomops.so")
|
||||
val session = ortEnv?.createSession(model, options)
|
||||
|
||||
3. Call OrtSession.run to generate the output.
|
||||
|
||||
val rawImageData = // raw image data in bytes
|
||||
val shape = longArrayOf(rawImageData.size.toLong())
|
||||
val tensor = OnnxTensor.createTensor(env, ByteBuffer.wrap(rawImageData), shape, OnnxJavaType.UINT8)
|
||||
val output = session?.run(Collections.singletonMap("raw_input_image", tensor))
|
||||
|
||||
"output" is the jpeg encoded high resolution image.
|
||||
'''
|
Загрузка…
Ссылка в новой задаче