diff --git a/tests/data/vector/data.py b/tests/data/vector/data.py new file mode 100755 index 000000000..97bdb902e --- /dev/null +++ b/tests/data/vector/data.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json + +# Create an L shape: +# +# +--+ +# | | +# +--+--+ +# | | | +# +--+--+ +# +# This allows us to test queries: +# +# * within the L +# * within the dataset bounding box but with no features +# * outside the dataset bounding box + +geojson = { + "type": "FeatureCollection", + "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, + "features": [ + { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": {}, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]] + ], + }, + }, + ], +} + +with open("vector.geojson", "w") as f: + json.dump(geojson, f) diff --git a/tests/data/vector/vector.geojson b/tests/data/vector/vector.geojson new file mode 100644 index 000000000..26a9a7bde --- /dev/null +++ b/tests/data/vector/vector.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]} \ No newline at end of file diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index f71f8a332..4d790de9c 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -16,7 +16,6 @@ from torch.utils.data import ConcatDataset from torchgeo.datasets import ( NAIP, BoundingBox, - CanadianBuildingFootprints, GeoDataset, IntersectionDataset, RasterDataset, @@ -44,6 +43,10 @@ class CustomGeoDataset(GeoDataset): return {"index": query} +class CustomVectorDataset(VectorDataset): + filename_glob = "*.geojson" + + class CustomVisionDataset(VisionDataset): def __getitem__(self, index: int) -> Dict[str, int]: return {"index": index} @@ -201,20 +204,25 @@ class TestRasterDataset: class TestVectorDataset: - @pytest.fixture - def dataset(self) -> CanadianBuildingFootprints: - root = os.path.join("tests", "data", "cbf") + @pytest.fixture(scope="class") + def dataset(self) -> CustomVectorDataset: + root = os.path.join("tests", "data", "vector") transforms = nn.Identity() # type: ignore[no-untyped-call] - return CanadianBuildingFootprints(root, res=0.1, transforms=transforms) + return CustomVectorDataset(root, res=0.1, transforms=transforms) - def test_getitem(self, dataset: CanadianBuildingFootprints) -> None: + def test_getitem(self, dataset: CustomVectorDataset) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) assert isinstance(x["crs"], CRS) assert isinstance(x["mask"], torch.Tensor) - def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: - query = BoundingBox(2, 2, 2, 2, 2, 2) + def test_empty_shapes(self, dataset: CustomVectorDataset) -> None: + query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0) + x = dataset[query] + assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8)) + + def test_invalid_query(self, dataset: CustomVectorDataset) -> None: + query = BoundingBox(3, 3, 3, 3, 0, 0) with pytest.raises( IndexError, match="query: .* not found in index with bounds:" ): diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 725aeb4c9..55a76c4f1 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -661,9 +661,14 @@ class VectorDataset(GeoDataset): transform = rasterio.transform.from_bounds( query.minx, query.miny, query.maxx, query.maxy, width, height ) - masks = rasterio.features.rasterize( - shapes, out_shape=(int(height), int(width)), transform=transform - ) + if shapes: + masks = rasterio.features.rasterize( + shapes, out_shape=(int(height), int(width)), transform=transform + ) + else: + # If no features are found in this query, return an empty mask + # with the default fill value and dtype used by rasterize + masks = np.zeros((int(height), int(width)), dtype=np.uint8) sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}