Update Yolov8 tutorial with arguments (#658)
This commit is contained in:
Parent: 0c93c20761
Commit: 5c53aaad62

Binary file not shown (image diff; after: 24 KiB).
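Note: the commit replaces the old hard-coded test loop with command-line arguments. A minimal sketch of invoking the updated tutorial, assuming the script is saved as yolov8_pose_e2e.py (the filename is not shown in this diff); the flags match the argparse definitions added at the bottom of the diff:

# Hypothetical invocations of the updated script (script filename assumed).
import subprocess

# Default: model takes jpg/png bytes; writes yolov8n-pose.with_pre_post_processing.onnx.
subprocess.run(["python", "yolov8_pose_e2e.py", "--onnx_model_path", "yolov8n-pose.onnx"], check=True)

# RGB input with symbolic height/width, then run the updated model on a test image.
subprocess.run(["python", "yolov8_pose_e2e.py", "--input", "rgb", "--input_shape", "3,H,W",
                "--run_model", "--test_image", "data/stormtroopers.jpg"], check=True)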
@@ -1,13 +1,15 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+import argparse
 import onnx.shape_inference
 import onnxruntime_extensions
 from onnxruntime_extensions.tools.pre_post_processing import *
 from pathlib import Path
+from PIL import Image, ImageDraw
 
 
-def get_yolov8_pose_model(onnx_model_name: str):
+def _get_yolov8_pose_model(onnx_model_path: Path):
     # install yolov8
     from pip._internal import main as pipmain
     try:
@@ -15,38 +17,23 @@ def get_yolov8_pose_model(onnx_model_name: str):
     except ImportError:
         pipmain(['install', 'ultralytics'])
         import ultralytics
 
     pt_model = Path("yolov8n-pose.pt")
     model = ultralytics.YOLO(str(pt_model))  # load a pretrained model
     success = model.export(format="onnx")  # export the model to ONNX format
     assert success, "Failed to export yolov8n-pose.pt to onnx"
     import shutil
-    shutil.move(pt_model.with_suffix('.onnx'), onnx_model_name)
+    shutil.move(pt_model.with_suffix('.onnx'), str(onnx_model_path))
 
 
-def add_pre_post_processing_to_yolo(input_model_file: Path, output_model_file: Path,
-                                    output_image: bool = False,
-                                    decode_input: bool = True,
-                                    input_shape: Optional[List[Union[int, str]]] = None):
-    """Construct the pipeline for an end2end model with pre and post processing.
-    The final model can take raw image binary as inputs and output the result in raw image file.
-
-    Args:
-        input_model_file (Path): The onnx yolo model.
-        output_model_file (Path): where to save the final onnx model.
-        output_image (bool): Model will draw bounding boxes on the original image and output that. It will NOT draw
-                             the keypoints as there's no custom operator to handle that currently.
-                             If false, the output will have the same shape as the original model, with all the co-ordinates updated
-                             to match the original input image.
-        decode_input: Input is jpg/png to decode. Alternative is to provide RGB data
-        input_shape: Input shape if RGB data is being provided. Can use symbolic dimensions. Either the first or last
-                     dimension must be 3 to determine if layout is HWC or CHW.
-    """
-    if not Path(input_model_file).is_file():
-        print("Fetching the model...")
-        get_yolov8_pose_model(str(input_model_file))
+def _get_model_and_info(input_model_path: Path):
+    if not input_model_path.is_file():
+        print(f"Fetching the model... {str(input_model_path)}")
+        _get_yolov8_pose_model(input_model_path)
 
     print("Adding pre/post processing to the model...")
-    model = onnx.load(str(input_model_file.resolve(strict=True)))
+    model = onnx.load(str(input_model_path.resolve(strict=True)))
 
     model_with_shape_info = onnx.shape_inference.infer_shapes(model)
 
     model_input_shape = model_with_shape_info.graph.input[0].type.tensor_type.shape
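Note: _get_model_and_info reads the model's input and output shapes via ONNX shape inference rather than hard-coding them. A standalone sketch of that lookup, assuming any local ONNX model file:

import onnx
import onnx.shape_inference

model = onnx.load("yolov8n-pose.onnx")
model_with_shape_info = onnx.shape_inference.infer_shapes(model)
dims = model_with_shape_info.graph.input[0].type.tensor_type.shape.dim
# dim_value is set for fixed dimensions, dim_param for symbolic ones
print([d.dim_value if d.dim_value else d.dim_param for d in dims])  # e.g. [1, 3, 640, 640]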
@@ -57,63 +44,85 @@ def add_pre_post_processing_to_yolo(input_model_file: Path, output_model_file: P
     h_in = model_input_shape.dim[-2].dim_value
     assert w_in == 640 and h_in == 640  # expected values
 
-    # output is [1, 56, 8400]
-    # there are
+    # output should be [1, 56, 8400].
     classes_masks_out = model_output_shape.dim[1].dim_value
     boxes_out = model_output_shape.dim[2].dim_value
     assert classes_masks_out == 56
     assert boxes_out == 8400
 
-    # layout of image prior to Resize and LetterBox being run. post-processing needs to know this to determine where
-    # to get the original H and W from
-    if decode_input:
-        inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]
-        # ConvertImageToBGR produces HWC output
-        decoded_image_layout = "HWC"
-    else:
-        assert input_shape and len(input_shape) == 3, "3D input shape is required if decode_input is false."
-        if input_shape[0] == 3:
-            decoded_image_layout = "CHW"
-        elif input_shape[2] == 3:
-            decoded_image_layout = "HWC"
-        else:
-            raise ValueError("Invalid input shape. Either first or last dimension must be 3.")
-
-        inputs = [create_named_value("decoded_image", onnx.TensorProto.UINT8, input_shape)]
+    return (model, w_in, h_in)
+
+
+def _update_model(model: onnx.ModelProto, output_model_path: Path, pipeline: PrePostProcessor):
+    """
+    Update the model by running the pre/post processing pipeline
+    @param model: ONNX model to update
+    @param output_model_path: Filename to write the updated model to.
+    @param pipeline: Pre/Post processing pipeline to run.
+    """
+
+    new_model = pipeline.run(model)
+    print("Pre/post processing added.")
+
+    # run shape inferencing to validate the new model. shape inferencing will fail if any of the new node
+    # types or shapes are incorrect. infer_shapes returns a copy of the model with ValueInfo populated,
+    # but we ignore that and save new_model as it is smaller due to not containing the inferred shape information.
+    _ = onnx.shape_inference.infer_shapes(new_model, strict_mode=True)
+    onnx.save_model(new_model, str(output_model_path.resolve()))
+    print("Updated model saved.")
+
+
+def _add_pre_post_processing_to_rgb_input(input_model_path: Path,
+                                          output_model_path: Path,
+                                          input_shape: List[Union[int, str]]):
+    """
+    Add pre and post processing with model input of RGB data.
+    Pre-processing will convert the input to the correct height, width and data type for the model.
+    Post-processing will select the best bounding boxes using NonMaxSuppression, and scale the selected bounding
+    boxes and key-points to the original image size.
+
+    @param input_model_path: Path to ONNX model.
+    @param output_model_path: Path to write updated model to.
+    @param input_shape: Input shape of RGB data. Must be 3D. First or last value must be 3 (channels first or last).
+    """
+    model, w_in, h_in = _get_model_and_info(input_model_path)
+
+    if input_shape[0] == 3:
+        layout = "CHW"
+    elif input_shape[2] == 3:
+        layout = "HWC"
+    else:
+        raise ValueError("Invalid input shape. Either first or last dimension must be 3.")
 
     onnx_opset = 18
+    inputs = [create_named_value("rgb_data", onnx.TensorProto.UINT8, input_shape)]
     pipeline = PrePostProcessor(inputs, onnx_opset)
 
-    pre_processing_steps = []
-    if decode_input:
-        pre_processing_steps.append(ConvertImageToBGR(name="ImageHWC"))  # jpg/png image to BGR in HWC layout
+    if layout == "CHW":
+        # use Identity so we have an output named RGBImageCHW
+        # for ScaleNMSBoundingBoxesAndKeyPoints in the post-processing steps
+        pre_processing_steps = [Identity(name="RGBImageCHW")]
     else:
-        # use Identity if we don't need to call ChannelsLastToChannelsFirst as the next step
-        if decoded_image_layout == "CHW":
-            pre_processing_steps.append(Identity(name="DecodedImageCHW"))
-
-    if decoded_image_layout == "HWC":
-        pre_processing_steps.append(ChannelsLastToChannelsFirst(name="DecodedImageCHW"))  # HWC to CHW
+        pre_processing_steps = [ChannelsLastToChannelsFirst(name="RGBImageCHW")]  # HWC to CHW
 
     pre_processing_steps += [
-        # Resize an arbitrary sized image to a fixed size in not_larger policy
+        # Resize to match model input. Uses not_larger as we use LetterBox to pad as needed.
         Resize((h_in, w_in), policy='not_larger', layout='CHW'),
-        # padding or cropping the image to (h_in, w_in)
-        LetterBox(target_shape=(h_in, w_in), layout='CHW'),
+        LetterBox(target_shape=(h_in, w_in), layout='CHW'),  # padding or cropping the image to (h_in, w_in)
         ImageBytesToFloat(),  # Convert to float in range 0..1
         Unsqueeze([0]),  # add batch, CHW --> 1CHW
     ]
 
     pipeline.add_pre_processing(pre_processing_steps)
 
-    # NMS and drawing boxes
     post_processing_steps = [
         Squeeze([0]),  # Squeeze to remove batch dimension from [batch, 56, 8200] output
         Transpose([1, 0]),  # reverse so result info is inner dim
-        # split the 56 elements into the box, score for the 1 class, and mask info (17 locations x 3 values)
+        # split the 56 elements into 4 for the box, score for the 1 class, and mask info (17 locations x 3 values)
         Split(num_outputs=3, axis=1, splits=[4, 1, 51]),
         # Apply NMS to select best boxes. iou and score values match
         # https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/models/yolo/pose/predict.py#L34-L35
         # thresholds are arbitrarily chosen. adjust as needed.
         SelectBestBoundingBoxesByNMS(iou_threshold=0.7, score_threshold=0.25, has_mask_data=True),
         # Scale boxes and key point coords back to original image. Mask data has 17 key points per box.
         (ScaleNMSBoundingBoxesAndKeyPoints(num_key_points=17, layout='CHW'),
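Note: the Squeeze/Transpose/Split steps above reshape the raw YOLOv8 pose output before NMS. A numpy sketch of the same reshaping, using the shapes asserted in _get_model_and_info:

import numpy as np

raw = np.zeros((1, 56, 8400), dtype=np.float32)  # model output: [batch, 4 box + 1 score + 51 key point values, anchors]
per_box = np.squeeze(raw, 0).T                   # Squeeze + Transpose -> [8400, 56], one row per candidate box
boxes, scores, keypoints = np.split(per_box, [4, 5], axis=1)  # same sizes as Split(splits=[4, 1, 51])
assert boxes.shape == (8400, 4) and scores.shape == (8400, 1) and keypoints.shape == (8400, 51)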
@@ -124,64 +133,127 @@ def run_inference(onnx_model_file: Path, output_image: bool = False, model_decod
              # A connection from the LetterBoxed image to input 3
              # We use the three images to calculate the scale factor and offset.
              # With scale and offset, we can scale the bounding box and key points back to the original image.
-             utils.IoMapEntry("DecodedImageCHW", producer_idx=0, consumer_idx=1),
+             utils.IoMapEntry("RGBImageCHW", producer_idx=0, consumer_idx=1),
              utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
              utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
          ]),
     ]
 
-    if output_image:
-        # separate out the bounding boxes from the keypoint data to use the existing steps/custom op to draw the
-        # bounding boxes.
+    pipeline.add_post_processing(post_processing_steps)
+
+    _update_model(model, output_model_path, pipeline)
+
+
+def _add_pre_post_processing_to_image_input(input_model_path: Path,
+                                            output_model_path: Path,
+                                            output_image_format: Optional[str]):
+    """
+    Add pre and post processing with model input of jpg or png image bytes.
+    Pre-processing will convert the input to the correct height, width and data type for the model.
+    Post-processing will select the best bounding boxes using NonMaxSuppression, and scale the selected bounding
+    boxes and key-points to the original image size.
+
+    The post-processing can alternatively return the original image with the bounding boxes drawn on it
+    instead of the scaled bounding box and key point data.
+
+    @param input_model_path: Path to ONNX model.
+    @param output_model_path: Path to write updated model to.
+    @param output_image_format: Optional. Specify 'jpg' or 'png' for the post-processing to return image bytes in that
+                                format with the bounding boxes drawn on it.
+                                Otherwise the model will return the scaled bounding boxes and key points.
+    """
+    model, w_in, h_in = _get_model_and_info(input_model_path)
+
+    onnx_opset = 18
+    inputs = [create_named_value("image_bytes", onnx.TensorProto.UINT8, ["num_bytes"])]
+    pipeline = PrePostProcessor(inputs, onnx_opset)
+
+    pre_processing_steps = [
+        ConvertImageToBGR(name="BGRImageHWC"),  # jpg/png image to BGR in HWC layout
+        ChannelsLastToChannelsFirst(name="BGRImageCHW"),  # HWC to CHW
+        # Resize to match model input. Uses not_larger as we use LetterBox to pad as needed.
+        Resize((h_in, w_in), policy='not_larger', layout='CHW'),
+        LetterBox(target_shape=(h_in, w_in), layout='CHW'),  # padding or cropping the image to (h_in, w_in)
+        ImageBytesToFloat(),  # Convert to float in range 0..1
+        Unsqueeze([0]),  # add batch, CHW --> 1CHW
+    ]
+
+    pipeline.add_pre_processing(pre_processing_steps)
+
+    # NonMaxSuppression and drawing boxes
+    post_processing_steps = [
+        Squeeze([0]),  # Squeeze to remove batch dimension from [batch, 56, 8200] output
+        Transpose([1, 0]),  # reverse so result info is inner dim
+        # split the 56 elements into the box, score for the 1 class, and mask info (17 locations x 3 values)
+        Split(num_outputs=3, axis=1, splits=[4, 1, 51]),
+        # Apply NMS to select best boxes. iou and score values match
+        # https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/models/yolo/pose/predict.py#L34-L35
+        # thresholds are arbitrarily chosen. adjust as needed
+        SelectBestBoundingBoxesByNMS(iou_threshold=0.7, score_threshold=0.25, has_mask_data=True),
+        # Scale boxes and key point coords back to original image. Mask data has 17 key points per box.
+        (ScaleNMSBoundingBoxesAndKeyPoints(num_key_points=17, layout='CHW'),
+         [
+             # A default connection from SelectBestBoundingBoxesByNMS for input 0
+             # A connection from original image to input 1
+             # A connection from the resized image to input 2
+             # A connection from the LetterBoxed image to input 3
+             # We use the three images to calculate the scale factor and offset.
+             # With scale and offset, we can scale the bounding box and key points back to the original image.
+             utils.IoMapEntry("BGRImageCHW", producer_idx=0, consumer_idx=1),
+             utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
+             utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
+         ]),
+    ]
+
+    if output_image_format:
         post_processing_steps += [
+            # split out bounding box from keypoint data
             Split(num_outputs=2, axis=-1, splits=[6, 51], name="SplitScaledBoxesAndKeypoints"),
+            # separate out the bounding boxes from the keypoint data to use the existing steps/custom op to draw the
+            # bounding boxes.
             (DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=1, colour_by_classes=True),
              [
-                 utils.IoMapEntry("OriginalRGBImage", producer_idx=0, consumer_idx=0),
+                 utils.IoMapEntry("BGRImageHWC", producer_idx=0, consumer_idx=0),
                  utils.IoMapEntry("SplitScaledBoxesAndKeypoints", producer_idx=0, consumer_idx=1),
              ]),
             # Encode to jpg/png
-            ConvertBGRToImage(image_format="png"),
+            ConvertBGRToImage(image_format=output_image_format),
         ]
 
     pipeline.add_post_processing(post_processing_steps)
 
-    new_model = pipeline.run(model)
-    print("Pre/post proceessing added.")
-
-    # run shape inferencing to validate the new model. shape inferencing will fail if any of the new node
-    # types or shapes are incorrect. infer_shapes returns a copy of the model with ValueInfo populated,
-    # but we ignore that and save new_model as it is smaller due to not containing the inferred shape information.
-    _ = onnx.shape_inference.infer_shapes(new_model, strict_mode=True)
-    onnx.save_model(new_model, str(output_model_file.resolve()))
-    print("Updated model saved.")
+    print("Updating model ...")
+    _update_model(model, output_model_path, pipeline)
 
 
-def run_inference(onnx_model_file: Path, output_image: bool = False, model_decodes_image: bool = True):
+def _run_inference(onnx_model_path: Path, model_input: str, model_outputs_image: bool, test_image: Path,
+                   rgb_layout: Optional[str]):
     import onnxruntime as ort
     import numpy as np
 
-    print("Running the model to validate output.")
+    print(f"Running the model to validate output using {str(test_image)}.")
 
     providers = ['CPUExecutionProvider']
     session_options = ort.SessionOptions()
     session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
-    session = ort.InferenceSession(str(onnx_model_file), providers=providers, sess_options=session_options)
+    session = ort.InferenceSession(str(onnx_model_path), providers=providers, sess_options=session_options)
 
-    input_image_path = './data/bus.jpg'
     input_name = [i.name for i in session.get_inputs()]
-    if model_decodes_image:
-        image_bytes = np.frombuffer(open(input_image_path, 'rb').read(), dtype=np.uint8)
+    if model_input == "image":
+        image_bytes = np.frombuffer(open(test_image, 'rb').read(), dtype=np.uint8)
         model_input = {input_name[0]: image_bytes}
     else:
-        rgb_image = np.array(Image.open(input_image_path).convert('RGB'))
-        rgb_image = rgb_image.transpose((2, 0, 1))  # Channels first
+        rgb_image = np.array(Image.open(test_image).convert('RGB'))
+        if rgb_layout == "CHW":
+            rgb_image = rgb_image.transpose((2, 0, 1))  # Channels first
         model_input = {input_name[0]: rgb_image}
 
-    model_output = ['image_out'] if output_image else ['nms_output_with_scaled_boxes_and_keypoints']
+    model_output = ['image'] if model_outputs_image else ['nms_output_with_scaled_boxes_and_keypoints']
     outputs = session.run(model_output, model_input)
 
-    if output_image:
+    if model_outputs_image:
         # jpg or png with bounding boxes drawn
         image_out = outputs[0]
         from io import BytesIO
         s = BytesIO(image_out)
@@ -192,7 +264,7 @@ def run_inference(onnx_model_file: Path, output_image: bool = False, model_decod
                     [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
 
         # open original image so we can draw on it
-        input_image = Image.open(input_image_path).convert('RGB')
+        input_image = Image.open(test_image).convert('RGB')
         input_image_draw = ImageDraw.Draw(input_image)
 
         scaled_nms_output = outputs[0]
@@ -236,26 +308,63 @@ def run_inference(onnx_model_file: Path, output_image: bool = False, model_decod
 if __name__ == '__main__':
-    onnx_model_name = Path("./data/yolov8n-pose.onnx")
-    onnx_e2e_model_name = onnx_model_name.with_suffix(suffix=".with_pre_post_processing.onnx")
+    parser = argparse.ArgumentParser(
+        """Add pre and post processing to the YOLOv8 POSE model. The model can be updated to take either
+        jpg/png bytes as input (--input image), or RGB data (--input rgb).
+        By default the post processing will scale the bounding boxes and key points to the original image.
+        """)
+    parser.add_argument("--onnx_model_path", type=Path, default="yolov8n-pose.onnx",
+                        help="The ONNX YOLOv8 POSE model.")
+    parser.add_argument("--updated_onnx_model_path", type=Path, required=False,
+                        help="Filename to save the updated ONNX model to. If not provided default to the filename "
+                             "from --onnx_model_path with '.with_pre_post_processing' before the '.onnx' "
+                             "e.g. yolov8n-pose.onnx -> yolov8n-pose.with_pre_post_processing.onnx")
+    parser.add_argument("--input", choices=("image", "rgb"), default="image",
+                        help="Desired model input format. Image bytes from jpg/png or RGB data.")
+    parser.add_argument("--input_shape",
+                        type=lambda x: [int(dim) if dim.isnumeric() else dim for dim in x.split(",")],
+                        required=False,
+                        help="Shape of RGB input if input is 'rgb'. Provide a comma separated list of 3 dimensions. "
+                             "Symbolic dimensions are allowed. Either the first or last dimension must be 3 to infer "
+                             "if layout is HWC or CHW. "
+                             "examples: channels first with symbolic dims for height and width: --input_shape 3,H,W "
+                             "or channels last with fixed input shape: --input_shape 384,512,3")
+    parser.add_argument("--output_image", choices=("jpg", "png"), required=False,
+                        help="OPTIONAL. If the input is an image, instead of outputting the scaled bounding boxes and "
+                             "key points the model will draw the bounding boxes on the original image, convert to the "
+                             "specified format, and output the updated image bytes. The scaled key points for each "
+                             "selected bounding box will also be a model output. "
+                             "NOTE: it will NOT draw the key points as there's no custom operator to handle that.")
+    parser.add_argument("--run_model", action='store_true',
+                        help="Run inference on the model to validate output.")
+    parser.add_argument("--test_image", type=Path, default="data/stormtroopers.jpg",
+                        help="JPG or PNG image to run model with.")
 
-    # default output is the scaled non-max suppresion data which matches the original model.
-    # each result has bounding box (4), score (1), class (1), keypoints(17 x 3) = 57 elements
+    args = parser.parse_args()
+
+    if args.output_image and args.input == "rgb":
+        parser.error("--output_image can only be used if input is 'image'")
+
+    if args.input_shape and len(args.input_shape) != 3:
+        parser.error("Shape of RGB input must have 3 dimensions.")
+
+    updated_model_path = (args.updated_onnx_model_path
+                          if args.updated_onnx_model_path
+                          else args.onnx_model_path.with_suffix(suffix=".with_pre_post_processing.onnx"))
+
+    # default output is the scaled non-max suppression data which matches the original model.
+    # each result has bounding box (4), score (1), class (1), key points (17 x 3) = 57 elements
+    # bounding box is centered XYWH format.
+    # alternative is to output the original image with the bounding boxes but no key points drawn.
-    output_image_with_bounding_boxes = False
-
-    for model_decodes_image in [True, False]:
-        if model_decodes_image:
-            print("Running with model taking jpg/png as input.")
-        else:
-            print("Running with model taking RGB data as input.")
-
-        input_shape = None
-        if not model_decodes_image:
-            # NOTE: This uses CHW just for the sake of testing both layouts
-            input_shape = [3, "h_in", "w_in"]
-
-        add_pre_post_processing_to_yolo(onnx_model_name, onnx_e2e_model_name, output_image_with_bounding_boxes,
-                                        model_decodes_image, input_shape)
-        run_inference(onnx_e2e_model_name, output_image_with_bounding_boxes, model_decodes_image)
+    if args.input == "rgb":
+        print("Updating model with RGB data as input.")
+        _add_pre_post_processing_to_rgb_input(args.onnx_model_path, updated_model_path, args.input_shape)
+        rgb_layout = "CHW" if args.input_shape[0] == 3 else "HWC"
+    else:
+        assert args.input == "image"
+        print("Updating model with jpg/png image bytes as input.")
+        _add_pre_post_processing_to_image_input(args.onnx_model_path, updated_model_path, args.output_image)
+        rgb_layout = None
+
+    if args.run_model:
+        _run_inference(updated_model_path, args.input, args.output_image is not None, args.test_image, rgb_layout)
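Note: the updated model can also be run outside this script, as long as the onnxruntime-extensions custom operators are registered with the session. A minimal sketch for the default "image" input variant, using the input/output names created in the diff above:

import numpy as np
import onnxruntime as ort
import onnxruntime_extensions

session_options = ort.SessionOptions()
session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
session = ort.InferenceSession("yolov8n-pose.with_pre_post_processing.onnx",
                               providers=["CPUExecutionProvider"], sess_options=session_options)

image_bytes = np.frombuffer(open("data/stormtroopers.jpg", "rb").read(), dtype=np.uint8)
(results,) = session.run(["nms_output_with_scaled_boxes_and_keypoints"],
                         {session.get_inputs()[0].name: image_bytes})
# each row: cx, cy, w, h, score, class, then 17 x (x, y, score) key points = 57 values
print(results.shape)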