Add MLlama Image Processing Support (#823)
* initial checkins for mllama image process
* fix some tests
* some fixings
* add more image
* More test assertions
* parity test passed
* code clean up
* code refinement
Parent: 7ab9d24cb4
Commit: aa2c82fa67
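For context: the new MLlama pre-processing is configuration-driven. A JSON transform graph (data/processor/mllama/llama_3_image.json, added below) is loaded through the C API and surfaced via the pp_api Python bindings this commit updates. A minimal usage sketch, not part of the commit itself; the image path is illustrative:

    from onnxruntime_extensions import pp_api

    # load the transform pipeline added in this commit
    processor = pp_api.ImageProcessor("test/data/processor/mllama/llama_3_image.json")
    # pre_process accepts a list of image file paths
    result = processor.pre_process(["test/data/processor/australia.jpg"])
    # index 0 holds pixel_values; the Llama3 transform emits a 5-D tensor
    # (batch, num_tiles, channels, 560, 560), per the C++ test below
    pixel_values = processor.to_numpy(result, 0)
    print(pixel_values.shape)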
@@ -557,7 +557,6 @@ class Tensor<std::string_view> : public TensorBase {
  std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
};


template<typename ...Args>
class NamedArgumentDict{
 public:
@@ -72,8 +72,8 @@ class ImageProcessor:
        return image_pre_process(self.processor, images)

    @staticmethod
    def to_numpy(result):
        return tensor_result_get_at(result, 0)
    def to_numpy(result, idx):
        return tensor_result_get_at(result, idx)

    def __del__(self):
        if delete_object and self.processor:
@@ -95,5 +95,5 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm
    *result = nullptr;
  }

  return {};
  return status.Code();
}
@@ -160,26 +160,3 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
  return extError_t();
}

extError_t ORTX_API_CALL OrtxGetTensorDataInt64(OrtxTensor* tensor, const int64_t** data, const int64_t** shape,
                                                size_t* num_dims) {
  const void* data_ptr{};
  auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
  *data = reinterpret_cast<const int64_t*>(data_ptr);  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
  return err;
}

extError_t ORTX_API_CALL OrtxGetTensorDataFloat(OrtxTensor* tensor, const float** data, const int64_t** shape,
                                                size_t* num_dims) {
  const void* data_ptr{};
  auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
  *data = reinterpret_cast<const float*>(data_ptr);  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
  return err;
}

extError_t ORTX_API_CALL OrtxGetTensorDataUint8(OrtxTensor* tensor, const uint8_t** data, const int64_t** shape,
                                                size_t* num_dims) {
  const void* data_ptr{};
  auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
  *data = reinterpret_cast<const uint8_t*>(data_ptr);  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
  return err;
}
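With the typed getters removed above, callers fetch the raw pointer through the generic OrtxGetTensorData accessor and cast it themselves; the updated processor tests later in this diff do exactly that with reinterpret_cast<const void**>(&data).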
@@ -4,6 +4,7 @@
#pragma once
#include <vector>
#include <fstream>
#include <variant>

#include "ortx_utils.h"
#include "file_sys.h"

@@ -258,4 +259,10 @@ std::tuple<std::unique_ptr<T[]>, size_t> LoadRawData(It begin, It end) {

  return std::make_tuple(std::move(raw_data), n);
}

using AttrType =
    std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
using AttrDict = std::unordered_map<std::string, AttrType>;
}  // namespace ort_extensions

namespace ortx = ort_extensions;
@@ -17,19 +17,28 @@
#include "image_decoder.hpp"
#endif
#else
#include "image_decoder.hpp"
#include "image_decoder.hpp"
#endif

#include "image_transforms.hpp"
#include "image_transforms_phi_3.hpp"
#include "image_transforms_mllama.hpp"

namespace ort_extensions {
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
LoadRawImages(const std::initializer_list<const char*>& image_paths) {
  return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
}

std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
LoadRawImages(const char* image_paths[], size_t num_images) {
  return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths, image_paths + num_images);
}
}  // namespace ort_extensions

using namespace ort_extensions;
using json = nlohmann::json;

std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
ort_extensions::LoadRawImages(const std::initializer_list<const char*>& image_paths) {
  return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
}

Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
    {"DecodeImage", []() { return CreateKernelInstance(&DecodeImage::Compute); }},
    {"Resize", []() { return CreateKernelInstance(&Resize::Compute); }},

@@ -37,7 +46,9 @@ Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
    {"Normalize", []() { return CreateKernelInstance(&Normalize::Compute); }},
    {"CenterCrop", []() { return CreateKernelInstance(&CenterCrop::Compute); }},
    {"ConvertRGB", []() { return CreateKernelInstance(convert_to_rgb); }},
    {"Permute3D", []() { return CreateKernelInstance(&Permute3D::Compute); }},
    {"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }},
    {"Llama3ImageTransform", []() { return CreateKernelInstance(&Llama3ImageTransform::Compute); }},
};

OrtxStatus ImageProcessor::Init(std::string_view processor_def) {

@@ -189,7 +200,6 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
  operations_.back()->ResetTensors(allocator_);
  if (status.IsOk()) {
    r.SetTensors(std::move(img_result));
    // r.SetTensorTypes({kOrtxFloat, kOrtxInt64, kOrtxInt64});
  }

  return status;
@@ -17,7 +17,10 @@ namespace ort_extensions {
using ImageRawData = std::vector<uint8_t>;

std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
    const std::initializer_list<const char*>& image_paths);
    const std::initializer_list<const char*>& image_paths);

std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
    const char* image_paths[], size_t num_images);

class ProcessorResult : public OrtxObjectImpl {
 public:

@@ -26,6 +29,7 @@ class ProcessorResult : public OrtxObjectImpl {
  ortc::Tensor<int64_t>* image_sizes{};
  ortc::Tensor<int64_t>* num_img_tokens{};
};

class ImageProcessor : public OrtxObjectImpl {
 public:
  ImageProcessor();
@@ -3,9 +3,31 @@

#pragma once

#include "ocos.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "image_resample.h"

template <typename T>
void DumpTensorToFile(const ortc::Tensor<T>& tensor, const char* name) {
#if WIN32
  auto tic = GetTickCount();
  std::string dtype;
  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, std::byte>) {
    dtype = "_u_";
  } else {
    dtype = "_f_";
  }
  dtype += std::to_string(tensor.Shape()[1]);
  // use tic as part of the temp file name
  auto filename = std::string("\\temp\\") + name + std::to_string(tic) + dtype + ".bin";
  std::ofstream file(filename, std::ios::out | std::ios::binary);
  if (file.is_open()) {
    file.write(reinterpret_cast<const char*>(tensor.DataRaw()), tensor.SizeInBytes());
    file.close();
  }
#endif
}

inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
  auto& dimensions = input.Shape();
  if (dimensions.size() != 3ULL || dimensions[2] != 3) {
@@ -31,23 +53,13 @@ inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tenso
}

struct Resize {
  template <typename DictT>
  OrtxStatus Init(const DictT& attrs) {
    for (const auto& [key, value] : attrs) {
      if (key == "height") {
        height_ = std::get<int64_t>(value);
      } else if (key == "width") {
        width_ = std::get<int64_t>(value);
      } else if (key == "interpolation") {
        interpolation_ = std::get<std::string>(value);
        if (interpolation_ != "NEAREST" && interpolation_ != "LINEAR" && interpolation_ != "CUBIC") {
          return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
        }
      } else {
        return {kOrtxErrorInvalidArgument, "[Resize]: Invalid argument"};
      }
    }
    return {};
  static const std::unordered_map<std::string, int> InterpolationMethods() {
    return {
        {"NEAREST", IMAGING_TRANSFORM_NEAREST},
        {"LINEAR", IMAGING_TRANSFORM_BILINEAR},
        {"CUBIC", IMAGING_TRANSFORM_BICUBIC},
        {"LANCZOS", IMAGING_TRANSFORM_LANCZOS}
    };
  }

  OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
@@ -72,40 +84,57 @@ struct Resize {
      }
    }

    int interp = IMAGING_TRANSFORM_NEAREST;
    if (interpolation_ == "NEAREST") {
      interp = IMAGING_TRANSFORM_NEAREST;
    } else if (interpolation_ == "LINEAR") {
      interp = IMAGING_TRANSFORM_BILINEAR;
    } else if (interpolation_ == "CUBIC") {
      interp = IMAGING_TRANSFORM_BICUBIC;
    } else if (interpolation_ == "LANCZOS") {
      interp = IMAGING_TRANSFORM_LANCZOS;
    } else {
      return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
    int interp = InterpolationMethods().at(interpolation_);
    float box[4] = {0.0f, 0.0f, static_cast<float>(w), static_cast<float>(h)};
    auto [height, width] = std::make_tuple(height_, width_);

    if (keep_aspect_ratio_) {
      double scale = (std::max)(static_cast<double>(width) / w, static_cast<double>(height) / h);
      width = static_cast<int64_t>(w * scale);
      height = static_cast<int64_t>(h * scale);
    }

    float box[4] = {0.0f, 0.0f, static_cast<float>(width_), static_cast<float>(height_)};

    auto output_image = ImagingResample(rgb_image, static_cast<int>(width_), static_cast<int>(height_), interp, box);
    // cv::resize(image, output_image, {static_cast<int32_t>(width_), static_cast<int32_t>(height_)}, 0.0, 0.0, interp);
    auto output_image = ImagingResample(rgb_image, static_cast<int>(width), static_cast<int>(height), interp, box);
    ImagingDelete(rgb_image);

    auto* p_output_image = output.Allocate({height_, width_, c});
    for (auto i = height_ - height_; i < height_; ++i) {
      for (auto j = width_ - width_; j < width_; ++j) {
        auto c0_index = i * width_ * c + j * c;
        std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
    auto* p_output_image = output.Allocate({height, width, c});
    for (auto i = height - height; i < height; ++i) {
      for (auto j = width - width; j < width; ++j) {
        auto c0_index = i * width * c + j * c;
        std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
      }
    }
    // DumpTensor(output);

    ImagingDelete(output_image);
    return {};
  }

  template <typename DictT>
  OrtxStatus Init(const DictT& attrs) {
    for (const auto& [key, value] : attrs) {
      if (key == "height") {
        height_ = std::get<int64_t>(value);
      } else if (key == "width") {
        width_ = std::get<int64_t>(value);
      } else if (key == "keep_aspect_ratio") {
        keep_aspect_ratio_ = std::get<int64_t>(value) != 0;
      } else if (key == "interpolation") {
        interpolation_ = std::get<std::string>(value);
        if (InterpolationMethods().find(interpolation_) == InterpolationMethods().end()) {
          return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
        }
      } else {
        return {kOrtxErrorInvalidArgument, "[Resize]: Invalid argument"};
      }
    }
    return {};
  }

 private:
  int64_t height_{256};
  int64_t width_{256};
  bool keep_aspect_ratio_{true};
  std::string interpolation_{"CUBIC"};  // LINEAR, NEAREST, CUBIC
};
@@ -113,7 +142,7 @@ struct Rescale {
  template <typename DictT>
  OrtxStatus Init(const DictT& attrs) {
    for (const auto& [key, value] : attrs) {
      if (key == "scale") {
      if (key == "rescale_factor") {
        scale_ = static_cast<float>(std::get<double>(value));
      } else {
        return {kOrtxErrorInvalidArgument, "[Rescale]: Invalid argument"};

@@ -139,7 +168,7 @@ struct Rescale {
      for (int64_t k = 0; k < w; ++k) {
        auto c0_index = j * w * c + k * c;
        for (int64_t l = 0; l < c; ++l) {
          p_output_image[c0_index + l] = input_data[c0_index + l] * scale_;
          p_output_image[c0_index + l] = static_cast<float>(input_data[c0_index + l]) * scale_;
        }
      }
    }
@@ -220,7 +249,6 @@ struct CenterCrop {
  // s_h = torch.div((img_h - height), 2, rounding_mode='trunc')
  // s_w = torch.div((img_w - width), 2, rounding_mode='trunc')
  // x = img[:, :, s_h:s_h + height, s_w:s_w + width]

  OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
    auto& dimensions = input.Shape();
    if (dimensions.size() != 3ULL) {

@@ -252,3 +280,47 @@ struct CenterCrop {
  int64_t target_h_{224};
  int64_t target_w_{224};
};

struct Permute3D {

  OrtxStatus Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
    auto& dimensions = input.Shape();
    if (dimensions.size() != 3ULL || dims_.size() != 3ULL) {
      return {kOrtxErrorInvalidArgument, "[Permute]: Only 3D tensors are supported"};
    }

    auto* input_data = input.Data();
    std::vector<int64_t> output_shape = {dimensions[dims_[0]], dimensions[dims_[1]], dimensions[dims_[2]]};
    auto* p_output_image = output.Allocate(output_shape);

    for (int64_t i = 0; i < dimensions[0]; ++i) {
      for (int64_t j = 0; j < dimensions[1]; ++j) {
        for (int64_t k = 0; k < dimensions[2]; ++k) {
          auto c0_index = i * dimensions[1] * dimensions[2] + j * dimensions[2] + k;
          auto c1_index = (dims_[0] == 0 ? i : (dims_[0] == 1 ? j : k)) * output_shape[1] * output_shape[2] +
                          (dims_[1] == 0 ? i : (dims_[1] == 1 ? j : k)) * output_shape[2] +
                          (dims_[2] == 0 ? i : (dims_[2] == 1 ? j : k));
          p_output_image[c1_index] = input_data[c0_index];
        }
      }
    }

    return {};
  }

  template <typename DictT>
  OrtxStatus Init(const DictT& attrs) {
    for (const auto& [key, value] : attrs) {
      if (key == "dims") {
        dims_ = std::get<std::vector<int64_t>>(value);
      } else {
        return {kOrtxErrorInvalidArgument, "[Permute]: Invalid argument"};
      }
    }

    return {};
  }

 private:
  std::vector<int64_t> dims_{1, 2, 0};
};
@@ -0,0 +1,390 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <vector>

#include "ortx_processor.h"
#include "c_api_utils.hpp"
#include "image_resample.h"
#include "image_transforms.hpp"

struct Llama3ImageTransform {
  static void SplitIntoTitles(const ortc::Tensor<float>& normalized_image, ortc::Tensor<float>& pixel_values,
                              int64_t tile_height, int64_t tile_width) {
    auto& shape = normalized_image.Shape();
    int64_t image_height = shape[0];
    int64_t image_width = shape[1];
    int64_t num_channels = shape[2];

    const int64_t image_1c_size = tile_height * tile_width;
    assert(image_height % tile_height == 0);
    int64_t num_tiles_height = static_cast<int64_t>(image_height / tile_height);
    assert(image_width % tile_width == 0);
    int64_t num_tiles_width = static_cast<int64_t>(image_width / tile_width);

    auto p_normalized_image = normalized_image.Data();
    // shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
    float* output_pixel =
        pixel_values.Allocate({num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width});

    // From (image_height, image_width, num_channels)
    // Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
    for (int64_t i = 0; i < num_tiles_height; ++i) {
      for (int64_t j = 0; j < num_tiles_width; ++j) {
        // convert to be channel first
        for (int64_t k = 0; k < num_channels; ++k) {
          auto sub_index = image_1c_size * (i * num_tiles_width + j) * num_channels + image_1c_size * k;
          for (int64_t y = 0; y < tile_height; ++y) {
            for (int64_t x = 0; x < tile_width; ++x) {
              output_pixel[sub_index + y * tile_width + x] =
                  p_normalized_image[(i * tile_height + y) * image_width * num_channels +
                                     (j * tile_width + x) * num_channels + k];
            }
          }
        }
      }
    }
  }
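  // Illustrative trace (example values, not from the commit): a 1120x560x3
  // normalized image with 560x560 tiles gives num_tiles_height = 2 and
  // num_tiles_width = 1, so pixel_values is allocated as {2, 3, 560, 560}
  // and each tile is copied out channel-first.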

  OrtxStatus Compute(const ortc::Tensor<uint8_t>& image, ortc::Tensor<float>& pixel_values,
                     ortc::Tensor<int64_t>& aspect_ratio_ids, ortc::Tensor<int64_t>& aspect_ratio_mask,
                     ortc::Tensor<int64_t>& num_tiles) {
    auto& dimensions = image.Shape();
    if (dimensions.size() != 3ULL) {
      return {kOrtxErrorInvalidArgument, "[Llama3ImageTransform]: Only 3D decoded image tensors are supported"};
    }

    std::pair<int64_t, int64_t> aspect_ratio;
    ortc::Tensor<uint8_t> resized_image(&ortx::CppAllocator::Instance());
    OrtxStatus status = DoResize(image, resized_image, aspect_ratio);
    if (!status.IsOk()) {
      return status;
    }

    ortc::Tensor<uint8_t> padded_image(&ortx::CppAllocator::Instance());
    status = DoPad(resized_image, aspect_ratio, padded_image);
    if (!status.IsOk()) {
      return status;
    }
    resized_image.Release();

    ortc::Tensor<float> rescaled_image(&ortx::CppAllocator::Instance());
    status = rescale_.Compute(padded_image, rescaled_image);
    if (!status.IsOk()) {
      return status;
    }

    ortc::Tensor<float> normalized_image(&ortx::CppAllocator::Instance());
    status = normalize_.Compute(rescaled_image, normalized_image);
    if (!status.IsOk()) {
      return status;
    }

    // DumpTensorToFile(normalized_image, "normalized_image");

    SplitIntoTitles(normalized_image, pixel_values, tile_size_.first, tile_size_.second);

    std::vector<std::pair<int64_t, int64_t>> aspect_ratios = {aspect_ratio};
    auto v_aspect_ratio_ids = ConvertAspectRatiosToIds(aspect_ratios, max_image_tiles_);
    auto v_aspect_ratio_mask = BuildAspectRatioMask(aspect_ratios, max_image_tiles_);

    auto p_ids = aspect_ratio_ids.Allocate({static_cast<int64_t>(v_aspect_ratio_ids.size())});
    std::copy(v_aspect_ratio_ids.begin(), v_aspect_ratio_ids.end(), p_ids);

    auto p_mask = aspect_ratio_mask.Allocate({static_cast<int64_t>(v_aspect_ratio_mask[0].size())});
    std::copy(v_aspect_ratio_mask[0].begin(), v_aspect_ratio_mask[0].end(), p_mask);

    auto p_num_tiles = num_tiles.Allocate({1});
    p_num_tiles[0] = aspect_ratios[0].first * aspect_ratios[0].second;

    return status;
  }

 private:
  static std::vector<std::pair<int64_t, int64_t>> GetAllSupportedAspectRatios(int64_t max_image_tiles) {
    std::vector<std::pair<int64_t, int64_t>> aspect_ratios;

    for (int64_t width = 1; width <= max_image_tiles; ++width) {
      for (int64_t height = 1; height <= max_image_tiles; ++height) {
        if (width * height <= max_image_tiles) {
          aspect_ratios.emplace_back(width, height);
        }
      }
    }

    return aspect_ratios;
  }

  /*
  Calculates the new size of an image to fit within a canvas while maintaining aspect ratio.

  This function calculates the optimal size for an image to fit within a canvas defined by
  canvas_height and canvas_width, while ensuring that the image dimensions are not smaller than
  tile_size. If the image is larger than the canvas, the returned size will fit within the canvas.
  If the image already fits within the canvas, the size remains unchanged.
  The aspect ratio of the original image is preserved.

  Args:
      image_height (`int`):
          The height of the original image.
      image_width (`int`):
          The width of the original image.
      canvas_height (`int`):
          The height of the canvas.
      canvas_width (`int`):
          The width of the canvas.
      tile_size (`int`):
          The tile size.

  Returns:
      `Tuple[int, int]`: A tuple containing the new height and width of the image.
  */
  static std::tuple<int64_t, int64_t> GetImageSizeFitToCanvas(int64_t image_height, int64_t image_width,
                                                              int64_t canvas_height, int64_t canvas_width,
                                                              int64_t tile_size) {
    // Set target image size in between `tile_size` and canvas_size
    int64_t target_width = std::clamp(image_width, tile_size, canvas_width);
    int64_t target_height = std::clamp(image_height, tile_size, canvas_height);

    double scale_h = static_cast<double>(target_height) / image_height;
    double scale_w = static_cast<double>(target_width) / image_width;

    int64_t new_width, new_height;

    if (scale_w < scale_h) {
      new_width = target_width;
      new_height = static_cast<int64_t>(std::round(image_height * scale_w));
    } else {
      new_height = target_height;
      new_width = static_cast<int64_t>(std::round(image_width * scale_h));
    }

    return std::make_tuple(new_height, new_width);
  }
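  // Illustrative trace (example values, not from the commit): for a 400x800
  // image on a 560x1120 canvas with tile_size 560, the targets clamp to
  // 560x800, so scale_h = 1.4 and scale_w = 1.0; the smaller scale wins and
  // the function returns (400, 800), leaving the image at its own scale.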

  static std::vector<std::vector<int64_t>> BuildAspectRatioMask(
      const std::vector<std::pair<int64_t, int64_t>>& aspect_ratios, int64_t max_image_tiles) {
    int64_t max_num_images = aspect_ratios.size();

    // Initialize the 2D vector with zeros
    std::vector<std::vector<int64_t>> aspect_ratio_mask(max_num_images, std::vector<int64_t>(max_image_tiles, 0));

    // Set the first tile to 1 for all aspect ratios
    for (int64_t j = 0; j < max_num_images; ++j) {
      aspect_ratio_mask[j][0] = 1;
    }

    // Set the aspect ratio mask for the rest of the tiles
    for (size_t j = 0; j < aspect_ratios.size(); ++j) {
      int64_t num_tiles_w = aspect_ratios[j].first;
      int64_t num_tiles_h = aspect_ratios[j].second;
      int64_t num_tiles = num_tiles_w * num_tiles_h;
      for (int64_t k = 0; k < num_tiles && k < max_image_tiles; ++k) {
        aspect_ratio_mask[j][k] = 1;
      }
    }

    return aspect_ratio_mask;
  }
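  // Illustrative trace (example values, not from the commit): with
  // max_image_tiles = 4, an aspect ratio of (1, 2) marks two tiles and
  // yields the row {1, 1, 0, 0}, while (2, 2) uses all four tiles and
  // yields {1, 1, 1, 1}.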

  /*
  Determines the best canvas based on image and tile size and maximum number of tiles.

  First, calculates possible resolutions based on the maximum number of tiles and tile size.
  For example for max_image_tiles=2, tile_size=100, possible tile arrangements are:
  [(1, 1), (1, 2), (2, 1)] and corresponding canvas sizes are:
  [(100, 100), (100, 200), (200, 100)]

  For each possible resolution, calculates the scaling factors for
  width and height, and selects the smallest one, which is the limiting side.
  E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
  therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.

  If upscaling is possible (any of the scaling factors is greater than 1),
  then picks the smallest upscaling factor > 1.

  If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
  reduce downscaling as much as possible.

  If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
  to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
  has more padding.

  Args:
      image_height (`int`):
          The height of the image.
      image_width (`int`):
          The width of the image.
      max_image_tiles (`int`):
          The maximum number of tiles any image can be split into.
      tile_size (`int`):
          The tile size.

  Returns:
      `pair[int, int]`: The best canvas resolution [height, width] for the given image.
  */
  static std::pair<int64_t, int64_t> GetOptimalTiledCanvas(int64_t image_height, int64_t image_width,
                                                           int64_t max_image_tiles, int64_t tile_size) {
    {
      auto possible_tile_arrangements = GetAllSupportedAspectRatios(max_image_tiles);
      std::vector<std::pair<int, int>> possible_canvas_sizes;

      for (const auto& arrangement : possible_tile_arrangements) {
        possible_canvas_sizes.emplace_back(arrangement.first * tile_size, arrangement.second * tile_size);
      }

      std::vector<double> scales;
      for (const auto& size : possible_canvas_sizes) {
        double scale_h = static_cast<double>(size.first) / image_height;
        double scale_w = static_cast<double>(size.second) / image_width;
        scales.push_back(std::min(scale_h, scale_w));
      }

      double selected_scale = 0;
      std::vector<double> upscaling_options;
      for (double scale : scales) {
        if (scale >= 1) {
          upscaling_options.push_back(scale);
        }
      }

      if (!upscaling_options.empty()) {
        selected_scale = *std::min_element(upscaling_options.begin(), upscaling_options.end());
      } else {
        std::vector<double> downscaling_options;
        for (double scale : scales) {
          if (scale < 1) {
            downscaling_options.push_back(scale);
          }
        }
        selected_scale = *std::max_element(downscaling_options.begin(), downscaling_options.end());
      }

      std::vector<std::pair<int, int>> chosen_canvas;
      for (size_t i = 0; i < scales.size(); ++i) {
        if (std::abs(scales[i] - selected_scale) < 1e-9) {
          chosen_canvas.push_back(possible_canvas_sizes[i]);
        }
      }

      if (chosen_canvas.size() > 1) {
        auto optimal_canvas = std::min_element(chosen_canvas.begin(), chosen_canvas.end(),
                                               [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                                                 return (a.first * a.second) < (b.first * b.second);
                                               });
        return *optimal_canvas;
      } else {
        return chosen_canvas[0];
      }
    }
  }
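  // Illustrative trace (example values, not from the commit): for an
  // 800x1000 image with tile_size = 560 and max_image_tiles = 4, the only
  // canvas that allows upscaling is 1120x1120 (scale min(1.4, 1.12) = 1.12),
  // so the function returns the 2x2-tile canvas.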

  static std::vector<int64_t> ConvertAspectRatiosToIds(const std::vector<std::pair<int64_t, int64_t>>& aspect_ratios,
                                                       int64_t max_image_tiles) {
    int64_t max_num_images = aspect_ratios.size();

    auto supported_aspect_ratios = GetAllSupportedAspectRatios(max_image_tiles);

    // Initialize the 1D vector with zeros
    std::vector<int64_t> aspect_ratios_ids(max_num_images, 0);

    for (size_t j = 0; j < aspect_ratios.size(); ++j) {
      const auto& ratio = aspect_ratios[j];
      auto it = std::find(supported_aspect_ratios.begin(), supported_aspect_ratios.end(), ratio);
      if (it != supported_aspect_ratios.end()) {
        aspect_ratios_ids[j] = std::distance(supported_aspect_ratios.begin(), it) + 1;
      }
    }

    return aspect_ratios_ids;
  }
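  // Illustrative trace: with max_image_tiles = 4 the supported ratios are
  // generated in the order (1,1), (1,2), (1,3), (1,4), (2,1), (2,2), (3,1),
  // (4,1), so (1,1) maps to id 1 and (2,2) to id 6; these are the {6, 6, 1}
  // ids asserted by the new processor test in this commit.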

  OrtxStatus DoPad(const ortc::Tensor<uint8_t>& image, const std::pair<int64_t, int64_t>& aspect_ratio,
                   ortc::Tensor<uint8_t>& padded_image) const {
    auto& dimensions = image.Shape();
    auto [image_height, image_width] = std::make_tuple(dimensions[0], dimensions[1]);
    auto [num_tiles_height, num_tiles_width] = aspect_ratio;
    auto padded_height = num_tiles_height * tile_size_.first;
    auto padded_width = num_tiles_width * tile_size_.second;
    auto pad_size = std::make_pair(padded_height - image_height, padded_width - image_width);
    auto channels = dimensions[2];
    auto* padded_image_data = padded_image.Allocate({padded_height, padded_width, channels});
    std::memset(padded_image_data, 0, padded_height * padded_width * channels);
    auto* input_data = image.Data();
    for (int64_t j = 0; j < image_height; ++j) {
      std::memcpy(padded_image_data + j * padded_width * channels, input_data + j * image_width * channels,
                  image_width * channels);
    }

    return {};
  }

  OrtxStatus DoResize(const ortc::Tensor<uint8_t>& image, ortc::Tensor<uint8_t>& resized_image,
                      std::pair<int64_t, int64_t>& aspect_ratio) const {
    auto& dimensions = image.Shape();
    auto [image_height, image_width] = std::make_tuple(dimensions[0], dimensions[1]);
    auto tile_size = tile_size_.first;
    auto [canvas_height, canvas_width] = GetOptimalTiledCanvas(image_height, image_width, max_image_tiles_, tile_size);
    auto num_tiles_height = canvas_height / tile_size;
    auto num_tiles_width = canvas_width / tile_size;
    aspect_ratio = std::make_pair(num_tiles_height, num_tiles_width);
    auto [new_height, new_width] =
        GetImageSizeFitToCanvas(image_height, image_width, canvas_height, canvas_width, tile_size);

    Resize resizer;
    std::unordered_map<std::string, ortx::AttrType> attrs = {{"height", new_height},
                                                             {"width", new_width},
                                                             {"interpolation", std::string("LINEAR")},
                                                             {"keep_aspect_ratio", int64_t(0)}};
    OrtxStatus status = resizer.Init(attrs);
    if (!status.IsOk()) {
      return status;
    }

    return resizer.Compute(image, resized_image);
  }

 public:
  template <typename DictT>
  OrtxStatus Init(const DictT& attrs) {
    DictT normalizer_attrs;
    DictT rescaler_attrs;
    for (const auto& [key, value] : attrs) {
      if (key.find("normalize/") == 0) {
        normalizer_attrs[key.substr(10)] = value;
      } else if (key.find("rescale/") == 0) {
        rescaler_attrs[key.substr(8)] = value;
      } else if (key == "max_image_tiles") {
        max_image_tiles_ = std::get<int64_t>(value);
      } else if (key == "size") {
        auto tile_size = std::get<std::vector<int64_t>>(value);
        if (tile_size.size() != 2) {
          return {kOrtxErrorInvalidArgument, "[Llama3ImageTransform]: Invalid tile size"};
        }
        tile_size_ = std::make_pair(tile_size[0], tile_size[1]);
      } else if (key == "interpolation") {
        interpolation_ = std::get<std::string>(value);
      } else {
        return {kOrtxErrorInvalidArgument, "[Llama3ImageTransform]: Invalid argument"};
      }
    }

    OrtxStatus status = normalize_.Init(normalizer_attrs);
    if (!status.IsOk()) {
      return status;
    }

    return rescale_.Init(rescaler_attrs);
  }

 private:
  int64_t max_image_tiles_{};
  std::pair<int64_t, int64_t> tile_size_{};
  std::string interpolation_{};

  Rescale rescale_;
  Normalize normalize_;
};
@@ -3,7 +3,8 @@

#pragma once

#include "ocos.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "image_resample.h"

constexpr int max_crops = 16;
@@ -28,10 +28,6 @@ class KernelDef {
  virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
  virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;

  using AttrType =
      std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
  using AttrDict = std::unordered_map<std::string, AttrType>;

  template <typename... Args>
  using tuple_function_args = std::tuple<typename std::remove_reference<Args>::type*...>;

@@ -85,7 +81,6 @@ class KernelDef {

  template <typename T, typename... Args>
  static auto CastOutputAllType(TensorArgs::iterator tensor, T& arg, Args&... args) {
    // return std::make_tuple(static_cast<T&>(*tensor), CastOutputAllType(args...));
    return std::tuple_cat(CastOutputImpl<T>(tensor), CastOutputAllType(tensor + 1, args...));
  }

@@ -278,8 +273,6 @@ class Operation {

 private:
  std::vector<std::unique_ptr<ortc::TensorBase>> outputs_;

 private:
  const KernelRegistry* kernel_registry_;

  std::unique_ptr<KernelDef> kernel_;
@@ -294,9 +287,10 @@ class OrtxRunner {

  template <typename IT, typename OT>  // batch input/output container
  OrtxStatus Run(IT& input_seq, OT& output_seq) {
    for (size_t i = 0; i < input_seq.size(); ++i) {
    size_t i = 0;
    Operation* last_op = nullptr;
    for (; i < input_seq.size(); ++i) {
      auto& input = *(input_seq.begin() + i);
      Operation* last_op = nullptr;
      // sequentially apply the operations
      for (auto& op : ops_) {
        if (last_op != nullptr) {

@@ -305,7 +299,7 @@ class OrtxRunner {
        auto [status, ts_output] = op->Apply(allocator_, input);
        if (status.IsOk()) {
          if (op == ops_.back()) {
            output_seq.push_back(ts_output);
            output_seq.push_back(std::move(ts_output));
          } else {
            input = ts_output;
          }
@@ -317,9 +311,70 @@ class OrtxRunner {
      }
    }

    if (last_op != nullptr) {
      last_op->ResetTensors(allocator_);
    }

    return {};
  }

  static bool IsGreaterShape(const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
    if (lhs.size() != rhs.size()) {
      return lhs.size() > rhs.size();
    }

    for (size_t i = 0; i < lhs.size(); ++i) {
      if (lhs[i] != rhs[i]) {
        return lhs[i] > rhs[i];
      }
    }

    return false;
  }

  static void CopyOrPadTensor(const std::vector<int64_t>::const_iterator dest_shape_begin,
                              const std::vector<int64_t>::const_iterator dest_shape_end,
                              const std::vector<int64_t>::const_iterator src_shape_begin,
                              const std::vector<int64_t>::const_iterator src_shape_end,
                              std::byte* dest, const std::byte* src, size_t element_size) {
    // no broadcasting here
    assert(dest_shape_begin != dest_shape_end && src_shape_begin != src_shape_end);
    assert(dest_shape_end - dest_shape_begin == src_shape_end - src_shape_begin);

    if ((dest_shape_begin + 1) == dest_shape_end) {
      std::memcpy(dest, src, element_size * (*src_shape_begin));
      if (*dest_shape_begin > *src_shape_begin) {
        std::memset(dest + *src_shape_begin * element_size, 0, (*dest_shape_begin - *src_shape_begin) * element_size);
      }
      return;
    }

    int64_t dest_chunk_size = 1;
    int64_t src_chunk_size = 1;
    for (auto iter = dest_shape_begin + 1; iter != dest_shape_end; ++iter) {
      dest_chunk_size *= *iter;
    }

    for (auto iter = src_shape_begin + 1; iter != src_shape_end; ++iter) {
      src_chunk_size *= *iter;
    }

    for (int64_t i = 0; i < *dest_shape_begin; ++i) {
      if (i < *src_shape_begin) {
        if (dest_chunk_size == src_chunk_size) {
          std::memcpy(dest + i * dest_chunk_size * element_size, src + i * src_chunk_size * element_size,
                      dest_chunk_size * element_size);
        } else {
          CopyOrPadTensor(dest_shape_begin + 1, dest_shape_end, src_shape_begin + 1, src_shape_end,
                          dest + i * dest_chunk_size * element_size, src + i * src_chunk_size * element_size,
                          element_size);
        }
      } else {
        std::memset(dest + i * dest_chunk_size * element_size, 0, dest_chunk_size * element_size);
      }
    }
  }
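  // Illustrative trace (example values, not from the commit): padding a
  // {2, 3} source into a {3, 4} destination copies rows 0 and 1 (three
  // elements each, one zero-filled at the end), then zero-fills all of row 2.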

  static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
                                 ortc::IAllocator* allocator) {
    if (arg_lists.empty()) {

@@ -332,9 +387,18 @@ class OrtxRunner {
      std::vector<ortc::TensorBase*> ts_ptrs;
      ts_ptrs.reserve(arg_lists.size());
      std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
      size_t element_size = arg_lists[0][axis]->SizeInBytes() / arg_lists[0][axis]->NumberOfElement();
      bool is_same_shape = true;
      for (auto& ts : arg_lists) {
        if (shape != ts[axis]->Shape()) {
          return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
          is_same_shape = false;
          auto dtype = ts[axis]->Type();
          if (dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 && dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
            return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
          }
          if (IsGreaterShape(ts[axis]->Shape(), shape)) {
            shape = ts[axis]->Shape();
          }
        }
        ts_ptrs.push_back(ts[axis]);
      }
@@ -342,11 +406,16 @@ class OrtxRunner {
      std::vector<int64_t> output_shape = shape;
      output_shape.insert(output_shape.begin(), batch_size);
      std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
      auto ts_size = outputs[axis]->SizeInBytes() / batch_size;
      for (size_t i = 0; i < batch_size; ++i) {
        auto ts = ts_ptrs[i];
        const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
        auto ts_size = ts->SizeInBytes();
        std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
        if (is_same_shape /* || ts->Shape() == std::vector<int64_t>(output_shape.begin() + 1, output_shape.end()) */) {
          std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
        } else {
          CopyOrPadTensor(output_shape.begin() + 1, output_shape.end(), ts->Shape().begin(), ts->Shape().end(),
                          tensor_buf + i * ts_size, reinterpret_cast<const std::byte*>(ts->DataRaw()), element_size);
        }
      }
    }
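      // Illustrative trace (example values, not from the commit): stacking
      // int64 tensors shaped {3, 4} and {2, 4} now takes the element-wise
      // maximum shape {3, 4}, allocates a {2, 3, 4} batch, and zero-pads the
      // smaller tensor via CopyOrPadTensor instead of rejecting the batch.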
@@ -11,20 +11,14 @@
          }
        }
      },
      {
        "operation": {
          "name": "convert_to_rgb",
          "type": "ConvertRGB"
        }
      },
      {
        "operation": {
          "name": "resize",
          "type": "Resize",
          "attrs": {
            "interpolation": "CUBIC",
            "width": 256,
            "height": 256
            "width": 224,
            "height": 224
          }
        }
      },

@@ -53,6 +47,15 @@
            "std": [0.229, 0.224, 0.225]
          }
        }
      },
      {
        "operation": {
          "name": "to_channel_first",
          "type": "Permute3D",
          "attrs": {
            "dims": [2, 0, 1]
          }
        }
      }
    ]
  }
}
@@ -30,7 +30,7 @@ img_proc = ImageProcessor(R"""

img_name = "australia.jpg"
result = img_proc.pre_process(os.path.dirname(__file__) + "/" + img_name)
np_img = img_proc.to_numpy(result)
np_img = img_proc.to_numpy(result, 0)
print(np_img.shape, np_img.dtype)

# can save the image back to disk
@@ -0,0 +1,46 @@
import re
import os
import numpy as np

from PIL import Image

dumping_file_path = "C:\\temp\\normalized_image826885234_f_560.bin"


def regen_image(arr):
    mean = np.array([0.48145466, 0.4578275, 0.40821073])
    std = np.array([0.26862954, 0.26130258, 0.27577711])

    # Reverse normalization
    array = arr * std + mean

    # Clip the values to [0, 1] range
    array = np.clip(array, 0, 1)

    # Convert to [0, 255] range and uint8 type
    array = (array * 255).astype(np.uint8)
    return array


filename = os.path.basename(dumping_file_path)
res = re.search(r".+(\d+)_([u|f])_(\d+)", filename)
dtype = np.uint8 if res[2] == 'u' else np.float32
# load the binary raw data from the file
with open(dumping_file_path, 'rb') as file:
    raw_data = np.fromfile(file, dtype=dtype)


image_width = int(res[3])
image_height = int(raw_data.size / image_width) // 3
raw_data = raw_data.reshape((image_height, image_width, 3))

# from bgr to rgb
# raw_data = raw_data[:, :, ::-1]

# save the image to disk
if dtype == np.float32:
    raw_data = regen_image(raw_data)

img = Image.fromarray(raw_data)
img.save(dumping_file_path + ".png")
img.show()
@@ -0,0 +1,43 @@
{
  "processor": {
    "name": "MllamaImageProcessor",
    "transforms": [
      {
        "operation": {
          "name": "decode_image",
          "domain": "com.microsoft.extensions",
          "type": "DecodeImage",
          "attrs": {
            "color_space": "BGR"
          }
        }
      },
      {
        "operation": {
          "name": "llama3_image_transform",
          "domain": "com.microsoft.extensions",
          "type": "Llama3ImageTransform",
          "attrs": {
            "max_image_tiles": 4,
            "size": [
              560,
              560
            ],
            "interpolation": "LINEAR",
            "rescale/rescale_factor": 0.00392156862745098,
            "normalize/mean": [
              0.48145466,
              0.4578275,
              0.40821073
            ],
            "normalize/std": [
              0.26862954,
              0.26130258,
              0.27577711
            ]
          }
        }
      }
    ]
  }
}
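Note the key convention in the attrs above: the rescale/ and normalize/ prefixes are stripped by Llama3ImageTransform::Init and routed to the embedded Rescale and Normalize operations, so the 1/255 rescale_factor and the mean/std triples configure those sub-transforms directly.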
@@ -12,10 +12,12 @@

using namespace ort_extensions;

const char* test_image_paths[] = {"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"};
const size_t test_image_count = sizeof(test_image_paths) / sizeof(test_image_paths[0]);

TEST(ProcessorTest, TestPhi3VImageProcessing) {
  auto [input_data, n_data] = ort_extensions::LoadRawImages(
      {"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"});
  auto [input_data, n_data] = ort_extensions::LoadRawImages(test_image_paths, test_image_count);
  // {"data/processor/standard_s.jpg", "data/processor/australia.jpg", "data/processor/exceltable.png"});

  auto proc = OrtxObjectPtr<ImageProcessor>(OrtxCreateProcessor, "data/processor/phi_3_image.json");
  ortc::Tensor<float>* pixel_values;

@@ -43,11 +45,9 @@ TEST(ProcessorTest, TestPhi3VImageProcessing) {
  proc->ClearOutputs(&r);
}

TEST(ProcessorTest, TestClipImageProcessing) {
  const char* images_path[] = {"data/processor/standard_s.jpg", "data/processor/australia.jpg",
                               "data/processor/exceltable.png"};
  OrtxObjectPtr<OrtxRawImages> raw_images;
  extError_t err = OrtxLoadImages(ort_extensions::ptr(raw_images), images_path, 3, nullptr);
TEST(ProcessorTest, TestCLIPImageProcessing) {
  OrtxObjectPtr<OrtxRawImages> raw_images{};
  extError_t err = OrtxLoadImages(ort_extensions::ptr(raw_images), test_image_paths, test_image_count, nullptr);
  ASSERT_EQ(err, kOrtxOK);

  OrtxObjectPtr<OrtxProcessor> processor;

@@ -72,3 +72,53 @@ TEST(ProcessorTest, TestClipImageProcessing) {
  ASSERT_EQ(err, kOrtxOK);
  ASSERT_EQ(num_dims, 4);
}

TEST(ProcessorTest, TestMLlamaImageProcessing) {
  OrtxObjectPtr<OrtxRawImages> raw_images{};
  extError_t err = OrtxLoadImages(ort_extensions::ptr(raw_images), test_image_paths, test_image_count, nullptr);
  ASSERT_EQ(err, kOrtxOK);

  OrtxObjectPtr<OrtxProcessor> processor;
  err = OrtxCreateProcessor(ort_extensions::ptr(processor), "data/processor/mllama/llama_3_image.json");
  if (err != kOrtxOK) {
    std::cout << "Error: " << OrtxGetLastErrorMessage() << std::endl;
  }
  ASSERT_EQ(err, kOrtxOK);

  OrtxObjectPtr<OrtxTensorResult> result;
  err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result));
  ASSERT_EQ(err, kOrtxOK);

  OrtxTensor* tensor;
  err = OrtxTensorResultGetAt(result.get(), 0, &tensor);
  ASSERT_EQ(err, kOrtxOK);

  const float* data{};
  const int64_t* shape{};
  size_t num_dims;
  err = OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
  ASSERT_EQ(err, kOrtxOK);
  ASSERT_EQ(num_dims, 5);

  err = OrtxTensorResultGetAt(result.get(), 1, &tensor);
  ASSERT_EQ(err, kOrtxOK);
  const int64_t* int_data{};
  err = OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&int_data), &shape, &num_dims);
  ASSERT_EQ(err, kOrtxOK);
  ASSERT_EQ(num_dims, 2);
  ASSERT_EQ(std::vector<int64_t>(int_data, int_data + 3), std::vector<int64_t>({6, 6, 1}));

  err = OrtxTensorResultGetAt(result.get(), 2, &tensor);
  ASSERT_EQ(err, kOrtxOK);
  err = OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&int_data), &shape, &num_dims);
  ASSERT_EQ(err, kOrtxOK);
  ASSERT_EQ(num_dims, 2);
  ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 4}));

  err = OrtxTensorResultGetAt(result.get(), 3, &tensor);
  ASSERT_EQ(err, kOrtxOK);
  err = OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&int_data), &shape, &num_dims);
  ASSERT_EQ(err, kOrtxOK);
  ASSERT_EQ(num_dims, 2);
  ASSERT_EQ(std::vector<int64_t>(int_data, int_data + 3), std::vector<int64_t>({4, 4, 1}));
}
@@ -0,0 +1,119 @@
import os
import tempfile
import requests
import unittest
import numpy as np

from PIL import Image


is_pp_api_available = False
try:
    from transformers import AutoImageProcessor
    from onnxruntime_extensions import pp_api
    is_pp_api_available = True
except ImportError:
    pass


def regen_image(arr):
    mean = np.array([0.48145466, 0.4578275, 0.40821073])
    std = np.array([0.26862954, 0.26130258, 0.27577711])

    # Reverse normalization
    array = arr * std + mean

    # Clip the values to [0, 1] range
    array = np.clip(array, 0, 1)

    # Convert to [0, 255] range and uint8 type
    array = (array * 255).astype(np.uint8)

    # Convert NumPy array to PIL Image
    image = Image.fromarray(array)
    return image


@unittest.skipIf(not is_pp_api_available, "pp_api is not available")
class TestPPAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        if os.path.exists("/temp"):
            cls.temp_dir = "/temp"
        elif os.path.exists("/tmp"):
            cls.temp_dir = "/tmp"
        else:
            cls.temp_dir = tempfile.mkdtemp()
            print(f"Created temp dir: {cls.temp_dir}")
        cls.token_id = os.environ.get("HF_TOKEN", None)

    def test_CLIP_image_processing(self):
        model_id = "openai/clip-vit-large-patch14"
        image_list = ["test/data/processor/australia.jpg",
                      "test/data/processor/passport.png",
                      "test/data/processor/exceltable.png"]
        (image, image2, image3) = [Image.open(f) for f in image_list]

        processor = AutoImageProcessor.from_pretrained(model_id)
        inputs = processor.preprocess(
            [image, image2, image3], return_tensors="np")
        print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()})

        expected_images = inputs["pixel_values"]
        for i in range(len(expected_images)):
            expected = expected_images[i]
            e_image = regen_image(np.transpose(expected, (1, 2, 0)))
            e_image.save(f"{self.temp_dir}/CLIP_e_{i}.png")

        ort_processor = pp_api.ImageProcessor(
            "test/data/processor/clip_image.json")
        inputs = ort_processor.pre_process(image_list)
        print(ort_processor.to_numpy(inputs, 0).shape)
        actual_images = ort_processor.to_numpy(inputs, 0)
        for i in range(len(actual_images)):
            actual = actual_images[i]
            a_image = regen_image(np.transpose(actual, (1, 2, 0)))
            a_image.save(f"{self.temp_dir}/CLIP_a_{i}.png")

    def test_llama3_2_image_processing(self):
        model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

        url = ("https://huggingface.co/datasets/huggingface/"
               "documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
        # save the image to a file in self.temp_dir
        with open(f"{self.temp_dir}/rabbit.jpg", "wb") as f:
            f.write(requests.get(url).content)

        # image = Image.open(requests.get(url, stream=True).raw)
        image_list = [f"{self.temp_dir}/rabbit.jpg",
                      "test/data/processor/passport.png",
                      "test/data/processor/exceltable.png"]
        (image, image2, image3) = [Image.open(f) for f in image_list]

        processor = AutoImageProcessor.from_pretrained(model_id, token=TestPPAPI.token_id)
        inputs = processor.preprocess(
            [image, image2, image3], return_tensors="np")
        print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()})

        ort_processor = pp_api.ImageProcessor(
            "test/data/processor/mllama/llama_3_image.json")
        ort_inputs = ort_processor.to_numpy(ort_processor.pre_process(image_list), 0)
        print(ort_inputs.shape)

        for idx in range(len(image_list)):
            expected_images = inputs["pixel_values"][0][idx]
            for i in range(len(expected_images)):
                expected = expected_images[i]
                e_image = regen_image(np.transpose(expected, (1, 2, 0)))
                e_image.save(f"{self.temp_dir}/e_{idx}_{i}.png")

            actual_images = ort_inputs[idx]
            for i in range(len(actual_images)):
                actual = actual_images[i]
                a_image = regen_image(np.transpose(actual, (1, 2, 0)))
                a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")


if __name__ == '__main__':
    unittest.main()