simplified plotting functions
This commit is contained in:
Родитель
f913cbbaf0
Коммит
8c3cf1a634
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -15,7 +15,7 @@ from torchvision.transforms import ColorJitter
|
|||
import xml.etree.ElementTree as ET
|
||||
from PIL import Image
|
||||
|
||||
from .plot import display_bboxes_mask, plot_grid
|
||||
from .plot import plot_detections, plot_grid
|
||||
from .bbox import AnnotationBbox
|
||||
from .mask import binarise_mask
|
||||
from .references.utils import collate_fn
|
||||
|
@ -152,7 +152,7 @@ class DetectionDataset:
|
|||
im_dir: str = "images",
|
||||
mask_dir: str = None,
|
||||
seed: int = None,
|
||||
allow_negatives: bool = False
|
||||
allow_negatives: bool = False,
|
||||
):
|
||||
""" initialize dataset
|
||||
|
||||
|
@ -255,7 +255,7 @@ class DetectionDataset:
|
|||
# Assume mask image name matches image name but has .png
|
||||
# extension
|
||||
mask_name = os.path.basename(self.im_paths[-1])
|
||||
mask_name = mask_name[:mask_name.rindex('.')] + ".png"
|
||||
mask_name = mask_name[: mask_name.rindex(".")] + ".png"
|
||||
mask_path = self.root / self.mask_dir / mask_name
|
||||
# For mask prediction, if no mask provided and negatives not
|
||||
# allowed (), ignore the image
|
||||
|
@ -389,7 +389,15 @@ class DetectionDataset:
|
|||
if seed or self.seed:
|
||||
random.seed(seed or self.seed)
|
||||
|
||||
plot_grid(display_bboxes_mask, self._get_random_anno, rows=rows, cols=cols)
|
||||
def helper(im_paths):
|
||||
idx = random.randrange(len(im_paths))
|
||||
detection = {}
|
||||
detection["idx"] = idx
|
||||
detection["im_path"] = im_paths[idx]
|
||||
detection["det_bboxes"] = []
|
||||
return detection, self, None
|
||||
|
||||
plot_grid(plot_detections, helper(self.im_paths), rows=2)
|
||||
|
||||
def show_im_transformations(
|
||||
self, idx: int = None, rows: int = 1, cols: int = 3
|
||||
|
@ -438,21 +446,12 @@ class DetectionDataset:
|
|||
# for the tiny bounding box in _read_annos(), make the mask to
|
||||
# be the whole box
|
||||
mask = np.zeros(
|
||||
Image.open(self.im_paths[idx]).size[::-1],
|
||||
dtype=np.uint8
|
||||
Image.open(self.im_paths[idx]).size[::-1], dtype=np.uint8
|
||||
)
|
||||
binary_masks = binarise_mask(mask)
|
||||
|
||||
return binary_masks
|
||||
|
||||
def _get_random_anno(self) -> Tuple:
|
||||
""" Get random annotation and corresponding image
|
||||
|
||||
Returns a list of annotations and the image path
|
||||
"""
|
||||
idx = random.randrange(len(self.im_paths))
|
||||
return self.anno_bboxes[idx], self.im_paths[idx], self._get_binary_mask(idx)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
""" Make iterable. """
|
||||
# get box/labels from annotations
|
||||
|
|
|
@ -4,15 +4,7 @@
|
|||
import os
|
||||
import itertools
|
||||
import json
|
||||
from typing import (
|
||||
Callable,
|
||||
List,
|
||||
Tuple,
|
||||
Union,
|
||||
Generator,
|
||||
Optional,
|
||||
Dict,
|
||||
)
|
||||
from typing import Callable, List, Tuple, Union, Generator, Optional, Dict
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
@ -53,9 +45,9 @@ def _get_det_bboxes_and_mask(
|
|||
Return:
|
||||
a dict of DetectionBboxes and masks
|
||||
"""
|
||||
pred_labels = pred['labels'].tolist()
|
||||
pred_boxes = pred['boxes'].tolist()
|
||||
pred_scores = pred['scores'].tolist()
|
||||
pred_labels = pred["labels"].tolist()
|
||||
pred_boxes = pred["boxes"].tolist()
|
||||
pred_scores = pred["scores"].tolist()
|
||||
|
||||
det_bboxes = []
|
||||
for label, box, score in zip(pred_labels, pred_boxes, pred_scores):
|
||||
|
@ -69,7 +61,7 @@ def _get_det_bboxes_and_mask(
|
|||
)
|
||||
det_bboxes.append(det_bbox)
|
||||
|
||||
res = {"det_bboxes": det_bboxes}
|
||||
res = {"det_bboxes": det_bboxes, "im_path": im_path}
|
||||
|
||||
if "masks" in pred:
|
||||
res["masks"] = pred["masks"].squeeze(1)
|
||||
|
@ -77,8 +69,7 @@ def _get_det_bboxes_and_mask(
|
|||
|
||||
|
||||
def _apply_threshold(
|
||||
pred: Dict[str, np.ndarray],
|
||||
threshold: Optional[float] = 0.5,
|
||||
pred: Dict[str, np.ndarray], threshold: Optional[float] = 0.5
|
||||
) -> Dict:
|
||||
""" Return prediction results that are above the threshold if any.
|
||||
|
||||
|
@ -90,7 +81,7 @@ def _apply_threshold(
|
|||
"""
|
||||
# apply score threshold
|
||||
if threshold:
|
||||
selected = pred['scores'] > threshold
|
||||
selected = pred["scores"] > threshold
|
||||
pred = {k: v[selected] for k, v in pred.items()}
|
||||
# apply mask threshold
|
||||
if "masks" in pred:
|
||||
|
@ -157,10 +148,7 @@ def _tune_box_predictor(model: nn.Module, num_classes: int) -> nn.Module:
|
|||
return model
|
||||
|
||||
|
||||
def get_pretrained_fasterrcnn(
|
||||
num_classes: int = None,
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
def get_pretrained_fasterrcnn(num_classes: int = None, **kwargs) -> nn.Module:
|
||||
""" Gets a pretrained FasterRCNN model
|
||||
|
||||
Args:
|
||||
|
@ -177,10 +165,7 @@ def get_pretrained_fasterrcnn(
|
|||
# intuitive.
|
||||
|
||||
# load a model pre-trained on COCO
|
||||
model = _get_pretrained_rcnn(
|
||||
fasterrcnn_resnet50_fpn,
|
||||
**kwargs,
|
||||
)
|
||||
model = _get_pretrained_rcnn(fasterrcnn_resnet50_fpn, **kwargs)
|
||||
|
||||
# if num_classes is specified, then create new final bounding box
|
||||
# prediction layers, otherwise use pre-trained layers
|
||||
|
@ -190,10 +175,7 @@ def get_pretrained_fasterrcnn(
|
|||
return model
|
||||
|
||||
|
||||
def get_pretrained_maskrcnn(
|
||||
num_classes: int = None,
|
||||
**kwargs,
|
||||
) -> nn.Module:
|
||||
def get_pretrained_maskrcnn(num_classes: int = None, **kwargs) -> nn.Module:
|
||||
""" Gets a pretrained Mask R-CNN model
|
||||
|
||||
Args:
|
||||
|
@ -208,10 +190,7 @@ def get_pretrained_maskrcnn(
|
|||
|
||||
"""
|
||||
# load a model pre-trained on COCO
|
||||
model = _get_pretrained_rcnn(
|
||||
maskrcnn_resnet50_fpn,
|
||||
**kwargs,
|
||||
)
|
||||
model = _get_pretrained_rcnn(maskrcnn_resnet50_fpn, **kwargs)
|
||||
|
||||
# if num_classes is specified, then create new final bounding box
|
||||
# and mask prediction layers, otherwise use pre-trained layers
|
||||
|
@ -224,9 +203,7 @@ def get_pretrained_maskrcnn(
|
|||
in_features = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||
# replace the mask predictor with a new one
|
||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(
|
||||
in_features,
|
||||
256,
|
||||
num_classes
|
||||
in_features, 256, num_classes
|
||||
)
|
||||
|
||||
return model
|
||||
|
@ -477,6 +454,7 @@ class DetectionLearner:
|
|||
print_freq: int = 10,
|
||||
step_size: int = None,
|
||||
gamma: float = 0.1,
|
||||
skip_evaluation: bool = False,
|
||||
) -> None:
|
||||
""" The main training loop. """
|
||||
|
||||
|
@ -522,9 +500,12 @@ class DetectionLearner:
|
|||
self.lr_scheduler.step()
|
||||
|
||||
# evaluate
|
||||
e = self.evaluate(dl=self.dataset.test_dl)
|
||||
self.ap.append(_calculate_ap(e))
|
||||
self.ap_iou_point_5.append(_calculate_ap(e, iou_threshold_idx=0))
|
||||
if not skip_evaluation:
|
||||
e = self.evaluate(dl=self.dataset.test_dl)
|
||||
self.ap.append(_calculate_ap(e))
|
||||
self.ap_iou_point_5.append(
|
||||
_calculate_ap(e, iou_threshold_idx=0)
|
||||
)
|
||||
|
||||
def plot_precision_loss_curves(
|
||||
self, figsize: Tuple[int, int] = (10, 5)
|
||||
|
@ -536,7 +517,7 @@ class DetectionLearner:
|
|||
|
||||
for i, (k, v) in enumerate(ap.items()):
|
||||
|
||||
ax1 = fig.add_subplot(1, len(ap), i+1)
|
||||
ax1 = fig.add_subplot(1, len(ap), i + 1)
|
||||
|
||||
ax1.set_xlim([0, self.epochs - 1])
|
||||
ax1.set_xticks(range(0, self.epochs))
|
||||
|
@ -600,15 +581,11 @@ class DetectionLearner:
|
|||
# detach prediction results to cpu
|
||||
pred = {k: v.detach().cpu().numpy() for k, v in pred.items()}
|
||||
return _get_det_bboxes_and_mask(
|
||||
_apply_threshold(pred, threshold=threshold),
|
||||
self.labels,
|
||||
im_path
|
||||
_apply_threshold(pred, threshold=threshold), self.labels, im_path
|
||||
)
|
||||
|
||||
def predict_dl(
|
||||
self,
|
||||
dl: DataLoader,
|
||||
threshold: Optional[float] = 0.5,
|
||||
self, dl: DataLoader, threshold: Optional[float] = 0.5
|
||||
) -> List[DetectionBbox]:
|
||||
""" Predict all images in a dataloader object.
|
||||
|
||||
|
@ -623,9 +600,7 @@ class DetectionLearner:
|
|||
return [pred for preds in pred_generator for pred in preds]
|
||||
|
||||
def predict_batch(
|
||||
self,
|
||||
dl: DataLoader,
|
||||
threshold: Optional[float] = 0.5,
|
||||
self, dl: DataLoader, threshold: Optional[float] = 0.5
|
||||
) -> Generator[List[DetectionBbox], None, None]:
|
||||
""" Batch predict
|
||||
|
||||
|
@ -654,7 +629,7 @@ class DetectionLearner:
|
|||
bboxes_masks = _get_det_bboxes_and_mask(
|
||||
_apply_threshold(pred, threshold=threshold),
|
||||
self.labels,
|
||||
dl.dataset.dataset.im_paths[im_id]
|
||||
dl.dataset.dataset.im_paths[im_id],
|
||||
)
|
||||
results.append({"idx": im_id, **bboxes_masks})
|
||||
|
||||
|
|
|
@ -4,10 +4,8 @@
|
|||
"""
|
||||
Helper module for visualizations
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Tuple, Callable, Any, Iterator, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
@ -15,7 +13,7 @@ from PIL import Image, ImageDraw
|
|||
from torch.utils.data import Subset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from .bbox import _Bbox, AnnotationBbox, DetectionBbox
|
||||
from .bbox import _Bbox, DetectionBbox
|
||||
from .model import ims_eval_detections
|
||||
from .references.coco_eval import CocoEvaluator
|
||||
from ..common.misc import get_font
|
||||
|
@ -28,11 +26,11 @@ class PlotSettings:
|
|||
def __init__(
|
||||
self,
|
||||
rect_th: int = 4,
|
||||
rect_color: Tuple[int, int, int] = (255, 0, 0),
|
||||
rect_color: Tuple[int, int, int] = (0, 0, 255),
|
||||
text_size: int = 25,
|
||||
text_color: Tuple[int, int, int] = (255, 255, 255),
|
||||
mask_color: Tuple[int, int, int] = (2, 166, 101),
|
||||
mask_alpha: float = 0.5,
|
||||
text_color: Tuple[int, int, int] = (0, 0, 255),
|
||||
mask_color: Tuple[int, int, int] = (0, 0, 128),
|
||||
mask_alpha: float = 0.8,
|
||||
):
|
||||
self.rect_th = rect_th
|
||||
self.rect_color = rect_color
|
||||
|
@ -65,7 +63,7 @@ def plot_boxes(
|
|||
|
||||
for bbox in bboxes:
|
||||
# do not draw background bounding boxes
|
||||
if bbox.label_idx == 0:
|
||||
if hasattr(bbox, "label_idx") and bbox.label_idx == 0:
|
||||
continue
|
||||
|
||||
box = [(bbox.left, bbox.top), (bbox.right, bbox.bottom)]
|
||||
|
@ -91,7 +89,7 @@ def plot_boxes(
|
|||
return im
|
||||
|
||||
|
||||
def plot_mask(
|
||||
def plot_masks(
|
||||
im: Union[str, Path, PIL.Image.Image],
|
||||
mask: Union[str, Path, np.ndarray],
|
||||
plot_settings: PlotSettings = PlotSettings(),
|
||||
|
@ -103,15 +101,16 @@ def plot_mask(
|
|||
representing different objects, 0 as background.
|
||||
"""
|
||||
if isinstance(im, (str, Path)):
|
||||
SEE_IF_I_CAN_REMOVE_THIS_IF_CLAUSE
|
||||
im = Image.open(im)
|
||||
|
||||
# convert to RGBA for transparentising
|
||||
im = im.convert('RGBA')
|
||||
im = im.convert("RGBA")
|
||||
# colorise masks
|
||||
binary_masks = binarise_mask(mask)
|
||||
colored_masks = [
|
||||
colorise_binary_mask(bmask, plot_settings.mask_color) for bmask in
|
||||
binary_masks
|
||||
colorise_binary_mask(bmask, plot_settings.mask_color)
|
||||
for bmask in binary_masks
|
||||
]
|
||||
# merge masks into img one by one
|
||||
for cmask in colored_masks:
|
||||
|
@ -123,52 +122,60 @@ def plot_mask(
|
|||
return im
|
||||
|
||||
|
||||
def display_bboxes_mask(
|
||||
bboxes: List[_Bbox],
|
||||
im_path: Union[Path, str],
|
||||
mask_path: Union[Path, str] = None,
|
||||
ax: Optional[plt.axes] = None,
|
||||
plot_settings: PlotSettings = PlotSettings(),
|
||||
figsize: Tuple[int, int] = (12, 12),
|
||||
) -> None:
|
||||
""" Draw image with bounding boxes and mask.
|
||||
def plot_detections(detection, data=None, idx=None, ax: plt.axes = None):
|
||||
# Open image
|
||||
assert detection["im_path"], 'Detection["im_path"] should not be None.'
|
||||
im = Image.open(detection["im_path"])
|
||||
|
||||
Args:
|
||||
bboxes: A list of _Bbox, could be DetectionBbox or AnnotationBbox
|
||||
im_path: the location of image path to draw
|
||||
mask_path: the location of mask path to draw
|
||||
ax: an optional ax to specify where you wish the figure to be drawn on
|
||||
plot_settings: plotting parameters
|
||||
figsize: figure size
|
||||
# Get id of ground truth image/annotation
|
||||
if data and not idx:
|
||||
idx = detection["idx"]
|
||||
|
||||
Returns nothing, but plots the image with bounding boxes, labels and masks
|
||||
if any.
|
||||
"""
|
||||
# Read image
|
||||
im = Image.open(im_path)
|
||||
# Loop over all images
|
||||
det_bboxes = detection["det_bboxes"]
|
||||
|
||||
# set an image title
|
||||
title = os.path.basename(im_path)
|
||||
# Plot ground truth mask
|
||||
if data and data.mask_paths:
|
||||
mask_path = data.mask_paths[idx]
|
||||
if mask_path:
|
||||
im = plot_masks(
|
||||
im,
|
||||
mask_path,
|
||||
plot_settings=PlotSettings(mask_color=(0, 128, 0)),
|
||||
)
|
||||
|
||||
if mask_path is not None:
|
||||
# plot masks on im
|
||||
im = plot_mask(im_path, mask_path)
|
||||
# Plot predicted masks
|
||||
if "masks" in detection:
|
||||
mask = detection["masks"]
|
||||
im = plot_masks(im, mask, PlotSettings(mask_color=(128, 165, 0)))
|
||||
|
||||
if bboxes is not None:
|
||||
# plot boxes on im
|
||||
im = plot_boxes(im, bboxes, title=title, plot_settings=plot_settings)
|
||||
# Plot the detections
|
||||
plot_boxes(
|
||||
im,
|
||||
det_bboxes,
|
||||
plot_settings=PlotSettings(
|
||||
rect_color=(255, 165, 0), text_color=(255, 165, 0), rect_th=2
|
||||
),
|
||||
)
|
||||
|
||||
# display the image
|
||||
if ax is not None:
|
||||
# Plot the ground truth annotations
|
||||
if data:
|
||||
anno_bboxes = data.anno_bboxes[idx]
|
||||
plot_boxes(
|
||||
im,
|
||||
anno_bboxes,
|
||||
plot_settings=PlotSettings(
|
||||
rect_color=(0, 255, 0), text_color=(0, 255, 0)
|
||||
),
|
||||
)
|
||||
|
||||
# show image
|
||||
if ax:
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.imshow(im)
|
||||
else:
|
||||
plt.figure(figsize=figsize)
|
||||
plt.imshow(im)
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
plt.show()
|
||||
return im
|
||||
|
||||
|
||||
def plot_grid(
|
||||
|
@ -218,48 +225,6 @@ def plot_grid(
|
|||
plt.subplots_adjust(top=0.8, bottom=0.2, hspace=0.1, wspace=0.2)
|
||||
|
||||
|
||||
def plot_detection_vs_ground_truth(
|
||||
im_path: str,
|
||||
det_bboxes: List[DetectionBbox],
|
||||
anno_bboxes: List[AnnotationBbox],
|
||||
ax: plt.axes,
|
||||
) -> None:
|
||||
""" Plots bounding boxes of ground_truths and detections.
|
||||
|
||||
Args:
|
||||
im_path: the image to plot
|
||||
det_bboxes: a list of detected annotations
|
||||
anno_bboxes: a list of ground_truth detections
|
||||
ax: the axis to plot on
|
||||
|
||||
Returns nothing, but displays a graph
|
||||
"""
|
||||
im = Image.open(im_path).convert("RGB")
|
||||
|
||||
# plot detections
|
||||
det_params = PlotSettings(rect_color=(255, 0, 0), text_size=1)
|
||||
im = plot_boxes(
|
||||
im,
|
||||
det_bboxes,
|
||||
title=os.path.basename(im_path),
|
||||
plot_settings=det_params,
|
||||
)
|
||||
|
||||
# plot ground truth boxes
|
||||
anno_params = PlotSettings(rect_color=(0, 255, 0), text_size=1)
|
||||
im = plot_boxes(
|
||||
im,
|
||||
anno_bboxes,
|
||||
title=os.path.basename(im_path),
|
||||
plot_settings=anno_params,
|
||||
)
|
||||
|
||||
# show image
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.imshow(im)
|
||||
|
||||
|
||||
# ===== Precision - Recall curve =====
|
||||
|
||||
|
||||
|
@ -311,9 +276,7 @@ def _get_precision_recall_settings(
|
|||
|
||||
|
||||
def _plot_pr_curve_iou_range(
|
||||
ax: plt.axes,
|
||||
coco_eval: CocoEvaluator,
|
||||
iou_type: Optional[str] = None,
|
||||
ax: plt.axes, coco_eval: CocoEvaluator, iou_type: Optional[str] = None
|
||||
) -> None:
|
||||
""" Plots the PR curve over varying iou thresholds averaging over [K]
|
||||
categories. """
|
||||
|
@ -339,9 +302,7 @@ def _plot_pr_curve_iou_range(
|
|||
|
||||
|
||||
def _plot_pr_curve_iou_mean(
|
||||
ax: plt.axes,
|
||||
coco_eval: CocoEvaluator,
|
||||
iou_type: Optional[str] = None,
|
||||
ax: plt.axes, coco_eval: CocoEvaluator, iou_type: Optional[str] = None
|
||||
) -> None:
|
||||
""" Plots the PR curve, averaging over iou thresholds and [K] labels. """
|
||||
x = np.arange(0.0, 1.01, 0.01)
|
||||
|
@ -401,8 +362,6 @@ def plot_pr_curves(
|
|||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
# ===== Correct/missing detection counts curve =====
|
||||
|
||||
|
||||
|
@ -476,7 +435,6 @@ def _plot_counts_curves_obj(
|
|||
label="Total number of missed ground truths",
|
||||
)
|
||||
|
||||
|
||||
ax.legend()
|
||||
ax.set_xlabel("Score threshold")
|
||||
ax.set_ylabel("Frequency")
|
||||
|
|
Загрузка…
Ссылка в новой задаче