This commit is contained in:
Adam J. Stewart 2021-07-12 09:42:44 -05:00
Родитель e6b4031665
Коммит fd3e3a1ad2
5 изменённых файлов: 144 добавлений и 65 удалений

Просмотреть файл

@ -3,15 +3,12 @@ from typing import Any, Dict
import pytest
from torch.utils.data import ConcatDataset
from torchgeo.datasets import GeoDataset, VisionDataset, ZipDataset
from torchgeo.datasets import BoundingBox, GeoDataset, VisionDataset, ZipDataset
class CustomGeoDataset(GeoDataset):
def __getitem__(self, index: int) -> Dict[str, Any]:
return {"index": index}
def __len__(self) -> int:
return 2
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
return {"index": query}
class CustomVisionDataset(VisionDataset):
@ -28,17 +25,14 @@ class TestGeoDataset:
return CustomGeoDataset()
def test_getitem(self, dataset: GeoDataset) -> None:
assert dataset[0] == {"index": 0}
def test_len(self, dataset: GeoDataset) -> None:
assert len(dataset) == 2
query = BoundingBox(0, 0, 0, 0, 0, 0)
assert dataset[query] == {"index": query}
def test_add_two(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
dataset = ds1 + ds2
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
def test_add_three(self) -> None:
ds1 = CustomGeoDataset()
@ -46,7 +40,6 @@ class TestGeoDataset:
ds3 = CustomGeoDataset()
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
def test_add_four(self) -> None:
ds1 = CustomGeoDataset()
@ -55,15 +48,9 @@ class TestGeoDataset:
ds4 = CustomGeoDataset()
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
def test_str(self, dataset: GeoDataset) -> None:
assert "type: GeoDataset" in str(dataset)
assert "size: 2" in str(dataset)
def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoDataset() # type: ignore[abstract]
def test_add_vision(self, dataset: GeoDataset) -> None:
ds2 = CustomVisionDataset()
@ -125,14 +112,11 @@ class TestZipDataset:
return ZipDataset([ds1, ds2])
def test_getitem(self, dataset: ZipDataset) -> None:
assert dataset[0] == {"index": 0}
def test_len(self, dataset: ZipDataset) -> None:
assert len(dataset) == 2
query = BoundingBox(0, 0, 0, 0, 0, 0)
assert dataset[query] == {"index": query}
def test_str(self, dataset: ZipDataset) -> None:
assert "type: ZipDataset" in str(dataset)
assert "size: 2" in str(dataset)
def test_invalid_dataset(self) -> None:
ds1 = CustomVisionDataset()

Просмотреть файл

@ -2,14 +2,16 @@ from .benin_cashews import BeninSmallHolderCashews
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .geo import GeoDataset, VisionDataset, ZipDataset
from .geo import BoundingBox, GeoDataset, VisionDataset, ZipDataset
from .landcoverai import LandCoverAI
from .nwpu import VHR10
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
__all__ = (
"BeninSmallHolderCashews",
"BoundingBox",
"COWC",
"COWCCounting",
"COWCDetection",
@ -17,6 +19,8 @@ __all__ = (
"GeoDataset",
"LandCoverAI",
"SEN12MS",
"Sentinel",
"Sentinel2",
"So2Sat",
"TropicalCycloneWindEstimation",
"VHR10",

Просмотреть файл

@ -1,6 +1,9 @@
import abc
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, NamedTuple, Union
import rasterio
import torch
from rtree.index import Index, Property
from torch.utils.data import Dataset
# https://github.com/pytorch/pytorch/issues/60979
@ -8,42 +11,58 @@ from torch.utils.data import Dataset
Dataset.__module__ = "torch.utils.data"
class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""Abstract base class for datasets containing geospatial information.
class BoundingBox(NamedTuple):
"""Named tuple for indexing spatiotemporal data."""
minx: Union[int, float]
maxx: Union[int, float]
miny: Union[int, float]
maxy: Union[int, float]
mint: Union[int, float]
maxt: Union[int, float]
class GeoDataset(Dataset[Dict[str, Any]]):
"""Base class for datasets containing geospatial information.
Geospatial information includes things like:
* latitude, longitude
* time
* coordinate reference systems (CRS)
* :term:`coordinate reference system (CRS)`
These kind of datasets are special because they can be combined. For example:
* Combine Landsat8 and CDL to train a model for crop classification
* Combine Sentinel2 and Chesapeake to train a model for land cover mapping
This isn't true for VisionDataset, where the lack of geospatial information
This isn't true for :class:`VisionDataset`, where the lack of geospatial information
prohibits swapping image sources or target labels.
"""
@abc.abstractmethod
def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
#: R-tree to index geospatial data. Subclasses must insert data into this index in
#: order for the sampler to index it properly.
index = Index(properties=Property(dimension=3, interleaved=False))
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Parameters:
index: index to return
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
data and labels at that index
"""
@abc.abstractmethod
def __len__(self) -> int:
"""Return the length of the dataset.
Returns:
length of the dataset
sample of data/labels and metadata at that index
"""
bounds = rasterio.coords.BoundingBox(
query.minx, query.miny, query.maxx, query.maxy
)
hits = self.index.intersection(query, objects=True)
datasets = [hit.obj for hit in hits]
dest, out_transform = rasterio.merge.merge(datasets, bounds)
return {
"image": torch.tensor(dest), # type: ignore[attr-defined]
"transform": out_transform,
}
def __add__(self, other: "GeoDataset") -> "ZipDataset": # type: ignore[override]
"""Merge two GeoDatasets.
@ -64,8 +83,7 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""
return f"""\
{self.__class__.__name__} Dataset
type: GeoDataset
size: {len(self)}"""
type: GeoDataset"""
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
@ -123,28 +141,20 @@ class ZipDataset(GeoDataset):
self.datasets = datasets
def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Parameters:
index: index to return
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
data and labels at that index
sample of data/labels and metadata at that index
"""
sample = {}
for ds in self.datasets:
sample.update(ds[index])
sample.update(ds[query])
return sample
def __len__(self) -> int:
"""Return the length of the dataset.
Returns:
length of the dataset
"""
return min(map(len, self.datasets))
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -153,5 +163,4 @@ class ZipDataset(GeoDataset):
"""
return f"""\
{self.__class__.__name__} Dataset
type: ZipDataset
size: {len(self)}"""
type: ZipDataset"""

Просмотреть файл

@ -0,0 +1,83 @@
import abc
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Sequence
import rasterio
from .geo import GeoDataset
class Sentinel(GeoDataset, abc.ABC):
"""`Sentinel <https://sentinel.esa.int/web/sentinel/home>`_ is a family of
satellites launched by the `European Space Agency (ESA) <https://www.esa.int/>`_
under the `Copernicus Programme <https://www.copernicus.eu/en>`_.
If you use this dataset in your research, please cite it using the following format:
* https://asf.alaska.edu/data-sets/sar-data-sets/sentinel-1/sentinel-1-how-to-cite/
"""
# TODO: is this ABC actually needed?
# Do these datasets actually share anything in common?
# Could still keep it just to document what Sentinel is and how to cite it...
class Sentinel2(Sentinel):
"""The `Copernicus Sentinel-2 mission
<https://sentinel.esa.int/web/sentinel/missions/sentinel-2>`_ comprises a
constellation of two polar-orbiting satellites placed in the same sun-synchronous
orbit, phased at 180° to each other. It aims at monitoring variability in land
surface conditions, and its wide swath width (290 km) and high revisit time (10 days
at the equator with one satellite, and 5 days with 2 satellites under cloud-free
conditions which results in 2-3 days at mid-latitudes) will support monitoring of
Earth's surface changes.
"""
base_folder = "sentinel"
band_names = [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
]
def __init__(
self,
root: str = "data",
bands: Sequence[str] = band_names,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> None:
"""Initialize a new Sentinel-2 Dataset.
Parameters:
root: root directory where dataset can be found
bands: bands to return
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
"""
self.root = root
self.bands = bands
self.transforms = transforms
fileglob = os.path.join(root, self.base_folder, f"**_{bands[0]}_*.tif")
for filename in glob.iglob(fileglob):
# https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi/naming-convention
time = datetime.strptime(
os.path.basename(filename).split("_")[1], "%Y%m%dT%H%M%S"
)
timestamp = time.timestamp()
with rasterio.open(filename) as f:
minx, miny, maxx, maxy = f.bounds
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
self.index.insert(0, coords, filename)

Просмотреть файл

@ -12,13 +12,12 @@ Sampler.__module__ = "torch.utils.data"
class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
"""Abstract base class for sampling from :class:`GeoDataset
<torchgeo.datasets.GeoDataset>`.
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Unlike PyTorch's :class:`Sampler <torch.utils.data.Sampler>`, :class:`GeoSampler`
returns enough geospatial information to uniquely index any :class:`GeoDataset
<torchgeo.datasets.GeoDataset>`. This includes things like latitude, longitude,
height, width, projection, coordinate system, and time.
Unlike PyTorch's :class:`~torch.utils.data.Sampler`, :class:`GeoSampler`
returns enough geospatial information to uniquely index any
:class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude,
longitude, height, width, projection, coordinate system, and time.
"""
@abc.abstractmethod