add(tutorials): exporting yolo world model (#803)

* add(tutorials): exporting yolo world model

This allows us to export a YOLO World ONNX model which can later be used for mobile inference; see the example invocation below.

* add(tutorial): make classes optional
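
For example, a sketch of a typical invocation (the script path matches the usage comment inside the file below):

    python tutorials/yolov8_world_e2e.py yolov8m-worldv2.onnx --classes "person,helmet" --infer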

---------

Co-authored-by: Scott McKay <skottmckay@gmail.com>
Stalin Sabu Thomas 2024-10-03 10:12:35 +05:30, committed by GitHub
Parent 12a9e8beb4
Commit f47bed4596
1 changed file with 432 additions and 0 deletions

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
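"""
Export a YOLOv8 World (v2) model to ONNX and add pre/post processing so the updated
model can take jpg/png image bytes or raw RGB data as input, and output scaled
bounding boxes or the original image with the boxes drawn on it.
"""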
import argparse
import re
from functools import partial
from pathlib import Path
from typing import Literal, Optional, TypeAlias
from warnings import warn
import onnx.shape_inference
import onnxruntime_extensions
from onnxruntime_extensions.tools.pre_post_processing import *
from PIL import Image, ImageDraw

ModelSize: TypeAlias = Optional[Literal["s", "m", "l", "x"]]

BASE_PATH = Path(__file__).parent


def _get_yolov8_world_model(onnx_model_path: Path, classes: list[str], size: ModelSize = None):
    # install the ultralytics package (provides YOLOv8 / YOLO World) if it is missing
from pip._internal import main as pipmain
try:
import ultralytics
except ImportError:
pipmain(["install", "ultralytics"])
import ultralytics
if size is None:
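        # try to infer the size from a filename like 'yolov8m-worldv2.onnx'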
regex = r"yolov8(.{1})"
matches = re.search(regex, onnx_model_path.name)
if matches:
size = matches.group(1)
else:
size = "m"
warn(f"YOLO size not set and not able to determine yolo world size, defaulting to {size}.")
# NOTE: only v2 models are exportable
pt_model = Path(f"yolov8{size}-worldv2.pt")
# load yolo world pretrained model
model = ultralytics.YOLOWorld(str(pt_model))
# set to the right classes
model.set_classes(classes)
    # save the model with the custom vocabulary embedded so it can be reloaded as a standard YOLO model
pt_model_path = onnx_model_path.with_suffix(".pt")
model.save(pt_model_path)
model = ultralytics.YOLO(pt_model_path)
# export the model to ONNX format
success = model.export(format="onnx", optimize=True, simplify=True)
assert success, f"Failed to export {pt_model_path.name} to onnx"
    assert onnx_model_path.exists(), f"Failed to export {onnx_model_path.name} to onnx"


def _get_model_and_info(
    input_model_path: Path, classes: list[str], model_size: ModelSize = None
) -> tuple[onnx.ModelProto, int, int]:
if not input_model_path.is_file():
print(f"Fetching the model... {str(input_model_path)}")
_get_yolov8_world_model(input_model_path, classes, size=model_size)
print("Adding pre/post processing to the model...")
model = onnx.load(str(input_model_path.resolve(strict=True)))
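    # run shape inference so the model's input/output shapes can be read below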
model_with_shape_info = onnx.shape_inference.infer_shapes(model)
model_input_shape = model_with_shape_info.graph.input[0].type.tensor_type.shape
model_output_shape = model_with_shape_info.graph.output[0].type.tensor_type.shape
num_classes = len(classes)
# infer the input sizes from the model.
w_in = model_input_shape.dim[-1].dim_value
h_in = model_input_shape.dim[-2].dim_value
assert w_in == 640 and h_in == 640 # expected values
# output should be [1, num_classes+4(bbox coords), 8400].
classes_bbox = model_output_shape.dim[1].dim_value
boxes_out = model_output_shape.dim[2].dim_value
assert classes_bbox == 4 + num_classes
assert boxes_out == 8400
    return model, w_in, h_in


def _update_model(model: onnx.ModelProto, output_model_path: Path, pipeline: PrePostProcessor):
"""
Update the model by running the pre/post processing pipeline
@param model: ONNX model to update
@param output_model_path: Filename to write the updated model to.
@param pipeline: Pre/Post processing pipeline to run.
"""
new_model = pipeline.run(model)
print("Pre/post proceessing added.")
# run shape inferencing to validate the new model. shape inferencing will fail if any of the new node
# types or shapes are incorrect. infer_shapes returns a copy of the model with ValueInfo populated,
# but we ignore that and save new_model as it is smaller due to not containing the inferred shape information.
_ = onnx.shape_inference.infer_shapes(new_model, strict_mode=True)
onnx.save_model(new_model, str(output_model_path.resolve()))
print("Updated model saved.")
def _add_pre_post_processing(
classes: list[str],
input_type: Literal["rgb", "image"],
input_model_path: Path,
output_model_path: Path,
output_image_format: Optional[Literal["jpg", "png"]] = None,
input_shape: Optional[List[Union[int, str]]] = None,
model_size: ModelSize = None,
):
"""
Add pre and post processing with model input of jpg or png image bytes or just RGB data.
Pre-processing will convert the input to the correct height, width and data type for the model.
Post-processing will select the best bounding boxes using NonMaxSuppression, and scale the selected bounding
boxes to the original image size.
The post-processing can alternatively return the original image with the bounding boxes drawn on it
instead of the scaled bounding box and key point data.
@param classes: Classes that will be sent as prompt to yolo world model.
    @param input_type: Whether the input is jpg/png image bytes ('image') or raw RGB data ('rgb').
@param input_model_path: Path to ONNX model.
@param output_model_path: Path to write updated model to.
@param output_image_format: Optional. Specify 'jpg' or 'png' for the post-processing to return image bytes in that
format with the bounding boxes drawn on it.
Otherwise the model will return the scaled bounding boxes.
Can only be used if input_type is 'image'.
    @param input_shape: Optional. Input shape of RGB data. Must be 3D.
                        The first or last value must be 3 (channels first or last).
                        This is required if input_type is 'rgb'.
@param model_size: Size of the yolo model. Valid values ["s", "m", "l", "x"].
If None, we automatically detect based on filename.
"""
num_classes = len(classes)
model, w_in, h_in = _get_model_and_info(input_model_path, classes, model_size)
pre_processing_steps = []
if input_type == "rgb":
        if output_image_format is not None:
            raise ValueError("Model cannot output an image when input_type is 'rgb'.")
        if input_shape is None:
            raise ValueError("For input_type 'rgb', input_shape must be provided.")
elif input_shape[0] == 3:
layout = "CHW"
elif input_shape[2] == 3:
layout = "HWC"
else:
raise ValueError("Invalid input shape. Either first or last dimension must be 3.")
inputs = [create_named_value("rgb_data", onnx.TensorProto.UINT8, input_shape)]
if layout == "CHW":
# use Identity so we have an output named RGBImageCHW
# for ScaleNMSBoundingBoxesAndKeyPoints in the post-processing steps
pre_processing_steps += [Identity(name="RGBImageCHW")]
else:
pre_processing_steps += [ChannelsLastToChannelsFirst(name="RGBImageCHW")] # HWC to CHW
else:
inputs = [create_named_value("image_bytes", onnx.TensorProto.UINT8, ["num_bytes"])]
pre_processing_steps += [
ConvertImageToBGR(name="BGRImageHWC"), # jpg/png image to BGR in HWC layout
ChannelsLastToChannelsFirst(name="BGRImageCHW"), # HWC to CHW
]
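    # opset used for the operators in the added pre/post processing graph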
onnx_opset = 18
pipeline = PrePostProcessor(inputs, onnx_opset)
pre_processing_steps += [
# Resize to match model input. Uses not_larger as we use LetterBox to pad as needed.
Resize((h_in, w_in), policy="not_larger", layout="CHW"),
LetterBox(target_shape=(h_in, w_in), layout="CHW"), # padding or cropping the image to (h_in, w_in)
ImageBytesToFloat(), # Convert to float in range 0..1
Unsqueeze([0]), # add batch, CHW --> 1CHW
]
pipeline.add_pre_processing(pre_processing_steps)
# NonMaxSuppression and drawing boxes
post_processing_steps = [
Squeeze([0]), # Squeeze to remove batch dimension from [batch, num_classes+4, 8400] output
Transpose([1, 0]), # reverse so result info is inner dim
        # split the num_classes+4 elements into the 4 box coords and the num_classes class scores
Split(num_outputs=2, axis=1, splits=[4, num_classes]),
# Apply NMS to select best boxes. iou and score values match
# https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/models/yolo/pose/predict.py#L34-L35
# thresholds are arbitrarily chosen. adjust as needed
SelectBestBoundingBoxesByNMS(iou_threshold=0.7, score_threshold=0.25),
        # Scale the selected bounding boxes (and any key point data) back to the original image size.
        (
            ScaleNMSBoundingBoxesAndKeyPoints(name="ScaleBoundingBoxes", num_key_points=17, layout="CHW"),
[
# A default connection from SelectBestBoundingBoxesByNMS for input 0
# A connection from original image to input 1
# A connection from the resized image to input 2
# A connection from the LetterBoxed image to input 3
# We use the three images to calculate the scale factor and offset.
# With scale and offset, we can scale the bounding box and key points back to the original image.
utils.IoMapEntry(
"RGBImageCHW" if input_type == "rgb" else "BGRImageCHW", producer_idx=0, consumer_idx=1
),
utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
],
),
]
if output_image_format:
post_processing_steps += [
# DrawBoundingBoxes on the original image
            # Model imported from pytorch outputs boxes in CENTER_XYWH format.
            # Two modes for colouring the boxes:
            #   colour_by_classes=True (colour per class), colour_by_classes=False (colour by confidence)
(
DrawBoundingBoxes(mode="CENTER_XYWH", num_classes=num_classes, colour_by_classes=True),
[
utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
],
),
# Encode to jpg/png
ConvertBGRToImage(image_format=output_image_format),
]
pipeline.add_post_processing(post_processing_steps)
print("Updating model ...")
    _update_model(model, output_model_path, pipeline)


def _run_inference(
onnx_model_path: Path,
model_input: str,
model_outputs_image: bool,
test_image: Path,
classes: list[str],
rgb_layout: Optional[str],
):
import numpy as np
import onnxruntime as ort
print(f"Running the model to validate output using {str(test_image)}.")
providers = ["CPUExecutionProvider"]
session_options = ort.SessionOptions()
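    # the added pre/post processing uses custom operators from onnxruntime-extensions,
    # so the extensions library must be registered with the session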
session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
session = ort.InferenceSession(str(onnx_model_path), providers=providers, sess_options=session_options)
input_name = [i.name for i in session.get_inputs()]
if model_input == "image":
        image_bytes = np.frombuffer(test_image.read_bytes(), dtype=np.uint8)
model_input = {input_name[0]: image_bytes}
else:
rgb_image = np.array(Image.open(test_image).convert("RGB"))
if rgb_layout == "CHW":
rgb_image = rgb_image.transpose((2, 0, 1)) # Channels first
model_input = {input_name[0]: rgb_image}
model_output = ["image"] if model_outputs_image else ["nms_output_with_scaled_boxes_and_keypoints"]
outputs = session.run(model_output, model_input)
if model_outputs_image:
        # jpg or png image bytes with the bounding boxes drawn on it
image_out = outputs[0]
from io import BytesIO
s = BytesIO(image_out)
Image.open(s).show()
else:
# open original image so we can draw on it
input_image = Image.open(test_image).convert("RGB")
input_image_draw = ImageDraw.Draw(input_image)
scaled_nms_output = outputs[0]
for result in scaled_nms_output:
# split the 4 box coords, 1 score, 1 class
(box, score, class_id) = np.split(result, (4, 5))
class_id = int(class_id)
score = float(score * 100)
# convert box from centered XYWH to co-ords and draw rectangle
# NOTE: The pytorch model seems to output XYXY co-ords. Not sure why that's different.
half_w = box[2] / 2
half_h = box[3] / 2
x0 = box[0] - half_w
y0 = box[1] - half_h
x1 = box[0] + half_w
y1 = box[1] + half_h
input_image_draw.rectangle(((x0, y0), (x1, y1)), outline="red", width=4)
input_image_draw.text((x0, y0), f"{classes[class_id]}-{score:.2f}%")
print("Displaying original image with bounding boxes.")
        input_image.show()


def load_classes(args) -> list[str]:
if args.classes:
return args.classes
elif args.classes_file:
classes_file = Path(args.classes_file)
with classes_file.open() as fp:
if classes_file.suffix == ".json":
import json
classes = json.load(fp)
elif classes_file.suffix == ".txt":
classes = fp.read().splitlines()
else:
raise ValueError(f"Invalid file type {classes_file}")
return classes
else:
# default value for data/stormtroopers.jpg
return ["person", "helmet"]
# python tutorials/yolov8_world_e2e.py yolov8m-worldv2.onnx --infer --input=rgb --input_shape H,W,3
if __name__ == "__main__":
parser = argparse.ArgumentParser(
"""Add pre and post processing to the YOLOv8 World model. The model can be updated to take either
jpg/png bytes as input (--input image), or RGB data (--input rgb).
        NOTE: Use only a YOLO World v2 model, as v2 is the only version with ONNX export capability.
By default the post processing will scale the bounding boxes and key points to the original image.
"""
)
parser.add_argument("model", type=Path, help="The ONNX YOLOv8 World model.")
classes_group = parser.add_mutually_exclusive_group()
classes_group.add_argument(
"--classes",
type=lambda x: re.split(r"\s*,\s*", x.strip()),
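        # the lambda splits a comma-separated string, trimming whitespace around each class name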
default=["person", "helmet"],
help="List of class names that will be passed to yolo world prompt."
"Default values 'person,helmet' for data/stormtroopers.jpg",
)
classes_group.add_argument(
"--classes-file", type=Path, help="JSON file containing list of classes that will be set as yolo world prompt."
)
parser.add_argument(
"--size",
choices=("s", "m", "l", "x"),
default=None,
help="Size of yolo world model.",
)
parser.add_argument(
"--updated_onnx_model_path",
type=Path,
required=False,
help="Filename to save the updated ONNX model to. If not provided default to the filename "
"from --onnx_model_path with '.with_pre_post_processing' before the '.onnx' "
"e.g. yolov8m-worldv2.onnx -> yolov8m-worldv2.with_pre_post_processing.onnx",
)
parser.add_argument(
"--input",
choices=("image", "rgb"),
default="image",
help="Desired model input format. Image bytes from jpg/png or RGB data.",
)
parser.add_argument(
"--input_shape",
type=lambda x: [int(dim) if dim.isnumeric() else dim for dim in x.split(",")],
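        # parses e.g. '3,H,W' into [3, 'H', 'W'], keeping symbolic dims as strings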
default=["H", "W", 3],
required=False,
help="Shape of RGB input if input is 'rgb'. Provide a comma separated list of 3 dimensions. "
"Symbolic dimensions are allowed. Either the first or last dimension must be 3 to infer "
"if layout is HWC or CHW. "
"examples: channels first with symbolic dims for height and width: --input_shape 3,H,W "
"or channels last with fixed input shape: --input_shape 384,512,3",
)
parser.add_argument(
"--output_as_image",
choices=("jpg", "png"),
required=False,
help="OPTIONAL. If the input is an image, instead of outputting the scaled bounding boxes "
"the model will draw the bounding boxes on the original image, convert to the "
"specified format, and output the updated image bytes.",
)
parser.add_argument("--infer", action="store_true", help="Run inference on the model to validate output.")
parser.add_argument(
"--test_image",
type=Path,
default=BASE_PATH / "data/stormtroopers.jpg",
help="JPG or PNG image to run model with.",
)
args = parser.parse_args()
classes = load_classes(args)
num_classes = len(classes)
    assert num_classes > 0, "The YOLO World model requires at least one prompt class."
    if args.output_as_image and args.input == "rgb":
        parser.error("--output_as_image can only be used if --input is 'image'.")
    if args.input_shape and len(args.input_shape) != 3:
        parser.error("--input_shape must have exactly 3 dimensions.")
updated_model_path = (
args.updated_onnx_model_path
if args.updated_onnx_model_path
else args.model.with_suffix(suffix=".with_pre_post_processing.onnx")
)
    # default output is the scaled non-max suppression data which matches the original model.
    # each result has bounding box (4), score (1), class id (1) = 6 elements.
    # the bounding box is in centered XYWH format.
    # alternatively the model can output the original image with the bounding boxes drawn on it.
add_pre_post_processing = partial(
_add_pre_post_processing,
classes=classes,
input_type=args.input,
input_model_path=args.model,
output_model_path=updated_model_path,
model_size=args.size,
)
if args.input == "rgb":
print("Updating model with RGB data as input.")
add_pre_post_processing(input_shape=args.input_shape)
rgb_layout = "CHW" if args.input_shape[0] == 3 else "HWC"
elif args.input == "image":
print("Updating model with jpg/png image bytes as input.")
add_pre_post_processing(output_image_format=args.output_as_image)
rgb_layout = None
if args.infer:
_run_inference(
updated_model_path, args.input, args.output_as_image is not None, args.test_image, classes, rgb_layout
)