Adding layer support to CVPR dataset

This commit is contained in:
Caleb Robinson 2021-09-02 03:06:38 +00:00 коммит произвёл Adam J. Stewart
Родитель 50e6b3e1d3
Коммит 082dcac8d7
1 изменённых файлов: 38 добавлений и 25 удалений

Просмотреть файл

@ -5,7 +5,7 @@
import os
import sys
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, List
import fiona
import pyproj
@ -42,6 +42,15 @@ class CVPRChesapeake(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
valid_layers = [
"naip-new",
"naip-old",
"landsat-leaf-on",
"landsat-leaf-off",
"nlcd",
"lc",
"buildings"
]
states = ["de", "md", "va", "wv", "pa", "ny"]
splits = (
[f"{state}-train" for state in states]
@ -63,6 +72,7 @@ class CVPRChesapeake(GeoDataset):
self,
root: str = "data",
split: str = "de-train",
layers: List[str] = ["naip-new", "lc"],
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
@ -85,8 +95,10 @@ class CVPRChesapeake(GeoDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
assert split in self.splits
assert all([layer in self.valid_layers for layer in layers])
super().__init__(transforms) # creates self.index and self.transform
self.root = root
self.layers = layers
self.cache = cache
self.checksum = checksum
@ -137,45 +149,46 @@ class CVPRChesapeake(GeoDataset):
hits = self.index.intersection(query, objects=True)
filepaths = [hit.object for hit in hits]
sample = {
"crs": self.crs,
"bbox": query,
}
if len(filepaths) == 0:
raise IndexError(
f"query: {query} not found in index with bounds: {self.bounds}"
)
elif len(filepaths) == 1:
filenames = filepaths[0]
naip_fn = filenames["naip-new"]
lc_fn = filenames["lc"]
query_geom_transformed = None # is set by the first layer
minx, maxx, miny, maxy, mint, maxt = query
query_box = shapely.geometry.box(minx, miny, maxx, maxy)
with rasterio.open(os.path.join(self.root, naip_fn)) as f:
dst_crs = f.crs.to_string().lower()
query_box_transformed = shapely.ops.transform(
self.p_transformers[dst_crs], query_box
).envelope
query_geom_transformed = shapely.geometry.mapping(query_box_transformed)
naip_data, _ = rasterio.mask.mask(
f, [query_geom_transformed], crop=True, all_touched=True
)
for layer in self.layers:
with rasterio.open(os.path.join(self.root, lc_fn)) as f:
lc_data, _ = rasterio.mask.mask(
f, [query_geom_transformed], crop=True, all_touched=True
)
fn = filenames[layer]
with rasterio.open(os.path.join(self.root, fn)) as f:
dst_crs = f.crs.to_string().lower()
if query_geom_transformed is None:
query_box_transformed = shapely.ops.transform(
self.p_transformers[dst_crs], query_box
).envelope
query_geom_transformed = shapely.geometry.mapping(
query_box_transformed
)
data, _ = rasterio.mask.mask(
f, [query_geom_transformed], crop=True, all_touched=True
)
sample[layer] = data.squeeze()
else:
raise IndexError(f"query: {query} spans multiple tiles which is not valid")
sample = {
"image": naip_data,
"mask": lc_data,
"crs": self.crs,
"bbox": query,
}
print(naip_data.shape, query)
if self.transforms is not None:
sample = self.transforms(sample)