Add train/val/test splits to UCMerced (#216)

* Added splits to the UCMerced dataset

* adding test files

* Removed random splits from UCMerced trainer
This commit is contained in:
Caleb Robinson 2021-11-01 08:09:36 -07:00 committed by GitHub
Parent 8f0243beb5
Commit 69598528e6
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files: 105 additions and 57 deletions

View file

@ -0,0 +1,4 @@
agricultural00.tif
agricultural01.tif
agricultural02.tif
airplane00.tif

View file

@ -0,0 +1,4 @@
agricultural00.tif
agricultural01.tif
agricultural02.tif
airplane00.tif

View file

@ -0,0 +1,4 @@
agricultural00.tif
agricultural01.tif
agricultural02.tif
airplane00.tif

View file

@ -9,6 +9,7 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
@ -21,9 +22,12 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
class TestUCMerced:
@pytest.fixture()
@pytest.fixture(params=["train", "val", "test"])
def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> UCMerced:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.ucmerced, "download_url", download_url
@ -32,9 +36,30 @@ class TestUCMerced:
monkeypatch.setattr(UCMerced, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip")
monkeypatch.setattr(UCMerced, "url", url) # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
UCMerced,
"split_urls",
{
"train": os.path.join(
"tests", "data", "ucmerced", "uc_merced-train.txt"
),
"val": os.path.join("tests", "data", "ucmerced", "uc_merced-val.txt"),
"test": os.path.join("tests", "data", "ucmerced", "uc_merced-test.txt"),
},
)
monkeypatch.setattr( # type: ignore[attr-defined]
UCMerced,
"split_md5s",
{
"train": "a01fa9f13333bb176fc1bfe26ff4c711",
"val": "a01fa9f13333bb176fc1bfe26ff4c711",
"test": "a01fa9f13333bb176fc1bfe26ff4c711",
},
)
root = str(tmp_path)
split = request.param
transforms = nn.Identity() # type: ignore[attr-defined]
return UCMerced(root, transforms, download=True, checksum=True)
return UCMerced(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: UCMerced) -> None:
x = dataset[0]

View file

@ -4,25 +4,16 @@
import os
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.trainers import UCMercedDataModule
@pytest.fixture(scope="module", params=[True, False])
def datamodule(request: SubRequest) -> UCMercedDataModule:
@pytest.fixture(scope="module")
def datamodule() -> UCMercedDataModule:
root = os.path.join("tests", "data", "ucmerced")
batch_size = 2
num_workers = 0
unsupervised_mode = request.param
dm = UCMercedDataModule(
root,
batch_size,
num_workers,
val_split_pct=0.33,
test_split_pct=0.33,
unsupervised_mode=unsupervised_mode,
)
dm = UCMercedDataModule(root, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

View file

@ -607,6 +607,7 @@ class VisionClassificationDataset(VisionDataset, ImageFolder): # type: ignore[m
root: str,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
loader: Optional[Callable[[str], Any]] = pil_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
"""Initialize a new VisionClassificationDataset instance.
@ -616,11 +617,17 @@ class VisionClassificationDataset(VisionDataset, ImageFolder): # type: ignore[m
entry and returns a transformed version
loader: a callable function which takes as input a path to an image and
returns a PIL Image or numpy array
is_valid_file: A function that takes the path of an Image file and checks if
the file is a valid file
"""
# When transform & target_transform are None, ImageFolder.__getitem__(index)
# returns a PIL.Image and int for image and label, respectively
super().__init__(
root=root, transform=None, target_transform=None, loader=loader
root=root,
transform=None,
target_transform=None,
loader=loader,
is_valid_file=is_valid_file,
)
# Must be set after calling super().__init__()

View file

@ -50,6 +50,11 @@ class UCMerced(VisionClassificationDataset):
* storagetanks
* tenniscourt
This dataset uses the train/val/test splits defined in the "In-domain representation
learning for remote sensing" paper:
* https://arxiv.org/abs/1911.06721.
If you use this dataset in your research, please cite the following paper:
* https://dl.acm.org/doi/10.1145/1869790.1869829
@ -85,9 +90,22 @@ class UCMerced(VisionClassificationDataset):
]
class_counts = {class_name: 100 for class_name in classes}
splits = ["train", "val", "test"]
split_urls = {
"train": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt", # noqa: E501
"val": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt", # noqa: E501
"test": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt", # noqa: E501
}
split_md5s = {
"train": "f2fb12eb2210cfb53f93f063a35ff374",
"val": "11ecabfc52782e5ea6a9c7c0d263aca0",
"test": "046aff88472d8fc07c4678d03749e28d",
}
def __init__(
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
@ -96,6 +114,7 @@ class UCMerced(VisionClassificationDataset):
Args:
root: root directory where dataset can be found
split: one of "train", "val", or "test"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@ -105,12 +124,24 @@ class UCMerced(VisionClassificationDataset):
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
assert split in self.splits
self.root = root
self.transforms = transforms
self.download = download
self.checksum = checksum
self._verify()
super().__init__(root=os.path.join(root, self.base_dir), transforms=transforms)
valid_fns = set()
with open(os.path.join(self.root, f"uc_merced-{split}.txt"), "r") as f:
for fn in f:
valid_fns.add(fn.strip())
is_in_split: Callable[[str], bool] = lambda x: os.path.basename(x) in valid_fns
super().__init__(
root=os.path.join(root, self.base_dir),
transforms=transforms,
is_valid_file=is_in_split,
)
def _check_integrity(self) -> bool:
"""Check integrity of dataset.
@ -159,6 +190,13 @@ class UCMerced(VisionClassificationDataset):
filename=self.filename,
md5=self.md5 if self.checksum else None,
)
for split in self.splits:
download_url(
self.split_urls[split],
self.root,
filename=f"uc_merced-{split}.txt",
md5=self.split_md5s[split] if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""

View file

@ -13,7 +13,6 @@ from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from ..datasets import UCMerced
from ..datasets.utils import dataset_split
from .tasks import ClassificationTask
# https://github.com/pytorch/pytorch/issues/60979
@ -44,9 +43,6 @@ class UCMercedDataModule(pl.LightningDataModule):
root_dir: str,
batch_size: int = 64,
num_workers: int = 4,
unsupervised_mode: bool = False,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for UCMerced based DataLoaders.
@ -55,19 +51,11 @@ class UCMercedDataModule(pl.LightningDataModule):
root_dir: The ``root`` argument to pass to the UCMerced Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
unsupervised_mode: Makes the train dataloader return imagery from the train,
val, and test sets
val_split_pct: What percentage of the dataset to use as a validation set
test_split_pct: What percentage of the dataset to use as a test set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.unsupervised_mode = unsupervised_mode
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
self.norm = Normalize(self.band_means, self.band_stds)
@ -107,16 +95,9 @@ class UCMercedDataModule(pl.LightningDataModule):
"""
transforms = Compose([self.preprocess])
if not self.unsupervised_mode:
dataset = UCMerced(self.root_dir, transforms=transforms)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
else:
self.train_dataset = UCMerced(self.root_dir, transforms=transforms)
self.val_dataset, self.test_dataset = None, None # type: ignore[assignment]
self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms)
self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms)
self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
@ -137,15 +118,12 @@ class UCMercedDataModule(pl.LightningDataModule):
Returns:
validation data loader
"""
if self.unsupervised_mode or self.val_split_pct == 0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
@ -153,12 +131,9 @@ class UCMercedDataModule(pl.LightningDataModule):
Returns:
testing data loader
"""
if self.unsupervised_mode or self.test_split_pct == 0:
return self.train_dataloader()
else:
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)