VectorDataset: fix issue with empty query (#467)

* VectorDataset: fix issue with empty query

* isort
This commit is contained in:
Adam J. Stewart 2022-03-19 10:30:02 -05:00 коммит произвёл GitHub
Родитель 76603f3343
Коммит 3d1a1e9b08
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 85 добавлений и 11 удалений

60
tests/data/vector/data.py Executable file
Просмотреть файл

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import json
# Create an L shape:
#
# +--+
# | |
# +--+--+
# | | |
# +--+--+
#
# This allows us to test queries:
#
# * within the L
# * within the dataset bounding box but with no features
# * outside the dataset bounding box
geojson = {
"type": "FeatureCollection",
"crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}},
"features": [
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
],
},
},
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]
],
},
},
{
"type": "Feature",
"properties": {},
"geometry": {
"type": "Polygon",
"coordinates": [
[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]
],
},
},
],
}
with open("vector.geojson", "w") as f:
json.dump(geojson, f)

Просмотреть файл

@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]}

Просмотреть файл

@ -16,7 +16,6 @@ from torch.utils.data import ConcatDataset
from torchgeo.datasets import ( from torchgeo.datasets import (
NAIP, NAIP,
BoundingBox, BoundingBox,
CanadianBuildingFootprints,
GeoDataset, GeoDataset,
IntersectionDataset, IntersectionDataset,
RasterDataset, RasterDataset,
@ -44,6 +43,10 @@ class CustomGeoDataset(GeoDataset):
return {"index": query} return {"index": query}
class CustomVectorDataset(VectorDataset):
filename_glob = "*.geojson"
class CustomVisionDataset(VisionDataset): class CustomVisionDataset(VisionDataset):
def __getitem__(self, index: int) -> Dict[str, int]: def __getitem__(self, index: int) -> Dict[str, int]:
return {"index": index} return {"index": index}
@ -201,20 +204,25 @@ class TestRasterDataset:
class TestVectorDataset: class TestVectorDataset:
@pytest.fixture @pytest.fixture(scope="class")
def dataset(self) -> CanadianBuildingFootprints: def dataset(self) -> CustomVectorDataset:
root = os.path.join("tests", "data", "cbf") root = os.path.join("tests", "data", "vector")
transforms = nn.Identity() # type: ignore[no-untyped-call] transforms = nn.Identity() # type: ignore[no-untyped-call]
return CanadianBuildingFootprints(root, res=0.1, transforms=transforms) return CustomVectorDataset(root, res=0.1, transforms=transforms)
def test_getitem(self, dataset: CanadianBuildingFootprints) -> None: def test_getitem(self, dataset: CustomVectorDataset) -> None:
x = dataset[dataset.bounds] x = dataset[dataset.bounds]
assert isinstance(x, dict) assert isinstance(x, dict)
assert isinstance(x["crs"], CRS) assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor) assert isinstance(x["mask"], torch.Tensor)
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2) query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0)
x = dataset[query]
assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8))
def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
query = BoundingBox(3, 3, 3, 3, 0, 0)
with pytest.raises( with pytest.raises(
IndexError, match="query: .* not found in index with bounds:" IndexError, match="query: .* not found in index with bounds:"
): ):

Просмотреть файл

@ -661,9 +661,14 @@ class VectorDataset(GeoDataset):
transform = rasterio.transform.from_bounds( transform = rasterio.transform.from_bounds(
query.minx, query.miny, query.maxx, query.maxy, width, height query.minx, query.miny, query.maxx, query.maxy, width, height
) )
masks = rasterio.features.rasterize( if shapes:
shapes, out_shape=(int(height), int(width)), transform=transform masks = rasterio.features.rasterize(
) shapes, out_shape=(int(height), int(width)), transform=transform
)
else:
# If no features are found in this query, return an empty mask
# with the default fill value and dtype used by rasterize
masks = np.zeros((int(height), int(width)), dtype=np.uint8)
sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query} sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}