зеркало из https://github.com/microsoft/torchgeo.git
Add basic Sentinel dataset
This commit is contained in:
Родитель
e6b4031665
Коммит
fd3e3a1ad2
|
@ -3,15 +3,12 @@ from typing import Any, Dict
|
|||
import pytest
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import GeoDataset, VisionDataset, ZipDataset
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset, VisionDataset, ZipDataset
|
||||
|
||||
|
||||
class CustomGeoDataset(GeoDataset):
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
return {"index": index}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 2
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
return {"index": query}
|
||||
|
||||
|
||||
class CustomVisionDataset(VisionDataset):
|
||||
|
@ -28,17 +25,14 @@ class TestGeoDataset:
|
|||
return CustomGeoDataset()
|
||||
|
||||
def test_getitem(self, dataset: GeoDataset) -> None:
|
||||
assert dataset[0] == {"index": 0}
|
||||
|
||||
def test_len(self, dataset: GeoDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_add_two(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_add_three(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -46,7 +40,6 @@ class TestGeoDataset:
|
|||
ds3 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2 + ds3
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_add_four(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
|
@ -55,15 +48,9 @@ class TestGeoDataset:
|
|||
ds4 = CustomGeoDataset()
|
||||
dataset = (ds1 + ds2) + (ds3 + ds4)
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_str(self, dataset: GeoDataset) -> None:
|
||||
assert "type: GeoDataset" in str(dataset)
|
||||
assert "size: 2" in str(dataset)
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
GeoDataset() # type: ignore[abstract]
|
||||
|
||||
def test_add_vision(self, dataset: GeoDataset) -> None:
|
||||
ds2 = CustomVisionDataset()
|
||||
|
@ -125,14 +112,11 @@ class TestZipDataset:
|
|||
return ZipDataset([ds1, ds2])
|
||||
|
||||
def test_getitem(self, dataset: ZipDataset) -> None:
|
||||
assert dataset[0] == {"index": 0}
|
||||
|
||||
def test_len(self, dataset: ZipDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_str(self, dataset: ZipDataset) -> None:
|
||||
assert "type: ZipDataset" in str(dataset)
|
||||
assert "size: 2" in str(dataset)
|
||||
|
||||
def test_invalid_dataset(self) -> None:
|
||||
ds1 = CustomVisionDataset()
|
||||
|
|
|
@ -2,14 +2,16 @@ from .benin_cashews import BeninSmallHolderCashews
|
|||
from .cowc import COWC, COWCCounting, COWCDetection
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .cyclone import TropicalCycloneWindEstimation
|
||||
from .geo import GeoDataset, VisionDataset, ZipDataset
|
||||
from .geo import BoundingBox, GeoDataset, VisionDataset, ZipDataset
|
||||
from .landcoverai import LandCoverAI
|
||||
from .nwpu import VHR10
|
||||
from .sen12ms import SEN12MS
|
||||
from .sentinel import Sentinel, Sentinel2
|
||||
from .so2sat import So2Sat
|
||||
|
||||
__all__ = (
|
||||
"BeninSmallHolderCashews",
|
||||
"BoundingBox",
|
||||
"COWC",
|
||||
"COWCCounting",
|
||||
"COWCDetection",
|
||||
|
@ -17,6 +19,8 @@ __all__ = (
|
|||
"GeoDataset",
|
||||
"LandCoverAI",
|
||||
"SEN12MS",
|
||||
"Sentinel",
|
||||
"Sentinel2",
|
||||
"So2Sat",
|
||||
"TropicalCycloneWindEstimation",
|
||||
"VHR10",
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import abc
|
||||
from typing import Any, Dict, Iterable
|
||||
from typing import Any, Dict, Iterable, NamedTuple, Union
|
||||
|
||||
import rasterio
|
||||
import torch
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
|
@ -8,42 +11,58 @@ from torch.utils.data import Dataset
|
|||
Dataset.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
||||
"""Abstract base class for datasets containing geospatial information.
|
||||
class BoundingBox(NamedTuple):
|
||||
"""Named tuple for indexing spatiotemporal data."""
|
||||
|
||||
minx: Union[int, float]
|
||||
maxx: Union[int, float]
|
||||
miny: Union[int, float]
|
||||
maxy: Union[int, float]
|
||||
mint: Union[int, float]
|
||||
maxt: Union[int, float]
|
||||
|
||||
|
||||
class GeoDataset(Dataset[Dict[str, Any]]):
|
||||
"""Base class for datasets containing geospatial information.
|
||||
|
||||
Geospatial information includes things like:
|
||||
|
||||
* latitude, longitude
|
||||
* time
|
||||
* coordinate reference systems (CRS)
|
||||
* :term:`coordinate reference system (CRS)`
|
||||
|
||||
These kinds of datasets are special because they can be combined. For example:
|
||||
|
||||
* Combine Landsat8 and CDL to train a model for crop classification
|
||||
* Combine Sentinel2 and Chesapeake to train a model for land cover mapping
|
||||
|
||||
This isn't true for VisionDataset, where the lack of geospatial information
|
||||
This isn't true for :class:`VisionDataset`, where the lack of geospatial information
|
||||
prohibits swapping image sources or target labels.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
#: R-tree to index geospatial data. Subclasses must insert data into this index in
|
||||
#: order for the sampler to index it properly.
|
||||
index = Index(properties=Property(dimension=3, interleaved=False))
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
"""Retrieve image and metadata indexed by query.
|
||||
|
||||
Parameters:
|
||||
index: index to return
|
||||
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
|
||||
|
||||
Returns:
|
||||
data and labels at that index
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int:
|
||||
"""Return the length of the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
sample of data/labels and metadata at that index
|
||||
"""
|
||||
bounds = rasterio.coords.BoundingBox(
|
||||
query.minx, query.miny, query.maxx, query.maxy
|
||||
)
|
||||
hits = self.index.intersection(query, objects=True)
|
||||
datasets = [hit.obj for hit in hits]
|
||||
dest, out_transform = rasterio.merge.merge(datasets, bounds)
|
||||
return {
|
||||
"image": torch.tensor(dest), # type: ignore[attr-defined]
|
||||
"transform": out_transform,
|
||||
}
|
||||
|
||||
def __add__(self, other: "GeoDataset") -> "ZipDataset": # type: ignore[override]
|
||||
"""Merge two GeoDatasets.
|
||||
|
@ -64,8 +83,7 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: GeoDataset
|
||||
size: {len(self)}"""
|
||||
type: GeoDataset"""
|
||||
|
||||
|
||||
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
|
||||
|
@ -123,28 +141,20 @@ class ZipDataset(GeoDataset):
|
|||
|
||||
self.datasets = datasets
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
"""Retrieve image and metadata indexed by query.
|
||||
|
||||
Parameters:
|
||||
index: index to return
|
||||
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
|
||||
|
||||
Returns:
|
||||
data and labels at that index
|
||||
sample of data/labels and metadata at that index
|
||||
"""
|
||||
sample = {}
|
||||
for ds in self.datasets:
|
||||
sample.update(ds[index])
|
||||
sample.update(ds[query])
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the length of the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return min(map(len, self.datasets))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -153,5 +163,4 @@ class ZipDataset(GeoDataset):
|
|||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: ZipDataset
|
||||
size: {len(self)}"""
|
||||
type: ZipDataset"""
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
import abc
|
||||
import glob
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
|
||||
import rasterio
|
||||
|
||||
from .geo import GeoDataset
|
||||
|
||||
|
||||
class Sentinel(GeoDataset, abc.ABC):
|
||||
"""`Sentinel <https://sentinel.esa.int/web/sentinel/home>`_ is a family of
|
||||
satellites launched by the `European Space Agency (ESA) <https://www.esa.int/>`_
|
||||
under the `Copernicus Programme <https://www.copernicus.eu/en>`_.
|
||||
|
||||
If you use this dataset in your research, please cite it using the following format:
|
||||
|
||||
* https://asf.alaska.edu/data-sets/sar-data-sets/sentinel-1/sentinel-1-how-to-cite/
|
||||
"""
|
||||
|
||||
# TODO: is this ABC actually needed?
|
||||
# Do these datasets actually share anything in common?
|
||||
# Could still keep it just to document what Sentinel is and how to cite it...
|
||||
|
||||
|
||||
class Sentinel2(Sentinel):
|
||||
"""The `Copernicus Sentinel-2 mission
|
||||
<https://sentinel.esa.int/web/sentinel/missions/sentinel-2>`_ comprises a
|
||||
constellation of two polar-orbiting satellites placed in the same sun-synchronous
|
||||
orbit, phased at 180° to each other. It aims at monitoring variability in land
|
||||
surface conditions, and its wide swath width (290 km) and high revisit time (10 days
|
||||
at the equator with one satellite, and 5 days with 2 satellites under cloud-free
|
||||
conditions which results in 2-3 days at mid-latitudes) will support monitoring of
|
||||
Earth's surface changes.
|
||||
"""
|
||||
|
||||
base_folder = "sentinel"
|
||||
band_names = [
|
||||
"B01",
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B8A",
|
||||
"B09",
|
||||
"B10",
|
||||
"B11",
|
||||
"B12",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
bands: Sequence[str] = band_names,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
"""Initialize a new Sentinel-2 Dataset.
|
||||
|
||||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
bands: bands to return
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
"""
|
||||
self.root = root
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
|
||||
fileglob = os.path.join(root, self.base_folder, f"**_{bands[0]}_*.tif")
|
||||
for filename in glob.iglob(fileglob):
|
||||
# https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi/naming-convention
|
||||
time = datetime.strptime(
|
||||
os.path.basename(filename).split("_")[1], "%Y%m%dT%H%M%S"
|
||||
)
|
||||
timestamp = time.timestamp()
|
||||
with rasterio.open(filename) as f:
|
||||
minx, miny, maxx, maxy = f.bounds
|
||||
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
|
||||
self.index.insert(0, coords, filename)
|
|
@ -12,13 +12,12 @@ Sampler.__module__ = "torch.utils.data"
|
|||
|
||||
|
||||
class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
|
||||
"""Abstract base class for sampling from :class:`GeoDataset
|
||||
<torchgeo.datasets.GeoDataset>`.
|
||||
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
|
||||
|
||||
Unlike PyTorch's :class:`Sampler <torch.utils.data.Sampler>`, :class:`GeoSampler`
|
||||
returns enough geospatial information to uniquely index any :class:`GeoDataset
|
||||
<torchgeo.datasets.GeoDataset>`. This includes things like latitude, longitude,
|
||||
height, width, projection, coordinate system, and time.
|
||||
Unlike PyTorch's :class:`~torch.utils.data.Sampler`, :class:`GeoSampler`
|
||||
returns enough geospatial information to uniquely index any
|
||||
:class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude,
|
||||
longitude, height, width, projection, coordinate system, and time.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
Загрузка…
Ссылка в новой задаче