Mirror of https://github.com/microsoft/torchgeo.git

Refactor GeoDataset

Parent: ae8049ac19
Commit: 9af70d3c40
docs/api/datasets.rst
@@ -117,6 +117,16 @@ GeoDataset

.. autoclass:: GeoDataset

RasterDataset
^^^^^^^^^^^^^

.. autoclass:: RasterDataset

VectorDataset
^^^^^^^^^^^^^

.. autoclass:: VectorDataset

VisionDataset
^^^^^^^^^^^^^
torchgeo/datasets/__init__.py
@@ -18,7 +18,7 @@ from .chesapeake import (
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .geo import GeoDataset, VisionDataset, ZipDataset
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .landcoverai import LandCoverAI
from .landsat import (
    Landsat,
@@ -81,6 +81,8 @@ __all__ = (
    "VHR10",
    # Base classes
    "GeoDataset",
    "RasterDataset",
    "VectorDataset",
    "VisionDataset",
    "ZipDataset",
    # Utilities
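A quick sketch (not part of the diff) of the public API after this change; the names are exactly those added to ``__all__`` above:

```python
# The refactor exposes two new base classes alongside GeoDataset.
from torchgeo.datasets import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset

# Both new classes specialize GeoDataset, so code written against GeoDataset
# keeps working unchanged.
assert issubclass(RasterDataset, GeoDataset)
assert issubclass(VectorDataset, GeoDataset)
```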
torchgeo/datasets/cdl.py
@@ -1,50 +1,15 @@
"""CDL dataset."""

import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor

from .geo import GeoDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive

_crs = CRS.from_wkt(
    """
PROJCS["Albers Conical Equal Area",
    GEOGCS["NAD83",
        DATUM["North_American_Datum_1983",
            SPHEROID["GRS 1980",6378137,298.257222101,
                AUTHORITY["EPSG","7019"]],
            AUTHORITY["EPSG","6269"]],
        PRIMEM["Greenwich",0,
            AUTHORITY["EPSG","8901"]],
        UNIT["degree",0.0174532925199433,
            AUTHORITY["EPSG","9122"]],
        AUTHORITY["EPSG","4269"]],
    PROJECTION["Albers_Conic_Equal_Area"],
    PARAMETER["latitude_of_center",23],
    PARAMETER["longitude_of_center",-96],
    PARAMETER["standard_parallel_1",29.5],
    PARAMETER["standard_parallel_2",45.5],
    PARAMETER["false_easting",0],
    PARAMETER["false_northing",0],
    UNIT["meters",1],
    AXIS["Easting",EAST],
    AXIS["Northing",NORTH]]
"""
)
from .geo import RasterDataset
from .utils import check_integrity, download_and_extract_archive


class CDL(GeoDataset):
class CDL(RasterDataset):
    """Cropland Data Layer (CDL) dataset.

    The `Cropland Data Layer
@@ -62,6 +27,14 @@ class CDL(GeoDataset):
    * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
    """  # noqa: E501

    filename_glob = "*_30m_cdls.*"
    filename_regex = r"""
        ^(?P<date>\d+)
        _30m_cdls\..*$
    """
    date_format = "%Y"
    is_image = False

    url = "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip"  # noqa: E501
    md5s = [
        (2020, "97b3b5fd62177c9ed857010bca146f36"),
@@ -82,24 +55,27 @@ class CDL(GeoDataset):
    def __init__(
        self,
        root: str = "data",
        crs: CRS = _crs,
        crs: Optional[CRS] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new CDL Dataset.
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            FileNotFoundError: if no files are found in ``root``
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        self.root = root
        self.crs = crs
        self.transforms = transforms
        self.checksum = checksum

        if download:
@@ -111,61 +87,7 @@ class CDL(GeoDataset):
                + "You can use download=True to download it"
            )

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        fileglob = os.path.join(root, "**_30m_cdls.img")
        for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
            year = int(os.path.basename(filename).split("_")[0])
            mint = datetime(year, 1, 1, 0, 0, 0).timestamp()
            maxt = datetime(year, 12, 31, 23, 59, 59).timestamp()
            with rasterio.open(filename) as src:
                cmap = src.colormap(1)
                with WarpedVRT(src, crs=self.crs) as vrt:
                    minx, miny, maxx, maxy = vrt.bounds
            coords = (minx, maxx, miny, maxy, mint, maxt)
            self.index.insert(i, coords, filename)
        self.cmap = np.array([cmap[i] for i in range(256)])

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of labels and metadata at that index

        Raises:
            IndexError: if query is not within bounds of the index
        """
        if not query.intersects(self.bounds):
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        hits = self.index.intersection(query, objects=True)
        filename = next(hits).object  # TODO: this assumes there is only a single hit
        with rasterio.open(filename) as src:
            with WarpedVRT(src, crs=self.crs) as vrt:
                window = rasterio.windows.from_bounds(
                    query.minx,
                    query.miny,
                    query.maxx,
                    query.maxy,
                    transform=vrt.transform,
                )
                masks = vrt.read(window=window)
        masks = masks.astype(np.int32)
        sample = {
            "masks": torch.tensor(masks),  # type: ignore[attr-defined]
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
        super().__init__(root, crs, transforms)

    def _check_integrity(self) -> bool:
        """Check integrity of dataset.
@@ -191,20 +113,3 @@ class CDL(GeoDataset):
            self.root,
            md5=md5 if self.checksum else None,
        )

    def plot(self, image: Tensor) -> None:
        """Plot an image on a map.

        Args:
            image: the image to plot
        """
        # Convert from class labels to RGBA values
        array = image.squeeze().numpy()
        array = self.cmap[array]

        # Plot the image
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()
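The net effect of the CDL changes: the class no longer builds its own R-tree or implements ``__getitem__``; it only declares how its files are named and lets ``RasterDataset`` do the indexing. A minimal sketch of that pattern, assuming torchgeo is installed; the dataset name and filename convention below are hypothetical, only the attribute names come from the diff:

```python
# Sketch of the subclassing pattern this diff applies to CDL.
# "YearlyMasks" and its filename convention are hypothetical; the attribute
# names (filename_glob, filename_regex, date_format, is_image) are from the diff.
from torchgeo.datasets import RasterDataset


class YearlyMasks(RasterDataset):
    """Hypothetical yearly land-cover masks, one raster file per year."""

    filename_glob = "*_masks.*"  # extension-agnostic, like CDL's "*_30m_cdls.*"
    filename_regex = r"""
        ^(?P<date>\d{4})
        _masks\..*$
    """
    date_format = "%Y"  # parsed with datetime.strptime to place each file in time
    is_image = False  # samples are returned under the "masks" key


# File discovery, date parsing, and R-tree insertion all happen in
# RasterDataset.__init__, which is why CDL can simply call super().__init__.
ds = YearlyMasks(root="data/yearly_masks")  # placeholder path; raises FileNotFoundError if empty
```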
torchgeo/datasets/chesapeake.py
@@ -2,49 +2,15 @@

import abc
import os
import sys
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor

from .geo import GeoDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive

_crs = CRS.from_wkt(
    """
PROJCS["USA_Contiguous_Albers_Equal_Area_Conic_USGS_version",
    GEOGCS["NAD83",
        DATUM["North_American_Datum_1983",
            SPHEROID["GRS 1980",6378137,298.257222101004,
                AUTHORITY["EPSG","7019"]],
            AUTHORITY["EPSG","6269"]],
        PRIMEM["Greenwich",0],
        UNIT["degree",0.0174532925199433,
            AUTHORITY["EPSG","9122"]],
        AUTHORITY["EPSG","4269"]],
    PROJECTION["Albers_Conic_Equal_Area"],
    PARAMETER["latitude_of_center",23],
    PARAMETER["longitude_of_center",-96],
    PARAMETER["standard_parallel_1",29.5],
    PARAMETER["standard_parallel_2",45.5],
    PARAMETER["false_easting",0],
    PARAMETER["false_northing",0],
    UNIT["metre",1,
        AUTHORITY["EPSG","9001"]],
    AXIS["Easting",EAST],
    AXIS["Northing",NORTH]]
"""
)
from .geo import RasterDataset
from .utils import check_integrity, download_and_extract_archive


class Chesapeake(GeoDataset, abc.ABC):
class Chesapeake(RasterDataset, abc.ABC):
    """Abstract base class for all Chesapeake datasets.

    `Chesapeake Bay High-Resolution Land Cover Project
@@ -70,6 +36,8 @@ class Chesapeake(GeoDataset, abc.ABC):
    * https://doi.org/10.1109/cvpr.2019.01301
    """

    is_image = False

    @property
    @abc.abstractmethod
    def base_folder(self) -> str:
@@ -97,46 +65,29 @@ class Chesapeake(GeoDataset, abc.ABC):
        url += f"/{self.base_folder}/{self.zipfile}"
        return url

    cmap = {
        0: (0, 0, 0, 0),
        1: (0, 197, 255, 255),
        2: (0, 168, 132, 255),
        3: (38, 115, 0, 255),
        4: (76, 230, 0, 255),
        5: (163, 255, 115, 255),
        6: (255, 170, 0, 255),
        7: (255, 0, 0, 255),
        8: (156, 156, 156, 255),
        9: (0, 0, 0, 255),
        10: (115, 115, 0, 255),
        11: (230, 230, 0, 255),
        12: (255, 255, 115, 255),
        13: (197, 0, 255, 255),
        14: (0, 0, 0, 0),
        15: (0, 0, 0, 0),
    }

    def __init__(
        self,
        root: str,
        crs: CRS = _crs,
        transforms: Optional[Callable[[Any], Any]] = None,
        root: str = "data",
        crs: Optional[CRS] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new Chesapeake dataset instance.
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        self.root = root
        self.crs = crs
        self.transforms = transforms
        self.checksum = checksum

        if download:
@@ -148,57 +99,7 @@ class Chesapeake(GeoDataset, abc.ABC):
                + "You can use download=True to download it"
            )

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        filename = os.path.join(self.root, self.filename)
        with rasterio.open(filename) as src:
            with WarpedVRT(src, crs=self.crs) as vrt:
                minx, miny, maxx, maxy = vrt.bounds
        mint = 0
        maxt = sys.maxsize
        coords = (minx, maxx, miny, maxy, mint, maxt)
        self.index.insert(0, coords, filename)

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve labels and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of labels and metadata at that index

        Raises:
            IndexError: if query is not within bounds of the index
        """
        if not query.intersects(self.bounds):
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        hits = self.index.intersection(query, objects=True)
        filename = next(hits).object
        with rasterio.open(filename) as src:
            with WarpedVRT(src, crs=self.crs) as vrt:
                window = rasterio.windows.from_bounds(
                    query.minx,
                    query.miny,
                    query.maxx,
                    query.maxy,
                    transform=vrt.transform,
                )
                masks = vrt.read(window=window)
        sample = {
            "masks": torch.tensor(masks),  # type: ignore[attr-defined]
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
        super().__init__(root, crs, transforms)

    def _check_integrity(self) -> bool:
        """Check integrity of dataset.
@@ -225,24 +126,6 @@ class Chesapeake(GeoDataset, abc.ABC):
            md5=self.md5,
        )

    def plot(self, image: Tensor) -> None:
        """Plot an image on a map.

        Args:
            image: the image to plot
        """
        # Convert from class labels to RGBA values
        cmap = np.array([self.cmap[i] for i in range(len(self.cmap))])
        array = image.squeeze().numpy()
        array = cmap[array]

        # Plot the image
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()


class Chesapeake7(Chesapeake):
    """Complete 7-class dataset.
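The Chesapeake classes keep their per-class RGBA ``cmap``; both the removed ``plot`` above and the new ``RasterDataset.plot`` turn class labels into colors with a NumPy lookup table. A standalone sketch of that step (the 4-entry colormap and the 2x2 mask below are made up for illustration):

```python
import numpy as np

# Made-up 4-entry colormap in the same {label: (R, G, B, A)} form as Chesapeake.cmap
cmap = {0: (0, 0, 0, 0), 1: (0, 197, 255, 255), 2: (0, 168, 132, 255), 3: (38, 115, 0, 255)}

# Build a lookup table indexed by class label, then fancy-index it with the mask
lut = np.array([cmap[i] for i in range(len(cmap))], dtype=np.uint8)
mask = np.array([[0, 1], [2, 3]])  # stand-in for image.squeeze().numpy()
rgba = lut[mask]  # shape (2, 2, 4): one RGBA pixel per class label
print(rgba.shape)
```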
torchgeo/datasets/geo.py
@@ -1,10 +1,21 @@
"""Base classes for all :mod:`torchgeo` datasets."""

import abc
from typing import Any, Dict, Sequence
import glob
import os
import re
import sys
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rtree.index import Index
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset

from .utils import BoundingBox
@@ -90,6 +101,234 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
        return BoundingBox(*self.index.bounds)


class RasterDataset(GeoDataset):
    """Abstract base class for :class:`GeoDataset`s stored as raster files."""

    #: Glob expression used to search for files.
    #:
    #: This expression should be specific enough that it will not pick up files from
    #: other datasets. It should not include a file extension, as the dataset may be in
    #: a different file format than what it was originally downloaded as.
    filename_glob = "*"

    #: Regular expression used to extract date from filename.
    #:
    #: The expression should use named groups. The expression may contain any number of
    #: groups. The following groups are specifically searched for by the base class:
    #:
    #: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion
    #: * ``band``: used when :attr:`separate_files` is True
    filename_regex = ".*"

    #: Date format string used to parse date from filename.
    #:
    #: Not used if :attr:`filename_regex` does not contain a ``date`` group.
    date_format = "%Y%m%d"

    #: True if dataset contains imagery, False if dataset contains mask
    is_image = True

    #: True if data is stored in a separate file for each band, else False.
    separate_files = False

    #: Names of all available bands in the dataset
    all_bands: List[str] = []

    #: Names of RGB bands in the dataset, used for plotting
    rgb_bands: List[str] = []

    #: If True, stretch the image from the 2nd percentile to the 98th percentile,
    #: used for plotting
    stretch = False

    #: Color map for the dataset, used for plotting
    cmap: Dict[int, Tuple[int, int, int, int]] = {}

    def __init__(
        self,
        root: str,
        crs: Optional[CRS] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        self.root = root
        self.crs = crs
        self.transforms = transforms

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        i = 0
        pathname = os.path.join(root, "**", self.filename_glob)
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for filepath in glob.iglob(pathname, recursive=True):
            match = re.match(filename_regex, os.path.basename(filepath))
            if match is not None:
                try:
                    with rasterio.open(filepath) as src:
                        # See if file has a color map
                        try:
                            self.cmap = src.colormap(1)
                        except ValueError:
                            pass

                        if self.crs is None:
                            self.crs = src.crs

                        with WarpedVRT(src, crs=self.crs) as vrt:
                            minx, miny, maxx, maxy = vrt.bounds
                except rasterio.errors.RasterioIOError:
                    # Skip files that rasterio is unable to read
                    continue
                else:
                    mint: float = 0
                    maxt: float = sys.maxsize
                    if "date" in match.groupdict():
                        date = match.group("date")
                        time = datetime.strptime(date, self.date_format)
                        mint = maxt = time.timestamp()

                    coords = (minx, maxx, miny, maxy, mint, maxt)
                    self.index.insert(i, coords, filepath)
                    i += 1

        if i == 0:
            raise FileNotFoundError(
                f"No {self.__class__.__name__} data was found in '{root}'"
            )

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(query, objects=True)

        try:
            hit = next(hits)  # TODO: this assumes there is only a single hit
        except StopIteration:
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        filepath = hit.object
        if self.separate_files:
            data_list = []
            filename_regex = re.compile(self.filename_regex, re.VERBOSE)
            for band in getattr(self, "bands", self.all_bands):
                filename = os.path.basename(filepath)
                directory = os.path.dirname(filepath)
                match = re.match(filename_regex, filename)
                if match:
                    start, end = match.start("band"), match.end("band")
                    filename = filename[:start] + band + filename[end:]
                data_list.append(
                    self._load_file(os.path.join(directory, filename), query)
                )
            data = torch.stack(data_list)
        else:
            data = self._load_file(filepath, query)

        key = "image" if self.is_image else "masks"
        sample = {
            key: data,
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _load_file(self, filepath: str, query: BoundingBox) -> Tensor:
        """Load a single raster file.

        Args:
            filepath: path to file to open
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            image/mask at that index
        """
        with rasterio.open(filepath) as src:
            with WarpedVRT(src, crs=self.crs, nodata=0) as vrt:
                window = rasterio.windows.from_bounds(
                    query.minx,
                    query.miny,
                    query.maxx,
                    query.maxy,
                    transform=vrt.transform,
                )
                array = vrt.read(window=window).astype(np.int32)
        tensor: Tensor = torch.tensor(array)  # type: ignore[attr-defined]
        return tensor

    def plot(self, data: Tensor) -> None:
        """Plot a data sample.

        Args:
            data: the data to plot

        Raises:
            AssertionError: if ``is_image`` is True and ``data`` has a different number
                of channels than expected
        """
        array = data.squeeze().numpy()

        if self.is_image:
            bands = getattr(self, "bands", self.all_bands)
            assert array.shape[0] == len(bands)

            # Only plot RGB bands
            if bands and self.rgb_bands:
                indices = np.array([bands.index(band) for band in self.rgb_bands])
                array = array[indices]

            # Convert from CxHxW to HxWxC
            array = np.rollaxis(array, 0, 3)

        if self.cmap:
            # Convert from class labels to RGBA values
            cmap = np.array([self.cmap[i] for i in range(len(self.cmap))])
            array = cmap[array]

        if self.stretch:
            # Stretch to the range of 2nd to 98th percentile
            per02 = np.percentile(array, 2)  # type: ignore[no-untyped-call]
            per98 = np.percentile(array, 98)  # type: ignore[no-untyped-call]
            array = (array - per02) / (per98 - per02)
            array = np.clip(array, 0, 1)

        # Plot the data
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()


class VectorDataset(GeoDataset):
    pass


class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
    """Abstract base class for datasets lacking geospatial information.
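To make the new control flow concrete, a usage sketch under the assumption that some CDL rasters sit under a placeholder ``data/cdl`` directory; the query arithmetic is illustrative only:

```python
# Usage sketch for the new RasterDataset.__getitem__ (placeholder path and sizes).
from torchgeo.datasets import CDL
from torchgeo.datasets.utils import BoundingBox

ds = CDL(root="data/cdl")  # crs defaults to the CRS of the first file found

# Spatiotemporal query in (minx, maxx, miny, maxy, mint, maxt) order
minx, maxx, miny, maxy, mint, maxt = ds.bounds
query = BoundingBox(minx, minx + 3000, miny, miny + 3000, mint, maxt)

sample = ds[query]  # is_image is False for CDL, so the tensor is under "masks"
print(sample["masks"].shape, sample["crs"], sample["bbox"])
```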
torchgeo/datasets/landsat.py
@@ -1,27 +1,14 @@
"""Landsat datasets."""

import abc
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor

from .geo import GeoDataset
from .utils import BoundingBox

_crs = CRS.from_epsg(32616)
from .geo import RasterDataset


class Landsat(GeoDataset, abc.ABC):
class Landsat(RasterDataset, abc.ABC):
    """Abstract base class for all Landsat datasets.

    `Landsat <https://landsat.gsfc.nasa.gov/>`_ is a joint NASA/USGS program,
@@ -32,27 +19,45 @@ class Landsat(GeoDataset, abc.ABC):
    * https://www.usgs.gov/centers/eros/data-citation
    """

    @property
    @abc.abstractmethod
    def band_names(self) -> Sequence[str]:
        """Spectral bands provided by a satellite.
    # https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes
    # https://www.usgs.gov/faqs/what-naming-convention-landsat-collection-2-level-1-and-level-2-scenes
    filename_glob = ""
    filename_regex = r"""
        ^L
        (?P<sensor>[COTEM])
        (?P<satellite>\d{2})
        _(?P<processing_correction_level>[A-Z0-9]{4})
        _(?P<wrs_path>\d{3})
        (?P<wrs_row>\d{3})
        _(?P<date>\d{8})
        _(?P<processing_date>\d{8})
        _(?P<collection_number>\d{2})
        _(?P<collection_category>[A-Z0-9]{2})
        _SR
        _(?P<band>[A-Z0-9]{2})
        \..*$
    """

        See https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites
        for more details.
        """
    # https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites
    all_bands: List[str] = []
    rgb_bands: List[str] = []

    separate_files = True
    stretch = True

    def __init__(
        self,
        root: str = "data",
        crs: CRS = _crs,
        crs: Optional[CRS] = None,
        bands: Sequence[str] = [],
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Landsat Dataset.
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            bands: bands to return (defaults to all bands)
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
@@ -60,106 +65,104 @@ class Landsat(GeoDataset, abc.ABC):
        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        self.root = root
        self.crs = crs
        self.bands = bands if bands else self.band_names
        self.transforms = transforms
        self.bands = bands if bands else self.all_bands

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        fileglob = os.path.join(root, f"**_{self.bands[0]}.TIF")
        for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
            # https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes
            # https://www.usgs.gov/faqs/what-naming-convention-landsat-collection-2-level-1-and-level-2-scenes
            time = datetime.strptime(os.path.basename(filename).split("_")[3], "%Y%m%d")
            timestamp = time.timestamp()
            with rasterio.open(filename) as src:
                with WarpedVRT(src, crs=self.crs) as vrt:
                    minx, miny, maxx, maxy = vrt.bounds
            coords = (minx, maxx, miny, maxy, timestamp, timestamp)
            self.index.insert(i, coords, filename)
        super().__init__(root, crs, transforms)

        if "filename" not in locals():
            raise FileNotFoundError(
                f"No {self.__class__.__name__} data was found in '{root}'"
            )

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image and metadata indexed by query.
class Landsat1(Landsat):
    """Landsat 1 Multispectral Scanner (MSS)."""

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
    filename_glob = "LM01_*.*"

        Returns:
            sample of data and metadata at that index
    all_bands = [
        "B4",
        "B5",
        "B6",
        "B7",
    ]
    rgb_bands = ["B6", "B5", "B4"]

        Raises:
            IndexError: if query is not within bounds of the index
        """
        if not query.intersects(self.bounds):
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        hits = self.index.intersection(query, objects=True)
        filename = next(hits).object  # TODO: this assumes there is only a single hit
        data_list = []
        for band in self.bands:
            tokens = filename.split("_")
            tokens[-1] = band + ".TIF"
            filename = "_".join(tokens)
            with rasterio.open(filename) as src:
                with WarpedVRT(src, crs=self.crs) as vrt:
                    window = rasterio.windows.from_bounds(
                        query.minx,
                        query.miny,
                        query.maxx,
                        query.maxy,
                        transform=vrt.transform,
                    )
                    image = vrt.read(window=window)
            data_list.append(image)
        image = np.concatenate(data_list)  # type: ignore[no-untyped-call]
        image = image.astype(np.int32)
        sample = {
            "image": torch.tensor(image),  # type: ignore[attr-defined]
            "crs": self.crs,
            "bbox": query,
        }
class Landsat2(Landsat1):
    """Landsat 2 Multispectral Scanner (MSS)."""

        if self.transforms is not None:
            sample = self.transforms(sample)
    filename_glob = "LM02_*.*"

        return sample

    def plot(self, image: Tensor) -> None:
        """Plot an image on a map.
class Landsat3(Landsat1):
    """Landsat 3 Multispectral Scanner (MSS)."""

        Args:
            image: the image to plot
        """
        # Convert from CxHxW to HxWxC
        image = image.permute((1, 2, 0))
        array = image.numpy()
    filename_glob = "LM03_*.*"

        # Stretch to the range of 2nd to 98th percentile
        per98 = np.percentile(array, 98)  # type: ignore[no-untyped-call]
        per02 = np.percentile(array, 2)  # type: ignore[no-untyped-call]
        array = (array - per02) / (per98 - per02)
        array = np.clip(array, 0, 1)

        # Plot the image
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()
class Landsat4MSS(Landsat):
    """Landsat 4 Multispectral Scanner (MSS)."""

    filename_glob = "LM04_*.*"

    all_bands = [
        "B1",
        "B2",
        "B3",
        "B4",
    ]
    rgb_bands = ["B3", "B2", "B1"]


class Landsat4TM(Landsat):
    """Landsat 4 Thematic Mapper (TM)."""

    filename_glob = "LT04_*.*"

    all_bands = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
    ]
    rgb_bands = ["B3", "B2", "B1"]


class Landsat5MSS(Landsat4MSS):
    """Landsat 5 Multispectral Scanner (MSS)."""

    filename_glob = "LM05_*.*"


class Landsat5TM(Landsat4TM):
    """Landsat 5 Thematic Mapper (TM)."""

    filename_glob = "LT05_*.*"


class Landsat7(Landsat):
    """Landsat 7 Enhanced Thematic Mapper Plus (ETM+)."""

    filename_glob = "LE07_*.*"

    all_bands = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
    ]
    rgb_bands = ["B3", "B2", "B1"]


class Landsat8(Landsat):
    """Landsat 8-9 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""
    """Landsat 8 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""

    band_names = [
    filename_glob = "LC08_*.*"

    all_bands = [
        "B1",
        "B2",
        "B3",
@@ -172,67 +175,10 @@ class Landsat8(Landsat):
        "B10",
        "B11",
    ]
    rgb_bands = ["B4", "B3", "B2"]


Landsat9 = Landsat8
class Landsat9(Landsat8):
    """Landsat 9 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""


class Landsat7(Landsat):
    """Landsat 7 Enhanced Thematic Mapper Plus (ETM+)."""

    band_names = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8",
    ]


class Landsat4TM(Landsat):
    """Landsat 4-5 Thematic Mapper (TM)."""

    band_names = [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
    ]


Landsat5TM = Landsat4TM


class Landsat4MSS(Landsat):
    """Landsat 4-5 Multispectral Scanner (MSS)."""

    band_names = [
        "B1",
        "B2",
        "B3",
        "B4",
    ]


Landsat5MSS = Landsat4MSS


class Landsat1(Landsat):
    """Landsat 1-3 Multispectral Scanner (MSS)."""

    band_names = [
        "B4",
        "B5",
        "B6",
        "B7",
    ]


Landsat2 = Landsat1
Landsat3 = Landsat1
    filename_glob = "LC09_*.*"
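Because ``Landsat`` now sets ``separate_files = True``, each requested band is read from its own surface-reflectance file and the per-band tensors are stacked. A hedged usage sketch; the root path is a placeholder and assumes Landsat 8 Collection 2 Level-2 scenes are present on disk:

```python
# Usage sketch (placeholder path; assumes Landsat 8 Collection 2 L2 SR scenes on disk).
from torchgeo.datasets import Landsat8

# Only the listed bands are read; each comes from its own file because
# separate_files is True, and the per-band tensors are stacked along dim 0.
ds = Landsat8(root="data/landsat8", bands=["B4", "B3", "B2"])

sample = ds[ds.bounds]  # is_image is True, so the stacked tensor is under "image"
print(sample["image"].shape, sample["bbox"])
```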
torchgeo/datasets/naip.py
@@ -1,26 +1,9 @@
"""National Agriculture Imagery Program (NAIP) dataset."""

import glob
import os
import re
from datetime import datetime
from typing import Any, Callable, Dict, Optional

import matplotlib.pyplot as plt
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor

from .geo import GeoDataset
from .utils import BoundingBox

_crs = CRS.from_epsg(26918)
from .geo import RasterDataset


class NAIP(GeoDataset):
class NAIP(RasterDataset):
    """National Agriculture Imagery Program (NAIP) dataset.

    The `National Agriculture Imagery Program (NAIP)
@@ -41,118 +24,18 @@ class NAIP(GeoDataset):

    # https://www.nrcs.usda.gov/Internet/FSE_DOCUMENTS/nrcs141p2_015644.pdf
    # https://planetarycomputer.microsoft.com/dataset/naip#Storage-Documentation
    filename_glob = "m_*.tif"
    filename_regex = re.compile(
        r"""
    filename_glob = "m_*.*"
    filename_regex = r"""
        ^m
        _(?P<quadrangle>\d+)
        _(?P<quarter_quad>[a-z]+)
        _(?P<utm_zone>\d+)
        _(?P<resolution>\d+)
        _(?P<acquisition_date>\d+)
        _(?P<date>\d+)
        (?:_(?P<processing_date>\d+))?
        .tif$
        """,
        re.VERBOSE,
    )
    date_format = "%Y%m%d"
        \..*$
    """

    def __init__(
        self,
        root: str,
        crs: CRS = _crs,
        transforms: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Initialize a new NAIP dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        self.root = root
        self.crs = crs
        self.transforms = transforms

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        fileglob = os.path.join(root, "**", self.filename_glob)
        for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
            match = re.match(self.filename_regex, os.path.basename(filename))
            if match is not None:
                with rasterio.open(filename) as src:
                    with WarpedVRT(src, crs=self.crs) as vrt:
                        minx, miny, maxx, maxy = vrt.bounds
                date = match.group("acquisition_date")
                time = datetime.strptime(date, self.date_format)
                timestamp = time.timestamp()
                coords = (minx, maxx, miny, maxy, timestamp, timestamp)
                self.index.insert(i, coords, filename)

        if "filename" not in locals():
            raise FileNotFoundError(f"No NAIP data was found in '{root}'")

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image and metadata at that index

        Raises:
            IndexError: if query is not within bounds of the index
        """
        if not query.intersects(self.bounds):
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        hits = self.index.intersection(query, objects=True)
        filename = next(hits).object
        with rasterio.open(filename) as src:
            with WarpedVRT(src, crs=self.crs, nodata=0) as vrt:
                window = rasterio.windows.from_bounds(
                    query.minx,
                    query.miny,
                    query.maxx,
                    query.maxy,
                    transform=vrt.transform,
                )
                image = vrt.read(window=window)

        sample = {
            "image": torch.tensor(image),  # type: ignore[attr-defined]
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def plot(self, image: Tensor) -> None:
        """Plot an image on a map.

        Args:
            image: the image to plot
        """
        # Drop NIR channel
        image = image[:3]

        # Convert from CxHxW to HxWxC
        image = image.permute((1, 2, 0))
        array = image.numpy()

        # Plot the image
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()
    # Plotting
    all_bands = ["R", "G", "B", "NIR"]
    rgb_bands = ["R", "G", "B"]
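NAIP's regex now exposes the acquisition date through the generic ``date`` group that ``RasterDataset`` looks for. A standalone sketch of that parse; the filename below is a made-up example that follows the documented NAIP naming convention:

```python
import re
from datetime import datetime

# Regex body and date format copied from the diff above; the filename is made up.
filename_regex = r"""
    ^m
    _(?P<quadrangle>\d+)
    _(?P<quarter_quad>[a-z]+)
    _(?P<utm_zone>\d+)
    _(?P<resolution>\d+)
    _(?P<date>\d+)
    (?:_(?P<processing_date>\d+))?
    \..*$
"""
date_format = "%Y%m%d"

pattern = re.compile(filename_regex, re.VERBOSE)
match = pattern.match("m_3807511_ne_18_060_20181104_20190605.tif")
if match is not None:
    # This mirrors what RasterDataset.__init__ does to place the file in time
    mint = maxt = datetime.strptime(match.group("date"), date_format).timestamp()
    print(match.group("quadrangle"), datetime.fromtimestamp(mint).date())
```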
torchgeo/datasets/sentinel.py
@@ -1,23 +1,13 @@
"""Sentinel datasets."""

import abc
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Sequence

import numpy as np
import rasterio
import torch
from rasterio.crs import CRS as RCRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from rasterio.crs import CRS

from .geo import GeoDataset
from .utils import BoundingBox
from .geo import RasterDataset


class Sentinel(GeoDataset, abc.ABC):
class Sentinel(RasterDataset):
    """Abstract base class for all Sentinel datasets.

    `Sentinel <https://sentinel.esa.int/web/sentinel/home>`_ is a family of
@@ -43,7 +33,22 @@ class Sentinel2(Sentinel):
    Earth's surface changes.
    """

    band_names = [
    # TODO: files downloaded from USGS Earth Explorer seem to have a different
    # filename format than the official documentation
    # https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/naming-convention
    # https://sentinel.esa.int/documents/247904/685211/Sentinel-2-MSI-L2A-Product-Format-Specifications.pdf
    filename_glob = "T*_*_B*_*m.*"
    filename_regex = r"""
        ^T(?P<tile>\d{2}[A-Z]{3})
        _(?P<date>\d{8}T\d{6})
        _(?P<band>B\d{2})
        _(?P<resolution>\d{2}m)
        \..*$
    """
    date_format = "%Y%m%dT%H%M%S"

    # https://gisgeography.com/sentinel-2-bands-combinations/
    all_bands = [
        "B01",
        "B02",
        "B03",
@@ -58,94 +63,30 @@ class Sentinel2(Sentinel):
        "B11",
        "B12",
    ]
    rgb_bands = ["B04", "B03", "B02"]

    separate_files = True

    def __init__(
        self,
        root: str = "data",
        crs: RCRS = RCRS.from_epsg(32641),
        bands: Sequence[str] = band_names,
        crs: Optional[CRS] = None,
        bands: Sequence[str] = [],
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Sentinel-2 Dataset.
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to
            bands: bands to return
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            bands: bands to return (defaults to all bands)
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        self.root = root
        self.crs = crs
        self.bands = bands
        self.transforms = transforms
        self.bands = bands if bands else self.all_bands

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        fileglob = os.path.join(root, f"**_{bands[0]}_*.tif")
        for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
            # https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi/naming-convention
            time = datetime.strptime(
                os.path.basename(filename).split("_")[1], "%Y%m%dT%H%M%S"
            )
            timestamp = time.timestamp()
            with rasterio.open(filename) as src:
                with WarpedVRT(src, crs=self.crs) as vrt:
                    minx, miny, maxx, maxy = vrt.bounds
            coords = (minx, maxx, miny, maxy, timestamp, timestamp)
            self.index.insert(i, coords, filename)

        if "filename" not in locals():
            raise FileNotFoundError(f"No Sentinel2 data was found in '{root}'")

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of data and metadata at that index

        Raises:
            IndexError: if query is not within bounds of the index
        """
        if not query.intersects(self.bounds):
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            )

        hits = self.index.intersection(query, objects=True)
        filename = next(hits).object  # TODO: this assumes there is only a single hit
        data_list = []
        for band in self.bands:
            tokens = filename.split("_")
            tokens[2] = band
            filename = "_".join(tokens)
            with rasterio.open(filename) as src:
                with WarpedVRT(src, crs=self.crs) as vrt:
                    window = rasterio.windows.from_bounds(
                        query.minx,
                        query.miny,
                        query.maxx,
                        query.maxy,
                        transform=vrt.transform,
                    )
                    image = vrt.read(window=window)
            data_list.append(image)
        # FIXME: different bands have different resolution, won't be able to concatenate
        image = np.concatenate(data_list)  # type: ignore[no-untyped-call]
        image = image.astype(np.int32)
        sample = {
            "image": torch.tensor(image),  # type: ignore[attr-defined]
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
        super().__init__(root, crs, transforms)
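A hedged usage sketch for the refactored Sentinel2 class; the root path is a placeholder, and restricting the request to the 10 m bands sidesteps the resolution mismatch flagged in the removed FIXME above:

```python
# Usage sketch (placeholder path; assumes 10 m band files matching "T*_*_B*_10m.*").
from torchgeo.datasets import Sentinel2

# B02/B03/B04 share the same 10 m grid, so the per-band reads stack cleanly.
ds = Sentinel2(root="data/sentinel2", bands=["B04", "B03", "B02"])

sample = ds[ds.bounds]
print(sample["image"].shape, sample["crs"])
```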