Add random crop logic to DeepGlobeLandCover Datamodule (#876)

* crop logic

* typo

* change train_batch_size logic

* fix failing test

* typos and naming

* return argument train dataloader

* typo

* fix failing test

* suggestions except about test file

* remove test_deepglobe and add test to trainer

* forgot new conf file

* reanme collate function

* move cropping logic to transform and utils

* remove comment

* simplify

* move pad_segmentation to transforms

* another one

* naming and versionadded

* another transforms approach

* typo

* fix read the docs

* some checks for Ncrop

* add unit tests new transforms

* Remove cruft

* More simplification

* Add config file

* Implemented ExtractTensorPatches

* Remove tests

* Remove unnecessary attrs

* Apply to both input and mask

* Implement RandomNCrop

* Fix dimensions

* mypy fixes

* Fix docs

* Ensure that image and mask get the same transformation

* Bump min kornia version

* ignore still needed?

* Remove unneeded hacks

* Fix pydocstyle

* Fix dimensions

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2022-12-30 01:08:49 +01:00 коммит произвёл GitHub
Родитель 876b06aaac
Коммит c62d8321fb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 256 добавлений и 90 удалений

Просмотреть файл

@ -169,8 +169,8 @@ Kenya Crop Type
.. autoclass:: CV4AKenyaCropType
Deep Globe Land Cover
^^^^^^^^^^^^^^^^^^^^^
DeepGlobe Land Cover
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: DeepGlobeLandCover

Просмотреть файл

@ -5,7 +5,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Cloud Cover Detection`_,S,Sentinel-2,"22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
`Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
`Deep Globe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"3,981",15,"2,000x2,000",0.5,RGB
`ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI
@ -34,4 +34,4 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Vaihingen`_,S,Aerial,33,6,"1,281--3,816",0.09,RGB
`NWPU VHR-10`_,I,"Google Earth, Vaihingen",800,10,"358--1,728",0.08--2,RGB
`xView2`_,CD,Maxar,"3,732",4,"1,024x1,024",0.8,RGB
`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI
`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI

1 Dataset Task Source # Samples # Classes Size (px) Resolution (m) Bands
5 `Cloud Cover Detection`_ S Sentinel-2 22,728 2 512x512 10 MSI
6 `COWC`_ C, R CSUAV AFRL, ISPRS, LINZ, AGRC 388,435 2 256x256 0.15 RGB
7 `Kenya Crop Type`_ S Sentinel-2 4,688 7 3,035x2,016 10 MSI
8 `Deep Globe Land Cover`_ `DeepGlobe Land Cover`_ S DigitalGlobe +Vivid 803 7 2,448x2,448 0.5 RGB
9 `DFC2022`_ S Aerial 3,981 15 2,000x2,000 0.5 RGB
10 `ETCI2021 Flood Detection`_ S Sentinel-1 66,810 2 256x256 5--20 SAR
11 `EuroSAT`_ C Sentinel-2 27,000 10 64x64 10 MSI
34 `Vaihingen`_ S Aerial 33 6 1,281--3,816 0.09 RGB
35 `NWPU VHR-10`_ I Google Earth, Vaihingen 800 10 358--1,728 0.08--2 RGB
36 `xView2`_ CD Maxar 3,732 4 1,024x1,024 0.8 RGB
37 `ZueriCrop`_ I, T Sentinel-2 116K 48 24x24 10 MSI

Просмотреть файл

@ -22,7 +22,7 @@ dependencies:
- flake8>=3.8
- ipywidgets>=7
- isort[colors]>=5.8
- kornia>=0.6.4
- kornia>=0.6.5
- laspy>=2
- mypy>=0.900
- nbmake>=0.1

Просмотреть файл

@ -4,7 +4,7 @@ setuptools==42.0.0
# install
einops==0.3.0
fiona==1.8.0
kornia==0.6.4
kornia==0.6.5
matplotlib==3.3.0
numpy==1.17.2
omegaconf==2.1.0

Просмотреть файл

@ -29,8 +29,8 @@ install_requires =
einops>=0.3,<0.7
# fiona 1.8+ required for reading empty files
fiona>=1.8,<2
# kornia 0.6.4+ required for kornia.contrib.compute_padding
kornia>=0.6.4,<0.7
# kornia 0.6.5+ required due to change in kornia.augmentation API
kornia>=0.6.5,<0.7
# matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow
matplotlib>=3.3,<4
# numpy 1.17.2+ required by pytorch-lightning

Просмотреть файл

@ -14,6 +14,8 @@ experiment:
ignore_index: null
datamodule:
root: "tests/data/deepglobelandcover"
num_tiles_per_batch: 1
num_patches_per_tile: 1
patch_size: 2
val_split_pct: 0.5
batch_size: 1
num_workers: 0

Просмотреть файл

@ -1,19 +0,0 @@
experiment:
task: "deepglobelandcover"
module:
loss: "ce"
model: "unet"
backbone: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 3
num_classes: 7
num_filters: 1
ignore_index: null
datamodule:
root: "tests/data/deepglobelandcover"
val_split_pct: 0.0
batch_size: 1
num_workers: 0

Просмотреть файл

@ -36,8 +36,7 @@ class TestSemanticSegmentationTask:
"name,classname",
[
("chesapeake_cvpr_5", ChesapeakeCVPRDataModule),
("deepglobelandcover_0", DeepGlobeLandCoverDataModule),
("deepglobelandcover_5", DeepGlobeLandCoverDataModule),
("deepglobelandcover", DeepGlobeLandCoverDataModule),
("etci2021", ETCI2021DataModule),
("inria_train", InriaAerialImageLabelingDataModule),
("inria_val", InriaAerialImageLabelingDataModule),

Просмотреть файл

@ -3,14 +3,17 @@
"""DeepGlobe Land Cover Classification Challenge datamodule."""
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple, Union
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from kornia.augmentation import Normalize
from torch.utils.data import DataLoader
from ..datasets import DeepGlobeLandCover
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
from .utils import dataset_split
@ -18,72 +21,74 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the DeepGlobe Land Cover dataset.
Uses the train/test splits from the dataset.
"""
def __init__(
self,
batch_size: int = 64,
num_workers: int = 0,
num_tiles_per_batch: int = 16,
num_patches_per_tile: int = 16,
patch_size: Union[Tuple[int, int], int] = 64,
val_split_pct: float = 0.2,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders.
"""Initialize a new LightningDataModule instance.
The DeepGlobe Land Cover dataset contains images that are too large to pass
directly through a model. Instead, we randomly sample patches from image tiles
during training and chop up image tiles into patch grids during evaluation.
During training, the effective batch size is equal to
``num_tiles_per_batch`` x ``num_patches_per_tile``.
Args:
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
num_tiles_per_batch: The number of image tiles to sample from during
training
num_patches_per_tile: The number of patches to randomly sample from each
image tile during training
patch_size: The size of each patch, either ``size`` or ``(height, width)``.
Should be a multiple of 32 for most segmentation architectures
val_split_pct: The percentage of the dataset to use as a validation set
num_workers: The number of workers to use for parallel data loading
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.DeepGlobeLandCover`
.. versionchanged:: 0.4
*batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*,
and *patch_size*.
"""
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.num_tiles_per_batch = num_tiles_per_batch
self.num_patches_per_tile = num_patches_per_tile
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
self.num_workers = num_workers
self.kwargs = kwargs
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
self.train_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_RandomNCrop(self.patch_size, self.num_patches_per_tile),
data_keys=["image", "mask"],
)
self.test_transform = AugmentationSequential(
Normalize(mean=0.0, std=255.0),
_ExtractTensorPatches(self.patch_size),
data_keys=["image", "mask"],
)
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
"""Initialize the main Dataset objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = DeepGlobeLandCover(
split="train", transforms=transforms, **self.kwargs
)
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset
self.val_dataset = dataset
self.test_dataset = DeepGlobeLandCover(
split="test", transforms=transforms, **self.kwargs
train_dataset = DeepGlobeLandCover(split="train", **self.kwargs)
self.train_dataset, self.val_dataset = dataset_split(
train_dataset, self.val_split_pct
)
self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs)
def train_dataloader(self) -> DataLoader[Dict[str, Any]]:
"""Return a DataLoader for training.
@ -93,7 +98,7 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
batch_size=self.num_tiles_per_batch,
num_workers=self.num_workers,
shuffle=True,
)
@ -105,10 +110,7 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
validation data loader
"""
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)
def test_dataloader(self) -> DataLoader[Dict[str, Any]]:
@ -118,12 +120,35 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule):
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
)
def on_after_batch_transfer(
self, batch: Dict[str, Any], dataloader_idx: int
) -> Dict[str, Any]:
"""Apply augmentations to batch after transferring to GPU.
Args:
batch: A batch of data that needs to be altered or augmented
dataloader_idx: The index of the dataloader to which the batch belongs
Returns:
A batch of data
"""
# Kornia requires masks to have a channel dimension
batch["mask"] = batch["mask"].unsqueeze(1)
if self.trainer:
if self.trainer.training:
batch = self.train_transform(batch)
elif self.trainer.validating or self.trainer.testing:
batch = self.test_transform(batch)
# Torchmetrics does not support masks with a channel dimension
batch["mask"] = batch["mask"].squeeze(1)
return batch
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
"""Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`.

Просмотреть файл

@ -3,7 +3,7 @@
"""InriaAerialImageLabeling datamodule."""
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union
import kornia.augmentation as K
import matplotlib.pyplot as plt
@ -69,7 +69,7 @@ class InriaAerialImageLabelingDataModule(pl.LightningDataModule):
self.num_workers = num_workers
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
self.patch_size = _to_tuple(patch_size)
self.num_patches_per_tile = num_patches_per_tile
self.kwargs = kwargs

Просмотреть файл

@ -173,7 +173,7 @@ class DeepGlobeLandCover(NonGeoDataset):
array: "np.typing.NDArray[np.int_]" = np.array(img)
tensor = torch.from_numpy(array)
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
tensor = tensor.permute((2, 0, 1)).to(torch.float32)
return tensor
def _load_target(self, index: int) -> Tensor:

Просмотреть файл

@ -4,13 +4,23 @@
"""Common sampler utilities."""
import math
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, overload
import torch
from ..datasets import BoundingBox
@overload
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]:
...
@overload
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
...
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
"""Convert value to a tuple if it is not already a tuple.

Просмотреть файл

@ -3,14 +3,19 @@
"""TorchGeo transforms."""
from typing import Dict, List, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import kornia.augmentation as K
import kornia
import torch
from kornia.augmentation import GeometricAugmentationBase2D
from kornia.augmentation.random_generator import CropGenerator
from kornia.contrib import compute_padding, extract_tensor_patches
from kornia.geometry import crop_by_indices
from torch import Tensor
from torch.nn.modules import Module
# TODO: contribute these to Kornia and delete this file
class AugmentationSequential(Module):
"""Wrapper around kornia AugmentationSequential to handle input dicts."""
@ -33,7 +38,7 @@ class AugmentationSequential(Module):
else:
keys.append(key)
self.augs = K.AugmentationSequential(*args, data_keys=keys)
self.augs = kornia.augmentation.AugmentationSequential(*args, data_keys=keys)
def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""Perform augmentations and update data dict.
@ -69,3 +74,147 @@ class AugmentationSequential(Module):
sample["boxes"] = sample["boxes"].to(boxes_dtype)
return sample
class _ExtractTensorPatches(GeometricAugmentationBase2D):
"""Chop up a tensor into a grid."""
def __init__(self, window_size: Union[int, Tuple[int, int]]) -> None:
"""Initialize a new _ExtractTensorPatches instance.
Args:
window_size: the size of each patch
"""
super().__init__(p=1)
self.flags = {"window_size": window_size}
def compute_transformation(
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]
) -> Tensor:
"""Compute the transformation.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
Returns:
the transformation
"""
out: Tensor = self.identity_matrix(input)
return out
def apply_transform(
self,
input: Tensor,
params: Dict[str, Tensor],
flags: Dict[str, Any],
transform: Optional[Tensor] = None,
) -> Tensor:
"""Apply the transform.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
transform: the geometric transformation tensor
Returns:
the augmented input
"""
size = flags["window_size"]
h, w = input.shape[-2:]
padding = compute_padding((h, w), size)
input = extract_tensor_patches(input, size, size, padding)
input = torch.flatten(input, 0, 1) # [B, N, C?, H, W] -> [B*N, C?, H, W]
return input
class _RandomNCrop(GeometricAugmentationBase2D):
"""Take N random crops of a tensor."""
def __init__(self, size: Tuple[int, int], num: int) -> None:
"""Initialize a new _RandomNCrop instance.
Args:
size: desired output size (out_h, out_w) of the crop
num: number of crops to take
"""
super().__init__(p=1)
self._param_generator: _NCropGenerator = _NCropGenerator(size, num)
self.flags = {"size": size, "num": num}
def compute_transformation(
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]
) -> Tensor:
"""Compute the transformation.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
Returns:
the transformation
"""
out: Tensor = self.identity_matrix(input)
return out
def apply_transform(
self,
input: Tensor,
params: Dict[str, Tensor],
flags: Dict[str, Any],
transform: Optional[Tensor] = None,
) -> Tensor:
"""Apply the transform.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
transform: the geometric transformation tensor
Returns:
the augmented input
"""
out = []
for i in range(flags["num"]):
out.append(crop_by_indices(input, params["src"][i], flags["size"]))
return torch.cat(out)
class _NCropGenerator(CropGenerator):
"""Generate N random crops."""
def __init__(self, size: Union[Tuple[int, int], Tensor], num: int) -> None:
"""Initialize a new _NCropGenerator instance.
Args:
size: desired output size (out_h, out_w) of the crop
num: number of crops to generate
"""
super().__init__(size)
self.num = num
def forward(
self, batch_shape: torch.Size, same_on_batch: bool = False
) -> Dict[str, Tensor]:
"""Generate the crops.
Args:
batch_shape: input size (b, c?, in_h, in_w)
same_on_batch: apply the same transformation across the batch
Returns:
the randomly generated parameters
"""
out = []
for _ in range(self.num):
out.append(super().forward(batch_shape, same_on_batch))
return {
"src": torch.stack([x["src"] for x in out]),
"dst": torch.stack([x["dst"] for x in out]),
"input_size": out[0]["input_size"],
"output_size": out[0]["output_size"],
}