Mirror of https://github.com/microsoft/torchgeo.git

Add VHR10 datamodule (#1082)

* Add VHR10 datamodule
* Add newline
* patch_size accepts int and tuple of ints
* Update conf
* VHR10 Datamodule v2
* Remove auto_lr_find
* Remove preprocess
* Update config
* Remove setting of matplotlib backend
* Remove import
* Typing update
* Key fix
* Coverage fix
* Update conf
* Update conf
* Download=True
* Use weights
* Empty commit
* Switch to ndim
* Remove conf, tight_layout and spacing
* Set constrained layout via rcParams
* Revert and bump min matplotlib version
* Switch back to dataset_split
* Separate out AugPipe
* Increase figsize & revert matplotlib
* Common collate_fn
* Class var std
* Undo std change in BaseDataModule
* Undo req changes
* Remove unused line
* Add version strings
* mypy fix

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

Parent: b0f5184bf5
Commit: 0f8b0ac3ea
@@ -0,0 +1,17 @@
+model:
+  class_path: ObjectDetectionTask
+  init_args:
+    model: "faster-rcnn"
+    backbone: "resnet50"
+    num_classes: 11
+    lr: 2.5e-5
+    patience: 10
+data:
+  class_path: VHR10DataModule
+  init_args:
+    batch_size: 1
+    num_workers: 0
+    patch_size: 4
+  dict_kwargs:
+    root: "tests/data/vhr10"
+    download: true
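Note: for orientation, the test configuration above corresponds roughly to the following Python setup. This is a sketch only; the argument names are taken from the config itself, the Lightning Trainer usage is an assumption, and the tiny batch_size/patch_size values are only meant for the test fixture.

import lightning.pytorch as pl

from torchgeo.datamodules import VHR10DataModule
from torchgeo.trainers import ObjectDetectionTask

# Same settings as the YAML config above; values are test-sized on purpose.
task = ObjectDetectionTask(
    model="faster-rcnn", backbone="resnet50", num_classes=11, lr=2.5e-5, patience=10
)
datamodule = VHR10DataModule(
    batch_size=1, num_workers=0, patch_size=4, root="tests/data/vhr10", download=True
)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model=task, datamodule=datamodule)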
Binary data: tests/data/vhr10/NWPU VHR-10 dataset.rar (binary file not shown)
@@ -1 +1 @@
-{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
+{"images": [{"file_name": "001.jpg", "height": 8, "width": 8, "id": 0}, {"file_name": "002.jpg", "height": 8, "width": 8, "id": 1}, {"file_name": "003.jpg", "height": 8, "width": 8, "id": 2}, {"file_name": "004.jpg", "height": 8, "width": 8, "id": 3}, {"file_name": "005.jpg", "height": 8, "width": 8, "id": 4}], "annotations": [{"id": 0, "image_id": 0, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 1, "image_id": 1, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 2, "image_id": 2, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 3, "image_id": 3, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}, {"id": 4, "image_id": 4, "category_id": 1, "area": 4.0, "bbox": [4, 4, 2, 2], "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]], "iscrowd": 0}]}
@@ -5,7 +5,6 @@ import json
import os
import shutil
import subprocess
-from copy import deepcopy

import numpy as np
from PIL import Image
@@ -47,7 +46,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
    )

    ann = 0
-    for i, img in enumerate(ANNOTATION_FILE["images"]):
+    for _, img in enumerate(ANNOTATION_FILE["images"]):
        annot = {
            "id": ann,
            "image_id": img["id"],
@@ -57,12 +56,7 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str:
            "segmentation": [[1, 1, 2, 2, 3, 3, 4, 5, 5]],
            "iscrowd": 0,
        }
-        if i != 0:
-            ANNOTATION_FILE["annotations"].append(annot)
-        else:
-            noseg_annot = deepcopy(annot)
-            del noseg_annot["segmentation"]
-            ANNOTATION_FILE["annotations"].append(noseg_annot)
+        ANNOTATION_FILE["annotations"].append(annot)
        ann += 1

    with open(ann_file, "w") as j:
@@ -35,11 +35,11 @@ class TestVHR10:
        monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
        url = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar")
        monkeypatch.setitem(VHR10.image_meta, "url", url)
-        md5 = "5fddb0dfd56a80638831df9f90cbf37a"
+        md5 = "92769845cae6a4e8c74bfa1a0d1d4a80"
        monkeypatch.setitem(VHR10.image_meta, "md5", md5)
        url = os.path.join("tests", "data", "vhr10", "annotations.json")
        monkeypatch.setitem(VHR10.target_meta, "url", url)
-        md5 = "833899cce369168e0d4ee420dac326dc"
+        md5 = "567c4cd8c12624864ff04865de504c58"
        monkeypatch.setitem(VHR10.target_meta, "md5", md5)
        root = str(tmp_path)
        split = request.param
@@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None:


class TestObjectDetectionTask:
-    @pytest.mark.parametrize("name", ["nasa_marine_debris"])
+    @pytest.mark.parametrize("name", ["nasa_marine_debris", "vhr10"])
    @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
@@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]:
    return {
        "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
        "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
+        "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }

@@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]:
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
+        "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }

@@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]:
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
+        "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }

@@ -79,7 +79,7 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
    expected = {
        "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
        "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
+        "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }
    augs = transforms.AugmentationSequential(
@@ -102,7 +102,7 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
+        "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }
    augs = transforms.AugmentationSequential(
@@ -129,7 +129,7 @@ def test_augmentation_sequential_multispectral(
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
+        "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }
    augs = transforms.AugmentationSequential(
@@ -156,7 +156,7 @@ def test_augmentation_sequential_image_only(
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float),
+        "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }
    augs = transforms.AugmentationSequential(
@@ -188,7 +188,7 @@ def test_sequential_transforms_augmentations(
            dtype=torch.float,
        ),
        "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float),
+        "boxes": torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
        "labels": torch.tensor([[0, 1]]),
    }
    train_transforms = transforms.AugmentationSequential(
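Note: the fixture changes above switch "boxes" from per-corner coordinates of shape [N, 4, 2] to xyxy tensors of shape [N, 4]. As a sanity check of the expected values, a forced horizontal flip of the 3x3 gray fixture should map the box (0, 0, 2, 2) to (1, 0, 3, 2), which is what the updated tests assert. A minimal sketch, assuming kornia and torchgeo are installed:

import kornia.augmentation as K
import torch

from torchgeo.transforms import AugmentationSequential

batch = {
    "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
    "boxes": torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
}
# Force the flip so the result is deterministic, as in the tests.
augs = AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=["image", "boxes"])
out = augs(batch)
print(out["boxes"])  # the updated tests expect tensor([[1., 0., 3., 2.]]) here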
@@ -38,6 +38,7 @@ from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
from .vaihingen import Vaihingen2DDataModule
+from .vhr10 import VHR10DataModule
from .xview import XView2DataModule

__all__ = (
@@ -79,6 +80,7 @@ __all__ = (
    "UCMercedDataModule",
    "USAVarsDataModule",
    "Vaihingen2DDataModule",
+    "VHR10DataModule",
    "XView2DataModule",
    # Base classes
    "BaseDataModule",
@@ -5,28 +5,13 @@

from typing import Any

import kornia.augmentation as K
import torch
from torch import Tensor

from ..datasets import NASAMarineDebris
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
-from .utils import dataset_split
-
-
-def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
-    """Custom object detection collate fn to handle variable boxes.
-
-    Args:
-        batch: list of sample dicts return by dataset
-
-    Returns:
-        batch dict output
-    """
-    output: dict[str, Any] = {}
-    output["image"] = torch.stack([sample["image"] for sample in batch])
-    output["boxes"] = [sample["boxes"] for sample in batch]
-    output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch]
-    return output
+from .utils import AugPipe, collate_fn_detection, dataset_split


class NASAMarineDebrisDataModule(NonGeoDataModule):
@@ -35,6 +20,8 @@ class NASAMarineDebrisDataModule(NonGeoDataModule):
    .. versionadded:: 0.2
    """

+    std = torch.tensor(255)
+
    def __init__(
        self,
        batch_size: int = 64,
@@ -58,7 +45,14 @@ class NASAMarineDebrisDataModule(NonGeoDataModule):
        self.val_split_pct = val_split_pct
        self.test_split_pct = test_split_pct

-        self.collate_fn = collate_fn
+        self.aug = AugPipe(
+            AugmentationSequential(
+                K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "boxes"]
+            ),
+            batch_size,
+        )
+
+        self.collate_fn = collate_fn_detection

    def setup(self, stage: str) -> None:
        """Set up datasets.
@@ -5,10 +5,13 @@

import math
from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

import numpy as np
-from torch import Generator
+import torch
+from einops import rearrange
+from torch import Generator, Tensor
+from torch.nn import Module
from torch.utils.data import Subset, TensorDataset, random_split

from ..datasets import NonGeoDataset
@@ -19,6 +22,86 @@ class MisconfigurationException(Exception):
    """Exception used to inform users of misuse with Lightning."""


+class AugPipe(Module):
+    """Pipeline for applying augmentations sequentially on select data keys.
+
+    .. versionadded:: 0.6
+    """
+
+    def __init__(
+        self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int
+    ) -> None:
+        """Initialize a new AugPipe instance.
+
+        Args:
+            augs: Augmentations to apply.
+            batch_size: Batch size
+        """
+        super().__init__()
+        self.augs = augs
+        self.batch_size = batch_size
+
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Apply the augmentation.
+
+        Args:
+            batch: Input batch.
+
+        Returns:
+            Augmented batch.
+        """
+        batch_len = len(batch["image"])
+        for bs in range(batch_len):
+            batch_dict = {
+                "image": batch["image"][bs],
+                "labels": batch["labels"][bs],
+                "boxes": batch["boxes"][bs],
+            }
+
+            if "masks" in batch:
+                batch_dict["masks"] = batch["masks"][bs]
+
+            batch_dict = self.augs(batch_dict)
+
+            batch["image"][bs] = batch_dict["image"]
+            batch["labels"][bs] = batch_dict["labels"]
+            batch["boxes"][bs] = batch_dict["boxes"]
+
+            if "masks" in batch:
+                batch["masks"][bs] = batch_dict["masks"]
+
+        # Stack images
+        batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w")
+
+        return batch
+
+
+def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
+    """Custom collate fn for object detection and instance segmentation.
+
+    Args:
+        batch: list of sample dicts return by dataset
+
+    Returns:
+        batch dict output
+
+    .. versionadded:: 0.6
+    """
+    output: dict[str, Any] = {}
+    output["image"] = [sample["image"] for sample in batch]
+    output["boxes"] = [sample["boxes"].float() for sample in batch]
+    if "labels" in batch[0]:
+        output["labels"] = [sample["labels"] for sample in batch]
+    else:
+        output["labels"] = [
+            torch.tensor([1] * len(sample["boxes"])) for sample in batch
+        ]
+
+    if "masks" in batch[0]:
+        output["masks"] = [sample["masks"] for sample in batch]
+    return output
+
+
def dataset_split(
    dataset: Union[TensorDataset, NonGeoDataset],
    val_pct: float,
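Note: collate_fn_detection keeps images, boxes, labels (and optional masks) as per-sample lists so that samples with different numbers of objects can be batched, and AugPipe then augments sample by sample and re-stacks the images. A toy sketch of the collate behaviour; the tensors are made up for illustration:

import torch

from torchgeo.datamodules.utils import collate_fn_detection

# Two samples with different numbers of boxes (illustrative values only).
samples = [
    {
        "image": torch.zeros(3, 8, 8),
        "boxes": torch.tensor([[0, 0, 2, 2]]),
        "labels": torch.tensor([1]),
    },
    {
        "image": torch.zeros(3, 8, 8),
        "boxes": torch.tensor([[1, 1, 3, 3], [2, 2, 5, 5]]),
        "labels": torch.tensor([1, 2]),
    },
]
batch = collate_fn_detection(samples)
print(len(batch["image"]), [b.shape for b in batch["boxes"]])
# 2 [torch.Size([1, 4]), torch.Size([2, 4])]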
@@ -0,0 +1,83 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""NWPU VHR-10 datamodule."""
+
+from typing import Any, Union
+
+import kornia.augmentation as K
+import torch
+
+from ..datasets import VHR10
+from ..samplers.utils import _to_tuple
+from ..transforms import AugmentationSequential
+from .geo import NonGeoDataModule
+from .utils import AugPipe, collate_fn_detection, dataset_split
+
+
+class VHR10DataModule(NonGeoDataModule):
+    """LightningDataModule implementation for the VHR10 dataset.
+
+    .. versionadded:: 0.6
+    """
+
+    std = torch.tensor(255)
+
+    def __init__(
+        self,
+        batch_size: int = 64,
+        patch_size: Union[tuple[int, int], int] = 512,
+        num_workers: int = 0,
+        val_split_pct: float = 0.2,
+        test_split_pct: float = 0.2,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new VHR10DataModule instance.
+
+        Args:
+            batch_size: Size of each mini-batch.
+            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
+            num_workers: Number of workers for parallel data loading.
+            val_split_pct: Percentage of the dataset to use as a validation set.
+            test_split_pct: Percentage of the dataset to use as a test set.
+            **kwargs: Additional keyword arguments passed to
+                :class:`~torchgeo.datasets.VHR10`.
+        """
+        super().__init__(VHR10, batch_size, num_workers, **kwargs)
+
+        self.val_split_pct = val_split_pct
+        self.test_split_pct = test_split_pct
+        self.patch_size = _to_tuple(patch_size)
+
+        self.collate_fn = collate_fn_detection
+
+        self.train_aug = AugPipe(
+            AugmentationSequential(
+                K.Normalize(mean=self.mean, std=self.std),
+                K.Resize(self.patch_size),
+                K.RandomHorizontalFlip(),
+                K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.7),
+                K.RandomVerticalFlip(),
+                data_keys=["image", "boxes", "masks"],
+            ),
+            batch_size,
+        )
+        self.aug = AugPipe(
+            AugmentationSequential(
+                K.Normalize(mean=self.mean, std=self.std),
+                K.Resize(self.patch_size),
+                data_keys=["image", "boxes", "masks"],
+            ),
+            batch_size,
+        )
+
+    def setup(self, stage: str) -> None:
+        """Set up datasets.
+
+        Args:
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
+        """
+        self.dataset = VHR10(**self.kwargs)
+        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
+            self.dataset, self.val_split_pct, self.test_split_pct
+        )
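Note: used outside a Lightning Trainer, the new datamodule can be exercised directly. A rough sketch, assuming the dataset is available under a hypothetical local root (or downloadable); the runtime behaviour described in the comments is an assumption based on the code above:

from torchgeo.datamodules import VHR10DataModule

dm = VHR10DataModule(
    batch_size=2,
    patch_size=256,  # an int is expanded to (256, 256) via _to_tuple
    num_workers=0,
    root="data/vhr10",  # hypothetical path
    download=True,
)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
# collate_fn_detection keeps per-sample lists; the AugPipe normalize/resize/flip
# pipeline is only applied later, in the Lightning on_after_batch_transfer hook.
print(len(batch["image"]), [b.shape for b in batch["boxes"]])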
@@ -45,10 +45,7 @@ def convert_coco_poly_to_mask(
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
-    if masks:
-        masks_tensor = torch.stack(masks, dim=0)
-    else:
-        masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8)
+    masks_tensor = torch.stack(masks, dim=0)
    return masks_tensor


@@ -89,10 +86,8 @@ class ConvertCocoAnnotations:
        categories = [obj["category_id"] for obj in anno]
        classes = torch.tensor(categories, dtype=torch.int64)

-        if "segmentation" in anno[0]:
-            segmentations = [obj["segmentation"] for obj in anno]
-        else:
-            segmentations = []
+        segmentations = [obj["segmentation"] for obj in anno]

        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
@@ -258,8 +253,7 @@ class VHR10(NonGeoDataset):
            sample = self.coco_convert(sample)
            sample["labels"] = sample["label"]["labels"]
            sample["boxes"] = sample["label"]["boxes"]
-            if "masks" in sample["label"]:
-                sample["masks"] = sample["label"]["masks"]
+            sample["masks"] = sample["label"]["masks"]
            del sample["label"]

        if self.transforms is not None:
@@ -296,6 +290,7 @@ class VHR10(NonGeoDataset):
        with Image.open(filename) as img:
            array: "np.typing.NDArray[np.int_]" = np.array(img)
            tensor = torch.from_numpy(array)
+            tensor = tensor.float()
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor
@@ -439,7 +434,7 @@ class VHR10(NonGeoDataset):
            ncols += 1

        # Display image
-        fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10))
+        fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 13))
        axs[0, 0].imshow(image)
        axs[0, 0].axis("off")

@@ -536,9 +531,9 @@ class VHR10(NonGeoDataset):
            if show_titles:
                axs[0, 1].set_title("Prediction")

-        plt.tight_layout()
-
        if suptitle is not None:
            plt.suptitle(suptitle)

+        plt.tight_layout()
+
        return fig
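Note: after these changes a "positive"-split sample always carries masks alongside boxes and labels, and the image tensor is float. A small sketch of inspecting and plotting one sample; the root path is hypothetical and the dataset must be available locally or downloadable:

from torchgeo.datasets import VHR10

ds = VHR10(root="data/vhr10", split="positive", download=True)  # hypothetical root
sample = ds[0]
print(sample["image"].dtype)  # torch.float32 after the added .float() conversion
print(sample["boxes"].shape, sample["labels"].shape, sample["masks"].shape)
fig = ds.plot(sample, suptitle="VHR-10 sample")
fig.savefig("vhr10_sample.png")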
@@ -10,6 +10,7 @@ import torch
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from kornia.geometry import crop_by_indices
+from kornia.geometry.boxes import Boxes
from torch import Tensor
from torch.nn.modules import Module

@@ -47,6 +48,8 @@ class AugmentationSequential(Module):
                keys.append("input")
            elif key == "boxes":
                keys.append("bbox")
+            elif key == "masks":
+                keys.append("mask")
            else:
                keys.append(key)

@@ -67,10 +70,19 @@ class AugmentationSequential(Module):
                dtype[key] = batch[key].dtype
                batch[key] = batch[key].float()

+        # Convert shape of boxes from [N, 4] to [N, 4, 2]
+        if "boxes" in batch and (
+            isinstance(batch["boxes"], list) or batch["boxes"].ndim == 2
+        ):
+            batch["boxes"] = Boxes.from_tensor(batch["boxes"]).data
+
        # Kornia requires masks to have a channel dimension
-        if "mask" in batch and len(batch["mask"].shape) == 3:
+        if "mask" in batch and batch["mask"].ndim == 3:
            batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w")

+        if "masks" in batch and batch["masks"].ndim == 3:
+            batch["masks"] = rearrange(batch["masks"], "c h w -> () c h w")
+
        inputs = [batch[k] for k in self.data_keys]
        outputs_list: Union[Tensor, list[Tensor]] = self.augs(*inputs)
        outputs_list = (
@@ -85,9 +97,17 @@ class AugmentationSequential(Module):
            for key in self.data_keys:
                batch[key] = batch[key].to(dtype[key])

+        # Convert boxes to default [N, 4]
+        if "boxes" in batch:
+            batch["boxes"] = Boxes(batch["boxes"]).to_tensor(
+                mode="xyxy"
+            )  # type:ignore[assignment]
+
        # Torchmetrics does not support masks with a channel dimension
        if "mask" in batch and batch["mask"].shape[1] == 1:
            batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w")
+        if "masks" in batch and batch["masks"].ndim == 4:
+            batch["masks"] = rearrange(batch["masks"], "() c h w -> c h w")

        return batch
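Note: Kornia's augmentations operate on boxes as [N, 4, 2] corner coordinates, while torchgeo samples now store [N, 4] xyxy tensors, so the two conversions added above are symmetric. A standalone sketch of the round trip, using the same kornia calls as the code above (only shapes are asserted here):

import torch
from kornia.geometry.boxes import Boxes

xyxy = torch.tensor([[0.0, 0.0, 2.0, 2.0]])  # [N, 4] xyxy, as stored in samples
corners = Boxes.from_tensor(xyxy).data  # [N, 4, 2] corner points for Kornia
xyxy_back = Boxes(corners).to_tensor(mode="xyxy")  # back to [N, 4]
print(corners.shape, xyxy_back.shape)  # torch.Size([1, 4, 2]) torch.Size([1, 4])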