зеркало из https://github.com/microsoft/torchgeo.git
VectorDataset: fix issue with empty query (#467)
* VectorDataset: fix issue with empty query * isort
This commit is contained in:
Родитель
76603f3343
Коммит
3d1a1e9b08
|
@ -0,0 +1,60 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Create an L shape:
|
||||||
|
#
|
||||||
|
# +--+
|
||||||
|
# | |
|
||||||
|
# +--+--+
|
||||||
|
# | | |
|
||||||
|
# +--+--+
|
||||||
|
#
|
||||||
|
# This allows us to test queries:
|
||||||
|
#
|
||||||
|
# * within the L
|
||||||
|
# * within the dataset bounding box but with no features
|
||||||
|
# * outside the dataset bounding box
|
||||||
|
|
||||||
|
geojson = {
|
||||||
|
"type": "FeatureCollection",
|
||||||
|
"crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}},
|
||||||
|
"features": [
|
||||||
|
{
|
||||||
|
"type": "Feature",
|
||||||
|
"properties": {},
|
||||||
|
"geometry": {
|
||||||
|
"type": "Polygon",
|
||||||
|
"coordinates": [
|
||||||
|
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Feature",
|
||||||
|
"properties": {},
|
||||||
|
"geometry": {
|
||||||
|
"type": "Polygon",
|
||||||
|
"coordinates": [
|
||||||
|
[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "Feature",
|
||||||
|
"properties": {},
|
||||||
|
"geometry": {
|
||||||
|
"type": "Polygon",
|
||||||
|
"coordinates": [
|
||||||
|
[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
with open("vector.geojson", "w") as f:
|
||||||
|
json.dump(geojson, f)
|
|
@ -0,0 +1 @@
|
||||||
|
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[1.0, 0.0], [1.0, 1.0], [2.0, 1.0], [2.0, 0.0], [1.0, 0.0]]]}}, {"type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 1.0], [0.0, 2.0], [1.0, 2.0], [1.0, 1.0], [0.0, 1.0]]]}}]}
|
|
@ -16,7 +16,6 @@ from torch.utils.data import ConcatDataset
|
||||||
from torchgeo.datasets import (
|
from torchgeo.datasets import (
|
||||||
NAIP,
|
NAIP,
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
CanadianBuildingFootprints,
|
|
||||||
GeoDataset,
|
GeoDataset,
|
||||||
IntersectionDataset,
|
IntersectionDataset,
|
||||||
RasterDataset,
|
RasterDataset,
|
||||||
|
@ -44,6 +43,10 @@ class CustomGeoDataset(GeoDataset):
|
||||||
return {"index": query}
|
return {"index": query}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomVectorDataset(VectorDataset):
|
||||||
|
filename_glob = "*.geojson"
|
||||||
|
|
||||||
|
|
||||||
class CustomVisionDataset(VisionDataset):
|
class CustomVisionDataset(VisionDataset):
|
||||||
def __getitem__(self, index: int) -> Dict[str, int]:
|
def __getitem__(self, index: int) -> Dict[str, int]:
|
||||||
return {"index": index}
|
return {"index": index}
|
||||||
|
@ -201,20 +204,25 @@ class TestRasterDataset:
|
||||||
|
|
||||||
|
|
||||||
class TestVectorDataset:
|
class TestVectorDataset:
|
||||||
@pytest.fixture
|
@pytest.fixture(scope="class")
|
||||||
def dataset(self) -> CanadianBuildingFootprints:
|
def dataset(self) -> CustomVectorDataset:
|
||||||
root = os.path.join("tests", "data", "cbf")
|
root = os.path.join("tests", "data", "vector")
|
||||||
transforms = nn.Identity() # type: ignore[no-untyped-call]
|
transforms = nn.Identity() # type: ignore[no-untyped-call]
|
||||||
return CanadianBuildingFootprints(root, res=0.1, transforms=transforms)
|
return CustomVectorDataset(root, res=0.1, transforms=transforms)
|
||||||
|
|
||||||
def test_getitem(self, dataset: CanadianBuildingFootprints) -> None:
|
def test_getitem(self, dataset: CustomVectorDataset) -> None:
|
||||||
x = dataset[dataset.bounds]
|
x = dataset[dataset.bounds]
|
||||||
assert isinstance(x, dict)
|
assert isinstance(x, dict)
|
||||||
assert isinstance(x["crs"], CRS)
|
assert isinstance(x["crs"], CRS)
|
||||||
assert isinstance(x["mask"], torch.Tensor)
|
assert isinstance(x["mask"], torch.Tensor)
|
||||||
|
|
||||||
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
|
def test_empty_shapes(self, dataset: CustomVectorDataset) -> None:
|
||||||
query = BoundingBox(2, 2, 2, 2, 2, 2)
|
query = BoundingBox(1.1, 1.9, 1.1, 1.9, 0, 0)
|
||||||
|
x = dataset[query]
|
||||||
|
assert torch.equal(x["mask"], torch.zeros(7, 7, dtype=torch.uint8))
|
||||||
|
|
||||||
|
def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
|
||||||
|
query = BoundingBox(3, 3, 3, 3, 0, 0)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
IndexError, match="query: .* not found in index with bounds:"
|
IndexError, match="query: .* not found in index with bounds:"
|
||||||
):
|
):
|
||||||
|
|
|
@ -661,9 +661,14 @@ class VectorDataset(GeoDataset):
|
||||||
transform = rasterio.transform.from_bounds(
|
transform = rasterio.transform.from_bounds(
|
||||||
query.minx, query.miny, query.maxx, query.maxy, width, height
|
query.minx, query.miny, query.maxx, query.maxy, width, height
|
||||||
)
|
)
|
||||||
masks = rasterio.features.rasterize(
|
if shapes:
|
||||||
shapes, out_shape=(int(height), int(width)), transform=transform
|
masks = rasterio.features.rasterize(
|
||||||
)
|
shapes, out_shape=(int(height), int(width)), transform=transform
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If no features are found in this query, return an empty mask
|
||||||
|
# with the default fill value and dtype used by rasterize
|
||||||
|
masks = np.zeros((int(height), int(width)), dtype=np.uint8)
|
||||||
|
|
||||||
sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}
|
sample = {"mask": torch.tensor(masks), "crs": self.crs, "bbox": query}
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче