This commit is contained in:
PatrickBue 2019-11-22 10:06:46 -05:00 коммит произвёл Young Park
Родитель f913cbbaf0
Коммит 8c3cf1a634
5 изменённых файлов: 663 добавлений и 1316 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -15,7 +15,7 @@ from torchvision.transforms import ColorJitter
import xml.etree.ElementTree as ET
from PIL import Image
from .plot import display_bboxes_mask, plot_grid
from .plot import plot_detections, plot_grid
from .bbox import AnnotationBbox
from .mask import binarise_mask
from .references.utils import collate_fn
@ -152,7 +152,7 @@ class DetectionDataset:
im_dir: str = "images",
mask_dir: str = None,
seed: int = None,
allow_negatives: bool = False
allow_negatives: bool = False,
):
""" initialize dataset
@ -255,7 +255,7 @@ class DetectionDataset:
# Assume mask image name matches image name but has .png
# extension
mask_name = os.path.basename(self.im_paths[-1])
mask_name = mask_name[:mask_name.rindex('.')] + ".png"
mask_name = mask_name[: mask_name.rindex(".")] + ".png"
mask_path = self.root / self.mask_dir / mask_name
# For mask prediction, if no mask provided and negatives not
# allowed (), ignore the image
@ -389,7 +389,15 @@ class DetectionDataset:
if seed or self.seed:
random.seed(seed or self.seed)
plot_grid(display_bboxes_mask, self._get_random_anno, rows=rows, cols=cols)
def helper(im_paths):
idx = random.randrange(len(im_paths))
detection = {}
detection["idx"] = idx
detection["im_path"] = im_paths[idx]
detection["det_bboxes"] = []
return detection, self, None
plot_grid(plot_detections, helper(self.im_paths), rows=2)
def show_im_transformations(
self, idx: int = None, rows: int = 1, cols: int = 3
@ -438,21 +446,12 @@ class DetectionDataset:
# for the tiny bounding box in _read_annos(), make the mask to
# be the whole box
mask = np.zeros(
Image.open(self.im_paths[idx]).size[::-1],
dtype=np.uint8
Image.open(self.im_paths[idx]).size[::-1], dtype=np.uint8
)
binary_masks = binarise_mask(mask)
return binary_masks
def _get_random_anno(self) -> Tuple:
""" Get random annotation and corresponding image
Returns a list of annotations and the image path
"""
idx = random.randrange(len(self.im_paths))
return self.anno_bboxes[idx], self.im_paths[idx], self._get_binary_mask(idx)
def __getitem__(self, idx):
""" Make iterable. """
# get box/labels from annotations

Просмотреть файл

@ -4,15 +4,7 @@
import os
import itertools
import json
from typing import (
Callable,
List,
Tuple,
Union,
Generator,
Optional,
Dict,
)
from typing import Callable, List, Tuple, Union, Generator, Optional, Dict
from pathlib import Path
import shutil
@ -53,9 +45,9 @@ def _get_det_bboxes_and_mask(
Return:
a dict of DetectionBboxes and masks
"""
pred_labels = pred['labels'].tolist()
pred_boxes = pred['boxes'].tolist()
pred_scores = pred['scores'].tolist()
pred_labels = pred["labels"].tolist()
pred_boxes = pred["boxes"].tolist()
pred_scores = pred["scores"].tolist()
det_bboxes = []
for label, box, score in zip(pred_labels, pred_boxes, pred_scores):
@ -69,7 +61,7 @@ def _get_det_bboxes_and_mask(
)
det_bboxes.append(det_bbox)
res = {"det_bboxes": det_bboxes}
res = {"det_bboxes": det_bboxes, "im_path": im_path}
if "masks" in pred:
res["masks"] = pred["masks"].squeeze(1)
@ -77,8 +69,7 @@ def _get_det_bboxes_and_mask(
def _apply_threshold(
pred: Dict[str, np.ndarray],
threshold: Optional[float] = 0.5,
pred: Dict[str, np.ndarray], threshold: Optional[float] = 0.5
) -> Dict:
""" Return prediction results that are above the threshold if any.
@ -90,7 +81,7 @@ def _apply_threshold(
"""
# apply score threshold
if threshold:
selected = pred['scores'] > threshold
selected = pred["scores"] > threshold
pred = {k: v[selected] for k, v in pred.items()}
# apply mask threshold
if "masks" in pred:
@ -157,10 +148,7 @@ def _tune_box_predictor(model: nn.Module, num_classes: int) -> nn.Module:
return model
def get_pretrained_fasterrcnn(
num_classes: int = None,
**kwargs,
) -> nn.Module:
def get_pretrained_fasterrcnn(num_classes: int = None, **kwargs) -> nn.Module:
""" Gets a pretrained FasterRCNN model
Args:
@ -177,10 +165,7 @@ def get_pretrained_fasterrcnn(
# intuitive.
# load a model pre-trained on COCO
model = _get_pretrained_rcnn(
fasterrcnn_resnet50_fpn,
**kwargs,
)
model = _get_pretrained_rcnn(fasterrcnn_resnet50_fpn, **kwargs)
# if num_classes is specified, then create new final bounding box
# prediction layers, otherwise use pre-trained layers
@ -190,10 +175,7 @@ def get_pretrained_fasterrcnn(
return model
def get_pretrained_maskrcnn(
num_classes: int = None,
**kwargs,
) -> nn.Module:
def get_pretrained_maskrcnn(num_classes: int = None, **kwargs) -> nn.Module:
""" Gets a pretrained Mask R-CNN model
Args:
@ -208,10 +190,7 @@ def get_pretrained_maskrcnn(
"""
# load a model pre-trained on COCO
model = _get_pretrained_rcnn(
maskrcnn_resnet50_fpn,
**kwargs,
)
model = _get_pretrained_rcnn(maskrcnn_resnet50_fpn, **kwargs)
# if num_classes is specified, then create new final bounding box
# and mask prediction layers, otherwise use pre-trained layers
@ -224,9 +203,7 @@ def get_pretrained_maskrcnn(
in_features = model.roi_heads.mask_predictor.conv5_mask.in_channels
# replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features,
256,
num_classes
in_features, 256, num_classes
)
return model
@ -477,6 +454,7 @@ class DetectionLearner:
print_freq: int = 10,
step_size: int = None,
gamma: float = 0.1,
skip_evaluation: bool = False,
) -> None:
""" The main training loop. """
@ -522,9 +500,12 @@ class DetectionLearner:
self.lr_scheduler.step()
# evaluate
e = self.evaluate(dl=self.dataset.test_dl)
self.ap.append(_calculate_ap(e))
self.ap_iou_point_5.append(_calculate_ap(e, iou_threshold_idx=0))
if not skip_evaluation:
e = self.evaluate(dl=self.dataset.test_dl)
self.ap.append(_calculate_ap(e))
self.ap_iou_point_5.append(
_calculate_ap(e, iou_threshold_idx=0)
)
def plot_precision_loss_curves(
self, figsize: Tuple[int, int] = (10, 5)
@ -536,7 +517,7 @@ class DetectionLearner:
for i, (k, v) in enumerate(ap.items()):
ax1 = fig.add_subplot(1, len(ap), i+1)
ax1 = fig.add_subplot(1, len(ap), i + 1)
ax1.set_xlim([0, self.epochs - 1])
ax1.set_xticks(range(0, self.epochs))
@ -600,15 +581,11 @@ class DetectionLearner:
# detach prediction results to cpu
pred = {k: v.detach().cpu().numpy() for k, v in pred.items()}
return _get_det_bboxes_and_mask(
_apply_threshold(pred, threshold=threshold),
self.labels,
im_path
_apply_threshold(pred, threshold=threshold), self.labels, im_path
)
def predict_dl(
self,
dl: DataLoader,
threshold: Optional[float] = 0.5,
self, dl: DataLoader, threshold: Optional[float] = 0.5
) -> List[DetectionBbox]:
""" Predict all images in a dataloader object.
@ -623,9 +600,7 @@ class DetectionLearner:
return [pred for preds in pred_generator for pred in preds]
def predict_batch(
self,
dl: DataLoader,
threshold: Optional[float] = 0.5,
self, dl: DataLoader, threshold: Optional[float] = 0.5
) -> Generator[List[DetectionBbox], None, None]:
""" Batch predict
@ -654,7 +629,7 @@ class DetectionLearner:
bboxes_masks = _get_det_bboxes_and_mask(
_apply_threshold(pred, threshold=threshold),
self.labels,
dl.dataset.dataset.im_paths[im_id]
dl.dataset.dataset.im_paths[im_id],
)
results.append({"idx": im_id, **bboxes_masks})

Просмотреть файл

@ -4,10 +4,8 @@
"""
Helper module for visualizations
"""
import os
from pathlib import Path
from typing import List, Union, Tuple, Callable, Any, Iterator, Optional
from pathlib import Path
import numpy as np
import PIL
@ -15,7 +13,7 @@ from PIL import Image, ImageDraw
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from .bbox import _Bbox, AnnotationBbox, DetectionBbox
from .bbox import _Bbox, DetectionBbox
from .model import ims_eval_detections
from .references.coco_eval import CocoEvaluator
from ..common.misc import get_font
@ -28,11 +26,11 @@ class PlotSettings:
def __init__(
self,
rect_th: int = 4,
rect_color: Tuple[int, int, int] = (255, 0, 0),
rect_color: Tuple[int, int, int] = (0, 0, 255),
text_size: int = 25,
text_color: Tuple[int, int, int] = (255, 255, 255),
mask_color: Tuple[int, int, int] = (2, 166, 101),
mask_alpha: float = 0.5,
text_color: Tuple[int, int, int] = (0, 0, 255),
mask_color: Tuple[int, int, int] = (0, 0, 128),
mask_alpha: float = 0.8,
):
self.rect_th = rect_th
self.rect_color = rect_color
@ -65,7 +63,7 @@ def plot_boxes(
for bbox in bboxes:
# do not draw background bounding boxes
if bbox.label_idx == 0:
if hasattr(bbox, "label_idx") and bbox.label_idx == 0:
continue
box = [(bbox.left, bbox.top), (bbox.right, bbox.bottom)]
@ -91,7 +89,7 @@ def plot_boxes(
return im
def plot_mask(
def plot_masks(
im: Union[str, Path, PIL.Image.Image],
mask: Union[str, Path, np.ndarray],
plot_settings: PlotSettings = PlotSettings(),
@ -103,15 +101,16 @@ def plot_mask(
representing different objects, 0 as background.
"""
if isinstance(im, (str, Path)):
SEE_IF_I_CAN_REMOVE_THIS_IF_CLAUSE
im = Image.open(im)
# convert to RGBA for transparentising
im = im.convert('RGBA')
im = im.convert("RGBA")
# colorise masks
binary_masks = binarise_mask(mask)
colored_masks = [
colorise_binary_mask(bmask, plot_settings.mask_color) for bmask in
binary_masks
colorise_binary_mask(bmask, plot_settings.mask_color)
for bmask in binary_masks
]
# merge masks into img one by one
for cmask in colored_masks:
@ -123,52 +122,60 @@ def plot_mask(
return im
def display_bboxes_mask(
bboxes: List[_Bbox],
im_path: Union[Path, str],
mask_path: Union[Path, str] = None,
ax: Optional[plt.axes] = None,
plot_settings: PlotSettings = PlotSettings(),
figsize: Tuple[int, int] = (12, 12),
) -> None:
""" Draw image with bounding boxes and mask.
def plot_detections(detection, data=None, idx=None, ax: plt.axes = None):
# Open image
assert detection["im_path"], 'Detection["im_path"] should not be None.'
im = Image.open(detection["im_path"])
Args:
bboxes: A list of _Bbox, could be DetectionBbox or AnnotationBbox
im_path: the location of image path to draw
mask_path: the location of mask path to draw
ax: an optional ax to specify where you wish the figure to be drawn on
plot_settings: plotting parameters
figsize: figure size
# Get id of ground truth image/annotation
if data and not idx:
idx = detection["idx"]
Returns nothing, but plots the image with bounding boxes, labels and masks
if any.
"""
# Read image
im = Image.open(im_path)
# Loop over all images
det_bboxes = detection["det_bboxes"]
# set an image title
title = os.path.basename(im_path)
# Plot ground truth mask
if data and data.mask_paths:
mask_path = data.mask_paths[idx]
if mask_path:
im = plot_masks(
im,
mask_path,
plot_settings=PlotSettings(mask_color=(0, 128, 0)),
)
if mask_path is not None:
# plot masks on im
im = plot_mask(im_path, mask_path)
# Plot predicted masks
if "masks" in detection:
mask = detection["masks"]
im = plot_masks(im, mask, PlotSettings(mask_color=(128, 165, 0)))
if bboxes is not None:
# plot boxes on im
im = plot_boxes(im, bboxes, title=title, plot_settings=plot_settings)
# Plot the detections
plot_boxes(
im,
det_bboxes,
plot_settings=PlotSettings(
rect_color=(255, 165, 0), text_color=(255, 165, 0), rect_th=2
),
)
# display the image
if ax is not None:
# Plot the ground truth annotations
if data:
anno_bboxes = data.anno_bboxes[idx]
plot_boxes(
im,
anno_bboxes,
plot_settings=PlotSettings(
rect_color=(0, 255, 0), text_color=(0, 255, 0)
),
)
# show image
if ax:
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(im)
else:
plt.figure(figsize=figsize)
plt.imshow(im)
plt.xticks([])
plt.yticks([])
plt.show()
return im
def plot_grid(
@ -218,48 +225,6 @@ def plot_grid(
plt.subplots_adjust(top=0.8, bottom=0.2, hspace=0.1, wspace=0.2)
def plot_detection_vs_ground_truth(
im_path: str,
det_bboxes: List[DetectionBbox],
anno_bboxes: List[AnnotationBbox],
ax: plt.axes,
) -> None:
""" Plots bounding boxes of ground_truths and detections.
Args:
im_path: the image to plot
det_bboxes: a list of detected annotations
anno_bboxes: a list of ground_truth detections
ax: the axis to plot on
Returns nothing, but displays a graph
"""
im = Image.open(im_path).convert("RGB")
# plot detections
det_params = PlotSettings(rect_color=(255, 0, 0), text_size=1)
im = plot_boxes(
im,
det_bboxes,
title=os.path.basename(im_path),
plot_settings=det_params,
)
# plot ground truth boxes
anno_params = PlotSettings(rect_color=(0, 255, 0), text_size=1)
im = plot_boxes(
im,
anno_bboxes,
title=os.path.basename(im_path),
plot_settings=anno_params,
)
# show image
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(im)
# ===== Precision - Recall curve =====
@ -311,9 +276,7 @@ def _get_precision_recall_settings(
def _plot_pr_curve_iou_range(
ax: plt.axes,
coco_eval: CocoEvaluator,
iou_type: Optional[str] = None,
ax: plt.axes, coco_eval: CocoEvaluator, iou_type: Optional[str] = None
) -> None:
""" Plots the PR curve over varying iou thresholds averaging over [K]
categories. """
@ -339,9 +302,7 @@ def _plot_pr_curve_iou_range(
def _plot_pr_curve_iou_mean(
ax: plt.axes,
coco_eval: CocoEvaluator,
iou_type: Optional[str] = None,
ax: plt.axes, coco_eval: CocoEvaluator, iou_type: Optional[str] = None
) -> None:
""" Plots the PR curve, averaging over iou thresholds and [K] labels. """
x = np.arange(0.0, 1.01, 0.01)
@ -401,8 +362,6 @@ def plot_pr_curves(
plt.show()
# ===== Correct/missing detection counts curve =====
@ -476,7 +435,6 @@ def _plot_counts_curves_obj(
label="Total number of missed ground truths",
)
ax.legend()
ax.set_xlabel("Score threshold")
ax.set_ylabel("Frequency")