зеркало из https://github.com/microsoft/torchgeo.git
L8 Biome: convert to IntersectionDataset (#2058)
This commit is contained in:
Родитель
171cb919e7
Коммит
5976bd15bf
|
@ -5,6 +5,7 @@ import glob
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
|
@ -65,7 +66,9 @@ class TestL8Biome:
|
|||
plt.close()
|
||||
|
||||
def test_already_extracted(self, dataset: L8Biome) -> None:
|
||||
L8Biome(dataset.paths, download=True)
|
||||
paths = cast(str, dataset.paths)
|
||||
L8Biome(paths, download=True)
|
||||
L8Biome([paths], download=True)
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
|
||||
|
|
|
@ -6,19 +6,79 @@
|
|||
import glob
|
||||
import os
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import RasterDataset
|
||||
from .geo import IntersectionDataset, RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
|
||||
|
||||
class L8Biome(RasterDataset):
|
||||
class L8BiomeImage(RasterDataset):
|
||||
"""Images from the L8 Biome dataset."""
|
||||
|
||||
# https://gisgeography.com/landsat-file-naming-convention/
|
||||
filename_glob = 'LC8*.TIF'
|
||||
filename_regex = r"""
|
||||
^LC8
|
||||
(?P<wrs_path>\d{3})
|
||||
(?P<wrs_row>\d{3})
|
||||
(?P<date>\d{7})
|
||||
(?P<gsi>[A-Z]{3})
|
||||
(?P<version>\d{2})
|
||||
\.TIF$
|
||||
"""
|
||||
date_format = '%Y%j'
|
||||
is_image = True
|
||||
rgb_bands = ['B4', 'B3', 'B2']
|
||||
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
|
||||
|
||||
|
||||
class L8BiomeMask(RasterDataset):
|
||||
"""Masks from the L8 Biome dataset."""
|
||||
|
||||
# https://gisgeography.com/landsat-file-naming-convention/
|
||||
filename_glob = 'LC8*_fixedmask.TIF'
|
||||
filename_regex = r"""
|
||||
^LC8
|
||||
(?P<wrs_path>\d{3})
|
||||
(?P<wrs_row>\d{3})
|
||||
(?P<date>\d{7})
|
||||
(?P<gsi>[A-Z]{3})
|
||||
(?P<version>\d{2})
|
||||
_fixedmask
|
||||
\.TIF$
|
||||
"""
|
||||
date_format = '%Y%j'
|
||||
is_image = False
|
||||
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
|
||||
ordinal_map = torch.zeros(256, dtype=torch.long)
|
||||
ordinal_map[64] = 1
|
||||
ordinal_map[128] = 2
|
||||
ordinal_map[192] = 3
|
||||
ordinal_map[255] = 4
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve image/mask and metadata indexed by query.
|
||||
|
||||
Args:
|
||||
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
|
||||
Returns:
|
||||
sample of image, mask and metadata at that index
|
||||
Raises:
|
||||
IndexError: if query is not found in the index
|
||||
"""
|
||||
sample = super().__getitem__(query)
|
||||
sample['mask'] = self.ordinal_map[sample['mask']]
|
||||
return sample
|
||||
|
||||
|
||||
class L8Biome(IntersectionDataset):
|
||||
"""L8 Biome dataset.
|
||||
|
||||
The `L8 Biome <https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data>`__
|
||||
|
@ -70,31 +130,12 @@ class L8Biome(RasterDataset):
|
|||
'wetlands': '1f86cc354631ca9a50ce54b7cab3f557',
|
||||
}
|
||||
|
||||
classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud']
|
||||
|
||||
# https://gisgeography.com/landsat-file-naming-convention/
|
||||
filename_glob = 'LC8*.TIF'
|
||||
filename_regex = r"""
|
||||
^LC8
|
||||
(?P<wrs_path>\d{3})
|
||||
(?P<wrs_row>\d{3})
|
||||
(?P<date>\d{7})
|
||||
(?P<gsi>[A-Z]{3})
|
||||
(?P<version>\d{2})
|
||||
\.TIF$
|
||||
"""
|
||||
date_format = '%Y%j'
|
||||
|
||||
separate_files = False
|
||||
rgb_bands = ['B4', 'B3', 'B2']
|
||||
all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str],
|
||||
crs: CRS | None = CRS.from_epsg(3857),
|
||||
res: float | None = None,
|
||||
bands: Sequence[str] = all_bands,
|
||||
bands: Sequence[str] = L8BiomeImage.all_bands,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
|
@ -124,18 +165,25 @@ class L8Biome(RasterDataset):
|
|||
|
||||
self._verify()
|
||||
|
||||
super().__init__(
|
||||
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
|
||||
)
|
||||
self.image = L8BiomeImage(paths, crs, res, bands, transforms, cache)
|
||||
self.mask = L8BiomeMask(paths, crs, res, None, transforms, cache)
|
||||
|
||||
super().__init__(self.image, self.mask)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted files already exist
|
||||
if self.files:
|
||||
if not isinstance(self.paths, str):
|
||||
return
|
||||
|
||||
for classname in [L8BiomeImage, L8BiomeMask]:
|
||||
pathname = os.path.join(self.paths, '**', classname.filename_glob)
|
||||
if not glob.glob(pathname, recursive=True):
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
# Check if the tar.gz files have already been downloaded
|
||||
assert isinstance(self.paths, str)
|
||||
pathname = os.path.join(self.paths, '*.tar.gz')
|
||||
if glob.glob(pathname):
|
||||
self._extract()
|
||||
|
@ -163,51 +211,6 @@ class L8Biome(RasterDataset):
|
|||
for tarfile in glob.iglob(pathname):
|
||||
extract_archive(tarfile)
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve image/mask and metadata indexed by query.
|
||||
|
||||
Args:
|
||||
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
|
||||
|
||||
Returns:
|
||||
sample of image, mask and metadata at that index
|
||||
|
||||
Raises:
|
||||
IndexError: if query is not found in the index
|
||||
"""
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
filepaths = cast(list[str], [hit.object for hit in hits])
|
||||
|
||||
if not filepaths:
|
||||
raise IndexError(
|
||||
f'query: {query} not found in index with bounds: {self.bounds}'
|
||||
)
|
||||
|
||||
image = self._merge_files(filepaths, query, self.band_indexes)
|
||||
|
||||
mask_filepaths = []
|
||||
for filepath in filepaths:
|
||||
mask_filepath = filepath.replace('.TIF', '_fixedmask.TIF')
|
||||
mask_filepaths.append(mask_filepath)
|
||||
|
||||
mask = self._merge_files(mask_filepaths, query)
|
||||
mask_mapping = {64: 1, 128: 2, 192: 3, 255: 4}
|
||||
|
||||
for k, v in mask_mapping.items():
|
||||
mask[mask == k] = v
|
||||
|
||||
sample = {
|
||||
'crs': self.crs,
|
||||
'bbox': query,
|
||||
'image': image.float(),
|
||||
'mask': mask.long(),
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: dict[str, Tensor],
|
||||
|
@ -217,7 +220,7 @@ class L8Biome(RasterDataset):
|
|||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
sample: a sample returned by :meth:`RasterDataset.__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional string to use as a suptitle
|
||||
|
||||
|
@ -228,9 +231,9 @@ class L8Biome(RasterDataset):
|
|||
RGBBandsMissingError: If *bands* does not include all RGB bands.
|
||||
"""
|
||||
rgb_indices = []
|
||||
for band in self.rgb_bands:
|
||||
if band in self.bands:
|
||||
rgb_indices.append(self.bands.index(band))
|
||||
for band in self.image.rgb_bands:
|
||||
if band in self.image.bands:
|
||||
rgb_indices.append(self.image.bands.index(band))
|
||||
else:
|
||||
raise RGBBandsMissingError()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче