* add astergdem dataset

* add astergdem dataset

* add plot method

* typo

* fix docs

* requested changes

* Update docs/api/datasets.rst

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/astergdem.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* split regex

* split regex

* split regex

* regex

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2022-02-24 22:49:40 +01:00 коммит произвёл GitHub
Родитель 26b6917306
Коммит d48e10ec15
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 270 добавлений и 0 удалений

Просмотреть файл

@ -12,6 +12,11 @@ Geospatial Datasets
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`.
Aster Global Digital Elevation Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: AsterGDEM
Canadian Building Footprints
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Двоичные данные
tests/data/astergdem/ASTGTMV003_N000000_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/astergdem/ASTGTMV003_N000010_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/astergdem/astergdem.zip Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import hashlib
import os
import random
import zipfile
import numpy as np
import rasterio
# Seed both RNGs so the generated test tiles are reproducible.
np.random.seed(0)
random.seed(0)

# Edge length, in pixels, of every generated tile.
SIZE = 64

# One entry per tile to generate; "image" is the output file name.
files = [
    {"image": tile_name}
    for tile_name in (
        "ASTGTMV003_N000000_dem.tif",
        "ASTGTMV003_N000010_dem.tif",
    )
]
def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a small random GeoTIFF to ``path``.

    The raster is ``SIZE`` x ``SIZE`` pixels, LZW-compressed, tagged with
    EPSG:4326, and filled with random non-negative integers.

    Args:
        path: destination file name
        dtype: integer dtype name of the raster (``np.iinfo`` is used to
            pick the upper bound, so float dtypes are not supported)
        num_channels: number of bands to write
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
        "compress": "lzw",
        "predictor": 2,
    }

    # The band axis must agree with ``count`` in the profile; a fixed
    # single-band array would be rejected for num_channels > 1.
    Z = np.random.randint(
        np.iinfo(dtype).max, size=(num_channels, SIZE, SIZE), dtype=dtype
    )

    # Use a context manager so the dataset is flushed and closed before the
    # caller zips the file.
    with rasterio.open(path, "w", **profile) as dst:
        dst.write(Z)
if __name__ == "__main__":
    archive_name = "astergdem.zip"
    tile_paths = []

    for entry in files:
        tile_path = entry["image"]

        # Drop any tile left over from a previous run before regenerating it.
        if os.path.exists(tile_path):
            os.remove(tile_path)

        # Write a fresh single-band int32 tile.
        create_file(tile_path, dtype="int32", num_channels=1)
        tile_paths.append(tile_path)

    # Bundle every generated tile into one archive, stored under its own name.
    with zipfile.ZipFile(archive_name, "w") as archive:
        for tile_path in tile_paths:
            archive.write(tile_path, arcname=tile_path)

    # Report the archive's MD5 so it can be recorded alongside the test data.
    with open(archive_name, "rb") as archive_file:
        digest = hashlib.md5(archive_file.read()).hexdigest()
        print(f"{archive_name}: {digest}")

Просмотреть файл

@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from pathlib import Path
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset
class TestAsterGDEM:
    """Tests for the ``AsterGDEM`` dataset using the bundled test archive."""

    @pytest.fixture
    def dataset(self, tmp_path: Path) -> AsterGDEM:
        """Unpack the test archive into ``tmp_path`` and load it."""
        archive = os.path.join("tests", "data", "astergdem", "astergdem.zip")
        shutil.unpack_archive(archive, tmp_path, "zip")
        return AsterGDEM(
            str(tmp_path),
            transforms=nn.Identity(),  # type: ignore[attr-defined]
        )

    def test_datasetmissing(self, tmp_path: Path) -> None:
        """An empty root directory must raise the 'not found' error."""
        # Recreate the directory so it is guaranteed to be empty.
        shutil.rmtree(tmp_path)
        os.makedirs(tmp_path)
        with pytest.raises(RuntimeError, match="Dataset not found in"):
            AsterGDEM(root=str(tmp_path))

    def test_getitem(self, dataset: AsterGDEM) -> None:
        """Indexing by the full bounds yields a dict with crs and mask."""
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["crs"], CRS)
        assert isinstance(sample["mask"], torch.Tensor)

    def test_and(self, dataset: AsterGDEM) -> None:
        """``&`` produces an IntersectionDataset."""
        assert isinstance(dataset & dataset, IntersectionDataset)

    def test_or(self, dataset: AsterGDEM) -> None:
        """``|`` produces a UnionDataset."""
        assert isinstance(dataset | dataset, UnionDataset)

    def test_plot(self, dataset: AsterGDEM) -> None:
        """Plotting a plain sample runs without error."""
        sample = dataset[dataset.bounds]
        dataset.plot(sample, suptitle="Test")

    def test_plot_prediction(self, dataset: AsterGDEM) -> None:
        """Plotting a sample that carries a prediction runs without error."""
        sample = dataset[dataset.bounds]
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample, suptitle="Prediction")

    def test_invalid_query(self, dataset: AsterGDEM) -> None:
        """A query outside the dataset bounds raises IndexError."""
        outside = BoundingBox(100, 100, 100, 100, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[outside]

Просмотреть файл

@ -4,6 +4,7 @@
"""TorchGeo datasets."""
from .advance import ADVANCE
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
from .cbf import CanadianBuildingFootprints
@ -85,6 +86,7 @@ from .zuericrop import ZueriCrop
__all__ = (
# GeoDataset
"AsterGDEM",
"CanadianBuildingFootprints",
"CDL",
"Chesapeake",

Просмотреть файл

@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Aster Global Digital Elevation Model dataset."""
import glob
import os
from typing import Any, Callable, Dict, Optional
import matplotlib.pyplot as plt
from rasterio.crs import CRS
from torch import Tensor
from .geo import RasterDataset
class AsterGDEM(RasterDataset):
    """Aster Global Digital Elevation Model Dataset.

    The `Aster Global Digital Elevation Model
    <https://lpdaac.usgs.gov/products/astgtmv003/>`_
    dataset is a Digital Elevation Model (DEM) on a global scale.
    The dataset can be downloaded from the
    `Earth Data website <https://search.earthdata.nasa.gov/search/>`_
    after making an account.

    Dataset features:

    * DEMs at 30 m per pixel spatial resolution (3601x3601 px)
    * data collected from the `Aster
      <https://terra.nasa.gov/about/terra-instruments/aster>`_ instrument

    Dataset format:

    * DEMs are single-channel tif files

    .. versionadded:: 0.3
    """

    # Samples are returned under the "mask" key rather than "image"
    # (see ``plot`` below, which reads ``sample["mask"]``).
    is_image = False
    filename_glob = "ASTGTMV003_*_dem*"
    # The product prefix is matched literally. A character class such as
    # ``[ASTGTMV003]{10}`` would accept any 10 characters drawn from that set.
    # NOTE(review): the whitespace in this pattern presumably relies on the
    # base class compiling it with ``re.VERBOSE`` -- confirm in RasterDataset.
    filename_regex = r"""
        (?P<name>ASTGTMV003)
        _(?P<id>[A-Z0-9]{7})
        _(?P<data>[a-z]{3})*
    """

    def __init__(
        self,
        root: str = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        cache: bool = True,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found; it must directly
                contain the tile files matching ``filename_glob``
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling

        Raises:
            FileNotFoundError: if no files are found in ``root``
            RuntimeError: if dataset is missing
        """
        # ``_verify`` reads ``self.root``, so set it before the parent
        # constructor runs.
        self.root = root

        self._verify()

        super().__init__(root, crs, res, transforms, cache)

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if dataset is missing
        """
        # Check if the extracted files already exist
        pathname = os.path.join(self.root, self.filename_glob)
        if glob.glob(pathname):
            return

        raise RuntimeError(
            f"Dataset not found in `root={self.root}` "
            "either specify a different `root` directory or make sure you "
            "have manually downloaded dataset tiles as suggested in the documentation."
        )

    def plot(  # type: ignore[override]
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`;
                must contain ``"mask"`` and may contain ``"prediction"``
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        mask = sample["mask"].squeeze()
        ncols = 1

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction = sample["prediction"].squeeze()
            ncols = 2

        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4))

        if showing_predictions:
            # Two panels: ``axs`` is an array of axes.
            axs[0].imshow(mask)
            axs[0].axis("off")
            axs[1].imshow(prediction)
            axs[1].axis("off")
            if show_titles:
                axs[0].set_title("Mask")
                axs[1].set_title("Prediction")
        else:
            # Single panel: ``plt.subplots`` returns a bare Axes object.
            axs.imshow(mask)
            axs.axis("off")
            if show_titles:
                axs.set_title("Mask")

        if suptitle is not None:
            # Title this figure explicitly rather than whichever figure
            # pyplot currently treats as active.
            fig.suptitle(suptitle)

        return fig