зеркало из https://github.com/microsoft/torchgeo.git
Add AsterGDEM dataset (#404)
* add astergdem dataset * add astergdem dataset * add plot method * typo * fix docs * requested changes * Update docs/api/datasets.rst Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/astergdem.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * split regex * split regex * split regex * regex Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
26b6917306
Коммит
d48e10ec15
|
@ -12,6 +12,11 @@ Geospatial Datasets
|
|||
|
||||
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`.
|
||||
|
||||
Aster Global Digital Elevation Model
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: AsterGDEM
|
||||
|
||||
Canadian Building Footprints
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
|
||||
# Seed both random number generators with a fixed value so that every run of
# this script draws the same pseudo-random pixel values.
random.seed(0)
np.random.seed(0)

# Height and width, in pixels, of each generated raster.
SIZE = 64

# One entry per fake tile; the "image" key holds the file name to generate.
files = [
    {"image": tile_name}
    for tile_name in (
        "ASTGTMV003_N000000_dem.tif",
        "ASTGTMV003_N000010_dem.tif",
    )
]
|
||||
|
||||
|
||||
def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a small GeoTIFF filled with random integers to ``path``.

    Args:
        path: destination file, opened in write mode
        dtype: name of an integer NumPy dtype used for the pixel values
            (its maximum is looked up with ``np.iinfo``)
        num_channels: number of bands to write
    """
    # NOTE(review): the transform is derived for a 1x1 raster while the file is
    # SIZE x SIZE pixels -- kept as-is so the generated data does not change;
    # confirm this is intentional.
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    # Draw one band per requested channel so the array shape always agrees
    # with the ``count`` declared in the profile.
    Z = np.random.randint(
        np.iinfo(dtype).max, size=(num_channels, SIZE, SIZE), dtype=dtype
    )

    # The context manager closes the dataset, which guarantees the data is
    # flushed to disk before the caller zips or checksums the file.
    with rasterio.open(path, "w", **profile) as src:
        src.write(Z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    zipfilename = "astergdem.zip"
    files_to_zip = []

    for file_dict in files:
        path = file_dict["image"]
        # Remove stale data left behind by a previous run
        if os.path.exists(path):
            os.remove(path)
        # Create the fake single-band int32 tile
        create_file(path, dtype="int32", num_channels=1)
        files_to_zip.append(path)

    # Compress data (the handle is not called ``zip`` to avoid shadowing the
    # builtin of the same name)
    with zipfile.ZipFile(zipfilename, "w") as archive:
        for filename in files_to_zip:
            archive.write(filename, arcname=filename)

    # Compute and print the MD5 checksum of the archive
    with open(zipfilename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{zipfilename}: {md5}")
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset
|
||||
|
||||
|
||||
class TestAsterGDEM:
    """Tests for :class:`AsterGDEM`, run against the archive in ``tests/data``."""

    @pytest.fixture
    def dataset(self, tmp_path: Path) -> AsterGDEM:
        """Unpack the test archive into ``tmp_path`` and load it as a dataset."""
        archive = os.path.join("tests", "data", "astergdem", "astergdem.zip")
        shutil.unpack_archive(archive, tmp_path, "zip")
        return AsterGDEM(
            str(tmp_path),
            transforms=nn.Identity(),  # type: ignore[attr-defined]
        )

    def test_datasetmissing(self, tmp_path: Path) -> None:
        """An empty root directory is reported with a ``RuntimeError``."""
        # Recreate the directory so it is guaranteed to contain nothing
        shutil.rmtree(tmp_path)
        os.makedirs(tmp_path)
        empty_root = str(tmp_path)
        with pytest.raises(RuntimeError, match="Dataset not found in"):
            AsterGDEM(root=empty_root)

    def test_getitem(self, dataset: AsterGDEM) -> None:
        """Indexing by the full bounds returns a dict with a CRS and a mask."""
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_and(self, dataset: AsterGDEM) -> None:
        """``&`` produces an :class:`IntersectionDataset`."""
        combined = dataset & dataset
        assert isinstance(combined, IntersectionDataset)

    def test_or(self, dataset: AsterGDEM) -> None:
        """``|`` produces a :class:`UnionDataset`."""
        combined = dataset | dataset
        assert isinstance(combined, UnionDataset)

    def test_plot(self, dataset: AsterGDEM) -> None:
        """Plotting a plain sample runs without raising."""
        bounds = dataset.bounds
        sample = dataset[bounds]
        dataset.plot(sample, suptitle="Test")

    def test_plot_prediction(self, dataset: AsterGDEM) -> None:
        """Plotting a sample that also carries a prediction runs without raising."""
        bounds = dataset.bounds
        sample = dataset[bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")

    def test_invalid_query(self, dataset: AsterGDEM) -> None:
        """A query box outside the index raises ``IndexError``."""
        outside = BoundingBox(100, 100, 100, 100, 0, 0)
        expected_message = "query: .* not found in index with bounds:"
        with pytest.raises(IndexError, match=expected_message):
            dataset[outside]
|
|
@ -4,6 +4,7 @@
|
|||
"""TorchGeo datasets."""
|
||||
|
||||
from .advance import ADVANCE
|
||||
from .astergdem import AsterGDEM
|
||||
from .benin_cashews import BeninSmallHolderCashews
|
||||
from .bigearthnet import BigEarthNet
|
||||
from .cbf import CanadianBuildingFootprints
|
||||
|
@ -85,6 +86,7 @@ from .zuericrop import ZueriCrop
|
|||
|
||||
__all__ = (
|
||||
# GeoDataset
|
||||
"AsterGDEM",
|
||||
"CanadianBuildingFootprints",
|
||||
"CDL",
|
||||
"Chesapeake",
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Aster Global Digital Elevation Model dataset."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
||||
|
||||
class AsterGDEM(RasterDataset):
    """Aster Global Digital Elevation Model Dataset.

    The `Aster Global Digital Elevation Model
    <https://lpdaac.usgs.gov/products/astgtmv003/>`_
    dataset is a Digital Elevation Model (DEM) on a global scale.
    The dataset can be downloaded from the
    `Earth Data website <https://search.earthdata.nasa.gov/search/>`_
    after making an account.

    Dataset features:

    * DEMs at 30 m per pixel spatial resolution (3601x3601 px)
    * data collected from the `Aster
      <https://terra.nasa.gov/about/terra-instruments/aster>`_ instrument

    Dataset format:

    * DEMs are single-channel tif files

    .. versionadded:: 0.3
    """

    is_image = False
    filename_glob = "ASTGTMV003_*_dem*"
    # The product prefix is matched literally; a character class such as
    # ``[ASTGTMV003]{10}`` would accept any 10-character mix of those symbols.
    filename_regex = r"""
        (?P<name>ASTGTMV003)
        _(?P<id>[A-Z0-9]{7})
        _(?P<data>[a-z]{3})*
    """

    def __init__(
        self,
        root: str = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        cache: bool = True,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found, i.e. the directory
                holding the extracted ``ASTGTMV003_*_dem*`` tile files
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling

        Raises:
            FileNotFoundError: if no files are found in ``root``
            RuntimeError: if dataset is missing
        """
        self.root = root

        # Fail early with a descriptive message before the parent class
        # starts indexing files.
        self._verify()

        super().__init__(root, crs, res, transforms, cache)

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if dataset is missing
        """
        # Check if the extracted files already exist directly under ``root``
        pathname = os.path.join(self.root, self.filename_glob)
        if glob.glob(pathname):
            return

        raise RuntimeError(
            f"Dataset not found in `root={self.root}` "
            "either specify a different `root` directory or make sure you "
            "have manually downloaded dataset tiles as suggested in the documentation."
        )

    def plot(  # type: ignore[override]
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # The mask is always drawn; a prediction panel is added when present.
        panels = [("Mask", sample["mask"].squeeze())]
        if "prediction" in sample:
            panels.append(("Prediction", sample["prediction"].squeeze()))

        ncols = len(panels)
        # ``squeeze=False`` always yields a 2-D array of axes, so the one- and
        # two-panel cases share the same drawing loop.
        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(4 * ncols, 4), squeeze=False
        )

        for ax, (title, data) in zip(axs[0], panels):
            ax.imshow(data)
            ax.axis("off")
            if show_titles:
                ax.set_title(title)

        if suptitle is not None:
            # Set the title on this figure explicitly rather than on whatever
            # pyplot currently considers the active figure.
            fig.suptitle(suptitle)

        return fig
|
Загрузка…
Ссылка в новой задаче