This commit is contained in:
Adam J. Stewart 2021-08-05 21:16:30 +00:00
Родитель ae8049ac19
Коммит 9af70d3c40
8 изменённых файлов: 445 добавлений и 636 удалений

Просмотреть файл

@ -117,6 +117,16 @@ GeoDataset
.. autoclass:: GeoDataset
RasterDataset
^^^^^^^^^^^^^
.. autoclass:: RasterDataset
VectorDataset
^^^^^^^^^^^^^
.. autoclass:: VectorDataset
VisionDataset
^^^^^^^^^^^^^

Просмотреть файл

@ -18,7 +18,7 @@ from .chesapeake import (
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .geo import GeoDataset, VisionDataset, ZipDataset
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .landcoverai import LandCoverAI
from .landsat import (
Landsat,
@ -81,6 +81,8 @@ __all__ = (
"VHR10",
# Base classes
"GeoDataset",
"RasterDataset",
"VectorDataset",
"VisionDataset",
"ZipDataset",
# Utilities

Просмотреть файл

@ -1,50 +1,15 @@
"""CDL dataset."""
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from .geo import GeoDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive
_crs = CRS.from_wkt(
"""
PROJCS["Albers Conical Equal Area",
GEOGCS["NAD83",
DATUM["North_American_Datum_1983",
SPHEROID["GRS 1980",6378137,298.257222101,
AUTHORITY["EPSG","7019"]],
AUTHORITY["EPSG","6269"]],
PRIMEM["Greenwich",0,
AUTHORITY["EPSG","8901"]],
UNIT["degree",0.0174532925199433,
AUTHORITY["EPSG","9122"]],
AUTHORITY["EPSG","4269"]],
PROJECTION["Albers_Conic_Equal_Area"],
PARAMETER["latitude_of_center",23],
PARAMETER["longitude_of_center",-96],
PARAMETER["standard_parallel_1",29.5],
PARAMETER["standard_parallel_2",45.5],
PARAMETER["false_easting",0],
PARAMETER["false_northing",0],
UNIT["meters",1],
AXIS["Easting",EAST],
AXIS["Northing",NORTH]]
"""
)
from .geo import RasterDataset
from .utils import check_integrity, download_and_extract_archive
class CDL(GeoDataset):
class CDL(RasterDataset):
"""Cropland Data Layer (CDL) dataset.
The `Cropland Data Layer
@ -62,6 +27,14 @@ class CDL(GeoDataset):
* https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0
""" # noqa: E501
filename_glob = "*_30m_cdls.*"
filename_regex = r"""
^(?P<date>\d+)
_30m_cdls\..*$
"""
date_format = "%Y"
is_image = False
url = "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip" # noqa: E501
md5s = [
(2020, "97b3b5fd62177c9ed857010bca146f36"),
@ -82,24 +55,27 @@ class CDL(GeoDataset):
def __init__(
self,
root: str = "data",
crs: CRS = _crs,
crs: Optional[CRS] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new CDL Dataset.
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to project to
crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
of the files by default
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
self.root = root
self.crs = crs
self.transforms = transforms
self.checksum = checksum
if download:
@ -111,61 +87,7 @@ class CDL(GeoDataset):
+ "You can use download=True to download it"
)
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
fileglob = os.path.join(root, "**_30m_cdls.img")
for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
year = int(os.path.basename(filename).split("_")[0])
mint = datetime(year, 1, 1, 0, 0, 0).timestamp()
maxt = datetime(year, 12, 31, 23, 59, 59).timestamp()
with rasterio.open(filename) as src:
cmap = src.colormap(1)
with WarpedVRT(src, crs=self.crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(i, coords, filename)
self.cmap = np.array([cmap[i] for i in range(256)])
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of labels and metadata at that index
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} is not within bounds of the index: {self.bounds}"
)
hits = self.index.intersection(query, objects=True)
filename = next(hits).object # TODO: this assumes there is only a single hit
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
window = rasterio.windows.from_bounds(
query.minx,
query.miny,
query.maxx,
query.maxy,
transform=vrt.transform,
)
masks = vrt.read(window=window)
masks = masks.astype(np.int32)
sample = {
"masks": torch.tensor(masks), # type: ignore[attr-defined]
"crs": self.crs,
"bbox": query,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
super().__init__(root, crs, transforms)
def _check_integrity(self) -> bool:
"""Check integrity of dataset.
@ -191,20 +113,3 @@ class CDL(GeoDataset):
self.root,
md5=md5 if self.checksum else None,
)
def plot(self, image: Tensor) -> None:
"""Plot an image on a map.
Args:
image: the image to plot
"""
# Convert from class labels to RGBA values
array = image.squeeze().numpy()
array = self.cmap[array]
# Plot the image
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()

Просмотреть файл

@ -2,49 +2,15 @@
import abc
import os
import sys
from typing import Any, Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from .geo import GeoDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive
_crs = CRS.from_wkt(
"""
PROJCS["USA_Contiguous_Albers_Equal_Area_Conic_USGS_version",
GEOGCS["NAD83",
DATUM["North_American_Datum_1983",
SPHEROID["GRS 1980",6378137,298.257222101004,
AUTHORITY["EPSG","7019"]],
AUTHORITY["EPSG","6269"]],
PRIMEM["Greenwich",0],
UNIT["degree",0.0174532925199433,
AUTHORITY["EPSG","9122"]],
AUTHORITY["EPSG","4269"]],
PROJECTION["Albers_Conic_Equal_Area"],
PARAMETER["latitude_of_center",23],
PARAMETER["longitude_of_center",-96],
PARAMETER["standard_parallel_1",29.5],
PARAMETER["standard_parallel_2",45.5],
PARAMETER["false_easting",0],
PARAMETER["false_northing",0],
UNIT["metre",1,
AUTHORITY["EPSG","9001"]],
AXIS["Easting",EAST],
AXIS["Northing",NORTH]]
"""
)
from .geo import RasterDataset
from .utils import check_integrity, download_and_extract_archive
class Chesapeake(GeoDataset, abc.ABC):
class Chesapeake(RasterDataset, abc.ABC):
"""Abstract base class for all Chesapeake datasets.
`Chesapeake Bay High-Resolution Land Cover Project
@ -70,6 +36,8 @@ class Chesapeake(GeoDataset, abc.ABC):
* https://doi.org/10.1109/cvpr.2019.01301
"""
is_image = False
@property
@abc.abstractmethod
def base_folder(self) -> str:
@ -97,46 +65,29 @@ class Chesapeake(GeoDataset, abc.ABC):
url += f"/{self.base_folder}/{self.zipfile}"
return url
cmap = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (0, 168, 132, 255),
3: (38, 115, 0, 255),
4: (76, 230, 0, 255),
5: (163, 255, 115, 255),
6: (255, 170, 0, 255),
7: (255, 0, 0, 255),
8: (156, 156, 156, 255),
9: (0, 0, 0, 255),
10: (115, 115, 0, 255),
11: (230, 230, 0, 255),
12: (255, 255, 115, 255),
13: (197, 0, 255, 255),
14: (0, 0, 0, 0),
15: (0, 0, 0, 0),
}
def __init__(
self,
root: str,
crs: CRS = _crs,
transforms: Optional[Callable[[Any], Any]] = None,
root: str = "data",
crs: Optional[CRS] = None,
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new Chesapeake dataset instance.
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to project to
crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
of the files by default
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
self.root = root
self.crs = crs
self.transforms = transforms
self.checksum = checksum
if download:
@ -148,57 +99,7 @@ class Chesapeake(GeoDataset, abc.ABC):
+ "You can use download=True to download it"
)
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
filename = os.path.join(self.root, self.filename)
with rasterio.open(filename) as src:
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
mint = 0
maxt = sys.maxsize
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(0, coords, filename)
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve labels and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of labels and metadata at that index
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} is not within bounds of the index: {self.bounds}"
)
hits = self.index.intersection(query, objects=True)
filename = next(hits).object
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
window = rasterio.windows.from_bounds(
query.minx,
query.miny,
query.maxx,
query.maxy,
transform=vrt.transform,
)
masks = vrt.read(window=window)
sample = {
"masks": torch.tensor(masks), # type: ignore[attr-defined]
"crs": self.crs,
"bbox": query,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
super().__init__(root, crs, transforms)
def _check_integrity(self) -> bool:
"""Check integrity of dataset.
@ -225,24 +126,6 @@ class Chesapeake(GeoDataset, abc.ABC):
md5=self.md5,
)
def plot(self, image: Tensor) -> None:
"""Plot an image on a map.
Args:
image: the image to plot
"""
# Convert from class labels to RGBA values
cmap = np.array([self.cmap[i] for i in range(len(self.cmap))])
array = image.squeeze().numpy()
array = cmap[array]
# Plot the image
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()
class Chesapeake7(Chesapeake):
"""Complete 7-class dataset.

Просмотреть файл

@ -1,10 +1,21 @@
"""Base classes for all :mod:`torchgeo` datasets."""
import abc
from typing import Any, Dict, Sequence
import glob
import os
import re
import sys
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rtree.index import Index
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset
from .utils import BoundingBox
@ -90,6 +101,234 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
return BoundingBox(*self.index.bounds)
class RasterDataset(GeoDataset):
    """Abstract base class for :class:`GeoDataset` objects stored as raster files.

    On construction, every file under ``root`` that matches both
    :attr:`filename_glob` and :attr:`filename_regex` and that rasterio can open is
    inserted into a 3D R-tree (x, y, time). Indexing the dataset with a
    :class:`BoundingBox` reads the matching window from the first file hit.
    """

    #: Glob expression used to search for files.
    #:
    #: This expression should be specific enough that it will not pick up files from
    #: other datasets. It should not include a file extension, as the dataset may be in
    #: a different file format than what it was originally downloaded as.
    filename_glob = "*"

    #: Regular expression used to extract date from filename.
    #:
    #: The expression should use named groups. The expression may contain any number of
    #: groups. The following groups are specifically searched for by the base class:
    #:
    #: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion
    #: * ``band``: used when :attr:`separate_files` is True
    filename_regex = ".*"

    #: Date format string used to parse date from filename.
    #:
    #: Not used if :attr:`filename_regex` does not contain a ``date`` group.
    date_format = "%Y%m%d"

    #: True if dataset contains imagery, False if dataset contains mask
    is_image = True

    #: True if data is stored in a separate file for each band, else False.
    separate_files = False

    #: Names of all available bands in the dataset
    all_bands: List[str] = []

    #: Names of RGB bands in the dataset, used for plotting
    rgb_bands: List[str] = []

    #: If True, stretch the image from the 2nd percentile to the 98th percentile,
    #: used for plotting
    stretch = False

    #: Color map for the dataset, used for plotting
    cmap: Dict[int, Tuple[int, int, int, int]] = {}

    def __init__(
        self,
        root: str,
        crs: Optional[CRS] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
                of the files by default
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        self.root = root
        self.crs = crs
        self.transforms = transforms

        # Create an R-tree to index the dataset
        self.index = Index(interleaved=False, properties=Property(dimension=3))

        i = 0
        pathname = os.path.join(root, "**", self.filename_glob)
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for filepath in glob.iglob(pathname, recursive=True):
            match = filename_regex.match(os.path.basename(filepath))
            if match is None:
                continue

            try:
                with rasterio.open(filepath) as src:
                    # See if file has a color map; if so it replaces the
                    # class-level default on this instance
                    try:
                        self.cmap = src.colormap(1)
                    except ValueError:
                        pass

                    # Without an explicit CRS, adopt the one of the first
                    # readable file and warp every later file to it
                    if self.crs is None:
                        self.crs = src.crs

                    with WarpedVRT(src, crs=self.crs) as vrt:
                        minx, miny, maxx, maxy = vrt.bounds
            except rasterio.errors.RasterioIOError:
                # Skip files that rasterio is unable to read
                continue

            # Files without a date cover all of time
            mint: float = 0
            maxt: float = sys.maxsize
            # groupdict() maps an optional group that did not participate in the
            # match to None, so check the value and not just the key
            date = match.groupdict().get("date")
            if date is not None:
                time = datetime.strptime(date, self.date_format)
                mint = maxt = time.timestamp()

            coords = (minx, maxx, miny, maxy, mint, maxt)
            self.index.insert(i, coords, filepath)
            i += 1

        if i == 0:
            raise FileNotFoundError(
                f"No {self.__class__.__name__} data was found in '{root}'"
            )

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        When :attr:`separate_files` is True, the ``band`` group of the matched
        filename is replaced by each requested band name in turn and the per-band
        arrays are concatenated along the first (channel) dimension.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(query, objects=True)
        try:
            hit = next(hits)  # TODO: this assumes there is only a single hit
        except StopIteration:
            raise IndexError(
                f"query: {query} is not within bounds of the index: {self.bounds}"
            ) from None

        filepath = hit.object

        if self.separate_files:
            # Everything derived from the hit is the same for every band, so
            # compute it once instead of once per band
            directory, filename = os.path.split(filepath)
            match = re.compile(self.filename_regex, re.VERBOSE).match(filename)

            # Only substitute the band when the regex actually captured one;
            # otherwise fall back to loading the file that was hit
            band_span = None
            if match is not None and match.groupdict().get("band") is not None:
                band_span = match.span("band")

            data_list = []
            for band in getattr(self, "bands", self.all_bands):
                band_filename = filename
                if band_span is not None:
                    start, end = band_span
                    band_filename = filename[:start] + band + filename[end:]
                data_list.append(
                    self._load_file(os.path.join(directory, band_filename), query)
                )

            # Each file is read as (bands, height, width); concatenating keeps the
            # result 3D, whereas stacking would insert a spurious extra dimension
            data = torch.cat(data_list)
        else:
            data = self._load_file(filepath, query)

        key = "image" if self.is_image else "masks"
        sample = {
            key: data,
            "crs": self.crs,
            "bbox": query,
        }

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _load_file(self, filepath: str, query: BoundingBox) -> Tensor:
        """Load a single raster file.

        The file is warped to :attr:`crs` on the fly and only the window covered
        by the spatial extent of ``query`` is read. Values are cast to ``int32``.

        Args:
            filepath: path to file to open
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            image/mask at that index
        """
        with rasterio.open(filepath) as src:
            with WarpedVRT(src, crs=self.crs, nodata=0) as vrt:
                window = rasterio.windows.from_bounds(
                    query.minx,
                    query.miny,
                    query.maxx,
                    query.maxy,
                    transform=vrt.transform,
                )
                array = vrt.read(window=window).astype(np.int32)

        tensor: Tensor = torch.tensor(array)  # type: ignore[attr-defined]
        return tensor

    def plot(self, data: Tensor) -> None:
        """Plot a data sample.

        Args:
            data: the data to plot

        Raises:
            AssertionError: if ``is_image`` is True and ``data`` has a different number
                of channels than expected
        """
        array = data.squeeze().numpy()

        if self.is_image:
            bands = getattr(self, "bands", self.all_bands)
            assert array.shape[0] == len(bands)

            # Only plot RGB bands
            if bands and self.rgb_bands:
                indices = np.array([bands.index(band) for band in self.rgb_bands])
                array = array[indices]

            # Convert from CxHxW to HxWxC
            array = np.rollaxis(array, 0, 3)

        if self.cmap:
            # Convert from class labels to RGBA values
            cmap = np.array([self.cmap[i] for i in range(len(self.cmap))])
            array = cmap[array]

        if self.stretch:
            # Stretch to the range of 2nd to 98th percentile
            per02 = np.percentile(array, 2)  # type: ignore[no-untyped-call]
            per98 = np.percentile(array, 98)  # type: ignore[no-untyped-call]
            if per98 > per02:
                array = (array - per02) / (per98 - per02)
            else:
                # Constant image: there is no range to stretch, and dividing by
                # zero would fill the plot with NaN/inf
                array = np.zeros(array.shape)
            array = np.clip(array, 0, 1)

        # Plot the data
        ax = plt.axes()
        ax.imshow(array)
        ax.axis("off")
        plt.show()
        plt.close()
class VectorDataset(GeoDataset):
    """Placeholder subclass of :class:`GeoDataset`.

    The class body is empty, so it currently adds nothing to its base class.
    Presumably intended as the counterpart of :class:`RasterDataset` for data
    stored as vector files (inferred from the name only -- TODO confirm).
    """

    pass
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
"""Abstract base class for datasets lacking geospatial information.

Просмотреть файл

@ -1,27 +1,14 @@
"""Landsat datasets."""
import abc
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from .geo import GeoDataset
from .utils import BoundingBox
_crs = CRS.from_epsg(32616)
from .geo import RasterDataset
class Landsat(GeoDataset, abc.ABC):
class Landsat(RasterDataset, abc.ABC):
"""Abstract base class for all Landsat datasets.
`Landsat <https://landsat.gsfc.nasa.gov/>`_ is a joint NASA/USGS program,
@ -32,27 +19,45 @@ class Landsat(GeoDataset, abc.ABC):
* https://www.usgs.gov/centers/eros/data-citation
"""
@property
@abc.abstractmethod
def band_names(self) -> Sequence[str]:
"""Spectral bands provided by a satellite.
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collection-2-level-1-and-level-2-scenes
filename_glob = ""
filename_regex = r"""
^L
(?P<sensor>[COTEM])
(?P<satellite>\d{2})
_(?P<processing_correction_level>[A-Z0-9]{4})
_(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
_(?P<date>\d{8})
_(?P<processing_date>\d{8})
_(?P<collection_number>\d{2})
_(?P<collection_category>[A-Z0-9]{2})
_SR
_(?P<band>[A-Z0-9]{2})
\..*$
"""
See https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites
for more details.
"""
# https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites
all_bands: List[str] = []
rgb_bands: List[str] = []
separate_files = True
stretch = True
def __init__(
self,
root: str = "data",
crs: CRS = _crs,
crs: Optional[CRS] = None,
bands: Sequence[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> None:
"""Initialize a new Landsat Dataset.
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to project to
crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
of the files by default
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
@ -60,106 +65,104 @@ class Landsat(GeoDataset, abc.ABC):
Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.root = root
self.crs = crs
self.bands = bands if bands else self.band_names
self.transforms = transforms
self.bands = bands if bands else self.all_bands
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
fileglob = os.path.join(root, f"**_{self.bands[0]}.TIF")
for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collection-2-level-1-and-level-2-scenes
time = datetime.strptime(os.path.basename(filename).split("_")[3], "%Y%m%d")
timestamp = time.timestamp()
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
self.index.insert(i, coords, filename)
super().__init__(root, crs, transforms)
if "filename" not in locals():
raise FileNotFoundError(
f"No {self.__class__.__name__} data was found in '{root}'"
)
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
class Landsat1(Landsat):
"""Landsat 1 Multispectral Scanner (MSS)."""
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
filename_glob = "LM01_*.*"
Returns:
sample of data and metadata at that index
all_bands = [
"B4",
"B5",
"B6",
"B7",
]
rgb_bands = ["B6", "B5", "B4"]
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} is not within bounds of the index: {self.bounds}"
)
hits = self.index.intersection(query, objects=True)
filename = next(hits).object # TODO: this assumes there is only a single hit
data_list = []
for band in self.bands:
tokens = filename.split("_")
tokens[-1] = band + ".TIF"
filename = "_".join(tokens)
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
window = rasterio.windows.from_bounds(
query.minx,
query.miny,
query.maxx,
query.maxy,
transform=vrt.transform,
)
image = vrt.read(window=window)
data_list.append(image)
image = np.concatenate(data_list) # type: ignore[no-untyped-call]
image = image.astype(np.int32)
sample = {
"image": torch.tensor(image), # type: ignore[attr-defined]
"crs": self.crs,
"bbox": query,
}
class Landsat2(Landsat1):
"""Landsat 2 Multispectral Scanner (MSS)."""
if self.transforms is not None:
sample = self.transforms(sample)
filename_glob = "LM02_*.*"
return sample
def plot(self, image: Tensor) -> None:
"""Plot an image on a map.
class Landsat3(Landsat1):
"""Landsat 3 Multispectral Scanner (MSS)."""
Args:
image: the image to plot
"""
# Convert from CxHxW to HxWxC
image = image.permute((1, 2, 0))
array = image.numpy()
filename_glob = "LM03_*.*"
# Stretch to the range of 2nd to 98th percentile
per98 = np.percentile(array, 98) # type: ignore[no-untyped-call]
per02 = np.percentile(array, 2) # type: ignore[no-untyped-call]
array = (array - per02) / (per98 - per02)
array = np.clip(array, 0, 1)
# Plot the image
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()
class Landsat4MSS(Landsat):
    """Landsat 4 Multispectral Scanner (MSS).

    Picks up files whose names start with ``LM04_``.
    """

    filename_glob = "LM04_*.*"

    # Bands available for this sensor, and the subset (in R, G, B order) that is
    # used for plotting
    all_bands = ["B1", "B2", "B3", "B4"]
    rgb_bands = ["B3", "B2", "B1"]
class Landsat4TM(Landsat):
    """Landsat 4 Thematic Mapper (TM).

    Picks up files whose names start with ``LT04_``.
    """

    filename_glob = "LT04_*.*"

    # Bands available for this sensor, and the subset (in R, G, B order) that is
    # used for plotting
    all_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
    rgb_bands = ["B3", "B2", "B1"]
class Landsat5MSS(Landsat4MSS):
    """Landsat 5 Multispectral Scanner (MSS).

    Shares its band layout with :class:`Landsat4MSS`; only the filename prefix
    differs.
    """

    # Satellite number 05, mirroring Landsat5TM's "LT05_*.*". The previous
    # "LM04_*.*" was copied from Landsat4MSS and matched Landsat 4 files instead.
    filename_glob = "LM05_*.*"
class Landsat5TM(Landsat4TM):
    """Landsat 5 Thematic Mapper (TM).

    Inherits everything from :class:`Landsat4TM` and only overrides the glob so
    that files whose names start with ``LT05_`` are picked up.
    """

    filename_glob = "LT05_*.*"
class Landsat7(Landsat):
    """Landsat 7 Enhanced Thematic Mapper Plus (ETM+).

    Picks up files whose names start with ``LE07_``.
    """

    filename_glob = "LE07_*.*"

    # Bands available for this sensor, and the subset (in R, G, B order) that is
    # used for plotting
    all_bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8"]
    rgb_bands = ["B3", "B2", "B1"]
class Landsat8(Landsat):
"""Landsat 8-9 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""
"""Landsat 8 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""
band_names = [
filename_glob = "LC08_*.*"
all_bands = [
"B1",
"B2",
"B3",
@ -172,67 +175,10 @@ class Landsat8(Landsat):
"B10",
"B11",
]
rgb_bands = ["B4", "B3", "B2"]
Landsat9 = Landsat8
class Landsat9(Landsat8):
"""Landsat 9 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""
class Landsat7(Landsat):
"""Landsat 7 Enhanced Thematic Mapper Plus (ETM+)."""
band_names = [
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
"B8",
]
class Landsat4TM(Landsat):
"""Landsat 4-5 Thematic Mapper (TM)."""
band_names = [
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
]
Landsat5TM = Landsat4TM
class Landsat4MSS(Landsat):
"""Landsat 4-5 Multispectral Scanner (MSS)."""
band_names = [
"B1",
"B2",
"B3",
"B4",
]
Landsat5MSS = Landsat4MSS
class Landsat1(Landsat):
"""Landsat 1-3 Multispectral Scanner (MSS)."""
band_names = [
"B4",
"B5",
"B6",
"B7",
]
Landsat2 = Landsat1
Landsat3 = Landsat1
filename_glob = "LC09_*.*"

Просмотреть файл

@ -1,26 +1,9 @@
"""National Agriculture Imagery Program (NAIP) dataset."""
import glob
import os
import re
from datetime import datetime
from typing import Any, Callable, Dict, Optional
import matplotlib.pyplot as plt
import rasterio
import torch
from rasterio.crs import CRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from .geo import GeoDataset
from .utils import BoundingBox
_crs = CRS.from_epsg(26918)
from .geo import RasterDataset
class NAIP(GeoDataset):
class NAIP(RasterDataset):
"""National Agriculture Imagery Program (NAIP) dataset.
The `National Agriculture Imagery Program (NAIP)
@ -41,118 +24,18 @@ class NAIP(GeoDataset):
# https://www.nrcs.usda.gov/Internet/FSE_DOCUMENTS/nrcs141p2_015644.pdf
# https://planetarycomputer.microsoft.com/dataset/naip#Storage-Documentation
filename_glob = "m_*.tif"
filename_regex = re.compile(
r"""
filename_glob = "m_*.*"
filename_regex = r"""
^m
_(?P<quadrangle>\d+)
_(?P<quarter_quad>[a-z]+)
_(?P<utm_zone>\d+)
_(?P<resolution>\d+)
_(?P<acquisition_date>\d+)
_(?P<date>\d+)
(?:_(?P<processing_date>\d+))?
.tif$
""",
re.VERBOSE,
)
date_format = "%Y%m%d"
\..*$
"""
def __init__(
self,
root: str,
crs: CRS = _crs,
transforms: Optional[Callable[[Any], Any]] = None,
) -> None:
"""Initialize a new NAIP dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to project to
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.root = root
self.crs = crs
self.transforms = transforms
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
fileglob = os.path.join(root, "**", self.filename_glob)
for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
match = re.match(self.filename_regex, os.path.basename(filename))
if match is not None:
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
date = match.group("acquisition_date")
time = datetime.strptime(date, self.date_format)
timestamp = time.timestamp()
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
self.index.insert(i, coords, filename)
if "filename" not in locals():
raise FileNotFoundError(f"No NAIP data was found in '{root}'")
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of image and metadata at that index
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} is not within bounds of the index: {self.bounds}"
)
hits = self.index.intersection(query, objects=True)
filename = next(hits).object
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs, nodata=0) as vrt:
window = rasterio.windows.from_bounds(
query.minx,
query.miny,
query.maxx,
query.maxy,
transform=vrt.transform,
)
image = vrt.read(window=window)
sample = {
"image": torch.tensor(image), # type: ignore[attr-defined]
"crs": self.crs,
"bbox": query,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def plot(self, image: Tensor) -> None:
"""Plot an image on a map.
Args:
image: the image to plot
"""
# Drop NIR channel
image = image[:3]
# Convert from CxHxW to HxWxC
image = image.permute((1, 2, 0))
array = image.numpy()
# Plot the image
ax = plt.axes()
ax.imshow(array)
ax.axis("off")
plt.show()
plt.close()
# Plotting
all_bands = ["R", "G", "B", "NIR"]
rgb_bands = ["R", "G", "B"]

Просмотреть файл

@ -1,23 +1,13 @@
"""Sentinel datasets."""
import abc
import glob
import os
from datetime import datetime
from typing import Any, Callable, Dict, Optional, Sequence
import numpy as np
import rasterio
import torch
from rasterio.crs import CRS as RCRS
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from rasterio.crs import CRS
from .geo import GeoDataset
from .utils import BoundingBox
from .geo import RasterDataset
class Sentinel(GeoDataset, abc.ABC):
class Sentinel(RasterDataset):
"""Abstract base class for all Sentinel datasets.
`Sentinel <https://sentinel.esa.int/web/sentinel/home>`_ is a family of
@ -43,7 +33,22 @@ class Sentinel2(Sentinel):
Earth's surface changes.
"""
band_names = [
# TODO: files downloaded from USGS Earth Explorer seem to have a different
# filename format than the official documentation
# https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/naming-convention
# https://sentinel.esa.int/documents/247904/685211/Sentinel-2-MSI-L2A-Product-Format-Specifications.pdf
filename_glob = "T*_*_B*_*m.*"
filename_regex = r"""
^T(?P<tile>\d{2}[A-Z]{3})
_(?P<date>\d{8}T\d{6})
_(?P<band>B\d{2})
_(?P<resolution>\d{2}m)
\..*$
"""
date_format = "%Y%m%dT%H%M%S"
# https://gisgeography.com/sentinel-2-bands-combinations/
all_bands = [
"B01",
"B02",
"B03",
@ -58,94 +63,30 @@ class Sentinel2(Sentinel):
"B11",
"B12",
]
rgb_bands = ["B04", "B03", "B02"]
separate_files = True
def __init__(
self,
root: str = "data",
crs: RCRS = RCRS.from_epsg(32641),
bands: Sequence[str] = band_names,
crs: Optional[CRS] = None,
bands: Sequence[str] = [],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> None:
"""Initialize a new Sentinel-2 Dataset.
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to project to
bands: bands to return
crs: :term:`coordinate reference system (CRS)` to project to. Uses the CRS
of the files by default
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.root = root
self.crs = crs
self.bands = bands
self.transforms = transforms
self.bands = bands if bands else self.all_bands
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
fileglob = os.path.join(root, f"**_{bands[0]}_*.tif")
for i, filename in enumerate(glob.iglob(fileglob, recursive=True)):
# https://sentinel.esa.int/web/sentinel/user-guides/sentinel-2-msi/naming-convention
time = datetime.strptime(
os.path.basename(filename).split("_")[1], "%Y%m%dT%H%M%S"
)
timestamp = time.timestamp()
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
minx, miny, maxx, maxy = vrt.bounds
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
self.index.insert(i, coords, filename)
if "filename" not in locals():
raise FileNotFoundError(f"No Sentinel2 data was found in '{root}'")
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of data and metadata at that index
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} is not within bounds of the index: {self.bounds}"
)
hits = self.index.intersection(query, objects=True)
filename = next(hits).object # TODO: this assumes there is only a single hit
data_list = []
for band in self.bands:
tokens = filename.split("_")
tokens[2] = band
filename = "_".join(tokens)
with rasterio.open(filename) as src:
with WarpedVRT(src, crs=self.crs) as vrt:
window = rasterio.windows.from_bounds(
query.minx,
query.miny,
query.maxx,
query.maxy,
transform=vrt.transform,
)
image = vrt.read(window=window)
data_list.append(image)
# FIXME: different bands have different resolution, won't be able to concatenate
image = np.concatenate(data_list) # type: ignore[no-untyped-call]
image = image.astype(np.int32)
sample = {
"image": torch.tensor(image), # type: ignore[attr-defined]
"crs": self.crs,
"bbox": query,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
super().__init__(root, crs, transforms)