* simplify vision ops

* remove commented

---------

Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
RandySheriffH 2023-06-02 20:47:48 -07:00 коммит произвёл GitHub
Родитель 6aaf2920bf
Коммит 5cb3153485
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 34 добавлений и 149 удалений

Просмотреть файл

@ -8,21 +8,19 @@
namespace ort_extensions {
void KernelDecodeImage::Compute(OrtKernelContext* context) {
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
// Setup inputs
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
OrtTensorDimensions dimensions(ort_, inputs);
const auto& dimensions = input.Shape();
if (dimensions.size() != 1ULL) {
ORTX_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
}
OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
const int64_t encoded_image_data_len = input.NumberOfElement();
// Decode the image
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
const void* encoded_image_data = ort_.GetTensorData<uint8_t>(inputs); // uint8 data
const void* encoded_image_data = input.Data();
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1, const_cast<void*>(encoded_image_data));
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
@ -37,8 +35,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
const int64_t colors = decoded_image.elemSize(); // == 3 as it's BGR
const std::vector<int64_t> output_dims{height, width, colors};
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
uint8_t* decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
uint8_t* decoded_image_data = output.Allocate(output_dims);
memcpy(decoded_image_data, decoded_image.data, narrow<size_t>(height * width * colors));
}
} // namespace ort_extensions

Просмотреть файл

@ -15,44 +15,7 @@ void decode_image(const ortc::Tensor<uint8_t>& input,
struct KernelDecodeImage : BaseKernel {
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
void Compute(OrtKernelContext* context);
};
struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
void KernelDestroy(void* op_kernel) {
delete static_cast<KernelDecodeImage*>(op_kernel);
}
const char* GetName() const {
return "DecodeImage";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
}
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
}
}
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output);
};
} // namespace ort_extensions

Просмотреть файл

@ -222,10 +222,11 @@ void DrawBoxesByScore(ImageView& image, const BoxArray& boxes, int64_t thickness
} // namespace
void DrawBoundingBoxes::Compute(OrtKernelContext* context) {
void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
const ortc::Tensor<float>& input_box,
ortc::Tensor<uint8_t>& output) {
// Setup inputs
const OrtValue* input_bgr = ort_.KernelContext_GetInput(context, 0ULL);
const OrtTensorDimensions dimensions_bgr(ort_, input_bgr);
const auto& dimensions_bgr = input_bgr.Shape();
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
// expect {H, W, C} as that's the inverse of what decode_image produces.
@ -233,26 +234,21 @@ void DrawBoundingBoxes::Compute(OrtKernelContext* context) {
ORTX_CXX_API_THROW("[DrawBoundingBoxes] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
}
const OrtValue* input_box = ort_.KernelContext_GetInput(context, 1ULL);
const OrtTensorDimensions dimensions_box(ort_, input_box);
const auto& dimensions_box = input_box.Shape();
// x,y, x/w y/h, score, class
if (dimensions_box.size() != 2 || dimensions_box[1] != 6) {
ORTX_CXX_API_THROW("[DrawBoundingBoxes] requires rank 2 input and the last dim should be 6.", ORT_INVALID_ARGUMENT);
}
auto box_span = gsl::make_span(ort_.GetTensorData<float>(input_box), dimensions_box[0] * dimensions_box[1]);
auto box_span = gsl::make_span(input_box.Data(), dimensions_box[0] * dimensions_box[1]);
BoxArray boxes(dimensions_box, box_span, bbox_mode_);
int64_t image_size = dimensions_bgr[0] * dimensions_bgr[1] * dimensions_bgr[2];
// Setup output & copy to destination
// can we reuse the input buffer?
const std::vector<int64_t>& output_dims = dimensions_bgr;
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0,
output_dims.data(),
output_dims.size());
auto* output_data = ort_.GetTensorMutableData<uint8_t>(output_value);
const auto* input_data = ort_.GetTensorData<uint8_t>(input_bgr);
auto* output_data = output.Allocate(output_dims);
const auto* input_data = input_bgr.Data();
std::copy(input_data, input_data + image_size, output_data);
auto data_span = gsl::make_span(output_data, image_size);

Просмотреть файл

@ -37,7 +37,9 @@ struct DrawBoundingBoxes : BaseKernel {
}
}
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
const ortc::Tensor<float>& input_box,
ortc::Tensor<uint8_t>& output);
private:
int64_t thickness_;
@ -46,41 +48,4 @@ struct DrawBoundingBoxes : BaseKernel {
BoundingBoxFormat bbox_mode_;
};
struct CustomOpDrawBoundingBoxes : OrtW::CustomOpBase<CustomOpDrawBoundingBoxes, DrawBoundingBoxes> {
void KernelDestroy(void* op_kernel) {
delete static_cast<DrawBoundingBoxes*>(op_kernel);
}
const char* GetName() const {
return "DrawBoundingBoxes";
}
size_t GetInputTypeCount() const {
return 2;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
default:
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
}
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
}
}
};
} // namespace ort_extensions

Просмотреть файл

@ -7,10 +7,10 @@
namespace ort_extensions {
void KernelEncodeImage ::Compute(OrtKernelContext* context) {
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input,
ortc::Tensor<uint8_t>& output) {
// Setup inputs
const OrtValue* input_bgr = ort_.KernelContext_GetInput(context, 0ULL);
const OrtTensorDimensions dimensions_bgr(ort_, input_bgr);
const auto dimensions_bgr = input.Shape();
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
// expect {H, W, C} as that's the inverse of what decode_image produces.
@ -23,7 +23,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
static_cast<int32_t>(dimensions_bgr[1])}; // W
// data is const uint8_t but opencv2 wants void*.
const void* bgr_data = ort_.GetTensorData<uint8_t>(input_bgr);
const void* bgr_data = input.Data();
const cv::Mat bgr_image(height_x_width, CV_8UC3, const_cast<void*>(bgr_data));
// don't know output size ahead of time so need to encode and then copy to output
@ -34,11 +34,8 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
// Setup output & copy to destination
std::vector<int64_t> output_dimensions{static_cast<int64_t>(encoded_image.size())};
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0,
output_dimensions.data(),
output_dimensions.size());
uint8_t* data = ort_.GetTensorMutableData<uint8_t>(output_value);
uint8_t* data = output.Allocate(output_dimensions);
memcpy(data, encoded_image.data(), encoded_image.size());
}
} // namespace ort_extensions

Просмотреть файл

@ -20,47 +20,11 @@ struct KernelEncodeImage : BaseKernel {
extension_ = std::string(".") + format;
}
void Compute(OrtKernelContext* context);
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
ortc::Tensor<uint8_t>& output);
private:
std::string extension_;
};
/// <summary>
/// EncodeImage
///
/// Converts rank 3 BGR input with channels last ordering to the requested file type.
/// Default is 'jpg'
/// </summary>
struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
const char* GetName() const {
return "EncodeImage";
}
size_t GetInputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetInputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
}
}
size_t GetOutputTypeCount() const {
return 1;
}
ONNXTensorElementDataType GetOutputType(size_t index) const {
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default:
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
}
}
};
} // namespace ort_extensions

Просмотреть файл

@ -6,8 +6,11 @@
#include "encode_image.hpp"
#include "draw_bounding_box.hpp"
FxLoadCustomOpFactory LoadCustomOpClasses_Vision =
LoadCustomOpClasses<CustomOpClassBegin,
ort_extensions::CustomOpDecodeImage,
ort_extensions::CustomOpEncodeImage,
ort_extensions::CustomOpDrawBoundingBoxes>;
const std::vector<const OrtCustomOp*>& VisionLoader() {
static OrtOpLoader op_loader(CustomCpuStruct("EncodeImage", ort_extensions::KernelEncodeImage),
CustomCpuStruct("DecodeImage", ort_extensions::KernelDecodeImage),
CustomCpuStruct("DrawBoundingBoxes", ort_extensions::DrawBoundingBoxes));
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Vision = VisionLoader;