Mirror of https://github.com/microsoft/torchgeo.git
Add the EnviroAtlas dataset (#364)
* Add dataset
* Add dataset to docs
* Tests for enviroatlas
* Test coverage
* Added numpy type
* Added plotting
* Code review changes
* Propagating code review comments to Chesapeake
This commit is contained in:
Parent
57981c823f
Commit
d4c8a4bd7b
docs/api/datasets.rst
@ -37,6 +37,11 @@ Cropland Data Layer (CDL)

.. autoclass:: CDL

EnviroAtlas
^^^^^^^^^^^

.. autoclass:: EnviroAtlas

Landsat
^^^^^^^
tests/data/enviroatlas/data.py
@ -0,0 +1,305 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from typing import Any, Dict

import fiona
import fiona.transform
import numpy as np
import rasterio
import shapely.geometry
from rasterio.crs import CRS
from rasterio.transform import Affine
from torchvision.datasets.utils import calculate_md5

suffix_to_key_map = {
    "a_naip": "naip",
    "b_nlcd": "nlcd",
    "c_roads": "roads",
    "d_water": "water",
    "d1_waterways": "waterways",
    "d2_waterbodies": "waterbodies",
    "e_buildings": "buildings",
    "h_highres_labels": "lc",
    "prior_from_cooccurrences_101_31": "prior",
    "prior_from_cooccurrences_101_31_no_osm_no_buildings": "prior_no_osm_no_buildings",
}

layer_data_profiles: Dict[str, Dict[Any, Any]] = {
    "a_naip": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 4,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "pixel",
        },
        "data_type": "continuous",
        "vals": (4, 255),
    },
    "b_nlcd": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15],
    },
    "c_roads": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [0, 1],
    },
    "d1_waterways": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [0, 1],
    },
    "d2_waterbodies": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [0, 1],
    },
    "d_water": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [0, 1],
    },
    "e_buildings": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [0, 1],
    },
    "h_highres_labels": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 1,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "categorical",
        "vals": [10, 20, 30, 40, 70],
    },
    "prior_from_cooccurrences_101_31": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 5,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "continuous",
        "vals": (0, 225),
    },
    "prior_from_cooccurrences_101_31_no_osm_no_buildings": {
        "profile": {
            "driver": "GTiff", "dtype": "uint8", "nodata": None, "count": 5,
            "crs": CRS.from_epsg(26914), "blockxsize": 512, "blockysize": 512,
            "tiled": True, "compress": "deflate", "interleave": "band",
        },
        "data_type": "continuous",
        "vals": (0, 220),
    },
}

tile_list = [
    "pittsburgh_pa-2010_1m-train_tiles-debuffered/4007925_se",
    "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw",
]


def write_data(path: str, profile: Dict[Any, Any], data_type: Any, vals: Any) -> None:
    assert all(key in profile for key in ("count", "height", "width", "dtype"))
    with rasterio.open(path, "w", **profile) as dst:
        size = (profile["count"], profile["height"], profile["width"])
        dtype = np.dtype(profile["dtype"])
        if data_type == "continuous":
            data = np.random.randint(vals[0], vals[1] + 1, size=size, dtype=dtype)
        elif data_type == "categorical":
            data = np.random.choice(vals, size=size).astype(dtype)
        else:
            raise ValueError(f"{data_type} is not recognized")
        dst.write(data)


def generate_test_data(root: str) -> str:
    """Creates test data archive for the EnviroAtlas dataset and returns its md5 hash.

    Args:
        root (str): Path to store test data

    Returns:
        str: md5 hash of created archive
    """
    size = (64, 64)
    folder_path = os.path.join(root, "enviroatlas_lotp")

    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    for prefix in tile_list:
        for suffix, data_profile in layer_data_profiles.items():
            img_path = os.path.join(folder_path, f"{prefix}_{suffix}.tif")
            img_dir = os.path.dirname(img_path)
            if not os.path.exists(img_dir):
                os.makedirs(img_dir)

            data_profile["profile"]["height"] = size[0]
            data_profile["profile"]["width"] = size[1]
            data_profile["profile"]["transform"] = Affine(
                1.0, 0.0, 608170.0, 0.0, -1.0, 3381430.0
            )

            write_data(
                img_path,
                data_profile["profile"],
                data_profile["data_type"],
                data_profile["vals"],
            )

    # build the spatial index
    schema = {
        "geometry": "Polygon",
        "properties": {
            "split": "str",
            "naip": "str",
            "nlcd": "str",
            "roads": "str",
            "water": "str",
            "waterways": "str",
            "waterbodies": "str",
            "buildings": "str",
            "lc": "str",
            "prior_no_osm_no_buildings": "str",
            "prior": "str",
        },
    }
    with fiona.open(
        os.path.join(folder_path, "spatial_index.geojson"),
        "w",
        driver="GeoJSON",
        crs="EPSG:3857",
        schema=schema,
    ) as dst:
        for prefix in tile_list:
            img_path = os.path.join(folder_path, f"{prefix}_a_naip.tif")
            with rasterio.open(img_path) as f:
                geom = shapely.geometry.mapping(shapely.geometry.box(*f.bounds))
                geom = fiona.transform.transform_geom(
                    f.crs.to_string(), "EPSG:3857", geom
                )

            row = {
                "geometry": geom,
                "properties": {
                    "split": prefix.split("/")[0].replace("_tiles-debuffered", "")
                },
            }
            for suffix, data_profile in layer_data_profiles.items():
                key = suffix_to_key_map[suffix]
                row["properties"][key] = f"{prefix}_{suffix}.tif"
            dst.write(row)

    # Create archive
    archive_path = os.path.join(root, "enviroatlas_lotp")
    shutil.make_archive(archive_path, "zip", root_dir=root, base_dir="enviroatlas_lotp")
    shutil.rmtree(folder_path)
    md5: str = calculate_md5(archive_path + ".zip")
    return md5


if __name__ == "__main__":
    md5_hash = generate_test_data(os.getcwd())
    print(md5_hash)
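As a hedged aside (not part of the commit): one way to sanity-check what the generator above produces is to read a tile back with rasterio after extracting the resulting enviroatlas_lotp.zip. The tile name comes from tile_list, and the expected band count, dtype, size, and CRS come from layer_data_profiles; the path below is illustrative.

import rasterio

path = (
    "enviroatlas_lotp/pittsburgh_pa-2010_1m-train_tiles-debuffered/"
    "4007925_se_a_naip.tif"
)
with rasterio.open(path) as f:
    # write_data() should have produced a 4-band, 64x64, uint8 NAIP tile in EPSG:26914
    assert f.count == 4
    assert f.dtypes[0] == "uint8"
    assert f.shape == (64, 64)
    assert f.crs.to_epsg() == 26914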
tests/data/enviroatlas/enviroatlas_lotp.zip
Binary file not shown.
tests/datasets/test_enviroatlas.py
@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets import (
    BoundingBox,
    EnviroAtlas,
    IntersectionDataset,
    UnionDataset,
)
from torchgeo.samplers import RandomGeoSampler


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestEnviroAtlas:
    @pytest.fixture(
        params=[
            (("naip", "prior", "lc"), False),
            (("naip", "prior", "buildings", "lc"), True),
            (("naip", "prior"), False),
        ]
    )
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> EnviroAtlas:
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.enviroatlas, "download_url", download_url
        )
        monkeypatch.setattr(  # type: ignore[attr-defined]
            EnviroAtlas, "md5", "071ec65c611e1d4915a5247bffb5ad87"
        )
        monkeypatch.setattr(  # type: ignore[attr-defined]
            EnviroAtlas,
            "url",
            os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"),
        )
        monkeypatch.setattr(  # type: ignore[attr-defined]
            EnviroAtlas,
            "files",
            ["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"],
        )
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return EnviroAtlas(
            root,
            layers=request.param[0],
            transforms=transforms,
            prior_as_input=request.param[1],
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: EnviroAtlas) -> None:
        sampler = RandomGeoSampler(dataset, size=16, length=32)
        bb = next(iter(sampler))
        x = dataset[bb]
        assert isinstance(x, dict)
        assert isinstance(x["crs"], CRS)
        assert isinstance(x["mask"], torch.Tensor)

    def test_and(self, dataset: EnviroAtlas) -> None:
        ds = dataset & dataset
        assert isinstance(ds, IntersectionDataset)

    def test_or(self, dataset: EnviroAtlas) -> None:
        ds = dataset | dataset
        assert isinstance(ds, UnionDataset)

    def test_already_extracted(self, dataset: EnviroAtlas) -> None:
        EnviroAtlas(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        root = str(tmp_path)
        shutil.copy(
            os.path.join("tests", "data", "enviroatlas", "enviroatlas_lotp.zip"), root
        )
        EnviroAtlas(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            EnviroAtlas(str(tmp_path), checksum=True)

    def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
        query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]

    def test_multiple_hits_query(self, dataset: EnviroAtlas) -> None:
        ds = EnviroAtlas(
            root=dataset.root,
            splits=["pittsburgh_pa-2010_1m-train", "austin_tx-2012_1m-test"],
            layers=dataset.layers,
        )
        with pytest.raises(
            IndexError, match="query: .* spans multiple tiles which is not valid"
        ):
            ds[dataset.bounds]

    def test_plot(self, dataset: EnviroAtlas) -> None:
        sampler = RandomGeoSampler(dataset, size=16, length=1)
        bb = next(iter(sampler))
        x = dataset[bb]
        if "naip" not in dataset.layers or "lc" not in dataset.layers:
            with pytest.raises(ValueError, match="The 'naip' and"):
                dataset.plot(x)
        else:
            dataset.plot(x, suptitle="Test")
            plt.close()
            dataset.plot(x, show_titles=False)
            plt.close()
            x["prediction"] = x["mask"][0].clone()
            dataset.plot(x)
            plt.close()
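As a usage note (not part of the diff), the new tests can be run on their own through pytest's Python entry point; the file path assumes the torchgeo repository root as the working directory.

import pytest

# Run only the EnviroAtlas tests, quietly; equivalent to `pytest tests/datasets/test_enviroatlas.py -q`.
pytest.main(["tests/datasets/test_enviroatlas.py", "-q"])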
torchgeo/datasets/__init__.py
@ -25,6 +25,7 @@ from .cowc import COWC, COWCCounting, COWCDetection
 from .cv4a_kenya_crop_type import CV4AKenyaCropType
 from .cyclone import TropicalCycloneWindEstimation
 from .dfc2022 import DFC2022
+from .enviroatlas import EnviroAtlas
 from .etci2021 import ETCI2021
 from .eurosat import EuroSAT
 from .fair1m import FAIR1M
@ -118,6 +119,7 @@ __all__ = (
     "COWCDetection",
     "CV4AKenyaCropType",
     "DFC2022",
+    "EnviroAtlas",
     "ETCI2021",
     "EuroSAT",
     "FAIR1M",
torchgeo/datasets/chesapeake.py
@ -6,7 +6,7 @@
 import abc
 import os
 import sys
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence

 import fiona
 import numpy as np
@ -402,7 +402,7 @@ class ChesapeakeCVPR(GeoDataset):
         self,
         root: str = "data",
         splits: Sequence[str] = ["de-train"],
-        layers: List[str] = ["naip-new", "lc"],
+        layers: Sequence[str] = ["naip-new", "lc"],
         transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
         cache: bool = True,
         download: bool = False,
@ -427,6 +427,7 @@ class ChesapeakeCVPR(GeoDataset):
         Raises:
             FileNotFoundError: if no files are found in ``root``
             RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+            AssertionError: if ``splits`` or ``layers`` are not valid
         """
         for split in splits:
             assert split in self.splits
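A brief note on the ChesapeakeCVPR change above (illustration only, not from the commit): widening layers from List[str] to Sequence[str] lets callers, and the type checker, pass a tuple as well as a list, matching how the new EnviroAtlas tests parametrize their layer combinations. A minimal sketch, assuming the Chesapeake data is already available locally under "data":

from torchgeo.datasets import ChesapeakeCVPR

# Both spellings now satisfy the annotation; only the container type differs.
ds_from_list = ChesapeakeCVPR(root="data", splits=["de-train"], layers=["naip-new", "lc"])
ds_from_tuple = ChesapeakeCVPR(root="data", splits=["de-train"], layers=("naip-new", "lc"))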
torchgeo/datasets/enviroatlas.py
@ -0,0 +1,537 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""EnviroAtlas High-Resolution Land Cover datasets."""

import os
import sys
from typing import Any, Callable, Dict, Optional, Sequence

import fiona
import matplotlib.pyplot as plt
import numpy as np
import pyproj
import rasterio
import rasterio.mask
import shapely.geometry
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS
from torch import Tensor

from .geo import GeoDataset
from .utils import BoundingBox, download_url, extract_archive


class EnviroAtlas(GeoDataset):
    """EnviroAtlas dataset covering four cities with prior and weak input data layers.

    The `EnviroAtlas
    <https://doi.org/10.5281/zenodo.5778192>`_ dataset contains NAIP aerial imagery,
    NLCD land cover labels, OpenStreetMap roads, water, waterways, and waterbodies,
    Microsoft building footprint labels, high-resolution land cover labels from the
    EPA EnviroAtlas dataset, and high-resolution land cover prior layers.

    This dataset was organized to accompany the 2022 paper, `"Resolving label
    uncertainty with implicit generative models"
    <https://openreview.net/forum?id=AEa_UepnMDX>`_. More details can be found at
    https://github.com/estherrolf/qr_for_landcover.

    If you use this dataset in your research, please cite the following paper:

    * https://openreview.net/forum?id=AEa_UepnMDX

    .. versionadded:: 0.3
    """

    url = "https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1"
    filename = "enviroatlas_lotp.zip"
    md5 = "6142f8d1ebfc7f8ad888337f0683dc7a"

    crs = CRS.from_epsg(3857)
    res = 1

    valid_prior_layers = ["prior", "prior_no_osm_no_buildings"]

    valid_layers = [
        "naip",
        "nlcd",
        "roads",
        "water",
        "waterways",
        "waterbodies",
        "buildings",
        "lc",
    ] + valid_prior_layers

    cities = [
        "pittsburgh_pa-2010_1m",
        "durham_nc-2012_1m",
        "austin_tx-2012_1m",
        "phoenix_az-2010_1m",
    ]
    splits = (
        [f"{state}-train" for state in cities[:1]]
        + [f"{state}-val" for state in cities[:1]]
        + [f"{state}-test" for state in cities]
        + [f"{state}-val5" for state in cities]
    )

    # these are used to check the integrity of the dataset
    files = [
        "austin_tx-2012_1m-test_tiles-debuffered",
        "austin_tx-2012_1m-val5_tiles-debuffered",
        "durham_nc-2012_1m-test_tiles-debuffered",
        "durham_nc-2012_1m-val5_tiles-debuffered",
        "phoenix_az-2010_1m-test_tiles-debuffered",
        "phoenix_az-2010_1m-val5_tiles-debuffered",
        "pittsburgh_pa-2010_1m-test_tiles-debuffered",
        "pittsburgh_pa-2010_1m-train_tiles-debuffered",
        "pittsburgh_pa-2010_1m-val5_tiles-debuffered",
        "pittsburgh_pa-2010_1m-val_tiles-debuffered",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_a_naip.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_b_nlcd.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_c_roads.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d1_waterways.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d2_waterbodies.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif",
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif",  # noqa: E501
        "austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif",  # noqa: E501
        "spatial_index.geojson",
    ]

    p_src_crs = pyproj.CRS("epsg:3857")
    p_transformers = {
        "epsg:26917": pyproj.Transformer.from_crs(
            p_src_crs, pyproj.CRS("epsg:26917"), always_xy=True
        ).transform,
        "epsg:26918": pyproj.Transformer.from_crs(
            p_src_crs, pyproj.CRS("epsg:26918"), always_xy=True
        ).transform,
        "epsg:26914": pyproj.Transformer.from_crs(
            p_src_crs, pyproj.CRS("epsg:26914"), always_xy=True
        ).transform,
        "epsg:26912": pyproj.Transformer.from_crs(
            p_src_crs, pyproj.CRS("epsg:26912"), always_xy=True
        ).transform,
    }

    # used to convert the 10 high-res classes labeled as [0, 10, 20, 30, 40, 52, 70, 80,
    # 82, 91, 92] to sequential labels [0, ..., 10]
    raw_enviroatlas_to_idx_map: "np.typing.NDArray[np.uint8]" = np.array(
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            2, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            6, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            7, 0, 8, 0, 0, 0, 0, 0, 0, 0,
            0, 9, 10,
        ],
        dtype=np.uint8,
    )

    highres_classes = [
        "Unclassified",
        "Water",
        "Impervious Surface",
        "Soil and Barren",
        "Trees and Forest",
        "Shrubs",
        "Grass and Herbaceous",
        "Agriculture",
        "Orchards",
        "Woody Wetlands",
        "Emergent Wetlands",
    ]
    highres_cmap = ListedColormap(
        [
            [1.00000000, 1.00000000, 1.00000000],
            [0.00000000, 0.77254902, 1.00000000],
            [0.61176471, 0.61176471, 0.61176471],
            [1.00000000, 0.66666667, 0.00000000],
            [0.14901961, 0.45098039, 0.00000000],
            [0.80000000, 0.72156863, 0.47450980],
            [0.63921569, 1.00000000, 0.45098039],
            [0.86274510, 0.85098039, 0.22352941],
            [0.67058824, 0.42352941, 0.15686275],
            [0.72156863, 0.85098039, 0.92156863],
            [0.42352941, 0.62352941, 0.72156863],
        ]
    )

    def __init__(
        self,
        root: str = "data",
        splits: Sequence[str] = ["pittsburgh_pa-2010_1m-train"],
        layers: Sequence[str] = ["naip", "prior"],
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        prior_as_input: bool = False,
        cache: bool = True,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: root directory where dataset can be found
            splits: a list of strings in the format "{city}-{train,val,test,val5}"
                indicating the subset of data to use, for example
                "pittsburgh_pa-2010_1m-train"
            layers: a list containing a subset of ``valid_layers`` indicating which
                layers to load
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            prior_as_input: bool describing whether the prior is used as an input (True)
                or as supervision (False)
            cache: if True, cache file handle to speed up repeated sampling
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            FileNotFoundError: if no files are found in ``root``
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
            AssertionError: if ``splits`` or ``layers`` are not valid
        """
        for split in splits:
            assert split in self.splits
        assert all([layer in self.valid_layers for layer in layers])
        self.root = root
        self.layers = layers
        self.cache = cache
        self.download = download
        self.checksum = checksum
        self.prior_as_input = prior_as_input

        self._verify()

        super().__init__(transforms)

        # Add all tiles into the index in epsg:3857 based on the included geojson
        mint: float = 0
        maxt: float = sys.maxsize
        with fiona.open(
            os.path.join(root, "enviroatlas_lotp", "spatial_index.geojson"), "r"
        ) as f:
            for i, row in enumerate(f):
                if row["properties"]["split"] in splits:
                    box = shapely.geometry.shape(row["geometry"])
                    minx, miny, maxx, maxy = box.bounds
                    coords = (minx, maxx, miny, maxy, mint, maxt)

                    self.index.insert(
                        i,
                        coords,
                        {
                            "naip": row["properties"]["naip"],
                            "nlcd": row["properties"]["nlcd"],
                            "roads": row["properties"]["roads"],
                            "water": row["properties"]["water"],
                            "waterways": row["properties"]["waterways"],
                            "waterbodies": row["properties"]["waterbodies"],
                            "buildings": row["properties"]["buildings"],
                            "lc": row["properties"]["lc"],
                            "prior_no_osm_no_buildings": row["properties"][
                                "naip"
                            ].replace(
                                "a_naip",
                                "prior_from_cooccurrences_101_31_no_osm_no_buildings",
                            ),
                            "prior": row["properties"]["naip"].replace(
                                "a_naip", "prior_from_cooccurrences_101_31"
                            ),
                        },
                    )

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(tuple(query), objects=True)
        filepaths = [hit.object for hit in hits]

        sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}

        if len(filepaths) == 0:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )
        elif len(filepaths) == 1:
            filenames = filepaths[0]
            query_geom_transformed = None  # is set by the first layer

            minx, maxx, miny, maxy, mint, maxt = query
            query_box = shapely.geometry.box(minx, miny, maxx, maxy)

            for layer in self.layers:
                fn = filenames[layer]

                with rasterio.open(
                    os.path.join(self.root, "enviroatlas_lotp", fn)
                ) as f:
                    dst_crs = f.crs.to_string().lower()

                    if query_geom_transformed is None:
                        query_box_transformed = shapely.ops.transform(
                            self.p_transformers[dst_crs], query_box
                        ).envelope
                        query_geom_transformed = shapely.geometry.mapping(
                            query_box_transformed
                        )

                    data, _ = rasterio.mask.mask(
                        f, [query_geom_transformed], crop=True, all_touched=True
                    )

                if layer in [
                    "naip",
                    "buildings",
                    "roads",
                    "waterways",
                    "waterbodies",
                    "water",
                ]:
                    sample["image"].append(data)
                elif layer in ["prior", "prior_no_osm_no_buildings"]:
                    if self.prior_as_input:
                        sample["image"].append(data)
                    else:
                        sample["mask"].append(data)
                elif layer in ["lc"]:
                    data = self.raw_enviroatlas_to_idx_map[data]
                    sample["mask"].append(data)
        else:
            raise IndexError(f"query: {query} spans multiple tiles which is not valid")

        sample["image"] = np.concatenate(sample["image"], axis=0)
        sample["mask"] = np.concatenate(sample["mask"], axis=0)

        sample["image"] = torch.from_numpy(  # type: ignore[attr-defined]
            sample["image"]
        )
        sample["mask"] = torch.from_numpy(sample["mask"])  # type: ignore[attr-defined]

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the extracted files already exist
        def exists(filename: str) -> bool:
            return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))

        if all(map(exists, self.files)):
            return

        # Check if the zip files have already been downloaded
        if os.path.exists(os.path.join(self.root, self.filename)):
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the dataset."""
        download_url(self.url, self.root, filename=self.filename, md5=self.md5)

    def _extract(self) -> None:
        """Extract the dataset."""
        extract_archive(os.path.join(self.root, self.filename))
    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Note: only plots the "naip" and "lc" layers.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample

        Raises:
            ValueError: if the NAIP or high-resolution land cover layer isn't
                included in ``self.layers``
        """
        if "naip" not in self.layers or "lc" not in self.layers:
            raise ValueError("The 'naip' and 'lc' layers must be included for plotting")

        image_layers = []
        mask_layers = []
        for layer in self.layers:
            if layer in [
                "naip",
                "buildings",
                "roads",
                "waterways",
                "waterbodies",
                "water",
            ]:
                image_layers.append(layer)
            elif layer in ["prior", "prior_no_osm_no_buildings"]:
                if self.prior_as_input:
                    image_layers.append(layer)
                else:
                    mask_layers.append(layer)
            elif layer in ["lc"]:
                mask_layers.append(layer)

        naip_index = image_layers.index("naip")
        lc_index = mask_layers.index("lc")

        image = np.rollaxis(
            sample["image"][naip_index : naip_index + 3, :, :].numpy(), 0, 3
        )
        mask = sample["mask"][lc_index].numpy()

        num_panels = 2
        showing_predictions = "prediction" in sample
        if showing_predictions:
            predictions = sample["prediction"].numpy()
            num_panels += 1

        fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
        axs[0].imshow(image)
        axs[0].axis("off")
        axs[1].imshow(
            mask, vmin=0, vmax=10, cmap=self.highres_cmap, interpolation="none"
        )
        axs[1].axis("off")
        if show_titles:
            axs[0].set_title("Image")
            axs[1].set_title("Mask")

        if showing_predictions:
            axs[2].imshow(
                predictions,
                vmin=0,
                vmax=10,
                cmap=self.highres_cmap,
                interpolation="none",
            )
            axs[2].axis("off")
            if show_titles:
                axs[2].set_title("Predictions")

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
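To make the layer routing in __getitem__ concrete, here is a minimal, hedged sketch (it assumes the archive has been downloaded or extracted under ./data; channel counts follow the fixture profiles above, where NAIP has 4 bands, each prior has 5, and the high-resolution labels are single-band):

from torchgeo.datasets import EnviroAtlas
from torchgeo.samplers import RandomGeoSampler

# Prior used as supervision (default): its bands land in sample["mask"] next to "lc".
ds = EnviroAtlas(
    root="data", layers=["naip", "prior", "lc"], prior_as_input=False, download=True
)
sampler = RandomGeoSampler(ds, size=256, length=1)
sample = ds[next(iter(sampler))]
assert sample["image"].shape[0] == 4      # NAIP bands only
assert sample["mask"].shape[0] == 5 + 1   # prior bands + remapped land cover

# Prior used as an input: the same bands are concatenated onto sample["image"] instead.
ds_prior_input = EnviroAtlas(
    root="data", layers=["naip", "prior", "lc"], prior_as_input=True
)
sample = ds_prior_input[next(iter(sampler))]  # same split, so the sampler bounds match
assert sample["image"].shape[0] == 4 + 5
assert sample["mask"].shape[0] == 1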