зеркало из https://github.com/microsoft/torchgeo.git
Added functionality for validation split (#1540)
* added functionality for validation split * Changed "valid" to "val" * updated docstring & removed redundant lists * Fixed format with Linters * Add more testing files * Simplify regex * Update datamodule to use new val split --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
76f92f86df
Коммит
6ae0d78448
|
@ -12,7 +12,5 @@ data:
|
|||
init_args:
|
||||
batch_size: 1
|
||||
patch_size: 2
|
||||
val_split_pct: 0.2
|
||||
test_split_pct: 0.2
|
||||
dict_kwargs:
|
||||
root: "tests/data/inria"
|
||||
|
|
|
@ -10,7 +10,5 @@ data:
|
|||
init_args:
|
||||
batch_size: 1
|
||||
patch_size: 2
|
||||
val_split_pct: 0.2
|
||||
test_split_pct: 0.2
|
||||
dict_kwargs:
|
||||
root: "tests/data/inria"
|
||||
|
|
|
@ -10,7 +10,5 @@ data:
|
|||
init_args:
|
||||
batch_size: 1
|
||||
patch_size: 2
|
||||
val_split_pct: 0.2
|
||||
test_split_pct: 0.2
|
||||
dict_kwargs:
|
||||
root: "tests/data/inria"
|
||||
|
|
|
@ -10,7 +10,5 @@ data:
|
|||
init_args:
|
||||
batch_size: 1
|
||||
patch_size: 2
|
||||
val_split_pct: 0.2
|
||||
test_split_pct: 0.2
|
||||
dict_kwargs:
|
||||
root: "tests/data/inria"
|
||||
|
|
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin10.tif
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin10.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin11.tif
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin11.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin12.tif
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin12.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin13.tif
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin13.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin14.tif
Двоичные данные
tests/data/inria/AerialImageDataset/test/images/austin14.tif
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin1.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin1.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin2.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin2.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin3.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin3.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin4.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin4.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin5.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/gt/austin5.tif
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin1.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin1.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin2.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin2.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin3.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin3.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin4.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin4.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin5.tif
Двоичные данные
tests/data/inria/AerialImageDataset/train/images/austin5.tif
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/inria/NEW2-AerialImageDataset.zip
Двоичные данные
tests/data/inria/NEW2-AerialImageDataset.zip
Двоичный файл не отображается.
|
@ -81,10 +81,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
|
|||
shutil.make_archive(
|
||||
archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
|
||||
)
|
||||
shutil.rmtree(folder_path)
|
||||
return calculate_md5(f"{archive_path}.zip")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
md5_hash = generate_test_data(os.getcwd(), 5)
|
||||
md5_hash = generate_test_data(os.getcwd(), 7)
|
||||
print(md5_hash)
|
||||
|
|
|
@ -15,12 +15,12 @@ from torchgeo.datasets import InriaAerialImageLabeling
|
|||
|
||||
|
||||
class TestInriaAerialImageLabeling:
|
||||
@pytest.fixture(params=["train", "test"])
|
||||
@pytest.fixture(params=["train", "val", "test"])
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch
|
||||
) -> InriaAerialImageLabeling:
|
||||
root = os.path.join("tests", "data", "inria")
|
||||
test_md5 = "478688944e4797c097d9387fd0b3f038"
|
||||
test_md5 = "3ecbe95eb84aea064e455c4321546be1"
|
||||
monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
|
||||
transforms = nn.Identity()
|
||||
return InriaAerialImageLabeling(
|
||||
|
@ -38,7 +38,12 @@ class TestInriaAerialImageLabeling:
|
|||
assert x["image"].ndim == 3
|
||||
|
||||
def test_len(self, dataset: InriaAerialImageLabeling) -> None:
|
||||
assert len(dataset) == 5
|
||||
if dataset.split == "train":
|
||||
assert len(dataset) == 2
|
||||
elif dataset.split == "val":
|
||||
assert len(dataset) == 5
|
||||
elif dataset.split == "test":
|
||||
assert len(dataset) == 7
|
||||
|
||||
def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
|
||||
InriaAerialImageLabeling(root=dataset.root)
|
||||
|
|
|
@ -12,7 +12,6 @@ from ..samplers.utils import _to_tuple
|
|||
from ..transforms import AugmentationSequential
|
||||
from ..transforms.transforms import _RandomNCrop
|
||||
from .geo import NonGeoDataModule
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
class InriaAerialImageLabelingDataModule(NonGeoDataModule):
|
||||
|
@ -29,8 +28,6 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
|
|||
batch_size: int = 64,
|
||||
patch_size: Union[tuple[int, int], int] = 64,
|
||||
num_workers: int = 0,
|
||||
val_split_pct: float = 0.1,
|
||||
test_split_pct: float = 0.1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new InriaAerialImageLabelingDataModule instance.
|
||||
|
@ -40,16 +37,12 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
|
|||
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
|
||||
Should be a multiple of 32 for most segmentation architectures.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
val_split_pct: Percentage of the dataset to use as a validation set.
|
||||
test_split_pct: Percentage of the dataset to use as a test set.
|
||||
**kwargs: Additional keyword arguments passed to
|
||||
:class:`~torchgeo.datasets.InriaAerialImageLabeling`.
|
||||
"""
|
||||
super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs)
|
||||
|
||||
self.patch_size = _to_tuple(patch_size)
|
||||
self.val_split_pct = val_split_pct
|
||||
self.test_split_pct = test_split_pct
|
||||
|
||||
self.train_aug = AugmentationSequential(
|
||||
K.Normalize(mean=self.mean, std=self.std),
|
||||
|
@ -75,11 +68,10 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule):
|
|||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'.
|
||||
"""
|
||||
if stage in ["fit", "validate", "test"]:
|
||||
self.dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
|
||||
self.dataset, self.val_split_pct, self.test_split_pct
|
||||
)
|
||||
if stage in ["fit"]:
|
||||
self.train_dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
|
||||
if stage in ["fit", "validate"]:
|
||||
self.val_dataset = InriaAerialImageLabeling(split="val", **self.kwargs)
|
||||
if stage in ["predict"]:
|
||||
# Test set masks are not public, use for prediction instead
|
||||
self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -45,6 +46,9 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
* https://doi.org/10.1109/IGARSS.2017.8127684
|
||||
|
||||
.. versionadded:: 0.3
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
Added support for a *val* split.
|
||||
"""
|
||||
|
||||
directory = "AerialImageDataset"
|
||||
|
@ -62,7 +66,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: train/test split
|
||||
split: train/val/test split
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version.
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
@ -72,7 +76,7 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
RuntimeError: if dataset is missing
|
||||
"""
|
||||
self.root = root
|
||||
assert split in {"train", "test"}
|
||||
assert split in {"train", "val", "test"}
|
||||
self.split = split
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
|
@ -90,15 +94,25 @@ class InriaAerialImageLabeling(NonGeoDataset):
|
|||
list of dicts containing paths for each pair of image and label
|
||||
"""
|
||||
files = []
|
||||
root_dir = os.path.join(root, self.directory, self.split)
|
||||
split = "train" if self.split in ["train", "val"] else "test"
|
||||
root_dir = os.path.join(root, self.directory, split)
|
||||
pattern = re.compile(r"([A-Za-z]+)(\d+)")
|
||||
|
||||
images = glob.glob(os.path.join(root_dir, "images", "*.tif"))
|
||||
images = sorted(images)
|
||||
if self.split == "train":
|
||||
|
||||
if split == "train":
|
||||
labels = glob.glob(os.path.join(root_dir, "gt", "*.tif"))
|
||||
labels = sorted(labels)
|
||||
|
||||
for img, lbl in zip(images, labels):
|
||||
files.append({"image": img, "label": lbl})
|
||||
if match := pattern.search(img):
|
||||
idx = int(match.group(2))
|
||||
# For validation, use the first 5 images of every location
|
||||
if self.split == "train" and idx > 5:
|
||||
files.append({"image": img, "label": lbl})
|
||||
elif self.split == "val" and idx < 6:
|
||||
files.append({"image": img, "label": lbl})
|
||||
else:
|
||||
for img in images:
|
||||
files.append({"image": img})
|
||||
|
|
Загрузка…
Ссылка в новой задаче