* Add dataset

* Add dataset to docs

* Tests for enviroatlas

* Test coverage

* Added numpy type

* Added plotting

* Code review changes

* Propagating code review comments to Chesapeake
This commit is contained in:
Caleb Robinson 2022-01-28 03:16:45 +00:00 committed by GitHub
Parent 57981c823f
Commit d4c8a4bd7b
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 986 additions and 2 deletions

View file

@@ -37,6 +37,11 @@ Cropland Data Layer (CDL)
.. autoclass:: CDL
EnviroAtlas
^^^^^^^^^^^
.. autoclass:: EnviroAtlas
Landsat
^^^^^^^

View file

@@ -0,0 +1,305 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from typing import Any, Dict
import fiona
import fiona.transform
import numpy as np
import rasterio
import shapely.geometry
from rasterio.crs import CRS
from rasterio.transform import Affine
from torchvision.datasets.utils import calculate_md5
suffix_to_key_map = {
"a_naip": "naip",
"b_nlcd": "nlcd",
"c_roads": "roads",
"d_water": "water",
"d1_waterways": "waterways",
"d2_waterbodies": "waterbodies",
"e_buildings": "buildings",
"h_highres_labels": "lc",
"prior_from_cooccurrences_101_31": "prior",
"prior_from_cooccurrences_101_31_no_osm_no_buildings": "prior_no_osm_no_buildings",
}
layer_data_profiles: Dict[str, Dict[Any, Any]] = {
"a_naip": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 4,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "pixel",
},
"data_type": "continuous",
"vals": (4, 255),
},
"b_nlcd": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15],
},
"c_roads": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [0, 1],
},
"d1_waterways": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [0, 1],
},
"d2_waterbodies": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [0, 1],
},
"d_water": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [0, 1],
},
"e_buildings": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [0, 1],
},
"h_highres_labels": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 1,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "categorical",
"vals": [10, 20, 30, 40, 70],
},
"prior_from_cooccurrences_101_31": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 5,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "continuous",
"vals": (0, 225),
},
"prior_from_cooccurrences_101_31_no_osm_no_buildings": {
"profile": {
"driver": "GTiff",
"dtype": "uint8",
"nodata": None,
"count": 5,
"crs": CRS.from_epsg(26914),
"blockxsize": 512,
"blockysize": 512,
"tiled": True,
"compress": "deflate",
"interleave": "band",
},
"data_type": "continuous",
"vals": (0, 220),
},
}
tile_list = [
"pittsburgh_pa-2010_1m-train_tiles-debuffered/4007925_se",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw",
]
def write_data(path: str, profile: Dict[Any, Any], data_type: Any, vals: Any) -> None:
assert all(key in profile for key in ("count", "height", "width", "dtype"))
with rasterio.open(path, "w", **profile) as dst:
size = (profile["count"], profile["height"], profile["width"])
dtype = np.dtype(profile["dtype"])
if data_type == "continuous":
data = np.random.randint(vals[0], vals[1] + 1, size=size, dtype=dtype)
elif data_type == "categorical":
data = np.random.choice(vals, size=size).astype(dtype)
else:
raise ValueError(f"{data_type} is not recognized")
dst.write(data)
def generate_test_data(root: str) -> str:
"""Creates test data archive for the EnviroAtlas dataset and returns its md5 hash.
Args:
root (str): Path to store test data
Returns:
str: md5 hash of created archive
"""
size = (64, 64)
folder_path = os.path.join(root, "enviroatlas_lotp")
if not os.path.exists(folder_path):
os.makedirs(folder_path)
for prefix in tile_list:
for suffix, data_profile in layer_data_profiles.items():
img_path = os.path.join(folder_path, f"{prefix}_{suffix}.tif")
img_dir = os.path.dirname(img_path)
if not os.path.exists(img_dir):
os.makedirs(img_dir)
data_profile["profile"]["height"] = size[0]
data_profile["profile"]["width"] = size[1]
data_profile["profile"]["transform"] = Affine(
1.0, 0.0, 608170.0, 0.0, -1.0, 3381430.0
)
write_data(
img_path,
data_profile["profile"],
data_profile["data_type"],
data_profile["vals"],
)
# build the spatial index
schema = {
"geometry": "Polygon",
"properties": {
"split": "str",
"naip": "str",
"nlcd": "str",
"roads": "str",
"water": "str",
"waterways": "str",
"waterbodies": "str",
"buildings": "str",
"lc": "str",
"prior_no_osm_no_buildings": "str",
"prior": "str",
},
}
with fiona.open(
os.path.join(folder_path, "spatial_index.geojson"),
"w",
driver="GeoJSON",
crs="EPSG:3857",
schema=schema,
) as dst:
for prefix in tile_list:
img_path = os.path.join(folder_path, f"{prefix}_a_naip.tif")
with rasterio.open(img_path) as f:
geom = shapely.geometry.mapping(shapely.geometry.box(*f.bounds))
geom = fiona.transform.transform_geom(
f.crs.to_string(), "EPSG:3857", geom
)
row = {
"geometry": geom,
"properties": {
"split": prefix.split("/")[0].replace("_tiles-debuffered", "")
},
}
for suffix, data_profile in layer_data_profiles.items():
key = suffix_to_key_map[suffix]
row["properties"][key] = f"{prefix}_{suffix}.tif"
dst.write(row)
# Create archive
archive_path = os.path.join(root, "enviroatlas_lotp")
shutil.make_archive(archive_path, "zip", root_dir=root, base_dir="enviroatlas_lotp")
shutil.rmtree(folder_path)
md5: str = calculate_md5(archive_path + ".zip")
return md5
if __name__ == "__main__":
md5_hash = generate_test_data(os.getcwd())
print(md5_hash)
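The MD5 printed by this script is the value the tests below monkeypatch onto EnviroAtlas.md5. A hedged sanity-check sketch (the fixture path and hash are taken from the tests; running it from the repository root is an assumption, not part of the commit):
# Sketch: confirm the committed fixture still matches the MD5 hard-coded in the tests.
from torchvision.datasets.utils import calculate_md5
assert (
    calculate_md5("tests/data/enviroatlas/enviroatlas_lotp.zip")
    == "071ec65c611e1d4915a5247bffb5ad87"
)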

Binary data
tests/data/enviroatlas/enviroatlas_lotp.zip Normal file

Binary file not shown.

View file

@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from pathlib import Path
from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import (
BoundingBox,
EnviroAtlas,
IntersectionDataset,
UnionDataset,
)
from torchgeo.samplers import RandomGeoSampler
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestEnviroAtlas:
@pytest.fixture(
params=[
(("naip", "prior", "lc"), False),
(("naip", "prior", "buildings", "lc"), True),
(("naip", "prior"), False),
]
)
def dataset(
self,
request: SubRequest,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> EnviroAtlas:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.enviroatlas, "download_url", download_url
)
monkeypatch.setattr( # type: ignore[attr-defined]
EnviroAtlas, "md5", "071ec65c611e1d4915a5247bffb5ad87"
)
monkeypatch.setattr( # type: ignore[attr-defined]
EnviroAtlas,
"url",
os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"),
)
monkeypatch.setattr( # type: ignore[attr-defined]
EnviroAtlas,
"files",
["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"],
)
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return EnviroAtlas(
root,
layers=request.param[0],
transforms=transforms,
prior_as_input=request.param[1],
download=True,
checksum=True,
)
def test_getitem(self, dataset: EnviroAtlas) -> None:
sampler = RandomGeoSampler(dataset, size=16, length=32)
bb = next(iter(sampler))
x = dataset[bb]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_and(self, dataset: EnviroAtlas) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: EnviroAtlas) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: EnviroAtlas) -> None:
EnviroAtlas(root=dataset.root, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
root = str(tmp_path)
shutil.copy(
os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"), root
)
EnviroAtlas(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
EnviroAtlas(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
def test_multiple_hits_query(self, dataset: EnviroAtlas) -> None:
ds = EnviroAtlas(
root=dataset.root,
splits=["pittsburgh_pa-2010_1m-train", "austin_tx-2012_1m-test"],
layers=dataset.layers,
)
with pytest.raises(
IndexError, match="query: .* spans multiple tiles which is not valid"
):
ds[dataset.bounds]
def test_plot(self, dataset: EnviroAtlas) -> None:
sampler = RandomGeoSampler(dataset, size=16, length=1)
bb = next(iter(sampler))
x = dataset[bb]
if "naip" not in dataset.layers or "lc" not in dataset.layers:
with pytest.raises(ValueError, match="The 'naip' and"):
dataset.plot(x)
else:
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"][0].clone()
dataset.plot(x)
plt.close()
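These tests can be run on their own with pytest; a minimal sketch (the test file path is an assumption based on the repository layout, not shown in the diff):
# Run only the EnviroAtlas tests.
import pytest
pytest.main(["tests/datasets/test_enviroatlas.py", "-q"])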

View file

@@ -25,6 +25,7 @@ from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .dfc2022 import DFC2022
from .enviroatlas import EnviroAtlas
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .fair1m import FAIR1M
@@ -118,6 +119,7 @@ __all__ = (
"COWCDetection",
"CV4AKenyaCropType",
"DFC2022",
"EnviroAtlas",
"ETCI2021",
"EuroSAT",
"FAIR1M",

View file

@@ -6,7 +6,7 @@
import abc
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence
import fiona
import numpy as np
@@ -402,7 +402,7 @@ class ChesapeakeCVPR(GeoDataset):
self,
root: str = "data",
splits: Sequence[str] = ["de-train"],
layers: List[str] = ["naip-new", "lc"],
layers: Sequence[str] = ["naip-new", "lc"],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
@@ -427,6 +427,7 @@ class ChesapeakeCVPR(GeoDataset):
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
"""
for split in splits:
assert split in self.splits
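Because layers is now typed as Sequence[str] rather than List[str], any sequence (for example a tuple) type-checks; an illustrative sketch using the defaults shown above (not part of the commit):
# Illustration only: a tuple of layers now satisfies the Sequence[str] annotation.
from torchgeo.datasets import ChesapeakeCVPR
ds = ChesapeakeCVPR(root="data", splits=["de-train"], layers=("naip-new", "lc"))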

View file

@@ -0,0 +1,537 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""EnviroAtlas High-Resolution Land Cover datasets."""
import os
import sys
from typing import Any, Callable, Dict, Optional, Sequence
import fiona
import matplotlib.pyplot as plt
import numpy as np
import pyproj
import rasterio
import rasterio.mask
import shapely.geometry
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS
from torch import Tensor
from .geo import GeoDataset
from .utils import BoundingBox, download_url, extract_archive
class EnviroAtlas(GeoDataset):
"""EnviroAtlas dataset covering four cities with prior and weak input data layers.
The `EnviroAtlas
<https://doi.org/10.5281/zenodo.5778192>`_ dataset contains NAIP aerial imagery,
NLCD land cover labels, OpenStreetMap roads, water, waterways, and waterbodies,
Microsoft building footprint labels, high-resolution land cover labels from the
EPA EnviroAtlas dataset, and high-resolution land cover prior layers.
This dataset was organized to accompany the 2022 paper, `"Resolving label
uncertainty with implicit generative models"
<https://openreview.net/forum?id=AEa_UepnMDX>`_. More details can be found at
https://github.com/estherrolf/qr_for_landcover.
If you use this dataset in your research, please cite the following paper:
* https://openreview.net/forum?id=AEa_UepnMDX
.. versionadded:: 0.3
"""
url = "https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1"
filename = "enviroatlas_lotp.zip"
md5 = "6142f8d1ebfc7f8ad888337f0683dc7a"
crs = CRS.from_epsg(3857)
res = 1
valid_prior_layers = ["prior", "prior_no_osm_no_buildings"]
valid_layers = [
"naip",
"nlcd",
"roads",
"water",
"waterways",
"waterbodies",
"buildings",
"lc",
] + valid_prior_layers
cities = [
"pittsburgh_pa-2010_1m",
"durham_nc-2012_1m",
"austin_tx-2012_1m",
"phoenix_az-2010_1m",
]
splits = (
[f"{state}-train" for state in cities[:1]]
+ [f"{state}-val" for state in cities[:1]]
+ [f"{state}-test" for state in cities]
+ [f"{state}-val5" for state in cities]
)
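# e.g. "pittsburgh_pa-2010_1m-train", "pittsburgh_pa-2010_1m-val",
# "austin_tx-2012_1m-test", "durham_nc-2012_1m-val5", ...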
# these are used to check the integrity of the dataset
files = [
"austin_tx-2012_1m-test_tiles-debuffered",
"austin_tx-2012_1m-val5_tiles-debuffered",
"durham_nc-2012_1m-test_tiles-debuffered",
"durham_nc-2012_1m-val5_tiles-debuffered",
"phoenix_az-2010_1m-test_tiles-debuffered",
"phoenix_az-2010_1m-val5_tiles-debuffered",
"pittsburgh_pa-2010_1m-test_tiles-debuffered",
"pittsburgh_pa-2010_1m-train_tiles-debuffered",
"pittsburgh_pa-2010_1m-val5_tiles-debuffered",
"pittsburgh_pa-2010_1m-val_tiles-debuffered",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_a_naip.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_b_nlcd.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_c_roads.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d1_waterways.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d2_waterbodies.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif",
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif", # noqa: E501
"austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif", # noqa: E501
"spatial_index.geojson",
]
p_src_crs = pyproj.CRS("epsg:3857")
p_transformers = {
"epsg:26917": pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS("epsg:26917"), always_xy=True
).transform,
"epsg:26918": pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS("epsg:26918"), always_xy=True
).transform,
"epsg:26914": pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS("epsg:26914"), always_xy=True
).transform,
"epsg:26912": pyproj.Transformer.from_crs(
p_src_crs, pyproj.CRS("epsg:26912"), always_xy=True
).transform,
}
# used to convert the 11 raw high-res class values [0, 10, 20, 30, 40, 52, 70, 80,
# 82, 91, 92] to sequential labels [0, ..., 10]
raw_enviroatlas_to_idx_map: "np.typing.NDArray[np.uint8]" = np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
0,
0,
0,
0,
0,
0,
0,
0,
0,
2,
0,
0,
0,
0,
0,
0,
0,
0,
0,
3,
0,
0,
0,
0,
0,
0,
0,
0,
0,
4,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
5,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
6,
0,
0,
0,
0,
0,
0,
0,
0,
0,
7,
0,
8,
0,
0,
0,
0,
0,
0,
0,
0,
9,
10,
],
dtype=np.uint8,
)
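# For example, raw_enviroatlas_to_idx_map[np.array([0, 10, 20, 30, 40, 52, 70, 80, 82, 91, 92])]
# returns array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=uint8).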
highres_classes = [
"Unclassified",
"Water",
"Impervious Surface",
"Soil and Barren",
"Trees and Forest",
"Shrubs",
"Grass and Herbaceous",
"Agriculture",
"Orchards",
"Woody Wetlands",
"Emergent Wetlands",
]
highres_cmap = ListedColormap(
[
[1.00000000, 1.00000000, 1.00000000],
[0.00000000, 0.77254902, 1.00000000],
[0.61176471, 0.61176471, 0.61176471],
[1.00000000, 0.66666667, 0.00000000],
[0.14901961, 0.45098039, 0.00000000],
[0.80000000, 0.72156863, 0.47450980],
[0.63921569, 1.00000000, 0.45098039],
[0.86274510, 0.85098039, 0.22352941],
[0.67058824, 0.42352941, 0.15686275],
[0.72156863, 0.85098039, 0.92156863],
[0.42352941, 0.62352941, 0.72156863],
]
)
def __init__(
self,
root: str = "data",
splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"],
layers: Sequence[str] = ["naip", "prior"],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
prior_as_input: bool = False,
cache: bool = True,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
splits: a list of strings in the format "{city}_{state}-{year}_{resolution}-{split}"
indicating the subset of data to use, for example "pittsburgh_pa-2010_1m-train"
layers: a list containing a subset of ``valid_layers`` indicating which
layers to load
transforms: a function/transform that takes an input sample
and returns a transformed version
prior_as_input: bool describing whether the prior is used as an input (True)
or as supervision (False)
cache: if True, cache file handle to speed up repeated sampling
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
"""
for split in splits:
assert split in self.splits
assert all([layer in self.valid_layers for layer in layers])
self.root = root
self.layers = layers
self.cache = cache
self.download = download
self.checksum = checksum
self.prior_as_input = prior_as_input
self._verify()
super().__init__(transforms)
# Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0
maxt: float = sys.maxsize
with fiona.open(
os.path.join(root, "enviroatlas_lotp", "spatial_index.geojson"), "r"
) as f:
for i, row in enumerate(f):
if row["properties"]["split"] in splits:
box = shapely.geometry.shape(row["geometry"])
minx, miny, maxx, maxy = box.bounds
coords = (minx, maxx, miny, maxy, mint, maxt)
self.index.insert(
i,
coords,
{
"naip": row["properties"]["naip"],
"nlcd": row["properties"]["nlcd"],
"roads": row["properties"]["roads"],
"water": row["properties"]["water"],
"waterways": row["properties"]["waterways"],
"waterbodies": row["properties"]["waterbodies"],
"buildings": row["properties"]["buildings"],
"lc": row["properties"]["lc"],
"prior_no_osm_no_buildings": row["properties"][
"naip"
].replace(
"a_naip",
"prior_from_cooccurrences_101_31_no_osm_no_buildings",
),
"prior": row["properties"]["naip"].replace(
"a_naip", "prior_from_cooccurrences_101_31"
),
},
)
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of image/mask and metadata at that index
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}
if len(filepaths) == 0:
raise IndexError(
f"query: {query} not found in index with bounds: {self.bounds}"
)
elif len(filepaths) == 1:
filenames = filepaths[0]
query_geom_transformed = None # is set by the first layer
minx, maxx, miny, maxy, mint, maxt = query
query_box = shapely.geometry.box(minx, miny, maxx, maxy)
for layer in self.layers:
fn = filenames[layer]
with rasterio.open(
os.path.join(self.root, "enviroatlas_lotp", fn)
) as f:
dst_crs = f.crs.to_string().lower()
if query_geom_transformed is None:
query_box_transformed = shapely.ops.transform(
self.p_transformers[dst_crs], query_box
).envelope
query_geom_transformed = shapely.geometry.mapping(
query_box_transformed
)
data, _ = rasterio.mask.mask(
f, [query_geom_transformed], crop=True, all_touched=True
)
if layer in [
"naip",
"buildings",
"roads",
"waterways",
"waterbodies",
"water",
]:
sample["image"].append(data)
elif layer in ["prior", "prior_no_osm_no_buildings"]:
if self.prior_as_input:
sample["image"].append(data)
else:
sample["mask"].append(data)
elif layer in ["lc"]:
data = self.raw_enviroatlas_to_idx_map[data]
sample["mask"].append(data)
else:
raise IndexError(f"query: {query} spans multiple tiles which is not valid")
sample["image"] = np.concatenate(sample["image"], axis=0)
sample["mask"] = np.concatenate(sample["mask"], axis=0)
sample["image"] = torch.from_numpy( # type: ignore[attr-defined]
sample["image"]
)
sample["mask"] = torch.from_numpy(sample["mask"]) # type: ignore[attr-defined]
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))
if all(map(exists, self.files)):
return
# Check if the zip files have already been downloaded
if os.path.exists(os.path.join(self.root, self.filename)):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
download_url(self.url, self.root, filename=self.filename, md5=self.md5)
def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.filename))
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Note: only plots the "naip" and "lc" layers.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if the 'naip' or 'lc' layer isn't included in ``self.layers``
"""
if "naip" not in self.layers or "lc" not in self.layers:
raise ValueError("The 'naip' and 'lc' layers must be included for plotting")
image_layers = []
mask_layers = []
for layer in self.layers:
if layer in [
"naip",
"buildings",
"roads",
"waterways",
"waterbodies",
"water",
]:
image_layers.append(layer)
elif layer in ["prior", "prior_no_osm_no_buildings"]:
if self.prior_as_input:
image_layers.append(layer)
else:
mask_layers.append(layer)
elif layer in ["lc"]:
mask_layers.append(layer)
naip_index = image_layers.index("naip")
lc_index = mask_layers.index("lc")
image = np.rollaxis(
sample["image"][naip_index : naip_index + 3, :, :].numpy(), 0, 3
)
mask = sample["mask"][lc_index].numpy()
num_panels = 2
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1
fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
axs[0].imshow(image)
axs[0].axis("off")
axs[1].imshow(
mask, vmin=0, vmax=10, cmap=self.highres_cmap, interpolation="none"
)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")
if showing_predictions:
axs[2].imshow(
predictions,
vmin=0,
vmax=10,
cmap=self.highres_cmap,
interpolation="none",
)
axs[2].axis("off")
if show_titles:
axs[2].set_title("Predictions")
if suptitle is not None:
plt.suptitle(suptitle)
return fig
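A minimal end-to-end usage sketch of the new dataset, mirroring the test fixtures above (the "data" root, the download/checksum flags, and the patch size are assumptions, not part of the commit):
from torchgeo.datasets import EnviroAtlas
from torchgeo.samplers import RandomGeoSampler
# Assumed example: download the archive into ./data and sample 256 x 256 patches
# from the Pittsburgh training split, keeping "naip" and "lc" so that plot() works.
ds = EnviroAtlas(
    root="data",
    splits=["pittsburgh_pa-2010_1m-train"],
    layers=["naip", "prior", "lc"],
    download=True,
    checksum=True,
)
sampler = RandomGeoSampler(ds, size=256, length=10)
for bbox in sampler:
    sample = ds[bbox]  # dict with "image", "mask", "crs", and "bbox"
    fig = ds.plot(sample, suptitle="EnviroAtlas sample")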