Update FAIR1M dataset and datamodule (#1275)

This commit is contained in:
Isaac Corley 2023-04-26 07:00:54 -05:00 committed by GitHub
Parent 698d2b5c97
Commit 28615a19df
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 500 additions and 82 deletions

Binary data: tests/data/fair1m/images.zip (binary file not shown)

Binary data: tests/data/fair1m/labelXmls.zip (binary file not shown)

Binary data: tests/data/fair1m/test/images0.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/test/images1.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/test/images2.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/images.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/images/0.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/images/1.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/images/2.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/images/3.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part1/labelXml.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part2/images.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part2/images/0.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part2/images/1.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part2/images/2.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/train/part2/images/3.tif (Normal file, binary file not shown)

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>0.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Liquid Cargo Ship</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>
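
The new label XMLs in this commit (this one and the three that follow) use FAIR1M's PASCAL-VOC-style annotation format: each <object> carries a class name under <possibleresult>/<name> and a closed polygon of five "x,y" points, with the last point repeating the first. Below is a minimal parsing sketch; parse_fair1m_xml is a hypothetical helper written for illustration (torchgeo's actual parser is parse_pascal_voc in the FAIR1M dataset module), but the element paths match the XML shown above.

# Hypothetical helper, for illustration only (not torchgeo's parse_pascal_voc):
# read one FAIR1M-style annotation and build the (N, 5, 2) box tensor that the
# dataset's __getitem__ returns alongside the image.
import xml.etree.ElementTree as ET

import torch


def parse_fair1m_xml(path: str) -> dict[str, object]:
    root = ET.parse(path).getroot()
    labels = []
    boxes = []
    for obj in root.iter("object"):
        # class name lives under <possibleresult><name>
        labels.append(obj.find("possibleresult/name").text)
        # five "x,y" pairs per object; the last point closes the polygon
        pts = [
            [float(v) for v in pt.text.split(",")]
            for pt in obj.findall("points/point")
        ]
        boxes.append(pts)
    return {"label": labels, "boxes": torch.tensor(boxes)}  # boxes: (N, 5, 2)

For the 2x2 test fixtures added here, each file contains one object, so the box tensor has shape (1, 5, 2), which is what test_getitem checks via x["boxes"].shape[-2:] == (5, 2).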

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>1.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Cargo Truck</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>2.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Boeing737</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>3.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>A220</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

Binary data: tests/data/fair1m/train/part2/labelXmls.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/validation/images.zip (Normal file, binary file not shown)
Binary data: tests/data/fair1m/validation/images/0.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/validation/images/1.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/validation/images/2.tif (Normal file, binary file not shown)
Binary data: tests/data/fair1m/validation/images/3.tif (Normal file, binary file not shown)

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>0.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Liquid Cargo Ship</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>1.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Cargo Truck</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>2.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>Boeing737</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

View file

@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<annotation>
<source>
<filename>3.tif</filename>
</source>
<size>
<width>2</width>
<height>2</height>
<depth>3</depth>
</size>
<objects>
<object>
<coordinate>pixel</coordinate>
<type>rectangle</type>
<description>None</description>
<possibleresult>
<name>A220</name>
</possibleresult>
<points>
<point>0.000000,0.000000</point>
<point>1.000000,0.000000</point>
<point>1.000000,1.000000</point>
<point>0.000000,1.000000</point>
<point>0.000000,0.000000</point>
</points>
</object>
</objects>
</annotation>

Binary data: tests/data/fair1m/validation/labelXmls.zip (Normal file, binary file not shown)

View file

@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
import pytest
from torchgeo.datamodules import FAIR1MDataModule
from torchgeo.datasets import unbind_samples
class TestFAIR1MDataModule:
@ -16,13 +15,7 @@ class TestFAIR1MDataModule:
root = os.path.join("tests", "data", "fair1m")
batch_size = 2
num_workers = 0
dm = FAIR1MDataModule(
root=root,
batch_size=batch_size,
num_workers=num_workers,
val_split_pct=0.33,
test_split_pct=0.33,
)
dm = FAIR1MDataModule(root=root, batch_size=batch_size, num_workers=num_workers)
return dm
def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None:
@ -33,13 +26,17 @@ class TestFAIR1MDataModule:
datamodule.setup("validate")
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup("test")
next(iter(datamodule.test_dataloader()))
def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup("predict")
next(iter(datamodule.predict_dataloader()))
def test_plot(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup("validate")
batch = next(iter(datamodule.val_dataloader()))
sample = unbind_samples(batch)[0]
sample = {
"image": batch["image"][0],
"boxes": batch["boxes"][0],
"label": batch["label"][0],
}
datamodule.plot(sample)
plt.close()

View file

@ -9,59 +9,119 @@ import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import FAIR1M
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
os.makedirs(root, exist_ok=True)
shutil.copy(url, os.path.join(root, filename))
class TestFAIR1M:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> FAIR1M:
md5s = ["f278aba757de9079225db42107e09e30", "ecef7bd264fcbc533bec5e9e1cacaff1"]
test_root = os.path.join("tests", "data", "fair1m")
@pytest.fixture(params=["train", "val", "test"])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> FAIR1M:
monkeypatch.setattr(torchgeo.datasets.fair1m, "download_url", download_url)
urls = {
"train": (
os.path.join(self.test_root, "train", "part1", "images.zip"),
os.path.join(self.test_root, "train", "part1", "labelXml.zip"),
os.path.join(self.test_root, "train", "part2", "images.zip"),
os.path.join(self.test_root, "train", "part2", "labelXmls.zip"),
),
"val": (
os.path.join(self.test_root, "validation", "images.zip"),
os.path.join(self.test_root, "validation", "labelXmls.zip"),
),
"test": (
os.path.join(self.test_root, "test", "images0.zip"),
os.path.join(self.test_root, "test", "images1.zip"),
os.path.join(self.test_root, "test", "images2.zip"),
),
}
md5s = {
"train": (
"ffbe9329e51ae83161ce24b5b46dc934",
"2db6fbe64be6ebb0a03656da6c6effe7",
"401b0f1d75d9d23f2e088bfeaf274cfa",
"d62b18eae8c3201f6112c2e9db84d605",
),
"val": (
"83d2f06574fc7158ded0eb1fb256c8fe",
"316490b200503c54cf43835a341b6dbe",
),
"test": (
"3c02845752667b96a5749c90c7fdc994",
"9359107f1d0abac6a5b98725f4064bc0",
"d7bc2985c625ffd47d86cdabb2a9d2bc",
),
}
monkeypatch.setattr(FAIR1M, "urls", urls)
monkeypatch.setattr(FAIR1M, "md5s", md5s)
root = os.path.join("tests", "data", "fair1m")
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
return FAIR1M(root, transforms)
return FAIR1M(root, split, transforms, download=True, checksum=True)
def test_getitem(self, dataset: FAIR1M) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["image"].shape[0] == 3
assert x["boxes"].shape[-2:] == (5, 2)
assert x["label"].ndim == 1
if dataset.split != "test":
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["boxes"].shape[-2:] == (5, 2)
assert x["label"].ndim == 1
def test_len(self, dataset: FAIR1M) -> None:
assert len(dataset) == 4
if dataset.split == "train":
assert len(dataset) == 8
else:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None:
shutil.rmtree(str(tmp_path))
shutil.copytree(dataset.root, str(tmp_path))
FAIR1M(root=str(tmp_path))
FAIR1M(root=str(tmp_path), split=dataset.split, download=True)
def test_already_downloaded_not_extracted(
self, dataset: FAIR1M, tmp_path: Path
) -> None:
for filename in dataset.filenames:
filepath = os.path.join("tests", "data", "fair1m", filename)
shutil.copy(filepath, str(tmp_path))
FAIR1M(root=str(tmp_path), checksum=True)
shutil.rmtree(dataset.root)
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
md5s = tuple(["randomhash"] * len(FAIR1M.md5s[dataset.split]))
FAIR1M.md5s[dataset.split] = md5s
shutil.rmtree(dataset.root)
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)
def test_corrupted(self, tmp_path: Path) -> None:
filenames = ["images.zip", "labelXmls.zip"]
for filename in filenames:
with open(os.path.join(tmp_path, filename), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
FAIR1M(root=str(tmp_path), checksum=True)
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory, "
"specify a different `root` directory."
with pytest.raises(RuntimeError, match=err):
FAIR1M(str(tmp_path))
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
shutil.rmtree(str(tmp_path))
with pytest.raises(RuntimeError, match="Dataset not found in"):
FAIR1M(root=str(tmp_path), split=dataset.split)
def test_plot(self, dataset: FAIR1M) -> None:
x = dataset[0].copy()
@ -69,6 +129,8 @@ class TestFAIR1M:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()
if dataset.split != "test":
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()

View file

@ -5,9 +5,33 @@
from typing import Any
import torch
from torch import Tensor
from ..datasets import FAIR1M
from .geo import NonGeoDataModule
from .utils import dataset_split
def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.
Args:
batch: list of sample dicts return by dataset
Returns:
batch dict output
.. versionadded:: 0.5
"""
output: dict[str, Any] = {}
output["image"] = torch.stack([sample["image"] for sample in batch])
if "boxes" in batch[0]:
output["boxes"] = [sample["boxes"] for sample in batch]
if "label" in batch[0]:
output["label"] = [sample["label"] for sample in batch]
return output
class FAIR1MDataModule(NonGeoDataModule):
@ -17,27 +41,21 @@ class FAIR1MDataModule(NonGeoDataModule):
"""
def __init__(
self,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
**kwargs: Any,
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
"""Initialize a new FAIR1MDataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
val_split_pct: Percentage of the dataset to use as a validation set.
test_split_pct: Percentage of the dataset to use as a test set.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.FAIR1M`.
.. versionchanged:: 0.5
Removed *val_split_pct* and *test_split_pct* parameters.
"""
super().__init__(FAIR1M, batch_size, num_workers, **kwargs)
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct
self.collate_fn = collate_fn
def setup(self, stage: str) -> None:
"""Set up datasets.
@ -45,7 +63,10 @@ class FAIR1MDataModule(NonGeoDataModule):
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.dataset = FAIR1M(**self.kwargs)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
if stage in ["fit"]:
self.train_dataset = FAIR1M(split="train", **self.kwargs)
if stage in ["fit", "validate"]:
self.val_dataset = FAIR1M(split="val", **self.kwargs)
if stage in ["predict"]:
# Test set labels are not publicly available
self.predict_dataset = FAIR1M(split="test", **self.kwargs)
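
The datamodule now needs a custom collate_fn because different images contain different numbers of objects: images can still be stacked into one tensor, but boxes and labels are kept as per-sample lists. A small sketch of the resulting batch layout, using fabricated samples and assuming the module path torchgeo.datamodules.fair1m (the file shown in this diff):

# Illustrative only: what the new collate_fn produces for variable numbers of boxes.
import torch

from torchgeo.datamodules.fair1m import collate_fn  # assumed module path

samples = [
    {"image": torch.rand(3, 2, 2), "boxes": torch.rand(1, 5, 2), "label": torch.zeros(1)},
    {"image": torch.rand(3, 2, 2), "boxes": torch.rand(3, 5, 2), "label": torch.zeros(3)},
]
batch = collate_fn(samples)
print(batch["image"].shape)  # torch.Size([2, 3, 2, 2]): images are stacked
print(type(batch["boxes"]))  # list: per-sample tensors whose lengths may differ

This matches how the updated datamodule test builds a single plotting sample by indexing batch["image"][0], batch["boxes"][0], and batch["label"][0].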

View file

@ -16,7 +16,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, extract_archive
from .utils import check_integrity, download_url, extract_archive
def parse_pascal_voc(path: str) -> dict[str, Any]:
@ -111,7 +111,7 @@ class FAIR1M(NonGeoDataset):
If you use this dataset in your research, please cite the following paper:
* https://arxiv.org/abs/2103.05569
* https://doi.org/10.1016/j.isprsjprs.2021.12.004
.. versionadded:: 0.2
"""
@ -156,15 +156,81 @@ class FAIR1M(NonGeoDataset):
"Bridge": {"id": 36, "category": "Road"},
}
filename_glob = {
"train": os.path.join("train", "**", "images", "*.tif"),
"val": os.path.join("validation", "images", "*.tif"),
"test": os.path.join("test", "images", "*.tif"),
}
directories = {
"train": (
os.path.join("train", "part1", "images"),
os.path.join("train", "part1", "labelXml"),
os.path.join("train", "part2", "images"),
os.path.join("train", "part2", "labelXml"),
),
"val": (
os.path.join("validation", "images"),
os.path.join("validation", "labelXml"),
),
"test": (os.path.join("test", "images")),
}
paths = {
"train": (
os.path.join("train", "part1", "images.zip"),
os.path.join("train", "part1", "labelXml.zip"),
os.path.join("train", "part2", "images.zip"),
os.path.join("train", "part2", "labelXmls.zip"),
),
"val": (
os.path.join("validation", "images.zip"),
os.path.join("validation", "labelXmls.zip"),
),
"test": (
os.path.join("test", "images0.zip"),
os.path.join("test", "images1.zip"),
os.path.join("test", "images2.zip"),
),
}
urls = {
"train": (
"https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf",
"https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u",
"https://drive.google.com/file/d/1cx4MRfpmh68SnGAYetNlDy68w0NgKucJ",
"https://drive.google.com/file/d/1RFVjadTHA_bsB7BJwSZoQbiyM7KIDEUI",
),
"val": (
"https://drive.google.com/file/d/1lSSHOD02B6_sUmr2b-R1iqhgWRQRw-S9",
"https://drive.google.com/file/d/1sTTna1C5n3Senpfo-73PdiNilnja1AV4",
),
"test": (
"https://drive.google.com/file/d/1HtOOVfK9qetDBjE7MM0dK_u5u7n4gdw3",
"https://drive.google.com/file/d/1iXKCPmmJtRYcyuWCQC35bk97NmyAsasq",
"https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0",
),
}
md5s = {
"train": (
"a460fe6b1b5b276bf856ce9ac72d6568",
"80f833ff355f91445c92a0c0c1fa7414",
"ad237e61dba304fcef23cd14aa6c4280",
"5c5948e68cd0f991a0d73f10956a3b05",
),
"val": ("dce782be65405aa381821b5f4d9eac94", "700b516a21edc9eae66ca315b72a09a1"),
"test": (
"fb8ccb274f3075d50ac9f7803fbafd3d",
"dc9bbbdee000e97f02276aa61b03e585",
"700b516a21edc9eae66ca315b72a09a1",
),
}
image_root: str = "images"
labels_root: str = "labelXml"
filenames = ["images.zip", "labelXmls.zip"]
md5s = ["a460fe6b1b5b276bf856ce9ac72d6568", "80f833ff355f91445c92a0c0c1fa7414"]
label_root: str = "labelXml"
def __init__(
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new FAIR1M dataset instance.
@ -174,13 +240,24 @@ class FAIR1M(NonGeoDataset):
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
.. versionchanged:: 0.5
Added *split* and *download* parameters.
"""
assert split in self.directories
self.root = root
self.split = split
self.transforms = transforms
self.download = download
self.checksum = checksum
self._verify()
self.files = sorted(
glob.glob(os.path.join(self.root, self.labels_root, "*.xml"))
glob.glob(os.path.join(self.root, self.filename_glob[split]))
)
def __getitem__(self, index: int) -> dict[str, Tensor]:
@ -193,10 +270,16 @@ class FAIR1M(NonGeoDataset):
data and label at that index
"""
path = self.files[index]
parsed = parse_pascal_voc(path)
image = self._load_image(parsed["filename"])
boxes, labels = self._load_target(parsed["points"], parsed["labels"])
sample = {"image": image, "boxes": boxes, "label": labels}
image = self._load_image(path)
sample = {"image": image}
if self.split != "test":
label_path = path.replace(self.image_root, self.label_root)
label_path = label_path.replace(".tif", ".xml")
voc = parse_pascal_voc(label_path)
boxes, labels = self._load_target(voc["points"], voc["labels"])
sample = {"image": image, "boxes": boxes, "label": labels}
if self.transforms is not None:
sample = self.transforms(sample)
@ -220,7 +303,6 @@ class FAIR1M(NonGeoDataset):
Returns:
the image
"""
path = os.path.join(self.root, self.image_root, path)
with Image.open(path) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
tensor = torch.from_numpy(array)
@ -251,17 +333,19 @@ class FAIR1M(NonGeoDataset):
Raises:
RuntimeError: if checksum fails or the dataset is not found
"""
# Check if the files already exist
# Check if the directories already exist
exists = []
for directory in [self.image_root, self.labels_root]:
for directory in self.directories[self.split]:
exists.append(os.path.exists(os.path.join(self.root, directory)))
if all(exists):
return
# Check if .zip files already exists (if so extract)
exists = []
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
paths = self.paths[self.split]
md5s = self.md5s[self.split]
for path, md5 in zip(paths, md5s):
filepath = os.path.join(self.root, path)
if os.path.isfile(filepath):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset found, but corrupted.")
@ -273,11 +357,39 @@ class FAIR1M(NonGeoDataset):
if all(exists):
return
if self.download:
self._download()
return
raise RuntimeError(
"Dataset not found in `root` directory, "
"specify a different `root` directory."
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
paths = self.paths[self.split]
urls = self.urls[self.split]
md5s = self.md5s[self.split]
for directory in self.directories[self.split]:
os.makedirs(os.path.join(self.root, directory), exist_ok=True)
for path, url, md5 in zip(paths, urls, md5s):
filepath = os.path.join(self.root, path)
if not os.path.exists(filepath):
download_url(
url=url,
root=os.path.dirname(filepath),
filename=os.path.basename(filepath),
md5=md5 if self.checksum else None,
)
extract_archive(filepath)
def plot(
self,
sample: dict[str, Tensor],
@ -306,12 +418,14 @@ class FAIR1M(NonGeoDataset):
axs[0].imshow(image)
axs[0].axis("off")
polygons = [
patches.Polygon(points, color="r", fill=False)
for points in sample["boxes"].numpy()
]
for polygon in polygons:
axs[0].add_patch(polygon)
if "boxes" in sample:
polygons = [
patches.Polygon(points, color="r", fill=False)
for points in sample["boxes"].numpy()
]
for polygon in polygons:
axs[0].add_patch(polygon)
if show_titles:
axs[0].set_title("Ground Truth")
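
Putting the changes together: the dataset now takes an explicit split, can download and verify its archives, and returns image-only samples for the test split, while the datamodule no longer accepts val_split_pct/test_split_pct and builds each split directly. A usage sketch follows; the root path is a placeholder, and download=True needs network access to the Google Drive URLs listed above.

from torchgeo.datamodules import FAIR1MDataModule
from torchgeo.datasets import FAIR1M

# Dataset API: splits are explicit; "test" samples contain only "image".
ds = FAIR1M(root="data/fair1m", split="train", download=True, checksum=True)
sample = ds[0]  # {"image", "boxes", "label"} for the train and val splits
ds.plot(sample)

# Datamodule API: no more val_split_pct/test_split_pct; extra kwargs go to FAIR1M.
dm = FAIR1MDataModule(root="data/fair1m", batch_size=8, num_workers=2, download=True)
dm.setup("fit")  # builds the train and val datasets
batch = next(iter(dm.train_dataloader()))  # "boxes"/"label" are lists (see collate_fn)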