Added functionality for validation split (#1540)

* added functionality for validation split

* Changed "valid" to "val"

* updated docstring & removed redundant lists

* Fixed format with Linters

* Add more testing files

* Simplify regex

* Update datamodule to use new val split

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Tarandeep Singh 2023-09-29 21:32:19 +05:30 коммит произвёл GitHub
Родитель 76f92f86df
Коммит 6ae0d78448
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
30 изменённых файлов: 32 добавлений и 30 удалений

Просмотреть файл

@ -12,7 +12,5 @@ data:
init_args:
batch_size: 1
patch_size: 2
val_split_pct: 0.2
test_split_pct: 0.2
dict_kwargs:
root: "tests/data/inria"

Просмотреть файл

@ -10,7 +10,5 @@ data:
init_args:
batch_size: 1
patch_size: 2
val_split_pct: 0.2
test_split_pct: 0.2
dict_kwargs:
root: "tests/data/inria"

Просмотреть файл

@ -10,7 +10,5 @@ data:
init_args:
batch_size: 1
patch_size: 2
val_split_pct: 0.2
test_split_pct: 0.2
dict_kwargs:
root: "tests/data/inria"

Просмотреть файл

@ -10,7 +10,5 @@ data:
init_args:
batch_size: 1
patch_size: 2
val_split_pct: 0.2
test_split_pct: 0.2
dict_kwargs:
root: "tests/data/inria"

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin15.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin16.tif Normal file

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin6.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin7.tif Normal file

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin6.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin7.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/inria/NEW2-AerialImageDataset.zip

Двоичный файл не отображается.

Просмотреть файл

@ -81,10 +81,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
shutil.make_archive(
archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
)
shutil.rmtree(folder_path)
return calculate_md5(f"{archive_path}.zip")
if __name__ == "__main__":
md5_hash = generate_test_data(os.getcwd(), 5)
md5_hash = generate_test_data(os.getcwd(), 7)
print(md5_hash)

Просмотреть файл

@ -15,12 +15,12 @@ from torchgeo.datasets import InriaAerialImageLabeling
class TestInriaAerialImageLabeling:
@pytest.fixture(params=["train", "test"])
@pytest.fixture(params=["train", "val", "test"])
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch
) -> InriaAerialImageLabeling:
root = os.path.join("tests", "data", "inria")
test_md5 = "478688944e4797c097d9387fd0b3f038"
test_md5 = "3ecbe95eb84aea064e455c4321546be1"
monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
transforms = nn.Identity()
return InriaAerialImageLabeling(
@ -38,7 +38,12 @@ class TestInriaAerialImageLabeling:
assert x["image"].ndim == 3
def test_len(self, dataset: InriaAerialImageLabeling) -> None:
assert len(dataset) == 5
if dataset.split == "train":
assert len(dataset) == 2
elif dataset.split == "val":
assert len(dataset) == 5
elif dataset.split == "test":
assert len(dataset) == 7
def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
InriaAerialImageLabeling(root=dataset.root)

Просмотреть файл

@ -12,7 +12,6 @@ from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from ..transforms.transforms import _RandomNCrop
from .geo import NonGeoDataModule
from .utils import dataset_split
class InriaAerialImageLabelingDataModule(NonGeoDataModule):
@ -29,8 +28,6 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
num_workers: int = 0,
val_split_pct: float = 0.1,
test_split_pct: float = 0.1,
**kwargs: Any,
) -> None:
"""Initialize a new InriaAerialImageLabelingDataModule instance.
@ -40,16 +37,12 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
Should be a multiple of 32 for most segmentation architectures.
num_workers: Number of workers for parallel data loading.
val_split_pct: Percentage of the dataset to use as a validation set.
test_split_pct: Percentage of the dataset to use as a test set.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.InriaAerialImageLabeling`.
"""
super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs)
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
@ -75,11 +68,10 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ["fit", "validate", "test"]:
self.dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
self.dataset, self.val_split_pct, self.test_split_pct
)
if stage in ["fit"]:
self.train_dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
if stage in ["fit", "validate"]:
self.val_dataset = InriaAerialImageLabeling(split="val", **self.kwargs)
if stage in ["predict"]:
# Test set masks are not public, use for prediction instead
self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs)

Просмотреть файл

@ -5,6 +5,7 @@
import glob
import os
import re
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
@ -45,6 +46,9 @@ class InriaAerialImageLabeling(NonGeoDataset):
* https://doi.org/10.1109/IGARSS.2017.8127684
.. versionadded:: 0.3
.. versionchanged:: 0.5
Added support for a *val* split.
"""
directory = "AerialImageDataset"
@ -62,7 +66,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
Args:
root: root directory where dataset can be found
split: train/test split
split: train/val/test split
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version.
checksum: if True, check the MD5 of the downloaded files (may be slow)
@ -72,7 +76,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
RuntimeError: if dataset is missing
"""
self.root = root
assert split in {"train", "test"}
assert split in {"train", "val", "test"}
self.split = split
self.transforms = transforms
self.checksum = checksum
@ -90,15 +94,25 @@ class InriaAerialImageLabeling(NonGeoDataset):
list of dicts containing paths for each pair of image and label
"""
files = []
root_dir = os.path.join(root, self.directory, self.split)
split = "train" if self.split in ["train", "val"] else "test"
root_dir = os.path.join(root, self.directory, split)
pattern = re.compile(r"([A-Za-z]+)(\d+)")
images = glob.glob(os.path.join(root_dir, "images", "*.tif"))
images = sorted(images)
if self.split == "train":
if split == "train":
labels = glob.glob(os.path.join(root_dir, "gt", "*.tif"))
labels = sorted(labels)
for img, lbl in zip(images, labels):
files.append({"image": img, "label": lbl})
if match := pattern.search(img):
idx = int(match.group(2))
# For validation, use the first 5 images of every location
if self.split == "train" and idx > 5:
files.append({"image": img, "label": lbl})
elif self.split == "val" and idx < 6:
files.append({"image": img, "label": lbl})
else:
for img in images:
files.append({"image": img})