Mirror of https://github.com/microsoft/torchgeo.git
Add train/val/test splits to UCMerced (#216)
* Added splits to the UCMerced dataset
* Added test files
* Removed random splits from UCMerced trainer
Parent: 8f0243beb5
Commit: 69598528e6
New test data — three split lists under tests/data/ucmerced (presumably the uc_merced-train.txt, uc_merced-val.txt, and uc_merced-test.txt fixtures referenced by the dataset tests below). Each of the three new files receives the same four-line hunk:

@@ -0,0 +1,4 @@
+agricultural00.tif
+agricultural01.tif
+agricultural02.tif
+airplane00.tif
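For orientation, these lists are consumed by reading one filename per line and keeping only images whose basename appears in the resulting set, which is what the UCMerced constructor further below does. A minimal sketch, assuming a hypothetical load_split helper and the fixture path above:

import os

def load_split(split_file: str) -> set:
    # One filename per line, e.g. {"agricultural00.tif", "airplane00.tif", ...}
    with open(split_file) as f:
        return {line.strip() for line in f if line.strip()}

valid_fns = load_split(os.path.join("tests", "data", "ucmerced", "uc_merced-train.txt"))
print("agricultural00.tif" in valid_fns)  # True for the fixture contents shown above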
UCMerced dataset tests (TestUCMerced) — the dataset fixture is parametrized over the three splits, and the split URLs/MD5s are monkeypatched to the local fixtures:

@@ -9,6 +9,7 @@ from typing import Generator
 import pytest
 import torch
 import torch.nn as nn
+from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from torch.utils.data import ConcatDataset
 
@@ -21,9 +22,12 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
 
 
 class TestUCMerced:
-    @pytest.fixture()
+    @pytest.fixture(params=["train", "val", "test"])
     def dataset(
-        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        tmp_path: Path,
+        request: SubRequest,
     ) -> UCMerced:
         monkeypatch.setattr(  # type: ignore[attr-defined]
             torchgeo.datasets.ucmerced, "download_url", download_url
@@ -32,9 +36,30 @@
         monkeypatch.setattr(UCMerced, "md5", md5)  # type: ignore[attr-defined]
         url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip")
         monkeypatch.setattr(UCMerced, "url", url)  # type: ignore[attr-defined]
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            UCMerced,
+            "split_urls",
+            {
+                "train": os.path.join(
+                    "tests", "data", "ucmerced", "uc_merced-train.txt"
+                ),
+                "val": os.path.join("tests", "data", "ucmerced", "uc_merced-val.txt"),
+                "test": os.path.join("tests", "data", "ucmerced", "uc_merced-test.txt"),
+            },
+        )
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            UCMerced,
+            "split_md5s",
+            {
+                "train": "a01fa9f13333bb176fc1bfe26ff4c711",
+                "val": "a01fa9f13333bb176fc1bfe26ff4c711",
+                "test": "a01fa9f13333bb176fc1bfe26ff4c711",
+            },
+        )
         root = str(tmp_path)
+        split = request.param
         transforms = nn.Identity()  # type: ignore[attr-defined]
-        return UCMerced(root, transforms, download=True, checksum=True)
+        return UCMerced(root, split, transforms, download=True, checksum=True)
 
     def test_getitem(self, dataset: UCMerced) -> None:
         x = dataset[0]
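For context, @pytest.fixture(params=["train", "val", "test"]) makes pytest build the dataset fixture once per split, so every test method in TestUCMerced now runs three times. A standalone sketch of the same mechanism (the fixture and test names here are illustrative, not from the commit):

import pytest
from _pytest.fixtures import SubRequest

@pytest.fixture(params=["train", "val", "test"])
def split(request: SubRequest) -> str:
    # request.param takes each listed value in turn; dependent tests run once per value
    return request.param

def test_split_is_valid(split: str) -> None:
    assert split in {"train", "val", "test"}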
UCMercedDataModule trainer tests — the random-split parameters disappear, so the fixture no longer needs to be parametrized:

@@ -4,25 +4,16 @@
 import os
 
 import pytest
-from _pytest.fixtures import SubRequest
 
 from torchgeo.trainers import UCMercedDataModule
 
 
-@pytest.fixture(scope="module", params=[True, False])
-def datamodule(request: SubRequest) -> UCMercedDataModule:
+@pytest.fixture(scope="module")
+def datamodule() -> UCMercedDataModule:
     root = os.path.join("tests", "data", "ucmerced")
     batch_size = 2
     num_workers = 0
-    unsupervised_mode = request.param
-    dm = UCMercedDataModule(
-        root,
-        batch_size,
-        num_workers,
-        val_split_pct=0.33,
-        test_split_pct=0.33,
-        unsupervised_mode=unsupervised_mode,
-    )
+    dm = UCMercedDataModule(root, batch_size, num_workers)
     dm.prepare_data()
     dm.setup()
     return dm
VisionClassificationDataset — the constructor gains an is_valid_file argument that is forwarded to torchvision's ImageFolder:

@@ -607,6 +607,7 @@ class VisionClassificationDataset(VisionDataset, ImageFolder):  # type: ignore[m
         root: str,
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         loader: Optional[Callable[[str], Any]] = pil_loader,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
     ) -> None:
         """Initialize a new VisionClassificationDataset instance.
 
@@ -616,11 +617,17 @@ class VisionClassificationDataset(VisionDataset, ImageFolder):  # type: ignore[m
                 entry and returns a transformed version
             loader: a callable function which takes as input a path to an image and
                 returns a PIL Image or numpy array
+            is_valid_file: A function that takes the path of an Image file and checks if
+                the file is a valid file
         """
         # When transform & target_transform are None, ImageFolder.__getitem__(index)
         # returns a PIL.Image and int for image and label, respectively
         super().__init__(
-            root=root, transform=None, target_transform=None, loader=loader
+            root=root,
+            transform=None,
+            target_transform=None,
+            loader=loader,
+            is_valid_file=is_valid_file,
         )
 
         # Must be set after calling super().__init__()
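The new keyword is simply passed through to torchvision's ImageFolder, which calls is_valid_file(path) for every candidate file and only indexes those returning True. A sketch of that behaviour on its own (the directory path and allow-list below are made up for illustration):

import os
from torchvision.datasets import ImageFolder

# Keep only files whose basename appears in an allow-list, mirroring how
# UCMerced restricts the ImageFolder index to a single split.
allowed = {"agricultural00.tif", "airplane00.tif"}
dataset = ImageFolder(
    root="UCMerced_LandUse/Images",  # hypothetical extracted dataset root
    is_valid_file=lambda path: os.path.basename(path) in allowed,
)
print(len(dataset))  # only allow-listed images are indexed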
UCMerced dataset (torchgeo.datasets.ucmerced) — the official split lists are documented, downloaded, and used to filter the image index:

@@ -50,6 +50,11 @@ class UCMerced(VisionClassificationDataset):
     * storagetanks
     * tenniscourt
 
+    This dataset uses the train/val/test splits defined in the "In-domain representation
+    learning for remote sensing" paper:
+
+    * https://arxiv.org/abs/1911.06721.
+
     If you use this dataset in your research, please cite the following paper:
 
     * https://dl.acm.org/doi/10.1145/1869790.1869829
@@ -85,9 +90,22 @@ class UCMerced(VisionClassificationDataset):
     ]
     class_counts = {class_name: 100 for class_name in classes}
 
+    splits = ["train", "val", "test"]
+    split_urls = {
+        "train": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt",  # noqa: E501
+        "val": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt",  # noqa: E501
+        "test": "https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt",  # noqa: E501
+    }
+    split_md5s = {
+        "train": "f2fb12eb2210cfb53f93f063a35ff374",
+        "val": "11ecabfc52782e5ea6a9c7c0d263aca0",
+        "test": "046aff88472d8fc07c4678d03749e28d",
+    }
+
     def __init__(
         self,
         root: str = "data",
+        split: str = "train",
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
         checksum: bool = False,
@@ -96,6 +114,7 @@ class UCMerced(VisionClassificationDataset):
 
         Args:
             root: root directory where dataset can be found
+            split: one of "train", "val", or "test"
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
@@ -105,12 +124,24 @@ class UCMerced(VisionClassificationDataset):
             RuntimeError: if ``download=False`` and data is not found, or checksums
                 don't match
         """
+        assert split in self.splits
         self.root = root
         self.transforms = transforms
         self.download = download
        self.checksum = checksum
         self._verify()
-        super().__init__(root=os.path.join(root, self.base_dir), transforms=transforms)
+
+        valid_fns = set()
+        with open(os.path.join(self.root, f"uc_merced-{split}.txt"), "r") as f:
+            for fn in f:
+                valid_fns.add(fn.strip())
+        is_in_split: Callable[[str], bool] = lambda x: os.path.basename(x) in valid_fns
+
+        super().__init__(
+            root=os.path.join(root, self.base_dir),
+            transforms=transforms,
+            is_valid_file=is_in_split,
+        )
 
     def _check_integrity(self) -> bool:
         """Check integrity of dataset.
@@ -159,6 +190,13 @@ class UCMerced(VisionClassificationDataset):
             filename=self.filename,
             md5=self.md5 if self.checksum else None,
         )
+        for split in self.splits:
+            download_url(
+                self.split_urls[split],
+                self.root,
+                filename=f"uc_merced-{split}.txt",
+                md5=self.split_md5s[split] if self.checksum else None,
+            )
 
     def _extract(self) -> None:
         """Extract the dataset."""
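With these changes each official split can be loaded directly. A minimal usage sketch (the root path is an example; download=True fetches the archive plus the three uc_merced-{split}.txt lists, and checksum=True verifies them):

from torchgeo.datasets import UCMerced

train_ds = UCMerced(root="data/ucmerced", split="train", download=True, checksum=True)
val_ds = UCMerced(root="data/ucmerced", split="val", download=True, checksum=True)
test_ds = UCMerced(root="data/ucmerced", split="test", download=True, checksum=True)

sample = train_ds[0]  # dict with "image" and "label" entries (VisionClassificationDataset convention)
print(len(train_ds), len(val_ds), len(test_ds))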
UCMercedDataModule (torchgeo.trainers) — random splitting with dataset_split is dropped in favour of the dataset's own train/val/test splits:

@@ -13,7 +13,6 @@ from torch.utils.data import DataLoader
 from torchvision.transforms import Compose, Normalize
 
 from ..datasets import UCMerced
-from ..datasets.utils import dataset_split
 from .tasks import ClassificationTask
 
 # https://github.com/pytorch/pytorch/issues/60979
@@ -44,9 +43,6 @@ class UCMercedDataModule(pl.LightningDataModule):
         root_dir: str,
         batch_size: int = 64,
         num_workers: int = 4,
-        unsupervised_mode: bool = False,
-        val_split_pct: float = 0.2,
-        test_split_pct: float = 0.2,
         **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for UCMerced based DataLoaders.
@@ -55,19 +51,11 @@
             root_dir: The ``root`` arugment to pass to the UCMerced Dataset classes
             batch_size: The batch size to use in all created DataLoaders
             num_workers: The number of workers to use in all created DataLoaders
-            unsupervised_mode: Makes the train dataloader return imagery from the train,
-                val, and test sets
-            val_split_pct: What percentage of the dataset to use as a validation set
-            test_split_pct: What percentage of the dataset to use as a test set
         """
         super().__init__()  # type: ignore[no-untyped-call]
         self.root_dir = root_dir
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.unsupervised_mode = unsupervised_mode
-
-        self.val_split_pct = val_split_pct
-        self.test_split_pct = test_split_pct
 
         self.norm = Normalize(self.band_means, self.band_stds)
 
@@ -107,16 +95,9 @@
         """
         transforms = Compose([self.preprocess])
 
-        if not self.unsupervised_mode:
-
-            dataset = UCMerced(self.root_dir, transforms=transforms)
-            self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
-                dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
-            )
-        else:
-
-            self.train_dataset = UCMerced(self.root_dir, transforms=transforms)
-            self.val_dataset, self.test_dataset = None, None  # type: ignore[assignment]
+        self.train_dataset = UCMerced(self.root_dir, "train", transforms=transforms)
+        self.val_dataset = UCMerced(self.root_dir, "val", transforms=transforms)
+        self.test_dataset = UCMerced(self.root_dir, "test", transforms=transforms)
 
     def train_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for training.
@@ -137,15 +118,12 @@
         Returns:
             validation data loader
         """
-        if self.unsupervised_mode or self.val_split_pct == 0:
-            return self.train_dataloader()
-        else:
-            return DataLoader(
-                self.val_dataset,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                shuffle=False,
-            )
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
 
     def test_dataloader(self) -> DataLoader[Any]:
         """Return a DataLoader for testing.
@@ -153,12 +131,9 @@
         Returns:
             testing data loader
         """
-        if self.unsupervised_mode or self.test_split_pct == 0:
-            return self.train_dataloader()
-        else:
-            return DataLoader(
-                self.test_dataset,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                shuffle=False,
-            )
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
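A sketch of how the simplified datamodule is now driven (only root_dir, batch_size, and num_workers remain as knobs, since the splits come from the dataset itself; the Trainer/model wiring here is illustrative):

import pytorch_lightning as pl
from torchgeo.trainers import UCMercedDataModule

dm = UCMercedDataModule(root_dir="data/ucmerced", batch_size=64, num_workers=4)
dm.prepare_data()
dm.setup()

trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=dm)  # `model` would be a ClassificationTask (not shown here)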