Simplify vision ops (#465)
* simplify vision ops * remove commented --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
Родитель
6aaf2920bf
Коммит
5cb3153485
|
@ -8,21 +8,19 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
|
||||
void KernelDecodeImage::Compute(OrtKernelContext* context) {
|
||||
void KernelDecodeImage::Compute(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
// Setup inputs
|
||||
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
OrtTensorDimensions dimensions(ort_, inputs);
|
||||
const auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 1ULL) {
|
||||
ORTX_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
|
||||
const int64_t encoded_image_data_len = ort_.GetTensorShapeElementCount(input_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(input_info);
|
||||
const int64_t encoded_image_data_len = input.NumberOfElement();
|
||||
|
||||
// Decode the image
|
||||
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
|
||||
const void* encoded_image_data = ort_.GetTensorData<uint8_t>(inputs); // uint8 data
|
||||
const void* encoded_image_data = input.Data();
|
||||
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1, const_cast<void*>(encoded_image_data));
|
||||
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
|
||||
|
||||
|
@ -37,8 +35,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
|
|||
const int64_t colors = decoded_image.elemSize(); // == 3 as it's BGR
|
||||
|
||||
const std::vector<int64_t> output_dims{height, width, colors};
|
||||
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
|
||||
uint8_t* decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
|
||||
uint8_t* decoded_image_data = output.Allocate(output_dims);
|
||||
memcpy(decoded_image_data, decoded_image.data, narrow<size_t>(height * width * colors));
|
||||
}
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -15,44 +15,7 @@ void decode_image(const ortc::Tensor<uint8_t>& input,
|
|||
|
||||
struct KernelDecodeImage : BaseKernel {
|
||||
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
|
||||
void KernelDestroy(void* op_kernel) {
|
||||
delete static_cast<KernelDecodeImage*>(op_kernel);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "DecodeImage";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
void Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output);
|
||||
};
|
||||
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -222,10 +222,11 @@ void DrawBoxesByScore(ImageView& image, const BoxArray& boxes, int64_t thickness
|
|||
|
||||
} // namespace
|
||||
|
||||
void DrawBoundingBoxes::Compute(OrtKernelContext* context) {
|
||||
void DrawBoundingBoxes::Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
const ortc::Tensor<float>& input_box,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_bgr = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
const OrtTensorDimensions dimensions_bgr(ort_, input_bgr);
|
||||
const auto& dimensions_bgr = input_bgr.Shape();
|
||||
|
||||
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
|
||||
// expect {H, W, C} as that's the inverse of what decode_image produces.
|
||||
|
@ -233,26 +234,21 @@ void DrawBoundingBoxes::Compute(OrtKernelContext* context) {
|
|||
ORTX_CXX_API_THROW("[DrawBoundingBoxes] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
const OrtValue* input_box = ort_.KernelContext_GetInput(context, 1ULL);
|
||||
const OrtTensorDimensions dimensions_box(ort_, input_box);
|
||||
const auto& dimensions_box = input_box.Shape();
|
||||
// x,y, x/w y/h, score, class
|
||||
if (dimensions_box.size() != 2 || dimensions_box[1] != 6) {
|
||||
ORTX_CXX_API_THROW("[DrawBoundingBoxes] requires rank 2 input and the last dim should be 6.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
auto box_span = gsl::make_span(ort_.GetTensorData<float>(input_box), dimensions_box[0] * dimensions_box[1]);
|
||||
auto box_span = gsl::make_span(input_box.Data(), dimensions_box[0] * dimensions_box[1]);
|
||||
BoxArray boxes(dimensions_box, box_span, bbox_mode_);
|
||||
int64_t image_size = dimensions_bgr[0] * dimensions_bgr[1] * dimensions_bgr[2];
|
||||
|
||||
// Setup output & copy to destination
|
||||
// can we reuse the input buffer?
|
||||
const std::vector<int64_t>& output_dims = dimensions_bgr;
|
||||
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0,
|
||||
output_dims.data(),
|
||||
output_dims.size());
|
||||
|
||||
auto* output_data = ort_.GetTensorMutableData<uint8_t>(output_value);
|
||||
const auto* input_data = ort_.GetTensorData<uint8_t>(input_bgr);
|
||||
auto* output_data = output.Allocate(output_dims);
|
||||
const auto* input_data = input_bgr.Data();
|
||||
|
||||
std::copy(input_data, input_data + image_size, output_data);
|
||||
auto data_span = gsl::make_span(output_data, image_size);
|
||||
|
|
|
@ -37,7 +37,9 @@ struct DrawBoundingBoxes : BaseKernel {
|
|||
}
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
const ortc::Tensor<float>& input_box,
|
||||
ortc::Tensor<uint8_t>& output);
|
||||
|
||||
private:
|
||||
int64_t thickness_;
|
||||
|
@ -46,41 +48,4 @@ struct DrawBoundingBoxes : BaseKernel {
|
|||
BoundingBoxFormat bbox_mode_;
|
||||
};
|
||||
|
||||
struct CustomOpDrawBoundingBoxes : OrtW::CustomOpBase<CustomOpDrawBoundingBoxes, DrawBoundingBoxes> {
|
||||
void KernelDestroy(void* op_kernel) {
|
||||
delete static_cast<DrawBoundingBoxes*>(op_kernel);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "DrawBoundingBoxes";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -7,10 +7,10 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
|
||||
void KernelEncodeImage ::Compute(OrtKernelContext* context) {
|
||||
void KernelEncodeImage::Compute(const ortc::Tensor<uint8_t>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_bgr = ort_.KernelContext_GetInput(context, 0ULL);
|
||||
const OrtTensorDimensions dimensions_bgr(ort_, input_bgr);
|
||||
const auto dimensions_bgr = input.Shape();
|
||||
|
||||
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
|
||||
// expect {H, W, C} as that's the inverse of what decode_image produces.
|
||||
|
@ -23,7 +23,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
|
|||
static_cast<int32_t>(dimensions_bgr[1])}; // W
|
||||
|
||||
// data is const uint8_t but opencv2 wants void*.
|
||||
const void* bgr_data = ort_.GetTensorData<uint8_t>(input_bgr);
|
||||
const void* bgr_data = input.Data();
|
||||
const cv::Mat bgr_image(height_x_width, CV_8UC3, const_cast<void*>(bgr_data));
|
||||
|
||||
// don't know output size ahead of time so need to encode and then copy to output
|
||||
|
@ -34,11 +34,8 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
|
|||
|
||||
// Setup output & copy to destination
|
||||
std::vector<int64_t> output_dimensions{static_cast<int64_t>(encoded_image.size())};
|
||||
OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0,
|
||||
output_dimensions.data(),
|
||||
output_dimensions.size());
|
||||
|
||||
uint8_t* data = ort_.GetTensorMutableData<uint8_t>(output_value);
|
||||
uint8_t* data = output.Allocate(output_dimensions);
|
||||
memcpy(data, encoded_image.data(), encoded_image.size());
|
||||
}
|
||||
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -20,47 +20,11 @@ struct KernelEncodeImage : BaseKernel {
|
|||
extension_ = std::string(".") + format;
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
void Compute(const ortc::Tensor<uint8_t>& input_bgr,
|
||||
ortc::Tensor<uint8_t>& output);
|
||||
|
||||
private:
|
||||
std::string extension_;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// EncodeImage
|
||||
///
|
||||
/// Converts rank 3 BGR input with channels last ordering to the requested file type.
|
||||
/// Default is 'jpg'
|
||||
/// </summary>
|
||||
struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
|
||||
const char* GetName() const {
|
||||
return "EncodeImage";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
switch (index) {
|
||||
case 0:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -6,8 +6,11 @@
|
|||
#include "encode_image.hpp"
|
||||
#include "draw_bounding_box.hpp"
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Vision =
|
||||
LoadCustomOpClasses<CustomOpClassBegin,
|
||||
ort_extensions::CustomOpDecodeImage,
|
||||
ort_extensions::CustomOpEncodeImage,
|
||||
ort_extensions::CustomOpDrawBoundingBoxes>;
|
||||
const std::vector<const OrtCustomOp*>& VisionLoader() {
|
||||
static OrtOpLoader op_loader(CustomCpuStruct("EncodeImage", ort_extensions::KernelEncodeImage),
|
||||
CustomCpuStruct("DecodeImage", ort_extensions::KernelDecodeImage),
|
||||
CustomCpuStruct("DrawBoundingBoxes", ort_extensions::DrawBoundingBoxes));
|
||||
return op_loader.GetCustomOps();
|
||||
}
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Vision = VisionLoader;
|
Загрузка…
Ссылка в новой задаче