* Added EuroSat dataset

* Cleaning up

* Removing unzipped data

* Added to docs

* EuroSat --> EuroSAT to match paper

* Changing class listing to use bullets

* Update torchgeo/datasets/eurosat.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Addressing review

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Caleb Robinson 2021-09-27 18:13:17 -07:00 committed by GitHub
Parent 5d4ad430d2
Commit 489ffdc2bd
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 files changed: 281 additions and 0 deletions

View file

@@ -97,6 +97,11 @@ ETCI2021 Flood Detection
.. autoclass:: ETCI2021

EuroSAT
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: EuroSAT

GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Binary data
tests/data/eurosat/EuroSATallBands.zip Normal file

Binary file not shown.

View file

@@ -0,0 +1,78 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import EuroSAT
from torchgeo.transforms import Identity


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


class TestEuroSAT:
    @pytest.fixture()
    def dataset(
        self,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> EuroSAT:
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.utils, "download_url", download_url
        )
        md5 = "aa051207b0547daba0ac6af57808d68e"
        monkeypatch.setattr(EuroSAT, "md5", md5)  # type: ignore[attr-defined]
        url = os.path.join("tests", "data", "eurosat", "EuroSATallBands.zip")
        monkeypatch.setattr(EuroSAT, "url", url)  # type: ignore[attr-defined]
        monkeypatch.setattr(  # type: ignore[attr-defined]
            EuroSAT,
            "class_counts",
            {
                "AnnualCrop": 1,
                "Forest": 1,
                "HerbaceousVegetation": 0,
                "Highway": 0,
                "Industrial": 0,
                "Pasture": 0,
                "PermanentCrop": 0,
                "Residential": 0,
                "River": 0,
                "SeaLake": 0,
            },
        )
        root = str(tmp_path)
        transforms = Identity()
        return EuroSAT(root, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: EuroSAT) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)

    def test_len(self, dataset: EuroSAT) -> None:
        assert len(dataset) == 2

    def test_add(self, dataset: EuroSAT) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 4

    def test_already_downloaded(self, dataset: EuroSAT) -> None:
        EuroSAT(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            EuroSAT(str(tmp_path))

View file

@@ -24,6 +24,7 @@ from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .gid15 import GID15
from .landcoverai import LandCoverAI
@@ -89,6 +90,7 @@ __all__ = (
    "COWCDetection",
    "CV4AKenyaCropType",
    "ETCI2021",
    "EuroSAT",
    "GID15",
    "LandCoverAI",
    "LEVIRCDPlus",

View file

@@ -0,0 +1,196 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""EuroSAT dataset."""

import os
from typing import Callable, Dict, Optional

import numpy as np
import rasterio
import torch
from torch import Tensor

from .geo import VisionDataset
from .utils import check_integrity, download_and_extract_archive


class EuroSAT(VisionDataset):
    """EuroSAT dataset.

    The `EuroSAT <https://github.com/phelber/EuroSAT>`_ dataset is based on Sentinel-2
    satellite images covering 13 spectral bands and consists of 10 target classes with
    a total of 27,000 labeled and geo-referenced images.

    Dataset format:

    * rasters are 13-channel GeoTIFFs
    * labels are values in the range [0, 9]

    Dataset classes:

    * Industrial Buildings
    * Residential Buildings
    * Annual Crop
    * Permanent Crop
    * River
    * Sea and Lake
    * Herbaceous Vegetation
    * Highway
    * Pasture
    * Forest

    If you use this dataset in your research, please cite the following papers:

    * https://ieeexplore.ieee.org/document/8736785
    * https://ieeexplore.ieee.org/document/8519248
    """

    url = "http://madm.dfki.de/files/sentinel/EuroSATallBands.zip"  # 2.0 GB download
    filename = "EuroSATallBands.zip"
    md5 = "5ac12b3b2557aa56e1826e981e8e200e"

    # For some reason the class directories are actually nested in this directory
    base_dir = os.path.join(
        "ds", "images", "remote_sensing", "otherDatasets", "sentinel_2", "tif"
    )

    class_counts = {
        "AnnualCrop": 3000,
        "Forest": 3000,
        "HerbaceousVegetation": 3000,
        "Highway": 2500,
        "Industrial": 2500,
        "Pasture": 2000,
        "PermanentCrop": 2500,
        "Residential": 3000,
        "River": 2500,
        "SeaLake": 3000,
    }

    class_name_to_label_idx = {
        "AnnualCrop": 0,
        "Forest": 1,
        "HerbaceousVegetation": 2,
        "Highway": 3,
        "Industrial": 4,
        "Pasture": 5,
        "PermanentCrop": 6,
        "Residential": 7,
        "River": 8,
        "SeaLake": 9,
    }
    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new EuroSAT dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        self.root = root
        self.transforms = transforms
        self.checksum = checksum

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                + "You can use download=True to download it"
            )

        self.fns = []
        self.labels = []
        for class_name, class_count in self.class_counts.items():
            for i in range(1, class_count + 1):
                self.fns.append(
                    os.path.join(
                        self.root, self.base_dir, class_name, f"{class_name}_{i}.tif"
                    )
                )
                self.labels.append(self.class_name_to_label_idx[class_name])

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        sample: Dict[str, Tensor] = {
            "image": self._load_image(index),
            "label": torch.tensor(self.labels[index]),  # type: ignore[attr-defined]
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.labels)
    def _load_image(self, index: int) -> Tensor:
        """Load a single image.

        Args:
            index: index of the image to load

        Returns:
            the image
        """
        filename = self.fns[index]
        with rasterio.open(filename) as f:
            array = f.read().astype(np.int32)
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            return tensor
    def _check_integrity(self) -> bool:
        """Check integrity of dataset.

        Returns:
            True if dataset files are found and/or MD5s match, else False
        """
        integrity: bool = check_integrity(
            os.path.join(self.root, self.filename),
            self.md5 if self.checksum else None,
        )
        return integrity

    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            RuntimeError: if download doesn't work correctly or checksums don't match
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
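
Not part of the diff, but for orientation: a minimal usage sketch of the new dataset class, based on the docstring above. It assumes the archive is available locally (or that the ~2 GB download is acceptable); the root path and the printed shapes are illustrative rather than taken from this PR.

# Minimal usage sketch (not part of this PR); "data" is an illustrative root path.
from torchgeo.datasets import EuroSAT

ds = EuroSAT(root="data", download=True, checksum=True)
print(len(ds))  # 27000 samples across the ten classes listed above

sample = ds[0]  # first AnnualCrop tile, given the class_counts ordering
print(sample["image"].shape)  # expected torch.Size([13, 64, 64]): 13 Sentinel-2 bands
print(sample["label"])        # tensor(0), i.e. the AnnualCrop label index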