Add keypoint detection with tuned model (#454)
* Add keypoint detection with tuned model * Add tests * Minor revision * Update tests * Fix bugs in tests * Use GPU device if available * Update tests * Fix bug: 'not idx' will be 'True' if 'idx=0' * Fix bugs * Move toy keypoint meta into notebook * Fix bugs * Fix bugs * Fix bugs in notebook * Add descriptions for keypoint meta data * Raise exception when RandomHorizontalFlip is used without specifying hflip_inds
This commit is contained in:
Родитель
1b35fdd655
Коммит
2bcf13aee0
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -25,10 +25,10 @@ from utils_cv.classification.data import Urls as ic_urls
|
|||
from utils_cv.detection.data import Urls as od_urls
|
||||
from utils_cv.detection.bbox import DetectionBbox, AnnotationBbox
|
||||
from utils_cv.detection.dataset import DetectionDataset
|
||||
from utils_cv.detection.keypoint import Keypoints
|
||||
from utils_cv.detection.model import (
|
||||
get_pretrained_fasterrcnn,
|
||||
get_pretrained_maskrcnn,
|
||||
get_pretrained_keypointrcnn,
|
||||
DetectionLearner,
|
||||
_extract_od_results,
|
||||
_apply_threshold,
|
||||
|
@ -432,19 +432,6 @@ def od_mask_rects() -> Tuple:
|
|||
return binary_masks, mask, rects, im
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def od_keypoints_for_plot() -> Tuple:
    """ Returns a blank image and a dummy `Keypoints` object for plotting tests. """
    # a completely black image (500 rows x 600 columns, 3 channels)
    blank_pixels = np.zeros((500, 600, 3), dtype=np.uint8)
    im = Image.fromarray(blank_pixels)

    # dummy keypoints: one object, two points, one skeleton edge joining them
    coords = np.array([[[100, 200, 2], [200, 200, 2]]])
    meta = {"num_keypoints": 2, "skeleton": [[0, 1]]}
    keypoints = Keypoints(coords, meta)

    return im, keypoints
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tiny_od_data_path(tmp_session) -> str:
|
||||
""" Returns the path to the fridge object detection dataset. """
|
||||
|
@ -467,6 +454,17 @@ def tiny_od_mask_data_path(tmp_session) -> str:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tiny_od_keypoint_data_path(tmp_session) -> str:
    """ Returns the path to the fridge object detection keypoint dataset.

    The tiny milk-bottle keypoint archive is unzipped into the session temp
    directory; the value handed back is whatever `unzip_url` returns.
    """
    archive_url = od_urls.fridge_objects_keypoint_milk_bottle_tiny_path
    return unzip_url(
        archive_url, fpath=tmp_session, dest=tmp_session, exist_ok=True
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def od_sample_im_anno(tiny_od_data_path) -> Tuple[Path, ...]:
|
||||
""" Returns an annotation and image path from the tiny_od_data_path fixture.
|
||||
|
@ -546,6 +544,21 @@ def od_sample_detection(od_sample_raw_preds, od_detection_mask_dataset):
|
|||
return detections
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def od_sample_keypoint_detection(
    od_sample_raw_preds, tiny_od_detection_keypoint_dataset
):
    """ Returns a sample detection dict (without masks) for keypoint plotting tests.

    The first raw prediction is thresholded at 0.9, converted with
    `_extract_od_results`, tagged with dataset index 0 and stripped of its
    "masks" entry.
    """
    thresholded = _apply_threshold(od_sample_raw_preds[0], threshold=0.9)
    label_names = ["one", "two", "three", "four"]
    first_im_path = tiny_od_detection_keypoint_dataset.im_paths[0]

    detections = _extract_od_results(thresholded, label_names, first_im_path)
    detections["idx"] = 0
    # the keypoint sample must not carry masks
    del detections["masks"]
    return detections
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def od_detection_dataset(tiny_od_data_path):
|
||||
""" returns a basic detection dataset. """
|
||||
|
@ -560,6 +573,34 @@ def od_detection_mask_dataset(tiny_od_mask_data_path):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tiny_od_detection_keypoint_dataset(tiny_od_keypoint_data_path):
    """ returns a basic detection keypoint dataset. """
    # keypoint names, in the order used by "skeleton" and "hflip_inds"
    keypoint_names = [
        "lid_left_top",
        "lid_right_top",
        "lid_left_bottom",
        "lid_right_bottom",
        "left_bottom",
        "right_bottom",
    ]
    # pairs of keypoint indices that are connected by a line
    skeleton_edges = [
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 3],
        [2, 4],
        [3, 5],
        [4, 5],
    ]
    # index permutation applied on horizontal flip (swaps each left/right pair)
    flip_permutation = [1, 0, 3, 2, 5, 4]

    milk_bottle_meta = {
        "labels": keypoint_names,
        "skeleton": skeleton_edges,
        "hflip_inds": flip_permutation,
    }
    return DetectionDataset(
        tiny_od_keypoint_data_path, keypoint_meta=milk_bottle_meta
    )
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.fixture(scope="session")
|
||||
def od_detection_learner(od_detection_dataset):
|
||||
|
@ -596,6 +637,27 @@ def od_detection_mask_learner(od_detection_mask_dataset):
|
|||
return learner
|
||||
|
||||
|
||||
@pytest.mark.gpu
@pytest.fixture(scope="session")
def od_detection_keypoint_learner(tiny_od_detection_keypoint_dataset):
    """ returns a keypoint detection learner that has been trained for one epoch. """
    dataset = tiny_od_detection_keypoint_dataset

    # small image sizes and proposal counts, passed through to the model factory
    size_settings = dict(min_size=100, max_size=200)
    rpn_settings = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )

    model = get_pretrained_keypointrcnn(
        # one more class than the dataset labels (presumably background -- confirm)
        num_classes=len(dataset.labels) + 1,
        num_keypoints=len(dataset.keypoint_meta["labels"]),
        **size_settings,
        **rpn_settings,
    )

    learner = DetectionLearner(dataset, model=model)
    learner.fit(1, skip_evaluation=True)
    return learner
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
@pytest.fixture(scope="session")
|
||||
def od_detection_eval(od_detection_learner):
|
||||
|
|
|
@ -4,7 +4,12 @@
|
|||
import pytest
|
||||
from typing import List, Optional
|
||||
|
||||
from utils_cv.detection.bbox import DetectionBbox, AnnotationBbox, _Bbox, bboxes_iou
|
||||
from utils_cv.detection.bbox import (
|
||||
DetectionBbox,
|
||||
AnnotationBbox,
|
||||
_Bbox,
|
||||
bboxes_iou,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
@ -24,10 +29,7 @@ def det_bbox() -> "DetectionBbox":
|
|||
)
|
||||
|
||||
|
||||
def validate_bbox(bbox: _Bbox, rect: Optional[List[int]] = None) -> None:
    """ Assert that `bbox` has the expected [left, top, right, bottom] coordinates.

    Args:
        bbox: the bounding box to check
        rect: expected coordinates; defaults to [0, 10, 100, 1000] when None
    """
    expected = [0, 10, 100, 1000] if rect is None else rect
    actual = [bbox.left, bbox.top, bbox.right, bbox.bottom]
    assert actual == expected
|
||||
|
@ -38,7 +40,7 @@ def validate_anno_bbox(
|
|||
label_idx: int,
|
||||
rect: Optional[List[int]] = None,
|
||||
im_path: Optional[str] = None,
|
||||
label_name: Optional[str] = None
|
||||
label_name: Optional[str] = None,
|
||||
):
|
||||
validate_bbox(bbox, rect)
|
||||
assert type(bbox) == AnnotationBbox
|
||||
|
@ -131,7 +133,7 @@ def test_detection_bbox_from_array(det_bbox):
|
|||
assert type(bbox_from_array) == DetectionBbox
|
||||
|
||||
|
||||
def test_bboxes_iou():
|
||||
def test_bboxes_iou():
|
||||
# test bboxes which do not overlap
|
||||
basic_bbox = _Bbox(left=0, top=10, right=100, bottom=1000)
|
||||
non_overlapping_bbox = _Bbox(left=200, top=10, right=300, bottom=1000)
|
||||
|
@ -139,5 +141,6 @@ def test_bboxes_iou():
|
|||
|
||||
# test bboxes which overlap
|
||||
overlapping_bbox = _Bbox(left=10, top=500, right=300, bottom=2000)
|
||||
assert bboxes_iou(basic_bbox, overlapping_bbox) == pytest.approx(0.092, rel=1e-2)
|
||||
|
||||
assert bboxes_iou(basic_bbox, overlapping_bbox) == pytest.approx(
|
||||
0.092, rel=1e-2
|
||||
)
|
||||
|
|
|
@ -84,11 +84,7 @@ def labelbox_export_data(tmp_session):
|
|||
# Dict version of the combination of keypoint_json and anno_xml
|
||||
keypoint_truth_dict = {
|
||||
"folder": "images",
|
||||
"size": {
|
||||
"width": "500",
|
||||
"height": "500",
|
||||
"depth": "3",
|
||||
},
|
||||
"size": {"width": "500", "height": "500", "depth": "3"},
|
||||
"object": {
|
||||
"milk_bottle": {
|
||||
"bndbox": {
|
||||
|
@ -176,16 +172,16 @@ def test_coco_labels():
|
|||
assert len(labels) == 91
|
||||
|
||||
|
||||
def test_extract_keypoints_from_labelbox_json(labelbox_export_data, tmp_session):
|
||||
def test_extract_keypoints_from_labelbox_json(
|
||||
labelbox_export_data, tmp_session
|
||||
):
|
||||
data_dir, _, keypoint_json_path, keypoint_truth_dict = labelbox_export_data
|
||||
keypoint_data_dir = Path(tmp_session) / "labelbox_test_keypoint_data"
|
||||
keypoint_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# run extract_keypoints_from_labelbox_json()
|
||||
extract_keypoints_from_labelbox_json(
|
||||
keypoint_json_path,
|
||||
data_dir,
|
||||
keypoint_data_dir,
|
||||
keypoint_json_path, data_dir, keypoint_data_dir
|
||||
)
|
||||
|
||||
# verify keypoint data directory structure
|
||||
|
@ -249,11 +245,7 @@ def test_extract_masks_from_labelbox_json(labelbox_export_data, tmp_session):
|
|||
mask_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# run masks_from_labelbox_json()
|
||||
extract_masks_from_labelbox_json(
|
||||
mask_json_path,
|
||||
data_dir,
|
||||
mask_data_dir,
|
||||
)
|
||||
extract_masks_from_labelbox_json(mask_json_path, data_dir, mask_data_dir)
|
||||
|
||||
# verify mask data directory structure
|
||||
# only 1.jpg, 1.xml and 1.png are included
|
||||
|
|
|
@ -68,16 +68,35 @@ def test_get_transform(basic_im):
|
|||
assert type(tfms_im) == Tensor
|
||||
|
||||
|
||||
def test_parse_pascal_voc(
    od_sample_im_anno, od_sample_bboxes, tiny_od_keypoint_data_path
):
    """ test that 'parse_pascal_voc' can parse the 'od_sample_im_anno' correctly. """
    anno_path, im_path = od_sample_im_anno
    anno_bboxes, im_path, _ = parse_pascal_voc_anno(anno_path)
    assert type(anno_bboxes[0]) == AnnotationBbox
    assert anno_bboxes[0].left == od_sample_bboxes[0].left
    assert anno_bboxes[0].right == od_sample_bboxes[0].right
    assert anno_bboxes[0].top == od_sample_bboxes[0].top
    assert anno_bboxes[0].bottom == od_sample_bboxes[0].bottom

    # test keypoints
    anno_path = Path(tiny_od_keypoint_data_path) / "annotations" / "9.xml"
    # keypoints are only extracted when keypoint meta data is supplied;
    # without it the parser returns an empty keypoint array
    keypoint_meta = {
        "labels": [
            "lid_left_top",
            "lid_right_top",
            "lid_left_bottom",
            "lid_right_bottom",
            "left_bottom",
            "right_bottom",
        ]
    }
    keypoints_truth = np.array(
        [
            [
                [328, 227, 2],
                [382, 228, 2],
                [326, 247, 2],
                [382, 249, 2],
                [302, 440, 2],
                [379, 446, 2],
            ]
        ]
    )
    _, _, keypoints_pred = parse_pascal_voc_anno(
        anno_path, keypoint_meta=keypoint_meta
    )
    # The comparison result used to be discarded; assert it so the test can fail.
    # NOTE(review): expected values are taken from the original truth array and
    # have not been re-checked against 9.xml -- confirm.
    assert keypoints_pred.shape == keypoints_truth.shape
    # The parser flags visible keypoints with a non-zero value that need not
    # equal the 2 used in the truth array, so coordinates and visibility are
    # checked separately.
    assert np.all(keypoints_pred[..., :2] == keypoints_truth[..., :2])
    assert np.all(keypoints_pred[..., 2] != 0)
|
||||
|
||||
|
||||
def validate_detection_dataset(data: DetectionDataset, labels: List[str]):
|
||||
assert len(data) == 39 if data.mask_paths is None else 31
|
||||
|
@ -90,10 +109,18 @@ def validate_detection_dataset(data: DetectionDataset, labels: List[str]):
|
|||
assert len(data.mask_paths) == len(data.im_paths)
|
||||
|
||||
|
||||
def validate_milkbottle_keypoint_tiny_dataset(data: DetectionDataset):
    """ Check the properties expected of the tiny milk-bottle keypoint dataset.

    Asserts 31 items, exact type `DetectionDataset`, a single label, and one
    keypoint entry per image path.
    """
    assert len(data) == 31
    assert type(data) == DetectionDataset
    assert len(data.labels) == 1

    num_keypoint_entries = len(data.keypoints)
    num_images = len(data.im_paths)
    assert num_keypoint_entries == num_images
|
||||
|
||||
|
||||
def test_detection_dataset_init_basic(
|
||||
tiny_od_data_path,
|
||||
od_data_path_labels,
|
||||
tiny_od_mask_data_path
|
||||
tiny_od_mask_data_path,
|
||||
tiny_od_keypoint_data_path,
|
||||
):
|
||||
""" Tests that initialization of the Detection Dataset works. """
|
||||
data = DetectionDataset(tiny_od_data_path)
|
||||
|
@ -109,16 +136,46 @@ def test_detection_dataset_init_basic(
|
|||
|
||||
# test mask data
|
||||
data = DetectionDataset(
|
||||
tiny_od_mask_data_path,
|
||||
mask_dir="segmentation-masks"
|
||||
tiny_od_mask_data_path, mask_dir="segmentation-masks"
|
||||
)
|
||||
validate_detection_dataset(data, od_data_path_labels)
|
||||
assert len(data.test_ds) == 15
|
||||
assert len(data.train_ds) == 16
|
||||
|
||||
# test keypoint data
|
||||
data = DetectionDataset(
|
||||
tiny_od_keypoint_data_path,
|
||||
keypoint_meta={
|
||||
"labels": [
|
||||
"lid_left_top",
|
||||
"lid_right_top",
|
||||
"lid_left_bottom",
|
||||
"lid_right_bottom",
|
||||
"left_bottom",
|
||||
"right_bottom",
|
||||
],
|
||||
"skeleton": [
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[2, 3],
|
||||
[2, 4],
|
||||
[3, 5],
|
||||
[4, 5],
|
||||
],
|
||||
"hflip_inds": [1, 0, 3, 2, 5, 4],
|
||||
},
|
||||
)
|
||||
validate_milkbottle_keypoint_tiny_dataset(data)
|
||||
assert len(data.test_ds) == 15
|
||||
assert len(data.train_ds) == 16
|
||||
|
||||
|
||||
def test_detection_dataset_init_train_pct(
|
||||
tiny_od_data_path, od_data_path_labels, tiny_od_mask_data_path
|
||||
tiny_od_data_path,
|
||||
od_data_path_labels,
|
||||
tiny_od_mask_data_path,
|
||||
tiny_od_keypoint_data_path,
|
||||
):
|
||||
""" Tests that initialization with train_pct."""
|
||||
data = DetectionDataset(tiny_od_data_path, train_pct=0.75)
|
||||
|
@ -128,31 +185,62 @@ def test_detection_dataset_init_train_pct(
|
|||
|
||||
# test mask data
|
||||
data = DetectionDataset(
|
||||
tiny_od_mask_data_path,
|
||||
train_pct=0.75,
|
||||
mask_dir="segmentation-masks"
|
||||
tiny_od_mask_data_path, train_pct=0.75, mask_dir="segmentation-masks"
|
||||
)
|
||||
validate_detection_dataset(data, od_data_path_labels)
|
||||
assert len(data.test_ds) == 7
|
||||
assert len(data.train_ds) == 24
|
||||
|
||||
# test keypoint data
|
||||
data = DetectionDataset(
|
||||
tiny_od_keypoint_data_path,
|
||||
train_pct=0.75,
|
||||
keypoint_meta={
|
||||
"labels": [
|
||||
"lid_left_top",
|
||||
"lid_right_top",
|
||||
"lid_left_bottom",
|
||||
"lid_right_bottom",
|
||||
"left_bottom",
|
||||
"right_bottom",
|
||||
],
|
||||
"skeleton": [
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 3],
|
||||
[2, 3],
|
||||
[2, 4],
|
||||
[3, 5],
|
||||
[4, 5],
|
||||
],
|
||||
"hflip_inds": [1, 0, 3, 2, 5, 4],
|
||||
},
|
||||
)
|
||||
validate_milkbottle_keypoint_tiny_dataset(data)
|
||||
assert len(data.test_ds) == 7
|
||||
assert len(data.train_ds) == 24
|
||||
|
||||
|
||||
def test_detection_dataset_show_ims(
    basic_detection_dataset,
    od_detection_mask_dataset,
    tiny_od_detection_keypoint_dataset,
):
    """ Smoke test: `show_ims` runs on bbox, mask and keypoint datasets. """
    # simply test that this is error free for now
    datasets = (
        basic_detection_dataset,
        od_detection_mask_dataset,
        tiny_od_detection_keypoint_dataset,
    )
    for dataset in datasets:
        dataset.show_ims()
|
||||
|
||||
|
||||
def test_detection_dataset_show_im_transformations(
    basic_detection_dataset,
    od_detection_mask_dataset,
    tiny_od_detection_keypoint_dataset,
):
    """ Smoke test: `show_im_transformations` runs on bbox, mask and keypoint datasets. """
    # simply test that this is error free for now
    datasets = (
        basic_detection_dataset,
        od_detection_mask_dataset,
        tiny_od_detection_keypoint_dataset,
    )
    for dataset in datasets:
        dataset.show_im_transformations()
|
||||
|
||||
|
||||
def test_detection_dataset_init_anno_im_dirs(
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from utils_cv.detection.keypoint import Keypoints
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def od_sample_keypoint_with_meta():
    """ Returns sample keypoints, their meta data and the expected skeleton lines.

    Four objects with two keypoints each are given as [x, y, visibility]. The
    expected lines list one [x1, y1, x2, y2] entry for the first, third and
    fourth object; the second object, whose second keypoint is [0, 0, 0], has
    no entry.
    """
    keypoint_meta = {"num_keypoints": 2, "skeleton": [[0, 1]]}

    per_object_points = [
        [[10.0, 20.0, 2], [20.0, 20.0, 2]],
        [[20.0, 10.0, 2], [0, 0, 0]],
        [[30.0, 30.0, 2], [40.0, 40.0, 2]],
        [[40.0, 10.0, 2], [50.0, 50.0, 2]],
    ]
    keypoints = np.array(per_object_points)

    lines = [
        [10.0, 20.0, 20.0, 20.0],
        [30.0, 30.0, 40.0, 40.0],
        [40.0, 10.0, 50.0, 50.0],
    ]

    return keypoints, keypoint_meta, lines
|
||||
|
||||
|
||||
def test_cocokeypoints(od_sample_keypoint_with_meta):
    """ Test `Keypoints` construction and its `get_lines` output. """
    coords, meta, expected_lines = od_sample_keypoint_with_meta

    # test init: the object keeps the coordinates and meta data it was given
    kp_obj = Keypoints(coords, meta)
    assert np.all(kp_obj.keypoints == coords)
    assert kp_obj.meta == meta

    # test get_lines
    assert kp_obj.get_lines() == expected_lines
|
|
@ -139,6 +139,14 @@ def test_detection_mask_learner_train_one_epoch(od_detection_mask_learner,):
|
|||
od_detection_mask_learner.fit(epochs=1)
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_learner_train_one_epoch(
    od_detection_keypoint_learner,
):
    """ Smoke test: the keypoint learner can run one more training epoch. """
    learner = od_detection_keypoint_learner
    learner.fit(epochs=1, skip_evaluation=True)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_detection_learner_plot_precision_loss_curves(od_detection_learner,):
|
||||
""" Simply test that `plot_precision_loss_curves` works. """
|
||||
|
@ -185,6 +193,19 @@ def test_detection_mask_learner_predict(
|
|||
assert len(bboxes) == len(masks)
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_learner_predict(
    od_detection_keypoint_learner, od_cup_path
):
    """ Smoke test: `predict` on the keypoint learner returns matching bboxes and keypoints. """
    result = od_detection_keypoint_learner.predict(od_cup_path, threshold=0.1)
    det_bboxes, kps = result["det_bboxes"], result["keypoints"]

    assert type(det_bboxes) == list
    assert type(kps) == np.ndarray
    # one keypoint set per detected box
    assert len(det_bboxes) == len(kps)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_detection_learner_predict_threshold(
|
||||
od_detection_learner, od_cup_path
|
||||
|
@ -215,6 +236,22 @@ def test_detection_mask_learner_predict_threshold(
|
|||
assert len(bboxes) == 0
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_learner_predict_threshold(
    od_detection_keypoint_learner, od_cup_path
):
    """ Test that `predict` on the keypoint learner honours a threshold.

    The threshold is set so high (0.9999) that no detection is expected.
    """
    result = od_detection_keypoint_learner.predict(
        od_cup_path, threshold=0.9999
    )
    det_bboxes, kps = result["det_bboxes"], result["keypoints"]

    assert type(det_bboxes) == list
    assert type(kps) == np.ndarray
    assert len(det_bboxes) == len(kps)
    # nothing should survive the threshold
    assert len(det_bboxes) == 0
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_detection_learner_predict_batch(
|
||||
od_detection_learner, od_detection_dataset
|
||||
|
@ -237,6 +274,17 @@ def test_detection_mask_learner_predict_batch(
|
|||
assert isinstance(generator, Iterable)
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_learner_predict_batch(
    od_detection_keypoint_learner, tiny_od_detection_keypoint_dataset
):
    """ Smoke test: `predict_batch` on the keypoint learner yields an iterable. """
    test_loader = tiny_od_detection_keypoint_dataset.test_dl
    generator = od_detection_keypoint_learner.predict_batch(test_loader)
    assert isinstance(generator, Iterable)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_detection_learner_predict_batch_threshold(
|
||||
od_detection_learner, od_detection_dataset
|
||||
|
@ -263,6 +311,19 @@ def test_detection_mask_learner_predict_batch_threshold(
|
|||
assert isinstance(generator, Iterable)
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_learner_predict_batch_threshold(
    od_detection_keypoint_learner, tiny_od_detection_keypoint_dataset
):
    """ Smoke test: `predict_batch` on the keypoint learner accepts a threshold.

    The threshold is set really high (0.9999); only the type of the returned
    object is checked.
    """
    test_loader = tiny_od_detection_keypoint_dataset.test_dl
    generator = od_detection_keypoint_learner.predict_batch(
        test_loader, threshold=0.9999
    )
    assert isinstance(generator, Iterable)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_detection_dataset_predict_dl(
|
||||
od_detection_learner, od_detection_dataset
|
||||
|
@ -279,6 +340,16 @@ def test_detection_mask_dataset_predict_dl(
|
|||
od_detection_mask_learner.predict_dl(od_detection_mask_dataset.test_dl)
|
||||
|
||||
|
||||
@pytest.mark.gpu
def test_detection_keypoint_dataset_predict_dl(
    od_detection_keypoint_learner, tiny_od_detection_keypoint_dataset
):
    """ Simply test that `predict_dl` works for keypoint learner. """
    test_loader = tiny_od_detection_keypoint_dataset.test_dl
    od_detection_keypoint_learner.predict_dl(test_loader)
|
||||
|
||||
|
||||
def validate_saved_model(name: str, path: str) -> None:
|
||||
""" Tests that saved model is there """
|
||||
assert (Path(path)).exists()
|
||||
|
|
|
@ -80,12 +80,18 @@ def test_02_notebook_run(detection_notebooks, tiny_od_mask_data_path):
|
|||
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.notebooks
|
||||
def test_03_notebook_run(detection_notebooks):
|
||||
def test_03_notebook_run(detection_notebooks, tiny_od_keypoint_data_path):
|
||||
notebook_path = detection_notebooks["03"]
|
||||
pm.execute_notebook(
|
||||
notebook_path,
|
||||
OUTPUT_NOTEBOOK,
|
||||
parameters=dict(PM_VERSION=pm.__version__, IM_SIZE=100),
|
||||
parameters=dict(
|
||||
PM_VERSION=pm.__version__,
|
||||
IM_SIZE=100,
|
||||
EPOCHS=1,
|
||||
DATA_PATH=tiny_od_keypoint_data_path,
|
||||
THRESHOLD=0.01,
|
||||
),
|
||||
kernel_name=KERNEL_NAME,
|
||||
)
|
||||
nb_output = sb.read_notebook(OUTPUT_NOTEBOOK)
|
||||
|
|
|
@ -77,23 +77,57 @@ def test_plot_masks(od_mask_rects):
|
|||
assert background_uniques[0] == ch_uniques[0]
|
||||
|
||||
|
||||
def test_plot_keypoints(basic_plot_settings):
    """ Smoke test: `plot_keypoints` runs with default and custom plot settings. """
    # dummy keypoints: one object, two points joined by one skeleton edge
    keypoint_meta = {"skeleton": [[0, 1]]}
    keypoints = np.array([[[100, 200, 2], [200, 200, 2]]])

    # a completely black image
    black_pixels = np.zeros((500, 600, 3), dtype=np.uint8)
    im = Image.fromarray(black_pixels)

    # basic case
    plot_keypoints(im, keypoints, keypoint_meta)

    # with update plot_settings
    plot_keypoints(
        im, keypoints, keypoint_meta, plot_settings=basic_plot_settings
    )
|
||||
|
||||
|
||||
def test_plot_detections(
    od_sample_detection,
    od_detection_mask_dataset,
    od_sample_keypoint_detection,
    tiny_od_detection_keypoint_dataset,
):
    """ Smoke test: `plot_detections` runs for mask and keypoint detections.

    Each detection is plotted alone, with its dataset, and with its dataset
    plus index 0.
    """
    # plot masks
    for extra_args in ((), (od_detection_mask_dataset,), (od_detection_mask_dataset, 0)):
        plot_detections(od_sample_detection, *extra_args)

    # plot keypoints
    keypoint_dataset = tiny_od_detection_keypoint_dataset
    for extra_args in ((), (keypoint_dataset,), (keypoint_dataset, 0)):
        plot_detections(
            od_sample_keypoint_detection,
            *extra_args,
            keypoint_meta=keypoint_dataset.keypoint_meta,
        )
|
||||
|
||||
def test_plot_grid(od_sample_detection, od_detection_mask_dataset):
|
||||
|
||||
def test_plot_grid(
|
||||
od_sample_detection,
|
||||
od_detection_mask_dataset,
|
||||
od_sample_keypoint_detection,
|
||||
tiny_od_detection_keypoint_dataset,
|
||||
):
|
||||
""" Test that `plot_grid` works. """
|
||||
|
||||
# test callable args
|
||||
|
@ -107,6 +141,16 @@ def test_plot_grid(od_sample_detection, od_detection_mask_dataset):
|
|||
|
||||
plot_grid(plot_detections, callable_args, rows=1)
|
||||
|
||||
def callable_args():
|
||||
return (
|
||||
od_sample_keypoint_detection,
|
||||
tiny_od_detection_keypoint_dataset,
|
||||
None,
|
||||
tiny_od_detection_keypoint_dataset.keypoint_meta,
|
||||
)
|
||||
|
||||
plot_grid(plot_detections, callable_args, rows=1)
|
||||
|
||||
# test iterable args
|
||||
def iterator_args():
|
||||
for detection in [od_sample_detection, od_sample_detection]:
|
||||
|
@ -120,6 +164,20 @@ def test_plot_grid(od_sample_detection, od_detection_mask_dataset):
|
|||
|
||||
plot_grid(plot_detections, iterator_args(), rows=1, cols=2)
|
||||
|
||||
def iterator_args():
|
||||
for detection in [
|
||||
od_sample_keypoint_detection,
|
||||
od_sample_keypoint_detection,
|
||||
]:
|
||||
yield (
|
||||
detection,
|
||||
tiny_od_detection_keypoint_dataset,
|
||||
None,
|
||||
tiny_od_detection_keypoint_dataset.keypoint_meta,
|
||||
)
|
||||
|
||||
plot_grid(plot_detections, iterator_args(), rows=1, cols=2)
|
||||
|
||||
|
||||
def test__setup_pr_axes(basic_ax):
|
||||
""" Test that `_setup_pr_axes` works. """
|
||||
|
|
|
@ -185,8 +185,8 @@ class DetectionBbox(AnnotationBbox):
|
|||
""" Create a Bbox object from an array [left, top, right, bottom]
|
||||
This function must take in a score.
|
||||
"""
|
||||
score = kwargs['score']
|
||||
del kwargs['score']
|
||||
score = kwargs["score"]
|
||||
del kwargs["score"]
|
||||
bbox = super().from_array(arr, **kwargs)
|
||||
bbox.__class__ = DetectionBbox
|
||||
bbox.score = score
|
||||
|
|
|
@ -31,7 +31,17 @@ class Urls:
|
|||
|
||||
# mask datasets
|
||||
fridge_objects_mask_path = urljoin(base, "odFridgeObjectsMask.zip")
|
||||
fridge_objects_mask_tiny_path = urljoin(base, "odFridgeObjectsMaskTiny.zip")
|
||||
fridge_objects_mask_tiny_path = urljoin(
|
||||
base, "odFridgeObjectsMaskTiny.zip"
|
||||
)
|
||||
|
||||
# keypoint datasets
|
||||
fridge_objects_keypoint_milk_bottle_path = urljoin(
|
||||
base, "odFridgeObjectsMilkbottleKeypoint.zip"
|
||||
)
|
||||
fridge_objects_keypoint_milk_bottle_tiny_path = urljoin(
|
||||
base, "odFridgeObjectsMilkbottleKeypointTiny.zip"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List[str]:
|
||||
|
@ -260,10 +270,13 @@ def extract_masks_from_labelbox_json(
|
|||
# read mask images
|
||||
mask_urls = [obj["instanceURI"] for obj in anno["Label"]["objects"]]
|
||||
labels = [obj["value"] for obj in anno["Label"]["objects"]]
|
||||
binary_masks = np.array([
|
||||
np.array(Image.open(urllib.request.urlopen(url)))[..., 0] == 255
|
||||
for url in mask_urls
|
||||
])
|
||||
binary_masks = np.array(
|
||||
[
|
||||
np.array(Image.open(urllib.request.urlopen(url)))[..., 0]
|
||||
== 255
|
||||
for url in mask_urls
|
||||
]
|
||||
)
|
||||
|
||||
# rearrange masks with regard to annotation
|
||||
tree = ET.parse(dst_anno_path)
|
||||
|
@ -286,17 +299,19 @@ def extract_masks_from_labelbox_json(
|
|||
min_overlap = binary_masks.shape[1] * binary_masks.shape[2]
|
||||
for i, bmask in enumerate(binary_masks):
|
||||
bmask_out = bmask.copy()
|
||||
bmask_out[top:(bottom + 1), left:(right + 1)] = False
|
||||
bmask_out[top : (bottom + 1), left : (right + 1)] = False
|
||||
non_overlap = np.sum(bmask_out)
|
||||
if non_overlap < min_overlap:
|
||||
match = i
|
||||
min_overlap = non_overlap
|
||||
assert label == labels[match], \
|
||||
"{}: {}".format(label, labels[match])
|
||||
assert label == labels[match], "{}: {}".format(
|
||||
label, labels[match]
|
||||
)
|
||||
matches.append(match)
|
||||
|
||||
assert len(set(matches)) == len(matches), \
|
||||
"{}: {}".format(len(set(matches)), len(matches))
|
||||
assert len(set(matches)) == len(matches), "{}: {}".format(
|
||||
len(set(matches)), len(matches)
|
||||
)
|
||||
|
||||
binary_masks = binary_masks[matches]
|
||||
|
||||
|
@ -407,7 +422,7 @@ def extract_keypoints_from_labelbox_json(
|
|||
# process one image keypoints annotation per iteration
|
||||
for anno in annos:
|
||||
# get related file paths
|
||||
im_name = anno["External ID"] # image file name
|
||||
im_name = anno["External ID"] # image file name
|
||||
anno_name = im_name[:-4] + ".xml" # annotation file name
|
||||
|
||||
print("Processing image: {}".format(im_name))
|
||||
|
@ -422,20 +437,20 @@ def extract_keypoints_from_labelbox_json(
|
|||
shutil.copy(src_im_path, dst_im_path)
|
||||
|
||||
# add keypoints annotation into PASCAL VOC XML file
|
||||
kps_annos = anno["Label"]
|
||||
keypoints_annos = anno["Label"]
|
||||
tree = ET.parse(src_anno_path)
|
||||
root = tree.getroot()
|
||||
for obj in root.findall("object"):
|
||||
prefix = obj.find("name").text + "_"
|
||||
# add "keypoints" node for current object
|
||||
kps = ET.SubElement(obj, "keypoints")
|
||||
for k in kps_annos.keys():
|
||||
keypoints = ET.SubElement(obj, "keypoints")
|
||||
for k in keypoints_annos.keys():
|
||||
if k.startswith(prefix):
|
||||
# add keypoint into "keypoints" node
|
||||
pt = ET.SubElement(kps, k[len(prefix):])
|
||||
pt = ET.SubElement(keypoints, k[len(prefix) :])
|
||||
x = ET.SubElement(pt, "x") # add x coordinate
|
||||
y = ET.SubElement(pt, "y") # add y coordinate
|
||||
geo = kps_annos[k][0]["geometry"]
|
||||
geo = keypoints_annos[k][0]["geometry"]
|
||||
x.text = str(geo["x"])
|
||||
y.text = str(geo["y"])
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import math
|
|||
import numpy as np
|
||||
from pathlib import Path
|
||||
import random
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, Subset, DataLoader
|
||||
|
@ -20,12 +20,51 @@ from .plot import plot_detections, plot_grid
|
|||
from .bbox import AnnotationBbox
|
||||
from .mask import binarise_mask
|
||||
from .references.utils import collate_fn
|
||||
from .references.transforms import RandomHorizontalFlip, Compose, ToTensor
|
||||
from .references.transforms import Compose, ToTensor
|
||||
from ..common.gpu import db_num_workers
|
||||
|
||||
Trans = Callable[[object, dict], Tuple[object, dict]]
|
||||
|
||||
|
||||
def _flip_keypoints(keypoints, width, hflip_inds):
|
||||
""" Variation of `references.transforms._flip_coco_person_keypoints` with additional
|
||||
hflip_inds. """
|
||||
flipped_keypoints = keypoints[:, hflip_inds]
|
||||
flipped_keypoints[..., 0] = width - flipped_keypoints[..., 0]
|
||||
# Maintain COCO convention that if visibility == 0, then x, y = 0
|
||||
inds = flipped_keypoints[..., 2] == 0
|
||||
flipped_keypoints[inds] = 0
|
||||
return flipped_keypoints
|
||||
|
||||
|
||||
class RandomHorizontalFlip(object):
    """ Variation of `references.transforms.RandomHorizontalFlip` to make sure flipping
    works on custom keypoints.

    With probability `prob`, the image is flipped along its last axis and the
    "boxes", "masks" and "keypoints" entries of the target are flipped to match.
    """

    def __init__(self, prob):
        """
        Args:
            prob: probability of applying the flip on each call
        """
        self.prob = prob

    def __call__(self, im, target):
        """ Randomly flip `im` and its annotations horizontally.

        Args:
            im: image whose last two dimensions are (height, width) and which
                offers a `.flip()` method (presumably a torch tensor -- confirm)
            target: annotation dict holding "boxes" and, optionally, "masks",
                "keypoints" and "hflip_inds"

        Returns:
            The image and the target, flipped or untouched.

        Raises:
            AssertionError: if `target` has "keypoints" but no "hflip_inds"
        """
        if random.random() < self.prob:
            _, width = im.shape[-2:]
            im = im.flip(-1)

            # mirror the boxes: the old right edge becomes the new left edge
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox

            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)

            if "keypoints" in target:
                if "hflip_inds" not in target:
                    # Raised explicitly instead of via `assert` so the check is
                    # not stripped when Python runs with -O; exception type and
                    # message are unchanged for existing callers.
                    raise AssertionError(
                        "To use random horizontal flipping, 'hflip_inds' needs to be specified"
                    )
                target["keypoints"] = _flip_keypoints(
                    target["keypoints"], width, target["hflip_inds"]
                )
        return im, target
|
||||
|
||||
|
||||
class ColorJitterTransform(object):
|
||||
""" Wrapper for torchvision's ColorJitter to make sure 'target
|
||||
object is passed along """
|
||||
|
@ -80,19 +119,26 @@ def get_transform(train: bool) -> Trans:
|
|||
|
||||
|
||||
def parse_pascal_voc_anno(
|
||||
anno_path: str, labels: List[str] = None
|
||||
) -> Tuple[List[AnnotationBbox], Union[str, Path]]:
|
||||
anno_path: str, labels: List[str] = None, keypoint_meta: Dict = None
|
||||
) -> Tuple[List[AnnotationBbox], Union[str, Path], np.ndarray]:
|
||||
""" Extract the annotations and image path from labelling in Pascal VOC format.
|
||||
|
||||
Args:
|
||||
anno_path: the path to the annotation xml file
|
||||
labels: list of all possible labels, used to compute label index for each label name
|
||||
keypoint_meta: meta data of keypoints which should include at least
|
||||
"labels".
|
||||
|
||||
Return
|
||||
A tuple of annotations and the image path
|
||||
A tuple of annotations, the image path and keypoints. Keypoints is a
|
||||
numpy array of shape (N, K, 3), where N is the number of objects of the
|
||||
category that defined the keypoints, and K is the number of keypoints
|
||||
defined in the category. `len(keypoints)` would be 0 if no keypoints
|
||||
found.
|
||||
"""
|
||||
|
||||
anno_bboxes = []
|
||||
keypoints = []
|
||||
tree = ET.parse(anno_path)
|
||||
root = tree.getroot()
|
||||
|
||||
|
@ -107,10 +153,44 @@ def parse_pascal_voc_anno(
|
|||
os.path.join(anno_dir, root.find("filename").text)
|
||||
)
|
||||
|
||||
# extract bounding boxes and classification
|
||||
# extract bounding boxes, classification and keypoints
|
||||
objs = root.findall("object")
|
||||
for obj in objs:
|
||||
label = obj.find("name").text
|
||||
# Get keypoints if any.
|
||||
# For keypoint detection, currently only one category (except
|
||||
# background) is allowed. We assume all annotated objects are of that
|
||||
# category.
|
||||
if keypoint_meta is not None:
|
||||
kps = []
|
||||
kps_labels = keypoint_meta["labels"]
|
||||
|
||||
# Assume keypoints are available
|
||||
kps_annos = obj.find("keypoints")
|
||||
if kps_annos is None:
|
||||
raise Exception(f"No keypoints found in {anno_path}")
|
||||
assert set([kp.tag for kp in kps_annos]).issubset(
|
||||
kps_labels
|
||||
), "Incompatible keypoint labels"
|
||||
|
||||
# Read keypoint coordinates: [x, y, visibility]
|
||||
# Visibility 0 means invisible, non-zero means visible
|
||||
for name in kps_labels:
|
||||
kp_anno = kps_annos.find(name)
|
||||
if kp_anno is None:
|
||||
# return 0 for invisible keypoints
|
||||
kps.append([0, 0, 0])
|
||||
else:
|
||||
kps.append(
|
||||
[
|
||||
int(float(kp_anno.find("x").text)),
|
||||
int(float(kp_anno.find("y").text)),
|
||||
1,
|
||||
]
|
||||
)
|
||||
keypoints.append(kps)
|
||||
|
||||
# get bounding box
|
||||
bnd_box = obj.find("bndbox")
|
||||
left = int(bnd_box.find("xmin").text)
|
||||
top = int(bnd_box.find("ymin").text)
|
||||
|
@ -132,7 +212,7 @@ def parse_pascal_voc_anno(
|
|||
assert anno_bbox.is_valid()
|
||||
anno_bboxes.append(anno_bbox)
|
||||
|
||||
return anno_bboxes, im_path
|
||||
return anno_bboxes, im_path, np.array(keypoints)
|
||||
|
||||
|
||||
class DetectionDataset:
|
||||
|
@ -152,6 +232,7 @@ class DetectionDataset:
|
|||
anno_dir: str = "annotations",
|
||||
im_dir: str = "images",
|
||||
mask_dir: str = None,
|
||||
keypoint_meta: Dict = None,
|
||||
seed: int = None,
|
||||
allow_negatives: bool = False,
|
||||
):
|
||||
|
@ -173,6 +254,8 @@ class DetectionDataset:
|
|||
im_dir: the name of the image subfolder under the root directory. If set to 'None' then infers image location from annotation .xml files
|
||||
allow_negatives: if false (default), an error is thrown when no annotation .xml file can be found for a given image. Otherwise the image is used as a negative, i.e. it is assumed not to contain any of the objects of interest.
|
||||
mask_dir: the name of the mask subfolder under the root directory if the dataset is used for instance segmentation
|
||||
keypoint_meta: meta data of keypoints which should include
|
||||
"labels", "skeleton" and "hflip_inds".
|
||||
seed: random seed for splitting dataset to training and testing data
|
||||
"""
|
||||
|
||||
|
@ -186,6 +269,7 @@ class DetectionDataset:
|
|||
self.train_pct = train_pct
|
||||
self.allow_negatives = allow_negatives
|
||||
self.seed = seed
|
||||
self.keypoint_meta = keypoint_meta
|
||||
|
||||
# read annotations
|
||||
self._read_annos()
|
||||
|
@ -223,12 +307,19 @@ class DetectionDataset:
|
|||
self.anno_paths = []
|
||||
self.anno_bboxes = []
|
||||
self.mask_paths = []
|
||||
self.keypoints = []
|
||||
for anno_idx, anno_filename in enumerate(anno_filenames):
|
||||
anno_path = self.root / self.anno_dir / str(anno_filename)
|
||||
|
||||
# Parse annotation file if present
|
||||
if os.path.exists(anno_path):
|
||||
anno_bboxes, im_path = parse_pascal_voc_anno(anno_path)
|
||||
anno_bboxes, im_path, keypoints = parse_pascal_voc_anno(
|
||||
anno_path, keypoint_meta=self.keypoint_meta
|
||||
)
|
||||
# When meta provided, we assume this is keypoint
|
||||
# detection.
|
||||
if self.keypoint_meta is not None:
|
||||
self.keypoints.append(keypoints)
|
||||
else:
|
||||
if not self.allow_negatives:
|
||||
raise FileNotFoundError(anno_path)
|
||||
|
@ -338,9 +429,10 @@ class DetectionDataset:
|
|||
def add_images(
|
||||
self,
|
||||
im_paths: List[str],
|
||||
anno_bboxes: List[AnnotationBbox],
|
||||
anno_bboxes: List[List[AnnotationBbox]],
|
||||
target: str = "train",
|
||||
mask_paths: List[str] = None,
|
||||
keypoints: List[np.ndarray] = None,
|
||||
):
|
||||
""" Add new images to either the training or test set.
|
||||
|
||||
|
@ -349,6 +441,9 @@ class DetectionDataset:
|
|||
anno_bboxes: ground truth boxes for each image.
|
||||
target: specify if images are to be added to the training or test set. Valid options: "train" or "test".
|
||||
mask_paths: path to the masks.
|
||||
keypoints: list of numpy array of shape (N, K, 3), where N is the
|
||||
number of objects of the category that defined the keypoints,
|
||||
and K is the number of keypoints defined in the category.
|
||||
|
||||
Raises:
|
||||
Exception if `target` variable is neither 'train' nor 'test'
|
||||
|
@ -361,6 +456,9 @@ class DetectionDataset:
|
|||
if mask_paths is not None:
|
||||
self.mask_paths.append(mask_paths[i])
|
||||
|
||||
if keypoints is not None:
|
||||
self.keypoints.append(keypoints[i])
|
||||
|
||||
if target.lower() == "train":
|
||||
self.train_ds.dataset.im_paths.append(im_path)
|
||||
self.train_ds.dataset.anno_bboxes.append(anno_bbox)
|
||||
|
@ -368,6 +466,9 @@ class DetectionDataset:
|
|||
if mask_paths is not None:
|
||||
self.train_ds.dataset.mask_paths.append(mask_paths[i])
|
||||
|
||||
if keypoints is not None:
|
||||
self.train_ds.dataset.keypoints.append(keypoints[i])
|
||||
|
||||
self.train_ds.indices.append(len(self.im_paths) - 1)
|
||||
elif target.lower() == "test":
|
||||
self.test_ds.dataset.im_paths.append(im_path)
|
||||
|
@ -376,6 +477,9 @@ class DetectionDataset:
|
|||
if mask_paths is not None:
|
||||
self.test_ds.dataset.mask_paths.append(mask_paths[i])
|
||||
|
||||
if keypoints is not None:
|
||||
self.test_ds.dataset.keypoints.append(keypoints[i])
|
||||
|
||||
self.test_ds.indices.append(len(self.im_paths) - 1)
|
||||
else:
|
||||
raise Exception(f"Target {target} unknown.")
|
||||
|
@ -504,6 +608,16 @@ class DetectionDataset:
|
|||
if binary_masks is not None:
|
||||
target["masks"] = torch.as_tensor(binary_masks, dtype=torch.uint8)
|
||||
|
||||
# get keypoints
|
||||
if self.keypoints:
|
||||
target["keypoints"] = torch.as_tensor(
|
||||
self.keypoints[idx], dtype=torch.float32
|
||||
)
|
||||
if "hflip_inds" in self.keypoint_meta:
|
||||
target["hflip_inds"] = torch.as_tensor(
|
||||
self.keypoint_meta["hflip_inds"], dtype=torch.int64
|
||||
)
|
||||
|
||||
# get image
|
||||
im = Image.open(im_path).convert("RGB")
|
||||
|
||||
|
|
|
@ -1,48 +1,81 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Keypoints:
    """ Keypoints class for keypoint detection. """

    def __init__(self, keypoints: np.ndarray, meta: Dict):
        """ Store the keypoints as a float array and validate the skeleton.

        Args:
            keypoints: keypoints array of shape (N, num_keypoints, 3), where N
                is the number of objects. 3 means x, y and visibility. 0 for
                visibility means invisible. Input that is not 3-dimensional
                is reshaped to (-1, num_keypoints, 3).
            meta: a dict includes "num_keypoints" and "skeleton".
                "num_keypoints" is an integer indicating the number of
                predefined keypoints. "skeleton" is a list of connections
                between each predefined keypoints; it may be empty.

        Raises:
            AssertionError: if a skeleton index is not smaller than
                `meta["num_keypoints"]`.
        """
        self.meta = meta

        # Always copy into a float64 array:
        #  * the `np.float` alias no longer exists in NumPy >= 1.24, the
        #    builtin `float` selects the same dtype;
        #  * copying keeps the zeroing of invisible points below from
        #    writing into the caller's array.
        self.keypoints = np.array(keypoints, dtype=float)
        if self.keypoints.ndim != 3:
            # shape must be (N, self.meta["num_keypoints"], 3)
            self.keypoints = self.keypoints.reshape(
                (-1, self.meta["num_keypoints"], 3)
            )

        # skeleton indexes should not be out of the range of keypoints; an
        # empty skeleton has nothing to check (and `max` of an empty array
        # would raise)
        skeleton = np.asarray(self.meta["skeleton"], dtype=int)
        assert (
            skeleton.size == 0 or skeleton.max() < self.meta["num_keypoints"]
        ), "Skeleton index out of range"

        # make sure invisible points' x, y = 0
        self.keypoints[self.keypoints[..., 2] == 0] = 0

    def get_lines(self) -> List[List[float]]:
        """ Return connected lines represented by list of [x1, y1, x2, y2].

        Only connections whose two end points are both visible (non-zero
        visibility) are returned. An empty skeleton yields an empty list.
        """
        # (num_connections, 2) index array; the explicit reshape keeps the
        # rank fixed even when the skeleton is empty
        skeleton = np.asarray(self.meta["skeleton"], dtype=int).reshape(
            (-1, 2)
        )
        # joints: (N, num_connections, 2 end points, 3)
        joints = self.keypoints[:, skeleton]
        # a connection is drawable only if both of its end points are visible
        visibles = (joints[..., 2] != 0).all(axis=2)
        # keep x, y of the visible connections
        bones = joints[visibles][..., :2]
        return bones.reshape((-1, 4)).tolist()
|
||||
# Meta data describing the 17 COCO person keypoints.
COCO_keypoint_meta = {
    # `labels` gives the names of the keypoints; the position of a name in
    # this list is the index used by `skeleton` and `hflip_inds` below.
    "labels": [
        "nose",  # 0
        "left_eye",  # 1
        "right_eye",  # 2
        "left_ear",  # 3
        "right_ear",  # 4
        "left_shoulder",  # 5
        "right_shoulder",  # 6
        "left_elbow",  # 7
        "right_elbow",  # 8
        "left_wrist",  # 9
        "right_wrist",  # 10
        "left_hip",  # 11
        "right_hip",  # 12
        "left_knee",  # 13
        "right_knee",  # 14
        "left_ankle",  # 15
        "right_ankle",  # 16
    ],
    # `skeleton` specifies how keypoints are connected with each other when
    # drawing. For example, `[15, 13]` means left_ankle (15) is connected to
    # left_knee (13) when plotting on the image, and `[13, 11]` means
    # left_knee (13) is connected to left_hip (11).
    "skeleton": [
        [15, 13],  # left_ankle -- left_knee
        [13, 11],  # left_knee -- left_hip
        [16, 14],  # right_ankle -- right_knee
        [14, 12],  # right_knee -- right_hip
        [11, 12],  # left_hip -- right_hip
        [5, 11],  # left_shoulder -- left_hip
        [6, 12],  # right_shoulder -- right_hip
        [5, 6],  # left_shoulder -- right_shoulder
        [5, 7],  # left_shoulder -- left_elbow
        [6, 8],  # right_shoulder -- right_elbow
        [7, 9],  # left_elbow -- left_wrist
        [8, 10],  # right_elbow -- right_wrist
        [1, 2],  # left_eye -- right_eye
        [0, 1],  # nose -- left_eye
        [0, 2],  # nose -- right_eye
        [1, 3],  # left_eye -- left_ear
        [2, 4],  # right_eye -- right_ear
        [3, 5],  # left_ear -- left_shoulder
        [4, 6],  # right_ear -- right_shoulder
    ],
    # When an image is flipped horizontally, keypoints tied to the concept of
    # left and right change to their opposite meaning. For example, the left
    # eye becomes the right eye.
    # `hflip_inds` is used by the horizontal flip transformation for data
    # augmentation during training to specify what each keypoint becomes.
    # In other words, `COCO_keypoint_meta["hflip_inds"][0]` specifies what
    # nose (`COCO_keypoint_meta["labels"][0]`) becomes when an image is
    # flipped horizontally. Because the nose is still the nose after
    # flipping, its value is still 0. Left eye
    # (`COCO_keypoint_meta["labels"][1]`) becomes right eye
    # (`COCO_keypoint_meta["labels"][2]`), so the value of
    # `COCO_keypoint_meta["hflip_inds"][1]` is 2.
    "hflip_inds": [
        0,  # nose
        2,  # left_eye -> right_eye
        1,  # right_eye -> left_eye
        4,  # left_ear -> right_ear
        3,  # right_ear -> left_ear
        6,  # left_shoulder -> right_shoulder
        5,  # right_shoulder -> left_shoulder
        8,  # left_elbow -> right_elbow
        7,  # right_elbow -> left_elbow
        10,  # left_wrist -> right_wrist
        9,  # right_wrist -> left_wrist
        12,  # left_hip -> right_hip
        11,  # right_hip -> left_hip
        14,  # left_knee -> right_knee
        13,  # right_knee -> left_knee
        16,  # left_ankle -> right_ankle
        15,  # right_ankle -> left_ankle
    ],
}
|
||||
|
|
|
@ -24,8 +24,9 @@ def binarise_mask(mask: Union[np.ndarray, str, Path]) -> np.ndarray:
|
|||
|
||||
# if all values are False or True, consider it's already binarised
|
||||
if mask.ndim == 3:
|
||||
assert all(i in [False, True] for i in np.unique(mask).tolist()), \
|
||||
"'mask' should be grayscale."
|
||||
assert all(
|
||||
i in [False, True] for i in np.unique(mask).tolist()
|
||||
), "'mask' should be grayscale."
|
||||
return mask
|
||||
|
||||
assert mask.ndim == 2, "'mask' should have at least 2 channels."
|
||||
|
@ -37,8 +38,7 @@ def binarise_mask(mask: Union[np.ndarray, str, Path]) -> np.ndarray:
|
|||
|
||||
|
||||
def colorise_binary_mask(
|
||||
binary_mask: np.ndarray,
|
||||
color: Tuple[int, int, int] = (2, 166, 101),
|
||||
binary_mask: np.ndarray, color: Tuple[int, int, int] = (2, 166, 101)
|
||||
) -> np.ndarray:
|
||||
""" Set the color for the instance in the mask. """
|
||||
# create empty RGB channels
|
||||
|
@ -53,16 +53,16 @@ def colorise_binary_mask(
|
|||
|
||||
|
||||
def transparentise_mask(
|
||||
colored_mask: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
colored_mask: np.ndarray, alpha: float = 0.5
|
||||
) -> np.ndarray:
|
||||
""" Return a mask with fully transparent background and alpha-transparent
|
||||
instances.
|
||||
|
||||
Assume channel is the third dimension of mask, and no alpha channel.
|
||||
"""
|
||||
assert colored_mask.shape[2] == 3, \
|
||||
"'colored_mask' should be of 3-channels RGB."
|
||||
assert (
|
||||
colored_mask.shape[2] == 3
|
||||
), "'colored_mask' should be of 3-channels RGB."
|
||||
# convert (0, 0, 0) to (0, 0, 0, 0) and
|
||||
# all other (x, y, z) to (x, y, z, alpha*255)
|
||||
binary_mask = (colored_mask != 0).any(axis=2)
|
||||
|
|
|
@ -17,7 +17,6 @@ from .bbox import _Bbox
|
|||
from .model import ims_eval_detections
|
||||
from .references.coco_eval import CocoEvaluator
|
||||
from ..common.misc import get_font
|
||||
from .keypoint import Keypoints
|
||||
from .mask import binarise_mask, colorise_binary_mask, transparentise_mask
|
||||
|
||||
|
||||
|
@ -101,9 +100,12 @@ def plot_masks(
|
|||
) -> PIL.Image.Image:
|
||||
""" Put mask onto image.
|
||||
|
||||
Assume the mask is already binary masks of [N, Height, Width], or
|
||||
grayscale mask of [Height, Width] with different values
|
||||
representing different objects, 0 as background.
|
||||
Args:
|
||||
im: the image to plot masks on
|
||||
mask: it should be binary masks of [N, Height, Width], or grayscale
|
||||
mask of [Height, Width] with different values representing
|
||||
different objects, 0 as background
|
||||
plot_settings: the parameter to plot the masks
|
||||
"""
|
||||
if isinstance(im, (str, Path)):
|
||||
im = Image.open(im)
|
||||
|
@ -128,16 +130,41 @@ def plot_masks(
|
|||
|
||||
def plot_keypoints(
|
||||
im: Union[str, Path, PIL.Image.Image],
|
||||
keypoints: Keypoints,
|
||||
keypoints: np.ndarray,
|
||||
keypoint_meta: Dict,
|
||||
plot_settings: PlotSettings = PlotSettings(),
|
||||
) -> PIL.Image.Image:
|
||||
""" Plot connected keypoints on Image and return the Image. """
|
||||
""" Plot connected keypoints on Image and return the Image.
|
||||
|
||||
Args:
|
||||
im: the image to plot keypoints on
|
||||
keypoints: the keypoints to plot, of shape (N, num_keypoints, 3),
|
||||
where N is the number of objects. 3 means x, y and visibility.
|
||||
0 for visibility means invisible
|
||||
keypoint_meta: meta data of keypoints which should include at least
|
||||
"skeleton"
|
||||
plot_settings: the parameter to plot the keypoints
|
||||
"""
|
||||
if isinstance(im, (str, Path)):
|
||||
im = Image.open(im)
|
||||
|
||||
if keypoints is not None:
|
||||
assert (
|
||||
keypoints.ndim == 3 and keypoints.shape[2] == 3
|
||||
), "Malformed keypoints array"
|
||||
assert (
|
||||
np.max(np.array(keypoint_meta["skeleton"])) < keypoints.shape[1]
|
||||
), "Skeleton index out of range"
|
||||
|
||||
draw = ImageDraw.Draw(im)
|
||||
for line in keypoints.get_lines():
|
||||
|
||||
# get connected skeleton lines of the keypoints
|
||||
joints = keypoints[:, keypoint_meta["skeleton"]]
|
||||
visibles = (joints[..., 2] != 0).all(axis=2)
|
||||
bones = joints[visibles][..., :2]
|
||||
|
||||
# draw skeleton lines
|
||||
for line in bones.reshape((-1, 4)).tolist():
|
||||
draw.line(
|
||||
line,
|
||||
fill=plot_settings.keypoint_color,
|
||||
|
@ -160,8 +187,8 @@ def plot_detections(
|
|||
detection: output running model prediction.
|
||||
data: dataset with ground truth information.
|
||||
idx: index into the data object to find the ground truth which corresponds to the detection.
|
||||
keypoint_meta: meta data of keypoints which should include
|
||||
"num_keypoints" and "skeleton".
|
||||
keypoint_meta: meta data of keypoints which should include at least
|
||||
"skeleton".
|
||||
ax: an optional ax to specify where you wish the figure to be drawn on
|
||||
"""
|
||||
# Open image
|
||||
|
@ -169,7 +196,7 @@ def plot_detections(
|
|||
im = Image.open(detection["im_path"])
|
||||
|
||||
# Get id of ground truth image/annotation
|
||||
if data and not idx:
|
||||
if data and idx is None:
|
||||
idx = detection["idx"]
|
||||
|
||||
# Loop over all images
|
||||
|
@ -190,11 +217,21 @@ def plot_detections(
|
|||
mask = detection["masks"]
|
||||
im = plot_masks(im, mask, PlotSettings(mask_color=(128, 165, 0)))
|
||||
|
||||
# Plot ground truth keypoints
|
||||
if data and data.keypoints and data.keypoint_meta:
|
||||
im = plot_keypoints(
|
||||
im,
|
||||
data.keypoints[idx],
|
||||
data.keypoint_meta,
|
||||
PlotSettings(keypoint_color=(0, 192, 0)),
|
||||
)
|
||||
|
||||
# Plot predicted keypoints
|
||||
if "keypoints" in detection:
|
||||
im = plot_keypoints(
|
||||
im,
|
||||
Keypoints(detection["keypoints"], keypoint_meta),
|
||||
detection["keypoints"],
|
||||
keypoint_meta,
|
||||
PlotSettings(keypoint_color=(192, 165, 0)),
|
||||
)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче