зеркало из https://github.com/microsoft/torchgeo.git
Adding the RwandaFieldBoundary dataset (#1574)
* Initial commit * Add to docs * Pyupgrade * Pyupgrade * Added tests * who actually cares about lines that are 91 characters long * Using Figure from matplotlib.figure instead of matplotlib.pyplot to make mypy happy even though they are the same thing * Documentation updates * Update docs link --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
2e01f99655
Коммит
b6d78b74c3
|
@ -297,6 +297,11 @@ RESISC45
|
|||
|
||||
.. autoclass:: RESISC45
|
||||
|
||||
Rwanda Field Boundary
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: RwandaFieldBoundary
|
||||
|
||||
Seasonal Contrast
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
|
||||
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
|
||||
`RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB
|
||||
`Rwanda Field Boundary`_,S,Planetscope,70,2,256x256,4.7,RGB + NIR
|
||||
`Seasonal Contrast`_,T,Sentinel-2,100K--1M,-,264x264,10,MSI
|
||||
`SeasoNet`_,S,Sentinel-2,"1,759,830",33,120x120,10,MSI
|
||||
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
|
||||
|
|
|
|
@ -0,0 +1,101 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
|
||||
dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")
|
||||
all_bands = ("B01", "B02", "B03", "B04")
|
||||
|
||||
SIZE = 32
|
||||
NUM_SAMPLES = 5
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def create_mask(fn: str) -> None:
|
||||
profile = {
|
||||
"driver": "GTiff",
|
||||
"dtype": "uint8",
|
||||
"nodata": 0.0,
|
||||
"width": SIZE,
|
||||
"height": SIZE,
|
||||
"count": 1,
|
||||
"crs": "epsg:3857",
|
||||
"compress": "lzw",
|
||||
"predictor": 2,
|
||||
"transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
|
||||
"blockysize": 32,
|
||||
"tiled": False,
|
||||
"interleave": "band",
|
||||
}
|
||||
with rasterio.open(fn, "w", **profile) as f:
|
||||
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)
|
||||
|
||||
|
||||
def create_img(fn: str) -> None:
|
||||
profile = {
|
||||
"driver": "GTiff",
|
||||
"dtype": "uint16",
|
||||
"nodata": 0.0,
|
||||
"width": SIZE,
|
||||
"height": SIZE,
|
||||
"count": 1,
|
||||
"crs": "epsg:3857",
|
||||
"compress": "lzw",
|
||||
"predictor": 2,
|
||||
"blockysize": 16,
|
||||
"transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
|
||||
"tiled": False,
|
||||
"interleave": "band",
|
||||
}
|
||||
with rasterio.open(fn, "w", **profile) as f:
|
||||
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Train and test images
|
||||
for split in ("train", "test"):
|
||||
for i in range(NUM_SAMPLES):
|
||||
for date in dates:
|
||||
directory = os.path.join(
|
||||
f"nasa_rwanda_field_boundary_competition_source_{split}",
|
||||
f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501
|
||||
)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for band in all_bands:
|
||||
create_img(os.path.join(directory, f"{band}.tif"))
|
||||
|
||||
# Create collections.json, this isn't used by the dataset but is checked to
|
||||
# exist
|
||||
with open(
|
||||
f"nasa_rwanda_field_boundary_competition_source_{split}/collections.json",
|
||||
"w",
|
||||
) as f:
|
||||
f.write("Not used")
|
||||
|
||||
# Train labels
|
||||
for i in range(NUM_SAMPLES):
|
||||
directory = os.path.join(
|
||||
"nasa_rwanda_field_boundary_competition_labels_train",
|
||||
f"nasa_rwanda_field_boundary_competition_labels_train_{i:02d}",
|
||||
)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
create_mask(os.path.join(directory, "raster_labels.tif"))
|
||||
|
||||
# Create directories and compute checksums
|
||||
for filename in [
|
||||
"nasa_rwanda_field_boundary_competition_source_train",
|
||||
"nasa_rwanda_field_boundary_competition_source_test",
|
||||
"nasa_rwanda_field_boundary_competition_labels_train",
|
||||
]:
|
||||
shutil.make_archive(filename, "gztar", ".", filename)
|
||||
# Compute checksums
|
||||
with open(f"{filename}.tar.gz", "rb") as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f"{filename}: {md5}")
|
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz
Normal file
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz
Normal file
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz
Normal file
Двоичные данные
tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1,140 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import RwandaFieldBoundary
|
||||
|
||||
|
||||
class Collection:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join("tests", "data", "rwanda_field_boundary", "*.tar.gz")
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
def fetch(dataset_id: str, **kwargs: str) -> Collection:
|
||||
return Collection()
|
||||
|
||||
|
||||
class TestRwandaFieldBoundary:
|
||||
@pytest.fixture(params=["train", "test"])
|
||||
def dataset(
|
||||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
||||
) -> RwandaFieldBoundary:
|
||||
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary, "number_of_patches_per_split", {"train": 5, "test": 5}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary,
|
||||
"md5s",
|
||||
{
|
||||
"train_images": "af9395e2e49deefebb35fa65fa378ba3",
|
||||
"test_images": "d104bb82323a39e7c3b3b7dd0156f550",
|
||||
"train_labels": "6cceaf16a141cf73179253a783e7d51b",
|
||||
},
|
||||
)
|
||||
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return RwandaFieldBoundary(
|
||||
root, split, transforms=transforms, api_key="", download=True, checksum=True
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
if dataset.split == "train":
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
else:
|
||||
assert "mask" not in x
|
||||
|
||||
def test_len(self, dataset: RwandaFieldBoundary) -> None:
|
||||
assert len(dataset) == 5
|
||||
|
||||
def test_add(self, dataset: RwandaFieldBoundary) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ConcatDataset)
|
||||
assert len(ds) == 10
|
||||
|
||||
def test_needs_extraction(self, tmp_path: Path) -> None:
|
||||
root = str(tmp_path)
|
||||
for fn in [
|
||||
"nasa_rwanda_field_boundary_competition_source_train.tar.gz",
|
||||
"nasa_rwanda_field_boundary_competition_source_test.tar.gz",
|
||||
"nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
|
||||
]:
|
||||
url = os.path.join("tests", "data", "rwanda_field_boundary", fn)
|
||||
shutil.copy(url, root)
|
||||
RwandaFieldBoundary(root, checksum=False)
|
||||
|
||||
def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
|
||||
RwandaFieldBoundary(root=dataset.root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
RwandaFieldBoundary(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
for fn in [
|
||||
"nasa_rwanda_field_boundary_competition_source_train.tar.gz",
|
||||
"nasa_rwanda_field_boundary_competition_source_test.tar.gz",
|
||||
"nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
|
||||
]:
|
||||
with open(os.path.join(tmp_path, fn), "w") as f:
|
||||
f.write("bad")
|
||||
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
|
||||
RwandaFieldBoundary(root=str(tmp_path), checksum=True)
|
||||
|
||||
def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary,
|
||||
"md5s",
|
||||
{"train_images": "bad", "test_images": "bad", "train_labels": "bad"},
|
||||
)
|
||||
root = str(tmp_path)
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
RwandaFieldBoundary(root, "train", api_key="", download=True, checksum=True)
|
||||
|
||||
def test_no_api_key(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Must provide an API key to download"):
|
||||
RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(ValueError, match="is an invalid band name."):
|
||||
RwandaFieldBoundary(bands=("foo", "bar"))
|
||||
|
||||
def test_plot(self, dataset: RwandaFieldBoundary) -> None:
|
||||
x = dataset[0].copy()
|
||||
dataset.plot(x, suptitle="Test")
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
if dataset.split == "train":
|
||||
x["prediction"] = x["mask"].clone()
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
def test_failed_plot(self, dataset: RwandaFieldBoundary) -> None:
|
||||
single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=("B01",))
|
||||
with pytest.raises(ValueError, match="Dataset doesn't contain"):
|
||||
x = single_band_dataset[0].copy()
|
||||
single_band_dataset.plot(x, suptitle="Test")
|
|
@ -83,6 +83,7 @@ from .patternnet import PatternNet
|
|||
from .potsdam import Potsdam2D
|
||||
from .reforestree import ReforesTree
|
||||
from .resisc45 import RESISC45
|
||||
from .rwanda_field_boundary import RwandaFieldBoundary
|
||||
from .seasonet import SeasoNet
|
||||
from .seco import SeasonalContrastS2
|
||||
from .sen12ms import SEN12MS
|
||||
|
@ -201,6 +202,7 @@ __all__ = (
|
|||
"Potsdam2D",
|
||||
"RESISC45",
|
||||
"ReforesTree",
|
||||
"RwandaFieldBoundary",
|
||||
"SeasonalContrastS2",
|
||||
"SeasoNet",
|
||||
"SEN12MS",
|
||||
|
|
|
@ -0,0 +1,328 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Rwanda Field Boundary Competition dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import rasterio.features
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
|
||||
|
||||
class RwandaFieldBoundary(NonGeoDataset):
|
||||
r"""Rwanda Field Boundary Competition dataset.
|
||||
|
||||
This dataset contains field boundaries for smallholder farms in eastern Rwanda.
|
||||
The Nasa Harvest program funded a team of annotators from TaQadam to label Planet
|
||||
imagery for the 2021 growing season for the purpose of conducting the Rwanda Field
|
||||
boundary detection Challenge. The dataset includes rasterized labeled field
|
||||
boundaries and time series satellite imagery from Planet's NICFI program.
|
||||
Planet's basemap imagery is provided for six months (March, April, August, October,
|
||||
November and December). Note: only fields that were big enough to be differentiated
|
||||
on the Planetscope imagery were labeled, only fields that were fully contained
|
||||
within the chips were labeled. The paired dataset is provided in 256x256 chips for a
|
||||
total of 70 tiles covering 1532 individual fields.
|
||||
|
||||
The labels are provided as binary semantic segmentation labels:
|
||||
|
||||
0. No field-boundary
|
||||
1. Field-boundary
|
||||
|
||||
If you use this dataset in your research, please cite the following:
|
||||
|
||||
* https://doi.org/10.34911/RDNT.G580WW
|
||||
|
||||
.. note::
|
||||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
|
||||
imagery and labels from the Radiant Earth MLHub
|
||||
|
||||
.. versionadded:: 0.5
|
||||
"""
|
||||
|
||||
dataset_id = "nasa_rwanda_field_boundary_competition"
|
||||
collection_ids = [
|
||||
"nasa_rwanda_field_boundary_competition_source_train",
|
||||
"nasa_rwanda_field_boundary_competition_labels_train",
|
||||
"nasa_rwanda_field_boundary_competition_source_test",
|
||||
]
|
||||
number_of_patches_per_split = {"train": 57, "test": 13}
|
||||
|
||||
filenames = {
|
||||
"train_images": "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
|
||||
"test_images": "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
|
||||
"train_labels": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
|
||||
}
|
||||
md5s = {
|
||||
"train_images": "1f9ec08038218e67e11f82a86849b333",
|
||||
"test_images": "17bb0e56eedde2e7a43c57aa908dc125",
|
||||
"train_labels": "10e4eb761523c57b6d3bdf9394004f5f",
|
||||
}
|
||||
|
||||
dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")
|
||||
|
||||
all_bands = ("B01", "B02", "B03", "B04")
|
||||
rgb_bands = ("B03", "B02", "B01")
|
||||
|
||||
classes = ["No field-boundary", "Field-boundary"]
|
||||
|
||||
splits = ["train", "test"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new RwandaFieldBoundary instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train" or "test"
|
||||
bands: the subset of bands to load
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
or if ``download=True`` and ``api_key=None``
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
assert split in self.splits
|
||||
if download and api_key is None:
|
||||
raise RuntimeError("Must provide an API key to download the dataset")
|
||||
self.root = os.path.expanduser(root)
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
self.split = split
|
||||
self.download = download
|
||||
self.api_key = api_key
|
||||
self.checksum = checksum
|
||||
self._verify()
|
||||
|
||||
self.image_filenames: list[list[list[str]]] = []
|
||||
self.mask_filenames: list[str] = []
|
||||
for i in range(self.number_of_patches_per_split[split]):
|
||||
dates = []
|
||||
for date in self.dates:
|
||||
patch = []
|
||||
for band in self.bands:
|
||||
fn = os.path.join(
|
||||
self.root,
|
||||
f"nasa_rwanda_field_boundary_competition_source_{split}",
|
||||
f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501
|
||||
f"{band}.tif",
|
||||
)
|
||||
patch.append(fn)
|
||||
dates.append(patch)
|
||||
self.image_filenames.append(dates)
|
||||
self.mask_filenames.append(
|
||||
os.path.join(
|
||||
self.root,
|
||||
f"nasa_rwanda_field_boundary_competition_labels_{split}",
|
||||
f"nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}",
|
||||
"raster_labels.tif",
|
||||
)
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
a dict containing image, mask, transform, crs, and metadata at index.
|
||||
"""
|
||||
img_fns = self.image_filenames[index]
|
||||
mask_fn = self.mask_filenames[index]
|
||||
|
||||
imgs = []
|
||||
for date_fns in img_fns:
|
||||
bands = []
|
||||
for band_fn in date_fns:
|
||||
with rasterio.open(band_fn) as f:
|
||||
bands.append(f.read(1).astype(np.int32))
|
||||
imgs.append(bands)
|
||||
img = torch.from_numpy(np.array(imgs))
|
||||
|
||||
sample = {"image": img}
|
||||
|
||||
if self.split == "train":
|
||||
with rasterio.open(mask_fn) as f:
|
||||
mask = f.read(1)
|
||||
mask = torch.from_numpy(mask)
|
||||
sample["mask"] = mask
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of chips in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.image_filenames)
|
||||
|
||||
def _validate_bands(self, bands: Sequence[str]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
Args:
|
||||
bands: user-provided sequence of bands to load
|
||||
|
||||
Raises:
|
||||
ValueError: if an invalid band name is provided
|
||||
"""
|
||||
for band in bands:
|
||||
if band not in self.all_bands:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
# Check if the subdirectories already exist and have the correct number of files
|
||||
checks = []
|
||||
for split, num_patches in self.number_of_patches_per_split.items():
|
||||
path = os.path.join(
|
||||
self.root, f"nasa_rwanda_field_boundary_competition_source_{split}"
|
||||
)
|
||||
if os.path.exists(path):
|
||||
num_files = len(os.listdir(path))
|
||||
# 6 dates + 1 collection.json file
|
||||
checks.append(num_files == (num_patches * 6) + 1)
|
||||
else:
|
||||
checks.append(False)
|
||||
|
||||
if all(checks):
|
||||
return
|
||||
|
||||
# Check if tar file already exists (if so then extract)
|
||||
have_all_files = True
|
||||
for group in ["train_images", "train_labels", "test_images"]:
|
||||
filepath = os.path.join(self.root, self.filenames[group])
|
||||
if os.path.exists(filepath):
|
||||
if self.checksum and not check_integrity(filepath, self.md5s[group]):
|
||||
raise RuntimeError("Dataset found, but corrupted.")
|
||||
extract_archive(filepath)
|
||||
else:
|
||||
have_all_files = False
|
||||
if have_all_files:
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
|
||||
# Download and extract the dataset
|
||||
self._download()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
for collection_id in self.collection_ids:
|
||||
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
|
||||
|
||||
for group in ["train_images", "train_labels", "test_images"]:
|
||||
filepath = os.path.join(self.root, self.filenames[group])
|
||||
if self.checksum and not check_integrity(filepath, self.md5s[group]):
|
||||
raise RuntimeError("Dataset not found or corrupted.")
|
||||
extract_archive(filepath, self.root)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
time_step: time step at which to access image, beginning with 0
|
||||
suptitle: optional string to use as a suptitle
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
Raises:
|
||||
ValueError: if the RGB bands are not included in ``self.bands``
|
||||
"""
|
||||
rgb_indices = []
|
||||
for band in self.rgb_bands:
|
||||
if band in self.bands:
|
||||
rgb_indices.append(self.bands.index(band))
|
||||
else:
|
||||
raise ValueError("Dataset doesn't contain some of the RGB bands")
|
||||
|
||||
num_time_points = sample["image"].shape[0]
|
||||
assert time_step < num_time_points
|
||||
|
||||
image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3)
|
||||
image = np.clip(image / 2000, 0, 1)
|
||||
|
||||
if "mask" in sample:
|
||||
mask = sample["mask"].numpy()
|
||||
else:
|
||||
mask = np.zeros_like(image)
|
||||
|
||||
num_panels = 2
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
predictions = sample["prediction"].numpy()
|
||||
num_panels += 1
|
||||
|
||||
fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))
|
||||
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
if show_titles:
|
||||
axs[0].set_title(f"t={time_step}")
|
||||
|
||||
axs[1].imshow(mask, vmin=0, vmax=1, interpolation="none")
|
||||
axs[1].axis("off")
|
||||
if show_titles:
|
||||
axs[1].set_title("Mask")
|
||||
|
||||
if showing_predictions:
|
||||
axs[2].imshow(predictions, vmin=0, vmax=1, interpolation="none")
|
||||
axs[2].axis("off")
|
||||
if show_titles:
|
||||
axs[2].set_title("Predictions")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
Загрузка…
Ссылка в новой задаче