L8 Biome: convert to IntersectionDataset (#2058)

This commit is contained in:
Adam J. Stewart 2024-05-13 16:23:18 +02:00 коммит произвёл GitHub
Родитель 171cb919e7
Коммит 5976bd15bf
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 84 добавлений и 78 удалений

Просмотреть файл

@ -5,6 +5,7 @@ import glob
import os
import shutil
from pathlib import Path
from typing import cast
import matplotlib.pyplot as plt
import pytest
@ -65,7 +66,9 @@ class TestL8Biome:
plt.close()
def test_already_extracted(self, dataset: L8Biome) -> None:
L8Biome(dataset.paths, download=True)
paths = cast(str, dataset.paths)
L8Biome(paths, download=True)
L8Biome([paths], download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')

Просмотреть файл

@ -6,19 +6,79 @@
import glob
import os
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
from typing import Any
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import RasterDataset
from .geo import IntersectionDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive
class L8Biome(RasterDataset):
class L8BiomeImage(RasterDataset):
    """Image rasters of the L8 Biome dataset.

    Declares how image files are located (glob and filename pattern), how
    the acquisition date embedded in a filename is parsed, and which bands
    the files provide.
    """

    is_image = True

    # https://gisgeography.com/landsat-file-naming-convention/
    filename_glob = 'LC8*.TIF'
    # NOTE(review): the one-token-per-line layout presumably relies on the
    # base class compiling this pattern in verbose mode -- confirm.
    filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""
    # Four-digit year followed by day of year: the 7 digits of ``date``.
    date_format = '%Y%j'

    # Bands B1 through B11, in numeric order.
    all_bands = [f'B{number}' for number in range(1, 12)]
    # Red, green and blue, in that order.
    rgb_bands = ['B4', 'B3', 'B2']
class L8BiomeMask(RasterDataset):
    """Mask rasters of the L8 Biome dataset.

    Mask files carry the same scene identifier as the images plus a
    ``_fixedmask`` suffix. Raw mask values are translated to small ordinal
    labels when a sample is retrieved.
    """

    is_image = False

    # https://gisgeography.com/landsat-file-naming-convention/
    filename_glob = 'LC8*_fixedmask.TIF'
    # NOTE(review): the one-token-per-line layout presumably relies on the
    # base class compiling this pattern in verbose mode -- confirm.
    filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
_fixedmask
\.TIF$
"""
    # Four-digit year followed by day of year: the 7 digits of ``date``.
    date_format = '%Y%j'

    classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']

    # Lookup table from raw mask value (0-255) to an ordinal label in 0..4,
    # presumably the matching position in ``classes``. Every raw value that
    # is not listed here maps to 0.
    ordinal_map = torch.zeros(256, dtype=torch.long)
    ordinal_map[[64, 128, 192, 255]] = torch.tensor([1, 2, 3, 4])

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query, with remapped labels.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index, where the mask values
            have been passed through ``ordinal_map``

        Raises:
            IndexError: if query is not found in the index
        """
        result = super().__getitem__(query)
        raw_mask = result['mask']
        # Index the lookup table with the raw values to get ordinal labels.
        result['mask'] = self.ordinal_map[raw_mask]
        return result
class L8Biome(IntersectionDataset):
"""L8 Biome dataset.
The `L8 Biome <https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data>`__
@ -70,31 +130,12 @@ class L8Biome(RasterDataset):
'wetlands': '1f86cc354631ca9a50ce54b7cab3f557',
}
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
# https://gisgeography.com/landsat-file-naming-convention/
filename_glob = 'LC8*.TIF'
filename_regex = r"""
^LC8
(?P<wrs_path>\d{3})
(?P<wrs_row>\d{3})
(?P<date>\d{7})
(?P<gsi>[A-Z]{3})
(?P<version>\d{2})
\.TIF$
"""
date_format = '%Y%j'
separate_files = False
rgb_bands = ['B4', 'B3', 'B2']
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
def __init__(
self,
paths: str | Iterable[str],
crs: CRS | None = CRS.from_epsg(3857),
res: float | None = None,
bands: Sequence[str] = all_bands,
bands: Sequence[str] = L8BiomeImage.all_bands,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
@ -124,18 +165,25 @@ class L8Biome(RasterDataset):
self._verify()
super().__init__(
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
self.image = L8BiomeImage(paths, crs, res, bands, transforms, cache)
self.mask = L8BiomeMask(paths, crs, res, None, transforms, cache)
super().__init__(self.image, self.mask)
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
if not isinstance(self.paths, str):
return
for classname in [L8BiomeImage, L8BiomeMask]:
pathname = os.path.join(self.paths, '**', classname.filename_glob)
if not glob.glob(pathname, recursive=True):
break
else:
return
# Check if the tar.gz files have already been downloaded
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, '*.tar.gz')
if glob.glob(pathname):
self._extract()
@ -163,51 +211,6 @@ class L8Biome(RasterDataset):
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
    """Fetch the image, its mask, and metadata for a bounding box.

    Args:
        query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

    Returns:
        sample of image, mask and metadata at that index

    Raises:
        IndexError: if query is not found in the index
    """
    matches = self.index.intersection(tuple(query), objects=True)
    image_paths = cast(list[str], [match.object for match in matches])

    if not image_paths:
        raise IndexError(
            f'query: {query} not found in index with bounds: {self.bounds}'
        )

    image = self._merge_files(image_paths, query, self.band_indexes)

    # Mask paths are derived from the image paths by swapping '.TIF' for
    # '_fixedmask.TIF'.
    mask_paths = [path.replace('.TIF', '_fixedmask.TIF') for path in image_paths]
    mask = self._merge_files(mask_paths, query)

    # Rewrite raw mask values to ordinal labels, in place.
    for raw_value, label in ((64, 1), (128, 2), (192, 3), (255, 4)):
        mask[mask == raw_value] = label

    sample = {
        'crs': self.crs,
        'bbox': query,
        'image': image.float(),
        'mask': mask.long(),
    }

    if self.transforms is None:
        return sample
    return self.transforms(sample)
def plot(
self,
sample: dict[str, Tensor],
@ -217,7 +220,7 @@ class L8Biome(RasterDataset):
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
sample: a sample returned by :meth:`RasterDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
@ -228,9 +231,9 @@ class L8Biome(RasterDataset):
RGBBandsMissingError: If *bands* does not include all RGB bands.
"""
rgb_indices = []
for band in self.rgb_bands:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
for band in self.image.rgb_bands:
if band in self.image.bands:
rgb_indices.append(self.image.bands.index(band))
else:
raise RGBBandsMissingError()