зеркало из https://github.com/microsoft/torchgeo.git
Update FAIR1M dataset and datamodule (#1275)
This commit is contained in:
Родитель
698d2b5c97
Коммит
28615a19df
Двоичные данные
tests/data/fair1m/images.zip
Двоичные данные
tests/data/fair1m/images.zip
Двоичный файл не отображается.
Двоичные данные
tests/data/fair1m/labelXmls.zip
Двоичные данные
tests/data/fair1m/labelXmls.zip
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>0.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Liquid Cargo Ship</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>1.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Cargo Truck</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>2.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Boeing737</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>3.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>A220</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>0.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Liquid Cargo Ship</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>1.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Cargo Truck</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>2.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>Boeing737</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
|
@ -0,0 +1,28 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<annotation>
|
||||
<source>
|
||||
<filename>3.tif</filename>
|
||||
</source>
|
||||
<size>
|
||||
<width>2</width>
|
||||
<height>2</height>
|
||||
<depth>3</depth>
|
||||
</size>
|
||||
<objects>
|
||||
<object>
|
||||
<coordinate>pixel</coordinate>
|
||||
<type>rectangle</type>
|
||||
<description>None</description>
|
||||
<possibleresult>
|
||||
<name>A220</name>
|
||||
</possibleresult>
|
||||
<points>
|
||||
<point>0.000000,0.000000</point>
|
||||
<point>1.000000,0.000000</point>
|
||||
<point>1.000000,1.000000</point>
|
||||
<point>0.000000,1.000000</point>
|
||||
<point>0.000000,0.000000</point>
|
||||
</points>
|
||||
</object>
|
||||
</objects>
|
||||
</annotation>
|
Двоичный файл не отображается.
|
@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
|
|||
import pytest
|
||||
|
||||
from torchgeo.datamodules import FAIR1MDataModule
|
||||
from torchgeo.datasets import unbind_samples
|
||||
|
||||
|
||||
class TestFAIR1MDataModule:
|
||||
|
@ -16,13 +15,7 @@ class TestFAIR1MDataModule:
|
|||
root = os.path.join("tests", "data", "fair1m")
|
||||
batch_size = 2
|
||||
num_workers = 0
|
||||
dm = FAIR1MDataModule(
|
||||
root=root,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
val_split_pct=0.33,
|
||||
test_split_pct=0.33,
|
||||
)
|
||||
dm = FAIR1MDataModule(root=root, batch_size=batch_size, num_workers=num_workers)
|
||||
return dm
|
||||
|
||||
def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
|
||||
|
@ -33,13 +26,17 @@ class TestFAIR1MDataModule:
|
|||
datamodule.setup("validate")
|
||||
next(iter(datamodule.val_dataloader()))
|
||||
|
||||
def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
|
||||
datamodule.setup("test")
|
||||
next(iter(datamodule.test_dataloader()))
|
||||
def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None:
|
||||
datamodule.setup("predict")
|
||||
next(iter(datamodule.predict_dataloader()))
|
||||
|
||||
def test_plot(self, datamodule: FAIR1MDataModule) -> None:
|
||||
datamodule.setup("validate")
|
||||
batch = next(iter(datamodule.val_dataloader()))
|
||||
sample = unbind_samples(batch)[0]
|
||||
sample = {
|
||||
"image": batch["image"][0],
|
||||
"boxes": batch["boxes"][0],
|
||||
"label": batch["label"][0],
|
||||
}
|
||||
datamodule.plot(sample)
|
||||
plt.close()
|
||||
|
|
|
@ -9,59 +9,119 @@ import matplotlib.pyplot as plt
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import FAIR1M
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
shutil.copy(url, os.path.join(root, filename))
|
||||
|
||||
|
||||
class TestFAIR1M:
|
||||
@pytest.fixture
|
||||
def dataset(self, monkeypatch: MonkeyPatch) -> FAIR1M:
|
||||
md5s = ["f278aba757de9079225db42107e09e30", "ecef7bd264fcbc533bec5e9e1cacaff1"]
|
||||
test_root = os.path.join("tests", "data", "fair1m")
|
||||
|
||||
@pytest.fixture(params=["train", "val", "test"])
|
||||
def dataset(
|
||||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
||||
) -> FAIR1M:
|
||||
monkeypatch.setattr(torchgeo.datasets.fair1m, "download_url", download_url)
|
||||
urls = {
|
||||
"train": (
|
||||
os.path.join(self.test_root, "train", "part1", "images.zip"),
|
||||
os.path.join(self.test_root, "train", "part1", "labelXml.zip"),
|
||||
os.path.join(self.test_root, "train", "part2", "images.zip"),
|
||||
os.path.join(self.test_root, "train", "part2", "labelXmls.zip"),
|
||||
),
|
||||
"val": (
|
||||
os.path.join(self.test_root, "validation", "images.zip"),
|
||||
os.path.join(self.test_root, "validation", "labelXmls.zip"),
|
||||
),
|
||||
"test": (
|
||||
os.path.join(self.test_root, "test", "images0.zip"),
|
||||
os.path.join(self.test_root, "test", "images1.zip"),
|
||||
os.path.join(self.test_root, "test", "images2.zip"),
|
||||
),
|
||||
}
|
||||
md5s = {
|
||||
"train": (
|
||||
"ffbe9329e51ae83161ce24b5b46dc934",
|
||||
"2db6fbe64be6ebb0a03656da6c6effe7",
|
||||
"401b0f1d75d9d23f2e088bfeaf274cfa",
|
||||
"d62b18eae8c3201f6112c2e9db84d605",
|
||||
),
|
||||
"val": (
|
||||
"83d2f06574fc7158ded0eb1fb256c8fe",
|
||||
"316490b200503c54cf43835a341b6dbe",
|
||||
),
|
||||
"test": (
|
||||
"3c02845752667b96a5749c90c7fdc994",
|
||||
"9359107f1d0abac6a5b98725f4064bc0",
|
||||
"d7bc2985c625ffd47d86cdabb2a9d2bc",
|
||||
),
|
||||
}
|
||||
monkeypatch.setattr(FAIR1M, "urls", urls)
|
||||
monkeypatch.setattr(FAIR1M, "md5s", md5s)
|
||||
root = os.path.join("tests", "data", "fair1m")
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return FAIR1M(root, transforms)
|
||||
return FAIR1M(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: FAIR1M) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["boxes"], torch.Tensor)
|
||||
assert isinstance(x["label"], torch.Tensor)
|
||||
assert x["image"].shape[0] == 3
|
||||
assert x["boxes"].shape[-2:] == (5, 2)
|
||||
assert x["label"].ndim == 1
|
||||
|
||||
if dataset.split != "test":
|
||||
assert isinstance(x["boxes"], torch.Tensor)
|
||||
assert isinstance(x["label"], torch.Tensor)
|
||||
assert x["boxes"].shape[-2:] == (5, 2)
|
||||
assert x["label"].ndim == 1
|
||||
|
||||
def test_len(self, dataset: FAIR1M) -> None:
|
||||
assert len(dataset) == 4
|
||||
if dataset.split == "train":
|
||||
assert len(dataset) == 8
|
||||
else:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None:
|
||||
shutil.rmtree(str(tmp_path))
|
||||
shutil.copytree(dataset.root, str(tmp_path))
|
||||
FAIR1M(root=str(tmp_path))
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: FAIR1M, tmp_path: Path
|
||||
) -> None:
|
||||
for filename in dataset.filenames:
|
||||
filepath = os.path.join("tests", "data", "fair1m", filename)
|
||||
shutil.copy(filepath, str(tmp_path))
|
||||
FAIR1M(root=str(tmp_path), checksum=True)
|
||||
shutil.rmtree(dataset.root)
|
||||
for filepath, url in zip(
|
||||
dataset.paths[dataset.split], dataset.urls[dataset.split]
|
||||
):
|
||||
output = os.path.join(str(tmp_path), filepath)
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
download_url(url, root=os.path.dirname(output), filename=output)
|
||||
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
|
||||
md5s = tuple(["randomhash"] * len(FAIR1M.md5s[dataset.split]))
|
||||
FAIR1M.md5s[dataset.split] = md5s
|
||||
shutil.rmtree(dataset.root)
|
||||
for filepath, url in zip(
|
||||
dataset.paths[dataset.split], dataset.urls[dataset.split]
|
||||
):
|
||||
output = os.path.join(str(tmp_path), filepath)
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
download_url(url, root=os.path.dirname(output), filename=output)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
filenames = ["images.zip", "labelXmls.zip"]
|
||||
for filename in filenames:
|
||||
with open(os.path.join(tmp_path, filename), "w") as f:
|
||||
f.write("bad")
|
||||
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
|
||||
FAIR1M(root=str(tmp_path), checksum=True)
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory, "
|
||||
"specify a different `root` directory."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
FAIR1M(str(tmp_path))
|
||||
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
|
||||
shutil.rmtree(str(tmp_path))
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split)
|
||||
|
||||
def test_plot(self, dataset: FAIR1M) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
@ -69,6 +129,8 @@ class TestFAIR1M:
|
|||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
x["prediction_boxes"] = x["boxes"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
if dataset.split != "test":
|
||||
x["prediction_boxes"] = x["boxes"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
|
|
@ -5,9 +5,33 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ..datasets import FAIR1M
|
||||
from .geo import NonGeoDataModule
|
||||
from .utils import dataset_split
|
||||
|
||||
|
||||
def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
|
||||
"""Custom object detection collate fn to handle variable boxes.
|
||||
|
||||
Args:
|
||||
batch: list of sample dicts return by dataset
|
||||
|
||||
Returns:
|
||||
batch dict output
|
||||
|
||||
.. versionadded:: 0.5
|
||||
"""
|
||||
output: dict[str, Any] = {}
|
||||
output["image"] = torch.stack([sample["image"] for sample in batch])
|
||||
|
||||
if "boxes" in batch[0]:
|
||||
output["boxes"] = [sample["boxes"] for sample in batch]
|
||||
if "label" in batch[0]:
|
||||
output["label"] = [sample["label"] for sample in batch]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FAIR1MDataModule(NonGeoDataModule):
|
||||
|
@ -17,27 +41,21 @@ class FAIR1MDataModule(NonGeoDataModule):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 0,
|
||||
val_split_pct: float = 0.2,
|
||||
test_split_pct: float = 0.2,
|
||||
**kwargs: Any,
|
||||
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize a new FAIR1MDataModule instance.
|
||||
|
||||
Args:
|
||||
batch_size: Size of each mini-batch.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
val_split_pct: Percentage of the dataset to use as a validation set.
|
||||
test_split_pct: Percentage of the dataset to use as a test set.
|
||||
**kwargs: Additional keyword arguments passed to
|
||||
:class:`~torchgeo.datasets.FAIR1M`.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
Removed *val_split_pct* and *test_split_pct* parameters.
|
||||
"""
|
||||
super().__init__(FAIR1M, batch_size, num_workers, **kwargs)
|
||||
|
||||
self.val_split_pct = val_split_pct
|
||||
self.test_split_pct = test_split_pct
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
"""Set up datasets.
|
||||
|
@ -45,7 +63,10 @@ class FAIR1MDataModule(NonGeoDataModule):
|
|||
Args:
|
||||
stage: Either 'fit', 'validate', 'test', or 'predict'.
|
||||
"""
|
||||
self.dataset = FAIR1M(**self.kwargs)
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
|
||||
self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
|
||||
)
|
||||
if stage in ["fit"]:
|
||||
self.train_dataset = FAIR1M(split="train", **self.kwargs)
|
||||
if stage in ["fit", "validate"]:
|
||||
self.val_dataset = FAIR1M(split="val", **self.kwargs)
|
||||
if stage in ["predict"]:
|
||||
# Test set labels are not publicly available
|
||||
self.predict_dataset = FAIR1M(split="test", **self.kwargs)
|
||||
|
|
|
@ -16,7 +16,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, extract_archive
|
||||
from .utils import check_integrity, download_url, extract_archive
|
||||
|
||||
|
||||
def parse_pascal_voc(path: str) -> dict[str, Any]:
|
||||
|
@ -111,7 +111,7 @@ class FAIR1M(NonGeoDataset):
|
|||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/abs/2103.05569
|
||||
* https://doi.org/10.1016/j.isprsjprs.2021.12.004
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
|
@ -156,15 +156,81 @@ class FAIR1M(NonGeoDataset):
|
|||
"Bridge": {"id": 36, "category": "Road"},
|
||||
}
|
||||
|
||||
filename_glob = {
|
||||
"train": os.path.join("train", "**", "images", "*.tif"),
|
||||
"val": os.path.join("validation", "images", "*.tif"),
|
||||
"test": os.path.join("test", "images", "*.tif"),
|
||||
}
|
||||
directories = {
|
||||
"train": (
|
||||
os.path.join("train", "part1", "images"),
|
||||
os.path.join("train", "part1", "labelXml"),
|
||||
os.path.join("train", "part2", "images"),
|
||||
os.path.join("train", "part2", "labelXml"),
|
||||
),
|
||||
"val": (
|
||||
os.path.join("validation", "images"),
|
||||
os.path.join("validation", "labelXml"),
|
||||
),
|
||||
"test": (os.path.join("test", "images")),
|
||||
}
|
||||
paths = {
|
||||
"train": (
|
||||
os.path.join("train", "part1", "images.zip"),
|
||||
os.path.join("train", "part1", "labelXml.zip"),
|
||||
os.path.join("train", "part2", "images.zip"),
|
||||
os.path.join("train", "part2", "labelXmls.zip"),
|
||||
),
|
||||
"val": (
|
||||
os.path.join("validation", "images.zip"),
|
||||
os.path.join("validation", "labelXmls.zip"),
|
||||
),
|
||||
"test": (
|
||||
os.path.join("test", "images0.zip"),
|
||||
os.path.join("test", "images1.zip"),
|
||||
os.path.join("test", "images2.zip"),
|
||||
),
|
||||
}
|
||||
urls = {
|
||||
"train": (
|
||||
"https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf",
|
||||
"https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u",
|
||||
"https://drive.google.com/file/d/1cx4MRfpmh68SnGAYetNlDy68w0NgKucJ",
|
||||
"https://drive.google.com/file/d/1RFVjadTHA_bsB7BJwSZoQbiyM7KIDEUI",
|
||||
),
|
||||
"val": (
|
||||
"https://drive.google.com/file/d/1lSSHOD02B6_sUmr2b-R1iqhgWRQRw-S9",
|
||||
"https://drive.google.com/file/d/1sTTna1C5n3Senpfo-73PdiNilnja1AV4",
|
||||
),
|
||||
"test": (
|
||||
"https://drive.google.com/file/d/1HtOOVfK9qetDBjE7MM0dK_u5u7n4gdw3",
|
||||
"https://drive.google.com/file/d/1iXKCPmmJtRYcyuWCQC35bk97NmyAsasq",
|
||||
"https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0",
|
||||
),
|
||||
}
|
||||
md5s = {
|
||||
"train": (
|
||||
"a460fe6b1b5b276bf856ce9ac72d6568",
|
||||
"80f833ff355f91445c92a0c0c1fa7414",
|
||||
"ad237e61dba304fcef23cd14aa6c4280",
|
||||
"5c5948e68cd0f991a0d73f10956a3b05",
|
||||
),
|
||||
"val": ("dce782be65405aa381821b5f4d9eac94", "700b516a21edc9eae66ca315b72a09a1"),
|
||||
"test": (
|
||||
"fb8ccb274f3075d50ac9f7803fbafd3d",
|
||||
"dc9bbbdee000e97f02276aa61b03e585",
|
||||
"700b516a21edc9eae66ca315b72a09a1",
|
||||
),
|
||||
}
|
||||
image_root: str = "images"
|
||||
labels_root: str = "labelXml"
|
||||
filenames = ["images.zip", "labelXmls.zip"]
|
||||
md5s = ["a460fe6b1b5b276bf856ce9ac72d6568", "80f833ff355f91445c92a0c0c1fa7414"]
|
||||
label_root: str = "labelXml"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new FAIR1M dataset instance.
|
||||
|
@ -174,13 +240,24 @@ class FAIR1M(NonGeoDataset):
|
|||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
Added *split* and *download* parameters.
|
||||
"""
|
||||
assert split in self.directories
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transforms = transforms
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self._verify()
|
||||
self.files = sorted(
|
||||
glob.glob(os.path.join(self.root, self.labels_root, "*.xml"))
|
||||
glob.glob(os.path.join(self.root, self.filename_glob[split]))
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
|
@ -193,10 +270,16 @@ class FAIR1M(NonGeoDataset):
|
|||
data and label at that index
|
||||
"""
|
||||
path = self.files[index]
|
||||
parsed = parse_pascal_voc(path)
|
||||
image = self._load_image(parsed["filename"])
|
||||
boxes, labels = self._load_target(parsed["points"], parsed["labels"])
|
||||
sample = {"image": image, "boxes": boxes, "label": labels}
|
||||
|
||||
image = self._load_image(path)
|
||||
sample = {"image": image}
|
||||
|
||||
if self.split != "test":
|
||||
label_path = path.replace(self.image_root, self.label_root)
|
||||
label_path = label_path.replace(".tif", ".xml")
|
||||
voc = parse_pascal_voc(label_path)
|
||||
boxes, labels = self._load_target(voc["points"], voc["labels"])
|
||||
sample = {"image": image, "boxes": boxes, "label": labels}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
@ -220,7 +303,6 @@ class FAIR1M(NonGeoDataset):
|
|||
Returns:
|
||||
the image
|
||||
"""
|
||||
path = os.path.join(self.root, self.image_root, path)
|
||||
with Image.open(path) as img:
|
||||
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
|
||||
tensor = torch.from_numpy(array)
|
||||
|
@ -251,17 +333,19 @@ class FAIR1M(NonGeoDataset):
|
|||
Raises:
|
||||
RuntimeError: if checksum fails or the dataset is not found
|
||||
"""
|
||||
# Check if the files already exist
|
||||
# Check if the directories already exist
|
||||
exists = []
|
||||
for directory in [self.image_root, self.labels_root]:
|
||||
for directory in self.directories[self.split]:
|
||||
exists.append(os.path.exists(os.path.join(self.root, directory)))
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if .zip files already exists (if so extract)
|
||||
exists = []
|
||||
for filename, md5 in zip(self.filenames, self.md5s):
|
||||
filepath = os.path.join(self.root, filename)
|
||||
paths = self.paths[self.split]
|
||||
md5s = self.md5s[self.split]
|
||||
for path, md5 in zip(paths, md5s):
|
||||
filepath = os.path.join(self.root, path)
|
||||
if os.path.isfile(filepath):
|
||||
if self.checksum and not check_integrity(filepath, md5):
|
||||
raise RuntimeError("Dataset found, but corrupted.")
|
||||
|
@ -273,11 +357,39 @@ class FAIR1M(NonGeoDataset):
|
|||
if all(exists):
|
||||
return
|
||||
|
||||
if self.download:
|
||||
self._download()
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory, "
|
||||
"specify a different `root` directory."
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
paths = self.paths[self.split]
|
||||
urls = self.urls[self.split]
|
||||
md5s = self.md5s[self.split]
|
||||
for directory in self.directories[self.split]:
|
||||
os.makedirs(os.path.join(self.root, directory), exist_ok=True)
|
||||
|
||||
for path, url, md5 in zip(paths, urls, md5s):
|
||||
filepath = os.path.join(self.root, path)
|
||||
if not os.path.exists(filepath):
|
||||
download_url(
|
||||
url=url,
|
||||
root=os.path.dirname(filepath),
|
||||
filename=os.path.basename(filepath),
|
||||
md5=md5 if self.checksum else None,
|
||||
)
|
||||
extract_archive(filepath)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: dict[str, Tensor],
|
||||
|
@ -306,12 +418,14 @@ class FAIR1M(NonGeoDataset):
|
|||
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
polygons = [
|
||||
patches.Polygon(points, color="r", fill=False)
|
||||
for points in sample["boxes"].numpy()
|
||||
]
|
||||
for polygon in polygons:
|
||||
axs[0].add_patch(polygon)
|
||||
|
||||
if "boxes" in sample:
|
||||
polygons = [
|
||||
patches.Polygon(points, color="r", fill=False)
|
||||
for points in sample["boxes"].numpy()
|
||||
]
|
||||
for polygon in polygons:
|
||||
axs[0].add_patch(polygon)
|
||||
|
||||
if show_titles:
|
||||
axs[0].set_title("Ground Truth")
|
||||
|
|
Загрузка…
Ссылка в новой задаче