OpenCV Image Decoder & SuperResolution CustomOps

This commit is contained in:
shaahji 2022-09-23 11:16:44 -07:00 коммит произвёл shaahji
Родитель d29f6d0f42
Коммит 78d8dd5705
16 изменённых файлов: 800 добавлений и 102 удалений

Просмотреть файл

@ -35,7 +35,7 @@ option(OCOS_ENABLE_BLINGFIRE "Enable operators depending on the Blingfire librar
option(OCOS_ENABLE_MATH "Enable math tensor operators building" ON)
option(OCOS_ENABLE_DLIB "Enable operators like Inverse depending on DLIB" ON)
option(OCOS_ENABLE_OPENCV "Enable operators depending on opencv" ON)
# ON by default: the ImageDecoder / SuperResolution custom ops need imgcodecs.
# (A duplicate OFF declaration of this option was removed — only one
# declaration per cache variable should exist.)
option(OCOS_ENABLE_OPENCV_CODECS "Enable operators depending on opencv imgcodecs" ON)
option(OCOS_ENABLE_STATIC_LIB "Enable generating static library" OFF)
option(OCOS_ENABLE_SELECTED_OPLIST "Enable including the selected_ops tool file" OFF)
@ -265,7 +265,7 @@ if (OCOS_ENABLE_OPENCV)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV ENABLE_OPENCV_CODEC)
endif()
list(APPEND ocos_libraries ${opencv_LIBS})
# PUBLIC so targets linking ocos_operators also see the OpenCV headers
# (the kernels expose OpenCV usage through their interface headers).
# A duplicate PRIVATE declaration of the same directories was removed.
target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
endif()
if (OCOS_ENABLE_GPT2_TOKENIZER)

216
cmake/externals/opencv.cmake поставляемый
Просмотреть файл

@ -1,108 +1,124 @@
# Disable every OpenCV component we do not need. Kept alphabetically sorted so
# additions/removals are easy to diff. Entries commented out below are
# conditionally enabled in the OCOS_ENABLE_OPENCV_CODECS block that follows.
# (A second, unsorted copy of this list was removed — each cache variable is
# now set exactly once.)
set(BUILD_ANDROID_EXAMPLES OFF CACHE INTERNAL "")
set(BUILD_ANDROID_PROJECTS OFF CACHE INTERNAL "")
set(BUILD_ANDROID_SERVICE OFF CACHE INTERNAL "")
set(BUILD_DOCS OFF CACHE INTERNAL "")
set(BUILD_FAT_JAVA_LIB OFF CACHE INTERNAL "")
set(BUILD_IPP_IW OFF CACHE INTERNAL "")
set(BUILD_ITT OFF CACHE INTERNAL "")
set(BUILD_JASPER OFF CACHE INTERNAL "")
set(BUILD_JAVA OFF CACHE INTERNAL "")
# set(BUILD_JPEG OFF CACHE INTERNAL "")
set(BUILD_OBJC OFF CACHE INTERNAL "")
# set(BUILD_OPENJPEG OFF CACHE INTERNAL "")
# set(BUILD_PNG OFF CACHE INTERNAL "")
set(BUILD_opencv_apps OFF CACHE INTERNAL "")
set(BUILD_opencv_calib3d OFF CACHE INTERNAL "")
set(BUILD_opencv_dnn OFF CACHE INTERNAL "")
set(BUILD_opencv_features2d OFF CACHE INTERNAL "")
set(BUILD_opencv_flann OFF CACHE INTERNAL "")
set(BUILD_opencv_gapi OFF CACHE INTERNAL "")
set(BUILD_opencv_highgui OFF CACHE INTERNAL "")
# set(BUILD_opencv_imgcodecs OFF CACHE INTERNAL "")
set(BUILD_opencv_java OFF CACHE INTERNAL "")
set(BUILD_opencv_js OFF CACHE INTERNAL "")
set(BUILD_opencv_ml OFF CACHE INTERNAL "")
set(BUILD_opencv_objc OFF CACHE INTERNAL "")
set(BUILD_opencv_objdetect OFF CACHE INTERNAL "")
set(BUILD_opencv_photo OFF CACHE INTERNAL "")
set(BUILD_opencv_python2 OFF CACHE INTERNAL "")
set(BUILD_opencv_python3 OFF CACHE INTERNAL "")
set(BUILD_opencv_stitching OFF CACHE INTERNAL "")
set(BUILD_opencv_ts OFF CACHE INTERNAL "")
set(BUILD_opencv_video OFF CACHE INTERNAL "")
set(BUILD_opencv_videoio OFF CACHE INTERNAL "")
set(BUILD_opencv_world OFF CACHE INTERNAL "")
set(BUILD_OPENEXR OFF CACHE INTERNAL "")
set(BUILD_TBB OFF CACHE INTERNAL "")
set(BUILD_TIFF OFF CACHE INTERNAL "")
set(BUILD_WEBP OFF CACHE INTERNAL "")
set(BUILD_WITH_STATIC_CRT OFF CACHE INTERNAL "")
set(BUILD_ZLIB OFF CACHE INTERNAL "")
set(ENABLE_FAST_MATH OFF CACHE INTERNAL "")
set(ENABLE_PRECOMPILED_HEADERS OFF CACHE INTERNAL "")
set(WITH_ANDROID_MEDIANDK OFF CACHE INTERNAL "")
set(WITH_AVFOUNDATION OFF CACHE INTERNAL "")
set(WITH_CAP_IOS OFF CACHE INTERNAL "")
set(WITH_CAROTENE OFF CACHE INTERNAL "")
set(WITH_CLP OFF CACHE INTERNAL "")
set(WITH_CPUFEATURES OFF CACHE INTERNAL "")
set(WITH_DSHOW OFF CACHE INTERNAL "")
set(WITH_EIGEN OFF CACHE INTERNAL "")
set(WITH_FFMPEG OFF CACHE INTERNAL "")
set(WITH_GDCM OFF CACHE INTERNAL "")
set(WITH_GSTREAMER OFF CACHE INTERNAL "")
set(WITH_GTK OFF CACHE INTERNAL "")
set(WITH_HALIDE OFF CACHE INTERNAL "")
set(WITH_HPX OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_HDR OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_PFM OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_PXM OFF CACHE INTERNAL "")
# set(WITH_IMGCODEC_SUNRASTER OFF CACHE INTERNAL "")
set(WITH_INF_ENGINE OFF CACHE INTERNAL "")
set(WITH_IPP OFF CACHE INTERNAL "")
set(WITH_ITT OFF CACHE INTERNAL "")
set(WITH_JASPER OFF CACHE INTERNAL "")
# set(WITH_JPEG OFF CACHE INTERNAL "")
set(WITH_MSMF OFF CACHE INTERNAL "")
set(WITH_NGRAPH OFF CACHE INTERNAL "")
set(WITH_ONNX OFF CACHE INTERNAL "")
set(WITH_OPENCL OFF CACHE INTERNAL "")
set(WITH_OPENCL_SVM OFF CACHE INTERNAL "")
set(WITH_OPENEXR OFF CACHE INTERNAL "")
# set(WITH_OPENJPEG OFF CACHE INTERNAL "")
set(WITH_OPENMP OFF CACHE INTERNAL "")
set(WITH_OPENVX OFF CACHE INTERNAL "")
# set(WITH_PNG OFF CACHE INTERNAL "")
set(WITH_PROTOBUF OFF CACHE INTERNAL "")
set(WITH_PTHREADS_PF OFF CACHE INTERNAL "")
set(WITH_QUIRC OFF CACHE INTERNAL "")
set(WITH_TBB OFF CACHE INTERNAL "")
set(WITH_TENGINE OFF CACHE INTERNAL "")
# set(WITH_TIFF OFF CACHE INTERNAL "")
set(WITH_V4L OFF CACHE INTERNAL "")
set(WITH_VULKAN OFF CACHE INTERNAL "")
set(WITH_WEBP OFF CACHE INTERNAL "")
set(WITH_WIN32UI OFF CACHE INTERNAL "")
# Image codecs are only built when the decoder/super-resolution ops are enabled.
# CACHE INTERNAL implies FORCE, so these override the OFF defaults above.
# (Duplicate set() lines left over from a merge were removed.)
if (OCOS_ENABLE_OPENCV_CODECS)
  set(BUILD_opencv_imgcodecs ON CACHE INTERNAL "")
  set(BUILD_JPEG ON CACHE INTERNAL "")
  set(BUILD_OPENJPEG ON CACHE INTERNAL "")
  set(BUILD_PNG ON CACHE INTERNAL "")
  set(BUILD_TIFF ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_HDR ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_PFM ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_PXM ON CACHE INTERNAL "")
  set(WITH_IMGCODEC_SUNRASTER ON CACHE INTERNAL "")
  set(WITH_JPEG ON CACHE INTERNAL "")
  set(WITH_OPENJPEG ON CACHE INTERNAL "")
  set(WITH_PNG ON CACHE INTERNAL "")
  set(WITH_TIFF ON CACHE INTERNAL "")
endif()

# General OpenCV build shape — always applies, regardless of codec support.
set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "")
set(BUILD_DOCS OFF CACHE INTERNAL "")
set(BUILD_EXAMPLES OFF CACHE INTERNAL "")
set(BUILD_TESTS OFF CACHE INTERNAL "")
# Fetch and build OpenCV in-tree. Configuration is driven entirely by the
# cache variables set above — FetchContent_Declare does not accept -D style
# arguments (those belong to ExternalProject CMAKE_ARGS), so the stray
# -DBUILD_DOCS/-DBUILD_EXAMPLES/... tokens were removed; their equivalents
# are already set via CACHE INTERNAL. The duplicated declaration body left
# over from a merge was also removed.
FetchContent_Declare(
  opencv
  GIT_REPOSITORY https://github.com/opencv/opencv.git
  GIT_TAG        4.5.4
  GIT_SHALLOW    TRUE
  # Strip RTTI usage so OpenCV links into RTTI-disabled extension builds.
  PATCH_COMMAND  git checkout . && git apply --whitespace=fix --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/cmake/externals/opencv-no-rtti.patch
)
FetchContent_MakeAvailable(opencv)

Просмотреть файл

@ -316,6 +316,21 @@ class GaussianBlur(CustomOp):
]
class ImageDecoder(CustomOp):
    """Schema for the ImageDecoder custom op: encoded image bytes in, HxWx3 uint8 pixels out."""

    @classmethod
    def get_inputs(cls):
        # One 1-D buffer holding the encoded (JPEG/PNG/...) image bytes.
        return [cls.io_def('raw_input_image', onnx_proto.TensorProto.UINT8, [])]

    @classmethod
    def get_outputs(cls):
        # Decoded pixels, shape (height, width, 3); channel order follows the
        # OpenCV-backed kernel implementation.
        return [cls.io_def('decoded_image', onnx_proto.TensorProto.UINT8, [None, None, 3])]
class SingleOpGraph:
@classmethod

Просмотреть файл

@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include "ocos.h"
#include "string_utils.h"
#include <cstdint>
struct KernelImageDecoder : BaseKernel {
KernelImageDecoder(const OrtApi& api) : BaseKernel(api) {}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
OrtTensorDimensions dimensions(ort_, inputs);
if (dimensions.size() != 1ULL) {
ORT_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
}
// Get data & the length
const uint8_t* const encoded_image_data = ort_.GetTensorData<uint8_t>(inputs);
OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
// Decode the image
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
// Setup output & copy to destination
const cv::Size decoded_image_size = decoded_image.size();
const int64_t colors = 3;
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
OrtValue *const output_value = ort_.KernelContext_GetOutput(
context, 0, output_dimensions.data(), output_dimensions.size());
uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
}
};
// Registration shim exposing KernelImageDecoder as the "ImageDecoder" op:
// one uint8 input (encoded bytes), one uint8 output (decoded pixels).
struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new KernelImageDecoder(api);
  }

  const char* GetName() const {
    return "ImageDecoder";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t index) const {
    if (index == 0) {
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    }
    ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t index) const {
    if (index == 0) {
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    }
    ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
  }
};

Просмотреть файл

@ -2,6 +2,9 @@
#include "gaussian_blur.hpp"
#ifdef ENABLE_OPENCV_CODEC
#include "imread.hpp"
#include "imdecode.hpp"
#include "super_resolution_preprocess.hpp"
#include "super_resolution_postprocess.hpp"
#endif // ENABLE_OPENCV_CODEC
@ -10,5 +13,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_OpenCV =
, CustomOpGaussianBlur
#ifdef ENABLE_OPENCV_CODEC
, CustomOpImageReader
, CustomOpImageDecoder
, CustomOpSuperResolutionPreProcess
, CustomOpSuperResolutionPostProcess
#endif // ENABLE_OPENCV_CODEC
>;

Просмотреть файл

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "super_resolution_postprocess.hpp"
#include "string_utils.h"
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include <cstdint>
KernelSuperResolutionPostProcess::KernelSuperResolutionPostProcess(const OrtApi& api) : BaseKernel(api) {}

// Reassembles the final image from the super-resolution model outputs:
// rescales the upsampled Y channel to [0, 255], bicubic-resizes Cr/Cb to
// match, merges YCrCb, converts to BGR and JPEG-encodes into the output.
void KernelSuperResolutionPostProcess::Compute(OrtKernelContext* context) {
  // Setup inputs: three 4-D (NCHW) float tensors — Y, Cr, Cb.
  const OrtValue* const input_y = ort_.KernelContext_GetInput(context, 0ULL);
  const OrtValue* const input_cr = ort_.KernelContext_GetInput(context, 1ULL);
  const OrtValue* const input_cb = ort_.KernelContext_GetInput(context, 2ULL);
  const OrtTensorDimensions dimensions_y(ort_, input_y);
  const OrtTensorDimensions dimensions_cr(ort_, input_cr);
  const OrtTensorDimensions dimensions_cb(ort_, input_cb);
  if ((dimensions_y.size() != 4ULL) || (dimensions_cr.size() != 4ULL) || (dimensions_cb.size() != 4ULL)) {
    // Use the ORT throw macro like the sibling kernels (imdecode.hpp) instead
    // of a bare std::runtime_error, so errors surface consistently through ORT.
    ORT_CXX_API_THROW("[SuperResolutionPostProcess]: Expecting 3 channels y, cr, and cb.", ORT_INVALID_ARGUMENT);
  }

  // Wrap each plane as an HxW float Mat (dims [2]/[3] are height/width).
  const float* const channel_y_data = ort_.GetTensorData<float>(input_y);
  const float* const channel_cr_data = ort_.GetTensorData<float>(input_cr);
  const float* const channel_cb_data = ort_.GetTensorData<float>(input_cb);
  cv::Mat y(
      std::vector<int32_t>{static_cast<int32_t>(dimensions_y[2]), static_cast<int32_t>(dimensions_y[3])},
      CV_32F, const_cast<void*>(static_cast<const void*>(channel_y_data)));
  cv::Mat cr(
      std::vector<int32_t>{static_cast<int32_t>(dimensions_cr[2]), static_cast<int32_t>(dimensions_cr[3])},
      CV_32F, const_cast<void*>(static_cast<const void*>(channel_cr_data)));
  cv::Mat cb(
      std::vector<int32_t>{static_cast<int32_t>(dimensions_cb[2]), static_cast<int32_t>(dimensions_cb[3])},
      CV_32F, const_cast<void*>(static_cast<const void*>(channel_cb_data)));

  // Scale the individual channels; Y comes from the model in [0, 1].
  y *= 255.0;
  cv::resize(cr, cr, y.size(), 0, 0, cv::INTER_CUBIC);
  cv::resize(cb, cb, y.size(), 0, 0, cv::INTER_CUBIC);

  // Merge the channels
  const cv::Mat channels[] = {y, cr, cb};
  cv::Mat ycrcb_image;
  cv::merge(channels, 3, ycrcb_image);

  // Convert it back to BGR format
  cv::Mat bgr_image;
  cv::cvtColor(ycrcb_image, bgr_image, cv::COLOR_YCrCb2BGR);

  // Encode it as jpg
  std::vector<uchar> encoded_image;
  cv::imencode(".jpg", bgr_image, encoded_image);

  // Setup output & copy to destination. The op's output is declared UINT8
  // (see CustomOpSuperResolutionPostProcess::GetOutputType), so the
  // destination pointer must be uint8_t — the previous
  // GetTensorMutableData<float>() mismatched the declared element type.
  const std::vector<int64_t> output_dimensions{1LL, static_cast<int64_t>(encoded_image.size())};
  OrtValue* const output_value = ort_.KernelContext_GetOutput(
      context, 0, output_dimensions.data(), output_dimensions.size());
  uint8_t* const data = ort_.GetTensorMutableData<uint8_t>(output_value);
  memcpy(data, encoded_image.data(), encoded_image.size());
}
// ---- CustomOpSuperResolutionPostProcess: op schema ----

const char* CustomOpSuperResolutionPostProcess::GetName() const {
  return "SuperResolutionPostProcess";
}

// Three inputs: the Y, Cr and Cb planes, all float.
size_t CustomOpSuperResolutionPostProcess::GetInputTypeCount() const {
  return 3;
}

ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetInputType(size_t index) const {
  if (index < 3) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
  ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}

// Single output: the encoded image bytes.
size_t CustomOpSuperResolutionPostProcess::GetOutputTypeCount() const {
  return 1;
}

ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetOutputType(size_t index) const {
  if (index == 0) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }
  ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}

Просмотреть файл

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
// Kernel for the SuperResolutionPostProcess custom op: consumes the model's
// Y/Cr/Cb float channel tensors and emits an encoded uint8 image buffer.
struct KernelSuperResolutionPostProcess : BaseKernel {
  KernelSuperResolutionPostProcess(const OrtApi& api);
  void Compute(OrtKernelContext* context);
};

// Schema/registration shim for the op: 3 float inputs, 1 uint8 output.
struct CustomOpSuperResolutionPostProcess : Ort::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> {
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};

Просмотреть файл

@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "super_resolution_preprocess.hpp"
#include "string_utils.h"
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/imgcodecs.hpp>
#include <cstdint>
KernelSuperResolutionPreProcess::KernelSuperResolutionPreProcess(const OrtApi& api) : BaseKernel(api) {}

// Decodes an encoded image and splits it into float Y (normalized to [0, 1]),
// Cr and Cb planes, each written as a 1x1xHxW output tensor.
void KernelSuperResolutionPreProcess::Compute(OrtKernelContext* context) {
  // Setup inputs: a single 1-D tensor of encoded image bytes.
  const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
  OrtTensorDimensions dimensions(ort_, inputs);
  if (dimensions.size() != 1ULL) {
    // Use the ORT throw macro like the sibling kernels (imdecode.hpp) instead
    // of a bare std::runtime_error, so errors surface consistently through ORT.
    ORT_CXX_API_THROW("[SuperResolutionPreProcess]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
  }

  // Get data & the length
  const uint8_t* const encoded_bgr_image_data = ort_.GetTensorData<uint8_t>(inputs);
  OrtTensorTypeAndShapeInfo* const input_info = ort_.GetTensorTypeAndShape(inputs);
  const int64_t encoded_bgr_image_data_len = ort_.GetTensorShapeElementCount(input_info);
  ort_.ReleaseTensorTypeAndShapeInfo(input_info);

  // Decode the image
  const std::vector<int32_t> encoded_bgr_image_sizes{1, static_cast<int32_t>(encoded_bgr_image_data_len)};
  const cv::Mat encoded_bgr_image(encoded_bgr_image_sizes, CV_8UC1,
                                  const_cast<void*>(static_cast<const void*>(encoded_bgr_image_data)));
  // OpenCV decodes images in BGR format.
  // Ref: https://stackoverflow.com/a/44359400
  const cv::Mat decoded_bgr_image = cv::imdecode(encoded_bgr_image, cv::IMREAD_COLOR);
  if (decoded_bgr_image.empty()) {
    // imdecode signals failure with an empty Mat rather than throwing.
    ORT_CXX_API_THROW("[SuperResolutionPreProcess]: Invalid or unsupported image data.", ORT_INVALID_ARGUMENT);
  }

  cv::Mat normalized_bgr_image;
  decoded_bgr_image.convertTo(normalized_bgr_image, CV_32F);
  cv::Mat ycrcb_image;
  cv::cvtColor(normalized_bgr_image, ycrcb_image, cv::COLOR_BGR2YCrCb);
  cv::Mat channels[3];
  cv::split(ycrcb_image, channels);
  channels[0] /= 255.0;  // the model consumes Y in [0, 1]; Cr/Cb stay unscaled

  // Setup output & copy each plane into its own 1x1xHxW float tensor.
  for (int32_t i = 0; i < 3; ++i) {
    const cv::Mat& channel = channels[i];
    const cv::Size size = channel.size();
    const std::vector<int64_t> output_dimensions{1LL, 1LL, size.height, size.width};
    OrtValue* const output_value = ort_.KernelContext_GetOutput(
        context, i, output_dimensions.data(), output_dimensions.size());
    float* const data = ort_.GetTensorMutableData<float>(output_value);
    memcpy(data, channel.data, channel.total() * channel.elemSize());
  }
}
// ---- CustomOpSuperResolutionPreProcess: op schema ----

const char* CustomOpSuperResolutionPreProcess::GetName() const {
  return "SuperResolutionPreProcess";
}

// Single input: the encoded image bytes.
size_t CustomOpSuperResolutionPreProcess::GetInputTypeCount() const {
  return 1;
}

ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetInputType(size_t index) const {
  if (index == 0) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }
  ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}

// Three outputs: the Y, Cr and Cb planes, all float.
size_t CustomOpSuperResolutionPreProcess::GetOutputTypeCount() const {
  return 3;
}

ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetOutputType(size_t index) const {
  if (index < 3) {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
  ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
}

Просмотреть файл

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
// Kernel for the SuperResolutionPreProcess custom op: decodes an encoded
// image buffer and emits separate Y/Cr/Cb float channel tensors.
struct KernelSuperResolutionPreProcess : BaseKernel {
  KernelSuperResolutionPreProcess(const OrtApi& api);
  void Compute(OrtKernelContext* context);
};

// Schema/registration shim for the op: 1 uint8 input, 3 float outputs.
struct CustomOpSuperResolutionPreProcess : Ort::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> {
  const char* GetName() const;
  size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(size_t index) const;
  size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(size_t index) const;
};

Двоичные данные
test/data/supres.onnx Normal file

Двоичный файл не отображается.

Двоичные данные
test/data/test_colors.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 111 KiB

Двоичные данные
test/data/test_supres.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 8.5 KiB

Просмотреть файл

@ -44,6 +44,31 @@ class TestOpenCV(unittest.TestCase):
# convimg.save('temp_pineapple.jpg')
self.assertFalse(np.allclose(np.asarray(img), np.asarray(convimg)))
def test_image_decoder(self):
    """ImageDecoder output must match PIL's decode (in BGR order) within atol=1."""
    input_image_file = util.get_test_data_file("data", "test_colors.jpg")
    model = OrtPyFunction.from_customop("ImageDecoder")
    # `with` closes the file handle deterministically; the original relied on
    # the GC to close the object returned by open().
    with open(input_image_file, 'rb') as f:
        input_data = f.read()
    raw_input_image = np.frombuffer(input_data, dtype=np.uint8)
    actual = model(raw_input_image)
    actual = np.asarray(actual, dtype=np.uint8)
    self.assertEqual(actual.shape[2], 3)
    expected = Image.open(input_image_file).convert('RGB')
    expected = np.asarray(expected, dtype=np.uint8).copy()
    # Convert the image to BGR format since cv2 is default BGR format.
    red = expected[:, :, 0].copy()
    expected[:, :, 0] = expected[:, :, 2].copy()
    expected[:, :, 2] = red
    self.assertEqual(actual.shape[0], expected.shape[0])
    self.assertEqual(actual.shape[1], expected.shape[1])
    self.assertEqual(actual.shape[2], expected.shape[2])
    # atol=1 tolerates rounding differences between JPEG decoder builds.
    self.assertTrue(np.allclose(actual, expected, atol=1))
if __name__ == "__main__":
unittest.main()

124
test/test_supres.py Normal file
Просмотреть файл

@ -0,0 +1,124 @@
import io
import numpy as np
import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import unittest
from PIL import Image
from onnxruntime_extensions import OrtPyFunction, util
_input_image_filepath = util.get_test_data_file("data", "test_supres.jpg")
_onnx_model_filepath = util.get_test_data_file("data", "supres.onnx")
_torch_model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
class SuperResolutionNet(nn.Module):
    """Sub-pixel super-resolution CNN: four conv stages followed by PixelShuffle upscaling."""

    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()
        # Attribute names must stay stable — they form the state_dict keys
        # used by load_state_dict with the pretrained checkpoint.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        # Three ReLU conv stages, then fold channels into spatial resolution.
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        hidden = self.relu(self.conv3(hidden))
        return self.pixel_shuffle(self.conv4(hidden))

    def _initialize_weights(self):
        # Orthogonal init; ReLU-fed layers get the matching gain.
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.orthogonal_(conv.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.conv4.weight)
def _run_torch_inferencing():
    """Run the reference PyTorch super-resolution pipeline on the test image.

    Returns the upscaled image as an RGB uint8 numpy array, produced by
    running the pretrained model on the Y channel and bicubic-resizing Cb/Cr.
    """
    # Create the super-resolution model by using the above model definition.
    torch_model = SuperResolutionNet(upscale_factor=3)

    # Initialize & load model with the pretrained weights
    map_location = lambda storage, loc: storage
    if torch.cuda.is_available():
        map_location = None
    torch_model.load_state_dict(model_zoo.load_url(_torch_model_url, map_location=map_location))

    # set the model to inferencing mode
    torch_model.eval()

    # Model operates on the luma (Y) channel only; Cb/Cr are upscaled separately.
    input_image_ycbcr = Image.open(_input_image_filepath).convert('YCbCr')
    input_image_y, input_image_cb, input_image_cr = input_image_ycbcr.split()
    input_image_y = torch.from_numpy(np.asarray(input_image_y, dtype=np.uint8)).float()
    input_image_y /= 255.0
    # NOTE(review): np.asarray(PIL) yields (height, width), but shape[1]/shape[0]
    # here places (width, height) into the HxW slots of the NCHW view — for a
    # non-square image this reinterprets (transposes) the plane. Confirm this
    # matches the layout the exported model was traced with.
    input_image_y = input_image_y.view(1, -1, input_image_y.shape[1], input_image_y.shape[0])

    output_image_y = torch_model(input_image_y)
    output_image_y = output_image_y.detach().cpu().numpy()
    # Back to a [0, 255] 8-bit luma image.
    output_image_y = Image.fromarray(np.uint8((output_image_y[0] * 255.0).clip(0, 255)[0]), mode='L')

    # get the output image follow post-processing step from PyTorch implementation
    output_image = Image.merge(
        "YCbCr", [
            output_image_y,
            input_image_cb.resize(output_image_y.size, Image.BICUBIC),
            input_image_cr.resize(output_image_y.size, Image.BICUBIC),
        ]).convert("RGB")

    # Uncomment to create a local file
    #
    # output_image_filepath = util.get_test_data_file("data", "test_supres_torch.jpg")
    # output_image.save(output_image_filepath)

    output_image = np.asarray(output_image, dtype=np.uint8)
    return output_image
def _run_onnx_inferencing():
    """Run the exported ONNX pipeline (with custom pre/post ops) end-to-end.

    Returns the decoded output image as an RGB uint8 numpy array.
    """
    # `with` closes the file handle deterministically; the original left the
    # object returned by open() for the GC to collect.
    with open(_input_image_filepath, 'rb') as f:
        encoded_input_image = f.read()
    encoded_input_image = np.frombuffer(encoded_input_image, dtype=np.uint8)

    onnx_model = OrtPyFunction.from_model(_onnx_model_filepath)
    encoded_output_image = onnx_model(encoded_input_image)
    encoded_output_image = encoded_output_image.tobytes()

    # Uncomment to create a local file
    #
    # output_image_filepath = util.get_test_data_file("data", "test_supres_onnx.jpg")
    # with open(output_image_filepath, 'wb') as strm:
    #     strm.write(encoded_output_image)
    #     strm.flush()

    # The model emits JPEG bytes; decode them back to pixels for comparison.
    with io.BytesIO(encoded_output_image) as strm:
        decoded_output_image = Image.open(strm).convert('RGB')

    decoded_output_image = np.asarray(decoded_output_image, dtype=np.uint8)
    return decoded_output_image
class TestSupres(unittest.TestCase):
    """End-to-end check: the ONNX pipeline with custom ops must match the PyTorch reference."""

    @classmethod
    def setUpClass(cls):
        pass

    def test_e2e(self):
        actual = _run_onnx_inferencing()
        expected = _run_torch_inferencing()
        # Identical spatial size and channel count...
        self.assertEqual(actual.shape[0], expected.shape[0])
        self.assertEqual(actual.shape[1], expected.shape[1])
        self.assertEqual(actual.shape[2], expected.shape[2])
        # ...and pixel values within a loose tolerance (JPEG round-trip plus
        # float/resize differences between OpenCV and PIL).
        self.assertTrue(np.allclose(actual, expected, atol=20))
if __name__ == "__main__":
unittest.main()

Двоичные данные
tutorials/data/cat_224x224.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 10 KiB

199
tutorials/supres_e2e.py Normal file
Просмотреть файл

@ -0,0 +1,199 @@
import io
import numpy as np
import onnx
import onnxruntime
import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# ref: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
_this_dirpath = os.path.dirname(os.path.abspath(__file__))
_data_dirpath = os.path.join(_this_dirpath, 'data')
_onnx_model_filepath = os.path.join(_data_dirpath, 'supres.onnx')
_domain = 'ai.onnx.contrib'
_opset_version = 11
class SuperResolutionNet(nn.Module):
    """Sub-pixel super-resolution CNN: four conv stages followed by PixelShuffle upscaling."""

    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()
        # Attribute names must stay stable — they form the state_dict keys
        # used by load_state_dict with the pretrained checkpoint.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        # Three ReLU conv stages, then fold channels into spatial resolution.
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        hidden = self.relu(self.conv3(hidden))
        return self.pixel_shuffle(self.conv4(hidden))

    def _initialize_weights(self):
        # Orthogonal init; ReLU-fed layers get the matching gain.
        for conv in (self.conv1, self.conv2, self.conv3):
            nn.init.orthogonal_(conv.weight, nn.init.calculate_gain('relu'))
        nn.init.orthogonal_(self.conv4.weight)
# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # just a random number
# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
# set the model to inference mode
torch_model.eval()
# Input to the model
input = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
# Export the model
with io.BytesIO() as strm:
torch.onnx.export(torch_model, # model being run
input, # model input (or a tuple for multiple inputs)
strm, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
onnx_model = onnx.load_model_from_string(strm.getvalue())
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(strm.getvalue())
torch_out = torch_model(input)
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)}
ort_outputs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outputs[0], rtol=1e-03, atol=1e-05)
# Generate an model pipeline with pre/post nodes
mkv = onnx.helper.make_tensor_value_info
onnx_opsetids = [
onnx.helper.make_opsetid('', _opset_version),
onnx.helper.make_opsetid(_domain, _opset_version)
]
# Create custom op node for pre-processing
preprocess_node = onnx.helper.make_node(
'SuperResolutionPreProcess',
inputs=['raw_input_image'],
outputs=['input', 'cr', 'cb'],
name='Preprocess',
doc_string='Preprocessing node',
domain=_domain)
process_model = onnx_model
process_model.opset_import.pop()
process_model.opset_import.extend(onnx_opsetids)
onnx.checker.check_model(process_model)
process_graph = process_model.graph
# Create custom op node for post-processing
postprocess_node = onnx.helper.make_node(
'SuperResolutionPostProcess',
inputs=['output', 'cr', 'cb'],
outputs=['raw_output_image'],
name='Postprocess',
doc_string='Postprocessing node',
domain=_domain)
inputs = [mkv('raw_input_image', onnx.onnx_pb.TensorProto.UINT8, [])]
outputs = [mkv('raw_output_image', onnx.onnx_pb.TensorProto.UINT8, [])]
nodes = [preprocess_node] + list(process_graph.node) + [postprocess_node]
graph = onnx.helper.make_graph(
nodes, 'supres_graph', inputs, outputs,
initializer=list(process_graph.initializer))
onnx_model = onnx.helper.make_model(graph)
onnx_model.opset_import.pop()
onnx_model.opset_import.extend(onnx_opsetids)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
onnx_model_as_string = str(onnx_model)
onnx_model_as_text = onnx.helper.printable_graph(onnx_model.graph)
# Sanity-check that both custom nodes made it into the composed graph.
# `raise` requires an exception instance — `raise "..."` is itself a
# TypeError in Python 3 and would mask the real failure message.
if 'op_type: "SuperResolutionPreProcess"' not in onnx_model_as_string:
    raise RuntimeError("Failed to add pre-process to onnx graph")
if 'op_type: "SuperResolutionPostProcess"' not in onnx_model_as_string:
    raise RuntimeError("Failed to add post-process to onnx graph")
if 'SuperResolutionPreProcess(%raw_input_image)' not in onnx_model_as_text:
    raise RuntimeError("Failed to add pre-process to onnx graph")
if 'SuperResolutionPostProcess(%output, %cr, %cb)' not in onnx_model_as_text:
    raise RuntimeError("Failed to add post-process to onnx graph")
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, _onnx_model_filepath)
# Test with a inferencing session
import numpy as np
from onnxruntime_extensions import OrtPyFunction
_input_image_filepath = os.path.join(_data_dirpath, 'cat_224x224.jpg')
_output_image_filepath = os.path.join(_data_dirpath, 'cat_672x672.jpg')
encoded_input_image = open(_input_image_filepath, 'rb').read()
raw_input_image = np.frombuffer(encoded_input_image, dtype=np.uint8)
model_func = OrtPyFunction.from_model(_onnx_model_filepath)
raw_output_image = model_func(raw_input_image)
encoded_output_image = raw_output_image.tobytes()
with open(_output_image_filepath, 'wb') as strm:
strm.write(encoded_output_image)
strm.flush()
'''
Steps to integrate the generated model in Android app:
Assuming the app is using extensions as a separate
binary rather than embedding extensions into the ORT build.
1. Drop the generated extensions binary (libortcustomops.so) into your app's
resources folder (usually app/src/main/jniLibs/armeabi-v7a)
2. When creating an OrtSession, add the following statement to register extensions
val options = OrtSession.SessionOptions()
options.registerCustomOpLibrary("libortcustomops.so")
val session = ortEnv?.createSession(model, options)
3. Call OrtSession.run to generate the output.
val rawImageData = // raw image data in bytes
val shape = longArrayOf(rawImageData.size.toLong())
val tensor = OnnxTensor.createTensor(env, ByteBuffer.wrap(rawImageData), shape, OnnxJavaType.UINT8)
val output = session?.run(Collections.singletonMap("input", tensor))
"output" is the jpeg encoded high resolution image.
'''