reimplement resize cpu kernel for image processing (#768)
* reimplement resize cpu kernel for image processing
* accuracy fixing and code refinement
* fix the build issues
* fix Linux build issue
* more fixings
* Fix the pipeline issue
* fix the ci script
* try to fix CUDA machine pool
This commit is contained in:
Parent: d79299e733
Commit: 620050fbe0
@@ -412,10 +412,6 @@ stages:
     name: 'onnxruntime-extensions-Windows-CPU'
   steps:
-    # the vcpkg build requires a cmake python module
-    - script: python -m pip install cmake
-      displayName: Install cmake python module
-
     - script: |
        call .\build.bat -DOCOS_ENABLE_CTEST=ON -DCMAKE_MSVC_RUNTIME_LIBRARY=MultiThreaded
        cd out/Windows
@@ -726,7 +726,7 @@ target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})

set (file_patterns "shared/lib/*.cc")
if (OCOS_ENABLE_C_API)
-   list(APPEND file_patterns "shared/api/*.h*" "shared/api/*.cc")
+   list(APPEND file_patterns "shared/api/*.h*" "shared/api/*.c" "shared/api/*.cc")
endif()

file(GLOB shared_TARGET_LIB_SRC ${file_patterns})
@@ -10,7 +10,10 @@ if (NOT Python3_FOUND)
  message(FATAL_ERROR "Python3 not found!")
endif()

- file(GLOB TARGET_SRC_PYOPS "pyop/*.cc" "pyop/*.h" "shared/*.cc")
+ file(GLOB TARGET_SRC_PYOPS "pyop/pyfunc.cc" "pyop/*.h" "shared/*.cc")
+ if (OCOS_ENABLE_C_API)
+   list(APPEND TARGET_SRC_PYOPS "pyop/py_c_api.cc")
+ endif()
if (WIN32)
  list(APPEND TARGET_SRC_PYOPS "pyop/extensions_pydll.def")
endif()
@@ -1,18 +1,17 @@
diff --git a/3rdparty/libjpeg-turbo/CMakeLists.txt b/3rdparty/libjpeg-turbo/CMakeLists.txt
index 3c7f29b08e..066ea4e545 100644
--- a/3rdparty/libjpeg-turbo/CMakeLists.txt
+++ b/3rdparty/libjpeg-turbo/CMakeLists.txt
@@ -67,7 +67,7 @@ set(JPEG_LIB_VERSION "${VERSION}-${JPEG_LIB_VERSION}" PARENT_SCOPE)
 set(THREAD_LOCAL "") # WITH_TURBOJPEG is not used
diff --git a/3rdparty/libjpeg/CMakeLists.txt b/3rdparty/libjpeg/CMakeLists.txt
index c0524cc38..69a71e416 100644
--- a/3rdparty/libjpeg/CMakeLists.txt
+++ b/3rdparty/libjpeg/CMakeLists.txt
@@ -27,7 +27,6 @@ endif()

 if(MSVC)
-  add_definitions(-W3 -wd4996 -wd4018)
+  add_definitions(-W3)
 endif()
 ocv_warnings_disable(CMAKE_C_FLAGS -Wcast-align -Wshadow -Wunused -Wshift-negative-value -Wimplicit-fallthrough)
 ocv_warnings_disable(CMAKE_C_FLAGS -Wunused-parameter) # clang
-ocv_warnings_disable(CMAKE_C_FLAGS /wd4013 /wd4244 /wd4267) # vs2005

 if(WIN32)
   set_target_properties(${JPEG_LIBRARY}
     PROPERTIES OUTPUT_NAME ${JPEG_LIBRARY}
diff --git a/3rdparty/zlib/CMakeLists.txt b/3rdparty/zlib/CMakeLists.txt
- index 9758861a6b..9e654ba922 100644
+ index 9758861a6..9e654ba92 100644
--- a/3rdparty/zlib/CMakeLists.txt
+++ b/3rdparty/zlib/CMakeLists.txt
@@ -81,7 +81,6 @@ set_target_properties(${ZLIB_LIBRARY} PROPERTIES DEFINE_SYMBOL ZLIB_DLL)

@@ -24,7 +23,7 @@ index 9758861a6b..9e654ba922 100644
 )

diff --git a/CMakeLists.txt b/CMakeLists.txt
- index d95e5db163..db185453df 100644
+ index d95e5db16..db185453d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -617,11 +617,6 @@ endif()

@@ -40,7 +39,7 @@ index d95e5db163..db185453df 100644

 ocv_cmake_hook(POST_COMPILER_OPTIONS)
diff --git a/cmake/OpenCVDetectCXXCompiler.cmake b/cmake/OpenCVDetectCXXCompiler.cmake
- index 7f229cde96..92e204a5b9 100644
+ index 7f229cde9..92e204a5b 100644
--- a/cmake/OpenCVDetectCXXCompiler.cmake
+++ b/cmake/OpenCVDetectCXXCompiler.cmake
@@ -171,7 +171,7 @@ elseif(MSVC)

@@ -53,7 +52,7 @@ index 7f229cde96..92e204a5b9 100644
 else()
   message(WARNING "OpenCV does not recognize MSVC_VERSION \"${MSVC_VERSION}\". Cannot set OpenCV_RUNTIME")
diff --git a/modules/core/include/opencv2/core/ocl.hpp b/modules/core/include/opencv2/core/ocl.hpp
- index 4503fa00dd..642b0508d0 100644
+ index 4503fa00d..642b0508d 100644
--- a/modules/core/include/opencv2/core/ocl.hpp
+++ b/modules/core/include/opencv2/core/ocl.hpp
@@ -302,21 +302,6 @@ public:

@@ -79,7 +78,7 @@ index 4503fa00dd..642b0508d0 100644
     inline Impl* getImpl() const { return (Impl*)p; }
     inline bool empty() const { return !p; }
diff --git a/modules/core/src/ocl_disabled.impl.hpp b/modules/core/src/ocl_disabled.impl.hpp
- index a217979a1e..0ba30d024c 100644
+ index a217979a1..0ba30d024 100644
--- a/modules/core/src/ocl_disabled.impl.hpp
+++ b/modules/core/src/ocl_disabled.impl.hpp
@@ -177,11 +177,6 @@ void* Context::getOpenCLContextProperty(int /*propertyId*/) const { OCL_NOT_AVAI

@@ -94,3 +93,25 @@ index a217979a1e..0ba30d024c 100644
 /* static */ Context Context::fromHandle(void* context) { OCL_NOT_AVAILABLE(); }
 /* static */ Context Context::fromDevice(const ocl::Device& device) { OCL_NOT_AVAILABLE(); }
 /* static */ Context Context::create(const std::string& configuration) { OCL_NOT_AVAILABLE(); }
+ diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt b/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt
+ deleted file mode 100644
+ index 6887c2ab2..000000000
+ --- a/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt
+ +++ /dev/null
+ @@ -1,15 +0,0 @@
+ -# Python 3.7.5
+ -onnx>=1.7.0
+ -numpy>=1.19.1
+ -
+ -torch>=1.5.1
+ -torchvision>=0.6.1
+ -
+ -tensorflow>=2.1.0
+ -tensorflow-gpu>=2.1.0
+ -
+ -paddlepaddle>=2.0.0
+ -paddlepaddle-gpu>=2.0.0
+ -paddlehub>=2.1.0
+ -paddle2onnx>=0.5.1
+ -paddleseg>=2.0.0
+ \ No newline at end of file
@@ -16,4 +16,4 @@ Most APIs accept raw data inputs such as audio, image compressed binary formats,

**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).

- **Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
+ **Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extraction.cc#L16).
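For orientation, the image half of that C API composes as follows. This is a minimal sketch based only on functions this commit touches (`OrtxCreateProcessor`, `OrtxLoadImages`, `OrtxImagePreProcess`, `OrtxTensorResultGetAt`, `OrtxDisposeOnly`); the JSON and image paths are borrowed from the test assets added here, and error handling is trimmed to one check:

    #include "ortx_processor.h"
    #include "ortx_utils.h"

    // Sketch only: paths mirror the test assets used elsewhere in this commit.
    int run_image_pipeline() {
      OrtxProcessor* processor = nullptr;
      if (OrtxCreateProcessor(&processor, "test/data/processor/phi_3_image.json") != kOrtxOK)
        return -1;  // OrtxGetLastErrorMessage() carries the details

      const char* files[] = {"test/data/processor/passport.png"};
      OrtxRawImages* images = nullptr;
      OrtxLoadImages(&images, files, 1, nullptr);

      OrtxTensorResult* result = nullptr;
      OrtxImagePreProcess(processor, images, &result);

      OrtxTensor* pixel_values = nullptr;
      OrtxTensorResultGetAt(result, 0, &pixel_values);  // index 0: normalized pixels

      // ... consume the tensor, then release everything in reverse order
      OrtxDisposeOnly(pixel_values);
      OrtxDisposeOnly(result);
      OrtxDisposeOnly(images);
      OrtxDisposeOnly(processor);
      return 0;
    }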
@@ -33,3 +33,23 @@ typedef enum {
  kOrtxErrorInternal = 9,
  kOrtxErrorUnknown = 1000
} extError_t;
+
+ typedef enum {
+   kOrtxUnknownType = 0,
+   kOrtxFloat = 1,
+   kOrtxDouble = 2,
+   kOrtxString = 3,
+   kOrtxBool = 4,
+   kOrtxComplex64 = 5,
+   kOrtxComplex128 = 6,
+   kOrtxBFloat16 = 7,
+   kOrtxUint8 = 8,
+   kOrtxInt8 = 9,
+   kOrtxUint16 = 10,
+   kOrtxInt16 = 11,
+   kOrtxInt32 = 12,
+   kOrtxUint32 = 13,
+   kOrtxInt64 = 14,
+   kOrtxUint64 = 15,
+   kOrtxFloat16 = 16
+ } extDataType_t;
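Callers that copy raw tensor bytes need the element width for each `extDataType_t`; the Python binding added below hardcodes 4 and 8 bytes for the two types it accepts, but a fuller mapping is a small switch. This helper is illustrative only, not part of the commit:

    #include "ortx_utils.h"

    // Hypothetical helper: byte width per element for the extDataType_t values above.
    inline size_t ElementSize(extDataType_t t) {
      switch (t) {
        case kOrtxBool: case kOrtxUint8: case kOrtxInt8: return 1;
        case kOrtxUint16: case kOrtxInt16: case kOrtxFloat16: case kOrtxBFloat16: return 2;
        case kOrtxFloat: case kOrtxUint32: case kOrtxInt32: return 4;
        case kOrtxDouble: case kOrtxUint64: case kOrtxInt64: case kOrtxComplex64: return 8;
        case kOrtxComplex128: return 16;
        default: return 0;  // kOrtxUnknownType, kOrtxString: no fixed width
      }
    }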
@@ -92,6 +92,19 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
 */
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);

+ /**
+  * @brief Retrieves the data type of the given tensor.
+  *
+  * This function returns the data type of the specified tensor. The data type is
+  * stored in the `type` parameter.
+  *
+  * @param tensor The tensor for which to retrieve the data type.
+  * @param type A pointer to a variable that will hold the retrieved data type.
+  *
+  * @return An `extError_t` value indicating the success or failure of the operation.
+  */
+ extError_t ORTX_API_CALL OrtxGetTensorType(OrtxTensor* tensor, extDataType_t* type);
+
/** \brief Get the data from the tensor
 *
 * \param tensor The tensor object
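Paired with `OrtxGetTensorData` declared just below, the type query gives callers a generic read path; a small sketch, assuming `tensor` already holds a valid `OrtxTensor*`:

    // Sketch: query the element type, then fetch the raw buffer and shape.
    extDataType_t dtype = kOrtxUnknownType;
    OrtxGetTensorType(tensor, &dtype);  // e.g. kOrtxFloat or kOrtxInt64

    const void* data = nullptr;
    const int64_t* shape = nullptr;
    size_t num_dims = 0;
    OrtxGetTensorData(tensor, &data, &shape, &num_dims);

    int64_t count = 1;
    for (size_t i = 0; i < num_dims; ++i) count *= shape[i];  // element count
    // bytes to copy = count * (element width for dtype)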
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

from . import _extensions_pydll as _C
if not hasattr(_C, "create_processor"):
    raise ImportError("onnxruntime_extensions is not built with pre-processing API")

create_processor = _C.create_processor
load_images = _C.load_images
image_pre_process = _C.image_pre_process
tensor_result_get_at = _C.tensor_result_get_at
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <thread>

#include "ortx_utils.h"
#include "ortx_processor.h"
#include "pykernel.h"

namespace py = pybind11;

template <typename T>
int64_t NumOfElement(const T& sp) {
  size_t c = 1;
  for (auto v : sp) {
    c *= v;
  }
  return c;
}

void AddGlobalMethodsCApi(pybind11::module& m) {
  m.def(
      "create_processor",
      [](const char* processor_def_json) {
        OrtxProcessor* processor = nullptr;
        auto err = OrtxCreateProcessor(&processor, processor_def_json);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to create processor") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(processor);
      },
      "Create a processor.");

  m.def(
      "load_images",
      [](const std::vector<std::string>& image_paths) {
        OrtxRawImages* images = nullptr;
        size_t num_images = image_paths.size();
        auto image_ptrs = std::make_unique<const char*[]>(num_images);
        for (size_t i = 0; i < num_images; ++i) {
          image_ptrs[i] = image_paths[i].c_str();
        }

        auto err = OrtxLoadImages(&images, image_ptrs.get(), num_images, nullptr);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to load images") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(images);
      },
      "Load images.");

  m.def(
      "image_pre_process",
      [](std::uintptr_t processor_h, std::uintptr_t images_h) {
        OrtxProcessor* processor = reinterpret_cast<OrtxProcessor*>(processor_h);
        OrtxRawImages* images = reinterpret_cast<OrtxRawImages*>(images_h);
        OrtxTensorResult* result{};
        auto err = OrtxImagePreProcess(processor, images, &result);
        if (err != kOrtxOK) {
          throw std::runtime_error(std::string("Failed to preprocess images") + OrtxGetLastErrorMessage());
        }
        return reinterpret_cast<std::uintptr_t>(result);
      },
      "Preprocess images.");

  m.def("tensor_result_get_at", [](std::uintptr_t result_h, size_t index) {
    OrtxTensorResult* result = reinterpret_cast<OrtxTensorResult*>(result_h);
    OrtxTensor* tensor{};
    auto err = OrtxTensorResultGetAt(result, index, &tensor);
    if (err != kOrtxOK) {
      throw std::runtime_error(std::string("Failed to get tensor") + OrtxGetLastErrorMessage());
    }

    extDataType_t tensor_type;

    OrtxGetTensorType(tensor, &tensor_type);
    const int64_t* shape{};
    size_t num_dims;
    const void* data{};
    size_t elem_size = 0;
    if (tensor_type == extDataType_t::kOrtxInt64 || tensor_type == extDataType_t::kOrtxFloat) {
      OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
      elem_size = 4;
      if (tensor_type == extDataType_t::kOrtxInt64) {
        elem_size = 8;
      }
    } else if (tensor_type == extDataType_t::kOrtxUnknownType) {
      throw std::runtime_error("Failed to get tensor type");
    } else {
      // any other element type is not copied out by this binding yet
      throw std::runtime_error("unsupported tensor type");
    }

    std::vector<std::size_t> npy_dims;
    for (size_t n = 0; n < num_dims; ++n) {
      npy_dims.push_back(shape[n]);
    }
    py::array obj{};

    if (tensor_type == extDataType_t::kOrtxFloat) {
      obj = py::array_t<float>(npy_dims);
    } else if (tensor_type == extDataType_t::kOrtxInt64) {
      obj = py::array_t<int64_t>(npy_dims);
    }

    void* out_ptr = obj.mutable_data();
    memcpy(out_ptr, data, NumOfElement(npy_dims) * elem_size);
    return obj;
  }, "Get tensor at index.");
}
@@ -482,6 +482,9 @@ PYBIND11_MODULE(_extensions_pydll, m) {
  m.doc() = "pybind11 stateful interface to ONNXRuntime-Extensions";

  AddGlobalMethods(m);
+ #if defined(ENABLE_C_API)
+   AddGlobalMethodsCApi(m);
+ #endif
  AddObjectMethods(m);
  auto atexit = py::module_::import("atexit");
  atexit.attr("register")(py::cpp_function([]() {
@@ -127,3 +127,7 @@ struct PyCustomOpFactory : public OrtCustomOp {
};

bool EnablePyCustomOps(bool enable = true);
+
+ #if defined(ENABLE_C_API)
+ void AddGlobalMethodsCApi(pybind11::module& m);
+ #endif
@@ -101,12 +101,29 @@ extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t
    return kOrtxErrorInvalidArgument;
  }

-  auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>>();
-  tensor_ptr->SetObject(ts);
+  auto tensor_ptr = std::make_unique<TensorObject>();
+  tensor_ptr->SetTensor(ts);
+  tensor_ptr->SetTensorType(result_ptr->GetTensorType(index));
  *tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
  return extError_t();
}

+ extError_t ORTX_API_CALL OrtxGetTensorType(OrtxTensor* tensor, extDataType_t* type) {
+   if (tensor == nullptr || type == nullptr) {
+     ReturnableStatus::last_error_message_ = "Invalid argument";
+     return kOrtxErrorInvalidArgument;
+   }
+
+   auto tensor_impl = static_cast<TensorObject*>(tensor);
+   if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
+     ReturnableStatus::last_error_message_ = "Invalid argument";
+     return kOrtxErrorInvalidArgument;
+   }
+
+   *type = tensor_impl->GetTensorType();
+   return extError_t();
+ }
+
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
                                           size_t* num_dims) {
  if (tensor == nullptr) {

@@ -114,15 +131,15 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
    return kOrtxErrorInvalidArgument;
  }

-  auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>*>(tensor);
+  auto tensor_impl = static_cast<TensorObject*>(tensor);
  if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
    ReturnableStatus::last_error_message_ = "Invalid argument";
    return kOrtxErrorInvalidArgument;
  }

-  *data = tensor_impl->GetObject()->DataRaw();
-  *shape = tensor_impl->GetObject()->Shape().data();
-  *num_dims = tensor_impl->GetObject()->Shape().size();
+  *data = tensor_impl->GetTensor()->DataRaw();
+  *shape = tensor_impl->GetTensor()->Shape().data();
+  *num_dims = tensor_impl->GetTensor()->Shape().size();
  return extError_t();
}
@@ -93,12 +93,33 @@ class StringArray : public OrtxObjectImpl {
  std::vector<std::string> strings_;
};

+ class TensorObject : public OrtxObjectImpl {
+  public:
+   TensorObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensor) {}
+   ~TensorObject() override = default;
+
+   void SetTensor(ortc::TensorBase* tensor) { tensor_ = tensor; }
+   void SetTensorType(extDataType_t type) { tensor_type_ = type; }
+
+   [[nodiscard]] extDataType_t GetTensorType() const { return tensor_type_; }
+
+   [[nodiscard]] ortc::TensorBase* GetTensor() const { return tensor_; }
+
+  private:
+   ortc::TensorBase* tensor_{};
+   extDataType_t tensor_type_{extDataType_t::kOrtxUnknownType};
+ };
+
class TensorResult : public OrtxObjectImpl {
 public:
  TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {}
  ~TensorResult() override = default;

  void SetTensors(std::vector<std::unique_ptr<ortc::TensorBase>>&& tensors) { tensors_ = std::move(tensors); }
+  void SetTensorTypes(const std::vector<extDataType_t>& types) { tensor_types_ = types; }
  [[nodiscard]] size_t NumTensors() const { return tensors_.size(); }

+  [[nodiscard]] const std::vector<extDataType_t>& tensor_types() const { return tensor_types_; }
+
  [[nodiscard]] const std::vector<std::unique_ptr<ortc::TensorBase>>& tensors() const { return tensors_; }

@@ -118,8 +139,16 @@ class TensorResult : public OrtxObjectImpl {
    return nullptr;
  }

+  extDataType_t GetTensorType(size_t i) const {
+    if (i < tensor_types_.size()) {
+      return tensor_types_[i];
+    }
+    return extDataType_t::kOrtxUnknownType;
+  }
+
 private:
  std::vector<std::unique_ptr<ortc::TensorBase>> tensors_;
+  std::vector<extDataType_t> tensor_types_;
};

struct ReturnableStatus {
@@ -143,7 +143,7 @@ std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extension

  *pixel_values = r.pixel_values = StackTensor<float>(outputs, 0, allocator_);
  *image_sizes = r.image_sizes = StackTensor<int64_t>(outputs, 1, allocator_);
-  *num_img_takens = r.num_img_takens = StackTensor<int64_t>(outputs, 2, allocator_);
+  *num_img_takens = r.num_img_tokens = StackTensor<int64_t>(outputs, 2, allocator_);

  return {status, std::move(r)};
}

@@ -179,6 +179,7 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
  operations_.back()->ResetTensors(allocator_);
  if (status.IsOk()) {
    r.SetTensors(std::move(img_result));
+    r.SetTensorTypes({kOrtxFloat, kOrtxInt64, kOrtxInt64});
  }

  return status;

@@ -195,8 +196,8 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) {
    r->image_sizes = nullptr;
  }

-  if (r->num_img_takens) {
-    std::unique_ptr<ortc::TensorBase>(r->num_img_takens).reset();
-    r->num_img_takens = nullptr;
+  if (r->num_img_tokens) {
+    std::unique_ptr<ortc::TensorBase>(r->num_img_tokens).reset();
+    r->num_img_tokens = nullptr;
  }
}
@@ -24,7 +24,7 @@ class ProcessorResult : public OrtxObjectImpl {
  ProcessorResult() : OrtxObjectImpl(kOrtxKindProcessorResult) {}
  ortc::Tensor<float>* pixel_values{};
  ortc::Tensor<int64_t>* image_sizes{};
-  ortc::Tensor<int64_t>* num_img_takens{};
+  ortc::Tensor<int64_t>* num_img_tokens{};
};
class ImageProcessor : public OrtxObjectImpl {
 public:
(The diff for one file is not shown because of its large size.)
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#define IMAGING_MODE_LENGTH 6 + 1 /* Band names ("1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "BGR;xy") */

/* standard filters */
#define IMAGING_TRANSFORM_NEAREST 0
#define IMAGING_TRANSFORM_BOX 4
#define IMAGING_TRANSFORM_BILINEAR 2
#define IMAGING_TRANSFORM_HAMMING 5
#define IMAGING_TRANSFORM_BICUBIC 3
#define IMAGING_TRANSFORM_LANCZOS 1

typedef struct ImagingMemoryInstance* Imaging;
typedef struct {
  char* ptr;
  int size;
} ImagingMemoryBlock;

struct ImagingMemoryInstance {
  /* Format */
  char mode[IMAGING_MODE_LENGTH]; /* Band names ("1", "L", "P", "RGB", "RGBA", "CMYK",
                                     "YCbCr", "BGR;xy") */
  int type;                       /* Data type (IMAGING_TYPE_*) */
  int bands;                      /* Number of bands (1, 2, 3, or 4) */
  int xsize;                      /* Image dimension. */
  int ysize;

  /* Data pointers */
  uint8_t** image8;  /* Set for 8-bit images (pixelsize=1). */
  int32_t** image32; /* Set for 32-bit images (pixelsize=4). */

  /* Internals */
  char** image;               /* Actual raster data. */
  char* block;                /* Set if data is allocated in a single block. */
  ImagingMemoryBlock* blocks; /* Memory blocks for pixel storage */

  int pixelsize; /* Size of a pixel, in bytes (1, 2 or 4) */
  int linesize;  /* Size of a line, in bytes (xsize * pixelsize) */

  /* Virtual methods */
  void (*destroy)(Imaging im);
};

#ifdef __cplusplus
extern "C" {
#endif

Imaging ImagingNew(const char* mode, int xsize, int ysize);
Imaging ImagingResample(Imaging imIn, int xsize, int ysize, int filter, float box[4]);
void ImagingDelete(Imaging im);

#ifdef __cplusplus
}
#endif
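The lifecycle implied by this header, as the kernels below use it, is allocate, resample, free; mode "RGB" stores each pixel in 4 bytes with the fourth byte unused. A sketch, where src_w and src_h are assumed input dimensions:

    #include "image_resample.h"

    // Sketch: upscale an RGB image to 336x336 with bicubic filtering.
    Imaging src = ImagingNew("RGB", src_w, src_h);
    // ... fill src->image[y][x * 4 + ch] for ch in {0,1,2}; byte 3 is padding ...
    float box[4] = {0.0f, 0.0f, (float)src_w, (float)src_h};  // source region to sample
    Imaging dst = ImagingResample(src, 336, 336, IMAGING_TRANSFORM_BICUBIC, box);
    ImagingDelete(src);
    // ... consume dst->image rows ...
    ImagingDelete(dst);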
@@ -4,6 +4,7 @@
#pragma once

#include "ocos.h"
+ #include "image_resample.h"

inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
  auto& dimensions = input.Shape();

@@ -60,24 +61,45 @@ struct Resize {
    int w = static_cast<int>(dimensions[1]);
    int c = static_cast<int>(dimensions[2]);

-    cv::Mat image(h, w, CV_8UC3, const_cast<uint8_t*>(input_data));
-    cv::Mat output_image;
-    cv::InterpolationFlags interp{};
+    Imaging rgb_image = ImagingNew("RGB", w, h);
+    for (int32_t i = 0; i < h; ++i) {
+      for (int32_t j = 0; j < w; ++j) {
+        uint8_t* pixel = reinterpret_cast<uint8_t*>(rgb_image->image[i] + j * 4);
+        pixel[0] = input_data[(i * w + j) * 3];
+        pixel[1] = input_data[(i * w + j) * 3 + 1];
+        pixel[2] = input_data[(i * w + j) * 3 + 2];
+        pixel[3] = 0;  // unused
+      }
+    }
+
+    int interp = IMAGING_TRANSFORM_NEAREST;
    if (interpolation_ == "NEAREST") {
-      interp = cv::INTER_NEAREST;
+      interp = IMAGING_TRANSFORM_NEAREST;
    } else if (interpolation_ == "LINEAR") {
-      interp = cv::INTER_LINEAR;
+      interp = IMAGING_TRANSFORM_BILINEAR;
    } else if (interpolation_ == "CUBIC") {
-      interp = cv::INTER_CUBIC;
+      interp = IMAGING_TRANSFORM_BICUBIC;
    } else if (interpolation_ == "LANCZOS") {
+      interp = IMAGING_TRANSFORM_LANCZOS;
    } else {
      return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
    }

-    cv::resize(image, output_image, {static_cast<int32_t>(width_), static_cast<int32_t>(height_)}, 0.0, 0.0, interp);
+    float box[4] = {0.0f, 0.0f, static_cast<float>(width_), static_cast<float>(height_)};
+
+    auto output_image = ImagingResample(rgb_image, static_cast<int>(width_), static_cast<int>(height_), interp, box);
+    // cv::resize(image, output_image, {static_cast<int32_t>(width_), static_cast<int32_t>(height_)}, 0.0, 0.0, interp);
+    ImagingDelete(rgb_image);

    auto* p_output_image = output.Allocate({height_, width_, c});
-    std::memcpy(p_output_image, output_image.data, height_ * width_ * c);
+    for (auto i = height_ - height_; i < height_; ++i) {
+      for (auto j = width_ - width_; j < width_; ++j) {
+        auto c0_index = i * width_ * c + j * c;
+        std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
+      }
+    }
+
+    ImagingDelete(output_image);
    return {};
  }
@@ -4,6 +4,7 @@
#pragma once

#include "ocos.h"
+ #include "image_resample.h"

constexpr int max_crops = 16;
constexpr int num_img_tokens = 144;

@@ -13,7 +14,7 @@ constexpr int image_resized_height = 336;
constexpr float OPENAI_CLIP_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};
constexpr float OPENAI_CLIP_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};

- inline cv::Mat padding_336(const cv::Mat& image) {
+ inline Imaging padding_336_h(Imaging image) {
  // def padding_336(b):
  //   width, height = b.size
  //   tar = int(np.ceil(height / 336) * 336)

@@ -21,29 +22,100 @@ inline cv::Mat padding_336(const cv::Mat& image) {
  //   bottom_padding = tar - height - top_padding
  //   left_padding = 0
  //   right_padding = 0
-  //   b = torchvision.transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255,255,255])
+  //   b = torchvision.transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding],
+  //                                             fill=[255,255,255])

  //   return b
-  float height = static_cast<float>(image.rows);
+  float height = static_cast<float>(image->ysize);
  int32_t tar = static_cast<int32_t>(std::ceil(height / image_resized_height) * image_resized_height);
+  if (tar == image->ysize) {
+    return image;
+  }
  int32_t top_padding = static_cast<int32_t>((tar - height) / 2);
-  int32_t bottom_padding = tar - image.rows - top_padding;
+  int32_t bottom_padding = tar - image->ysize - top_padding;

-  cv::Mat output;
-  cv::copyMakeBorder(image, output, top_padding, bottom_padding, 0, 0, cv::BORDER_CONSTANT, {255, 255, 255});
+  Imaging output = ImagingNew("RGB", image->xsize, tar);
+  if (output == nullptr) {
+    return nullptr;
+  }
+  for (int32_t i = 0; i < top_padding; ++i) {
+    for (int32_t j = 0; j < image->xsize; ++j) {
+      output->image[i][j * 4 + 0] = char(255);
+      output->image[i][j * 4 + 1] = char(255);
+      output->image[i][j * 4 + 2] = char(255);
+      output->image[i][j * 4 + 3] = 0;  // unused
+    }
+  }
+  for (int32_t i = top_padding; i < top_padding + image->ysize; ++i) {
+    for (int32_t j = 0; j < image->xsize; ++j) {
+      output->image[i][j * 4 + 0] = image->image[i - top_padding][j * 4];
+      output->image[i][j * 4 + 1] = image->image[i - top_padding][j * 4 + 1];
+      output->image[i][j * 4 + 2] = image->image[i - top_padding][j * 4 + 2];
+      output->image[i][j * 4 + 3] = 0;  // unused
+    }
+  }
+  for (int32_t i = top_padding + image->ysize; i < tar; ++i) {
+    for (int32_t j = 0; j < image->xsize; ++j) {
+      output->image[i][j * 4 + 0] = char(255);
+      output->image[i][j * 4 + 1] = char(255);
+      output->image[i][j * 4 + 2] = char(255);
+      output->image[i][j * 4 + 3] = 0;  // unused
+    }
+  }
+
+  ImagingDelete(image);
  return output;
}

- inline cv::Mat hd_transform(const cv::Mat& image, int hd_num) {
+ inline Imaging padding_336_w(Imaging image) {
+   float width = static_cast<float>(image->xsize);
+   int32_t tar = static_cast<int32_t>(std::ceil(width / image_resized_width) * image_resized_width);
+   if (tar == image->xsize) {
+     return image;
+   }
+
+   int32_t left_padding = static_cast<int32_t>((tar - width) / 2);
+   int32_t right_padding = tar - image->xsize - left_padding;
+
+   Imaging output = ImagingNew("RGB", tar, image->ysize);
+   if (output == nullptr) {
+     return nullptr;
+   }
+   for (int32_t i = 0; i < image->ysize; ++i) {
+     for (int32_t j = 0; j < left_padding; ++j) {
+       output->image[i][j * 4 + 0] = char(255);
+       output->image[i][j * 4 + 1] = char(255);
+       output->image[i][j * 4 + 2] = char(255);
+       output->image[i][j * 4 + 3] = 0;  // unused
+     }
+     for (int32_t j = left_padding; j < left_padding + image->xsize; ++j) {
+       output->image[i][j * 4 + 0] = image->image[i][(j - left_padding) * 4 + 0];
+       output->image[i][j * 4 + 1] = image->image[i][(j - left_padding) * 4 + 1];
+       output->image[i][j * 4 + 2] = image->image[i][(j - left_padding) * 4 + 2];
+       output->image[i][j * 4 + 3] = 0;  // unused
+     }
+     for (int32_t j = left_padding + image->xsize; j < tar; ++j) {
+       output->image[i][j * 4 + 0] = char(255);
+       output->image[i][j * 4 + 1] = char(255);
+       output->image[i][j * 4 + 2] = char(255);
+       output->image[i][j * 4 + 3] = 0;  // unused
+     }
+   }
+
+   ImagingDelete(image);
+   return output;
+ }
+
+ inline Imaging hd_transform(Imaging image, int hd_num) {
  // width, height = img.size
-  auto [width, height] = std::make_tuple(image.cols, image.rows);
+  auto [width, height] = std::make_tuple(image->xsize, image->ysize);

  // ratio = width / height if width >= height else height / width
-  float ratio = 1.0f * width;
+  double ratio = 1.0 * width;
  if (width >= height) {
    ratio /= height;
  } else {
-    ratio = 1.0f * height / width;
+    ratio = height / ratio;
  }

  // scale = 1

@@ -68,15 +140,16 @@ inline cv::Mat hd_transform(const cv::Mat& image, int hd_num) {
  }

  // img = torchvision.transforms.functional.resize(img, [new_h, new_w])
-  std::vector<int32_t> height_x_width{static_cast<int32_t>(new_h),  // H
-                                      static_cast<int32_t>(new_w)};  // W
+  float box[4] = {0.0f, 0.0f, static_cast<float>(image->xsize), static_cast<float>(image->ysize)};
+  auto output_image =
+      ImagingResample(image, static_cast<int>(new_w), static_cast<int>(new_h), IMAGING_TRANSFORM_BILINEAR, box);
+  ImagingDelete(image);
+  if (output_image == nullptr) {
+    return nullptr;
+  }

-  cv::Mat output_image;
-  cv::resize(image, output_image,
-             {static_cast<int32_t>(new_w), static_cast<int32_t>(new_h)}, 0.0, 0.0,
-             cv::INTER_LINEAR);
  // img = padding_336(img)
-  return padding_336(output_image);
+  return width < height ? padding_336_w(output_image) : padding_336_h(output_image);
}

// Function to calculate 1D index from 3D indices

@@ -97,10 +170,8 @@ inline void Permute3DArray(const float* array, float* permutedArray, size_t X, s
  }
}

- inline OrtxStatus phi3_hd_transform(const ortc::Tensor<uint8_t>& input,
-                                     ortc::Tensor<float>& pixel_values,
-                                     ortc::Tensor<int64_t>& image_sizes,
-                                     ortc::Tensor<int64_t>& num_img_takens) {
+ inline OrtxStatus phi3_hd_transform(const ortc::Tensor<uint8_t>& input, ortc::Tensor<float>& pixel_values,
+                                     ortc::Tensor<int64_t>& image_sizes, ortc::Tensor<int64_t>& num_img_tokens) {
  auto& dimensions = input.Shape();
  if (dimensions.size() != 3ULL) {
    return {kOrtxErrorInvalidArgument, "[hd_transform]: Only raw image formats"};

@@ -111,89 +182,121 @@ inline OrtxStatus phi3_hd_transform(const ortc::Tensor<uint8_t>& input,
  int32_t h = static_cast<int32_t>(dimensions[0]);
  int32_t w = static_cast<int32_t>(dimensions[1]);
  int32_t c = static_cast<int32_t>(dimensions[2]);
-  std::vector<int32_t> height_x_width{static_cast<int32_t>(h),  // H
-                                      static_cast<int32_t>(w)};  // W
+  // std::vector<int32_t> height_x_width{static_cast<int32_t>(h),  // H
+  //                                     static_cast<int32_t>(w)};  // W

-  cv::Mat rgb_image(height_x_width, CV_8UC3, const_cast<uint8_t*>(input_data));
-  // elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
-  auto elem = hd_transform(rgb_image, max_crops);
-  // # tensor transform and normalize
-  // hd_images = [img_processor(im) for im in elems]
-  std::tie(w, h) = std::make_tuple(elem.cols, elem.rows);
-  auto elem_image = elem.data;
-  auto rgb_image_ptr = std::make_unique<float[]>(h * w * c);
-  auto p_pixel_values = rgb_image_ptr.get();
-  for (int64_t j = 0; j < h; ++j) {
-    for (int64_t k = 0; k < w; ++k) {
-      auto c0_index = j * w * c + k * c;
-      p_pixel_values[c0_index] = (static_cast<float>(elem_image[c0_index]) / 255.f - OPENAI_CLIP_MEAN[0]) / OPENAI_CLIP_STD[0];
-      p_pixel_values[c0_index + 1] = (static_cast<float>(elem_image[c0_index + 1]) / 255.f - OPENAI_CLIP_MEAN[1]) / OPENAI_CLIP_STD[1];
-      p_pixel_values[c0_index + 2] = (static_cast<float>(elem_image[c0_index + 2]) / 255.f - OPENAI_CLIP_MEAN[2]) / OPENAI_CLIP_STD[2];
+  // cv::Mat rgb_image(static_cast<int>(h), static_cast<int>(w), CV_8UC3, const_cast<uint8_t*>(input_data));
+  Imaging rgb_image = ImagingNew("RGB", w, h);
+  if (rgb_image == nullptr) {
+    return {kOrtxErrorOutOfMemory, "[hd_transform]: Failed to allocate memory for RGB image"};
+  }
+  for (int32_t i = 0; i < h; ++i) {
+    for (int32_t j = 0; j < w; ++j) {
+      uint8_t* pixel = reinterpret_cast<uint8_t*>(rgb_image->image[i] + j * 4);
+      pixel[0] = input_data[(i * w + j) * 3];
+      pixel[1] = input_data[(i * w + j) * 3 + 1];
+      pixel[2] = input_data[(i * w + j) * 3 + 2];
+      pixel[3] = 0;  // unused
    }
  }

-  // Debug code to check the image parity
-  // auto rgb_image_ptr_debug = std::make_unique<float[]>(h * w * c);
-  // Permute3DArray(p_pixel_values, rgb_image_ptr_debug.get(), h, w, c);

-  cv::Mat hd_image(h, w, CV_32FC3, p_pixel_values);
-  // # create global image
-  // global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode='bicubic',).to(im.dtype) for im in hd_images]
-  cv::Mat global_image;
-  cv::resize(hd_image, global_image, {image_resized_height, image_resized_width}, 0.0, 0.0, cv::INTER_CUBIC);

-  int64_t shape[2];
-  // # [(3, h, w)], where h, w is multiple of 336
-  // shapes = [[im.size(1), im.size(2)] for im in hd_images]
-  {
-    auto shapes = image_sizes.Allocate({2});
-    shapes[0] = shape[0] = hd_image.rows;
-    shapes[1] = shape[1] = hd_image.cols;
+  // cv::Mat rgb_image(h, w, CV_8UC3, const_cast<uint8_t*>(input_data));
+  // elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
+  auto elem = hd_transform(rgb_image, max_crops);
+  // # tensor transform and normalize
+  if (elem == nullptr) {
+    return {kOrtxErrorOutOfMemory, "[hd_transform]: Failed to allocate memory for elem"};
+  }
-  // num_img_tokens = [int((h//336*w//336+1)*144 + 1 + (h//336+1)*12) for h, w in shapes]

+  std::tie(w, h) = std::make_tuple(elem->xsize, elem->ysize);
+  uint8_t** elem_image = reinterpret_cast<uint8_t**>(elem->image);
+  auto rgb_image_ptr = std::make_unique<float[]>(c * h * w);  // channel first
+  auto p_pixel_values = rgb_image_ptr.get();
+  for (int32_t k = 0; k < c; ++k) {
+    for (int32_t i = 0; i < h; ++i) {
+      for (int32_t j = 0; j < w; ++j) {
+        p_pixel_values[k * h * w + i * w + j] =
+            (static_cast<float>(elem_image[i][j * 4 + k]) / 255.f - OPENAI_CLIP_MEAN[k]) / OPENAI_CLIP_STD[k];
+      }
+    }
+  }
+  ImagingDelete(elem);

+  auto shape = image_sizes.Allocate({2});
+  // shapes = [[im.size(1), im.size(2)] for im in hd_images]
+  shape[0] = h;
+  shape[1] = w;

+  auto image_size_1c = h * w;
+  std::vector<Imaging> global_image(c);  // resample the image per channel
+  for (int32_t k = 0; k < c; ++k) {
+    // # create global image
+    auto image_1c = ImagingNew("F", w, h);
+    for (int32_t y = 0; y < h; ++y) {
+      for (int32_t x = 0; x < w; ++x) {
+        float* pixel = reinterpret_cast<float*>(image_1c->image[y]);
+        *(pixel + x) = p_pixel_values[k * image_size_1c + y * w + x];
+      }
+    }
+    // global_image = [torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336),
+    //                                                 mode='bicubic',).to(im.dtype) for im in hd_images]
+    float box[]{0.0f, 0.0f, static_cast<float>(image_1c->xsize), static_cast<float>(image_1c->ysize)};
+    global_image[k] =
+        ImagingResample(image_1c, image_resized_width, image_resized_height, IMAGING_TRANSFORM_BICUBIC, box);
+    if (global_image[k] == nullptr) {
+      return {kOrtxErrorOutOfMemory, "[hd_transform]: Failed to allocate memory for global_image"};
+    }
+    ImagingDelete(image_1c);
+  }

  {
-    auto n_tokens = num_img_takens.Allocate({1});
+    // num_img_tokens = [int((h//336*w//336+1)*144 + 1 + (h//336+1)*12) for h, w in shapes]
+    auto n_tokens = num_img_tokens.Allocate({1});
    auto [h_t, w_t] = std::make_tuple(image_sizes.Data()[0], image_sizes.Data()[1]);
-    auto num_t = (static_cast<int32_t>(
-                      static_cast<int32_t>(h_t / image_resized_height) * w_t / image_resized_width) +
-                  1) *
-                     144 +
-                 1 + static_cast<int32_t>(h_t / image_resized_height + 1) * 12;
+    auto num_t =
+        (static_cast<int32_t>(static_cast<int32_t>(h_t / image_resized_height) * w_t / image_resized_width) + 1) * 144 +
+        1 + static_cast<int32_t>(h_t / image_resized_height + 1) * 12;
    *n_tokens = static_cast<int64_t>(num_t);
  }
  // # reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
  // # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
-  // hd_images_reshape = [im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336).contiguous() for im, (h, w) in zip(hd_images, shapes)]
-  // # concat global image and local image
-  // hd_images_reshape = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)]
-  // # pad to max_num_crops
-  // image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops+1) for im in hd_images_reshape]
-  // image_transformed = torch.stack(image_transformed, dim=0)
-  // padded_images = image_transformed
+  // hd_images_reshape = [im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336,
+  // 336).contiguous() for im, (h, w) in zip(hd_images, shapes)] # concat global image and local image hd_images_reshape
+  // = [torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)] # pad
+  // to max_num_crops image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops+1) for im in
+  // hd_images_reshape] image_transformed = torch.stack(image_transformed, dim=0) padded_images = image_transformed
  std::vector<int64_t> padded_image_shape = {max_crops + 1, 3, image_resized_height, image_resized_width};
  float* output_pixel = pixel_values.Allocate(padded_image_shape);
  // Copy the image pixel value from the global image
-  const int image_c_size = image_resized_height * image_resized_width * 3;
-  Permute3DArray(reinterpret_cast<float*>(global_image.data), output_pixel, image_resized_height, image_resized_width, 3);
-  auto num_crops = static_cast<int>((shape[0] / image_resized_height) * (shape[1] / image_resized_width));
-  float* image_transformed = reinterpret_cast<float*>(hd_image.data);
+  // for (int i = 0; i < num_crops; ++i) {
+  //   Permute3DArray(image_transformed + i * image_c_size, output_pixel + (i + 1) * image_c_size, image_resized_height, image_resized_width, 3);
+  // }
+  const int image_1c_size = image_resized_height * image_resized_width;
+  for (auto i = c - c; i < c; ++i) {
+    for (int y = 0; y < image_resized_height; ++y) {
+      auto image_transformed = reinterpret_cast<float*>(global_image[i]->image[y]);
+      memcpy(output_pixel + i * image_1c_size + y * image_resized_width, image_transformed,
+             image_resized_width * sizeof(float));
+    }
+  }

-  float* output_pixel_n_1 = output_pixel + image_c_size;
+  for (auto img : global_image) {
+    ImagingDelete(img);
+  }

+  auto num_crops = static_cast<int>((shape[0] / image_resized_height) * (shape[1] / image_resized_width));
+  // chop the image into crops
+  float* output_pixel_n_1 = output_pixel + image_1c_size * c;
  int m = static_cast<int>(shape[0] / image_resized_height);
  int n = static_cast<int>(shape[1] / image_resized_width);
  h = image_resized_height;
  w = image_resized_width;
  assert(m * n == num_crops);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
-      int sub_index = (i * n + j) * image_c_size;
-      for (int x = 0; x < image_resized_height; ++x) {
-        for (int y = 0; y < image_resized_width; ++y) {
-          for (int k = 0; k < 3; ++k) {  // Loop over channels
-            output_pixel_n_1[sub_index + k * h * w + x * w + y] = image_transformed[((i * h + x) * shape[1] + (j * w + y)) * 3 + k];
+      for (int32_t k = 0; k < c; ++k) {
+        // channel first
+        int sub_index = (i * n + j) * image_1c_size * c + k * image_1c_size;
+        for (int y = 0; y < image_resized_height; ++y) {
+          for (int x = 0; x < image_resized_width; ++x) {
+            output_pixel_n_1[sub_index + y * image_resized_width + x] =
+                p_pixel_values[k * shape[0] * shape[1] + (i * image_resized_height + y) * shape[1] +
+                               (j * image_resized_width + x)];
          }
        }
      }

@@ -202,7 +305,8 @@ inline OrtxStatus phi3_hd_transform(const ortc::Tensor<uint8_t>& input,

  // padding the rest of the crops
  // pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
-  memset(output_pixel_n_1 + num_crops * image_c_size, 0, image_c_size * (max_crops - num_crops) * sizeof(float));
+  memset(output_pixel_n_1 + num_crops * image_1c_size * c, 0,
+         image_1c_size * c * (max_crops - num_crops) * sizeof(float));

  // image_sizes = shapes
  return {};
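The token-count comment above carries the whole formula, so a concrete instance helps. Values here are assumed for illustration only (after hd_transform, h and w are always multiples of 336):

    // Worked instance of: int((h//336 * w//336 + 1)*144 + 1 + (h//336 + 1)*12)
    constexpr int h = 1344, w = 1008;             // assumed post-transform size
    constexpr int crops = (h / 336) * (w / 336);  // 4 * 3 = 12 local crops
    constexpr int num = (crops + 1) * 144 + 1 + (h / 336 + 1) * 12;
    static_assert(num == 1933, "12 local crops plus 1 global image -> 1933 tokens");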
@@ -0,0 +1,76 @@
import os
import tempfile
from PIL import Image
from transformers import AutoProcessor
from onnxruntime_extensions.pp_api import create_processor, load_images, image_pre_process, tensor_result_get_at

import numpy as np


def regen_image(arr):
    mean = np.array([0.48145466, 0.4578275, 0.40821073])
    std = np.array([0.26862954, 0.26130258, 0.27577711])

    # Reverse normalization
    array = arr * std + mean

    # Clip the values to [0, 1] range
    array = np.clip(array, 0, 1)

    # Convert to [0, 255] range and uint8 type
    array = (array * 255).astype(np.uint8)

    # Convert NumPy array to PIL Image
    image = Image.fromarray(array)
    return image


test_image = "test/data/processor/passport.png"
# test_image = "/temp/passport_s.png"
# test_image = "/temp/passport_s2.png"
model_id = "microsoft/Phi-3-vision-128k-instruct"

processor = create_processor("test/data/processor/phi_3_image.json")
images = load_images([test_image])
c_out = image_pre_process(processor, images)
# print(tensor_result_get_at(c_out, 0))
# print(tensor_result_get_at(c_out, 1))

image = Image.open(test_image)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
messages = [
    {"role": "user", "content": "<|image_1|>\nWhat is shown in this image?"},
    {"role": "assistant", "content": "The chart displays the percentage of respondents who agree with various statements about their preparedness for meetings. It shows five categories: 'Having clear and pre-defined goals for meetings', 'Knowing where to find the information I need for a meeting', 'Understanding my exact role and responsibilities when I'm invited', 'Having tools to manage admin tasks like note-taking or summarization', and 'Having more focus time to sufficiently prepare for meetings'. Each category has an associated bar indicating the level of agreement, measured on a scale from 0% to 100%."},
    {"role": "user", "content": "Provide insightful questions to spark discussion."}
]
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)


inputs = processor(prompt, [image], return_tensors="pt")
# print(inputs["pixel_values"].numpy())
# print(inputs["image_sizes"])

np.testing.assert_allclose(
    inputs["image_sizes"].numpy(), tensor_result_get_at(c_out, 1))
# np.testing.assert_allclose(inputs["pixel_values"].numpy(), tensor_result_get_at(c_out, 0), rtol=1e-1)

if os.path.exists("/temp"):
    temp_dir = "/temp"
else:
    temp_dir = tempfile.mkdtemp()
    print(f"Created temp dir: {temp_dir}")

for i in range(17):
    expected = inputs["pixel_values"].numpy()[0, i]
    actual = tensor_result_get_at(c_out, 0)[0, i]
    e_image = regen_image(expected.transpose(1, 2, 0))
    a_image = regen_image(actual.transpose(1, 2, 0))
    e_image.save(f"{temp_dir}/e_{i}.png")
    a_image.save(f"{temp_dir}/a_{i}.png")

    try:
        np.testing.assert_allclose(inputs["pixel_values"].numpy(
        )[0, i], tensor_result_get_at(c_out, 0)[0, i], rtol=1e-2)
    except AssertionError as e:
        print(str(e))
(Binary file not shown. Added image, size: 824 KiB.)
@@ -4,10 +4,10 @@ $latest_valid_vs = ""
function choose_latter_vs {
  param([string]$path)
  if ($global:latest_valid_vs -lt $path) {
-    $cmake_path = $path + "\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
-    if (Test-Path -Path $cmake_path) {
-      $global:latest_valid_vs = $path
-    }
+    # $cmake_path = $path + "\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
+    # if (Test-Path -Path $cmake_path) {
+    $global:latest_valid_vs = $path
+    # }
  }
}