зеркало из https://github.com/microsoft/torchgeo.git
Fix reprojection issues (#1344)
* Add tests to catch reprojection issues
* More specific tests
* Fix boundless reprojection
* black/flake8
* Add getters and setters for all crs/res attributes
* Actually fix reprojection bug
* Fix tests
* Get all tests passing
* More specific tests
* Remove aux files
* Ignore *.aux.xml files
* Fix mypy
* Increase coverage
* Use dtype properly
* Increase coverage
* Add newline
This commit is contained in:
Родитель
f562378bcb
Коммит
bbf6108516
|
@ -4,6 +4,7 @@
|
|||
/output/
|
||||
*.pdf
|
||||
/results/
|
||||
*.aux.xml
|
||||
|
||||
# Spack
|
||||
.spack-env/
|
||||
|
|
|
@ -2,41 +2,99 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import rasterio.transform
|
||||
from torchvision.datasets.utils import calculate_md5
|
||||
import rasterio as rio
|
||||
from rasterio.transform import from_bounds
|
||||
from rasterio.warp import calculate_default_transform, reproject
|
||||
|
||||
RES = [2, 4, 8]
|
||||
EPSG = [4087, 4326, 32631]
|
||||
SIZE = 16
|
||||
|
||||
|
||||
def generate_test_data(fn: str) -> str:
|
||||
"""Creates test data with uint32 datatype.
|
||||
def write_raster(
|
||||
res: int = RES[0],
|
||||
epsg: int = EPSG[0],
|
||||
dtype: str = "uint8",
|
||||
path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Write a raster file.
|
||||
|
||||
Args:
|
||||
fn (str): Filename to write
|
||||
|
||||
Returns:
|
||||
str: md5 hash of created archive
|
||||
res: Resolution.
|
||||
epsg: EPSG of file.
|
||||
dtype: Data type.
|
||||
path: File path.
|
||||
"""
|
||||
size = SIZE // res
|
||||
profile = {
|
||||
"driver": "GTiff",
|
||||
"dtype": "uint32",
|
||||
"dtype": dtype,
|
||||
"count": 1,
|
||||
"crs": "epsg:4326",
|
||||
"transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
|
||||
"height": 4,
|
||||
"width": 4,
|
||||
"compress": "lzw",
|
||||
"predictor": 2,
|
||||
"crs": f"epsg:{epsg}",
|
||||
"transform": from_bounds(0, 0, SIZE, SIZE, size, size),
|
||||
"height": size,
|
||||
"width": size,
|
||||
"nodata": 0,
|
||||
}
|
||||
|
||||
with rasterio.open(fn, "w", **profile) as f:
|
||||
f.write(np.random.randint(0, 256, size=(1, 4, 4)))
|
||||
if path is None:
|
||||
name = f"res_{res}_epsg_{epsg}"
|
||||
path = os.path.join(name, f"{name}.tif")
|
||||
|
||||
md5: str = calculate_md5(fn)
|
||||
return md5
|
||||
directory = os.path.dirname(path)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
with rio.open(path, "w", **profile) as f:
|
||||
x = np.ones((1, size, size))
|
||||
f.write(x)
|
||||
|
||||
|
||||
def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None:
    """Reproject a previously written raster into another CRS.

    The source file is ``<name>/<name>.tif`` with
    ``name = res_{res}_epsg_{src_epsg}``; the output uses the same layout with
    ``dst_epsg`` in the name. The output directory is created if it does not
    exist. Only band 1 is reprojected.

    Args:
        res: Resolution encoded in the source and destination file names.
        src_epsg: EPSG code of the source file.
        dst_epsg: EPSG code of the destination file.
    """
    source_stem = f"res_{res}_epsg_{src_epsg}"
    target_stem = f"res_{res}_epsg_{dst_epsg}"
    target_crs = f"epsg:{dst_epsg}"

    with rio.open(os.path.join(source_stem, f"{source_stem}.tif")) as src:
        # Destination grid (affine transform and raster size) in the new CRS.
        transform, width, height = calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds
        )

        # Start from the source profile and override only the georeferencing.
        target_profile = src.profile.copy()
        target_profile.update(
            {
                "crs": target_crs,
                "transform": transform,
                "width": width,
                "height": height,
            }
        )

        os.makedirs(target_stem, exist_ok=True)
        target_path = os.path.join(target_stem, f"{target_stem}.tif")
        # The source must stay open while warping, hence the nested context.
        with rio.open(target_path, "w", **target_profile) as dst:
            reproject(
                source=rio.band(src, 1),
                destination=rio.band(dst, 1),
                src_transform=src.transform,
                src_crs=src.crs,
                dst_transform=dst.transform,
                dst_crs=dst.crs,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif"))
|
||||
print(md5_hash)
|
||||
for res in RES:
|
||||
src_epsg = EPSG[0]
|
||||
write_raster(res, src_epsg)
|
||||
|
||||
for dst_epsg in EPSG[1:]:
|
||||
reproject_raster(res, src_epsg, dst_epsg)
|
||||
|
||||
for dtype in ["uint16", "uint32"]:
|
||||
path = os.path.join(dtype, f"{dtype}.tif")
|
||||
write_raster(dtype=dtype, path=path)
|
||||
with open(os.path.join(dtype, "corrupted.tif"), "w") as f:
|
||||
f.write("not a tif file\n")
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/raster/test0.tif
Двоичные данные
tests/data/raster/test0.tif
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
not a tif file
|
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
not a tif file
|
Двоичный файл не отображается.
|
@ -30,7 +30,7 @@ class CustomGeoDataset(GeoDataset):
|
|||
def __init__(
|
||||
self,
|
||||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(3005),
|
||||
crs: CRS = CRS.from_epsg(4087),
|
||||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -74,7 +74,7 @@ class TestGeoDataset:
|
|||
def test_len(self, dataset: GeoDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
||||
@pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)])
|
||||
@pytest.mark.parametrize("crs", [CRS.from_epsg(4087), CRS.from_epsg(32631)])
|
||||
def test_crs(self, dataset: GeoDataset, crs: CRS) -> None:
|
||||
dataset.crs = crs
|
||||
|
||||
|
@ -157,7 +157,7 @@ class TestRasterDataset:
|
|||
def naip(self, request: SubRequest) -> NAIP:
|
||||
root = os.path.join("tests", "data", "naip")
|
||||
bands = request.param[0]
|
||||
crs = CRS.from_epsg(3005)
|
||||
crs = CRS.from_epsg(4087)
|
||||
transforms = nn.Identity()
|
||||
cache = request.param[1]
|
||||
return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache)
|
||||
|
@ -178,11 +178,6 @@ class TestRasterDataset:
|
|||
cache = request.param[1]
|
||||
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
|
||||
|
||||
@pytest.fixture()
|
||||
def custom_dtype_ds(self) -> RasterDataset:
|
||||
root = os.path.join("tests", "data", "raster")
|
||||
return RasterDataset(root)
|
||||
|
||||
def test_getitem_single_file(self, naip: NAIP) -> None:
|
||||
x = naip[naip.bounds]
|
||||
assert isinstance(x, dict)
|
||||
|
@ -197,8 +192,11 @@ class TestRasterDataset:
|
|||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert len(sentinel.bands) == x["image"].shape[0]
|
||||
|
||||
def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None:
|
||||
x = custom_dtype_ds[custom_dtype_ds.bounds]
|
||||
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
|
||||
def test_getitem_uint_dtype(self, dtype: str) -> None:
|
||||
root = os.path.join("tests", "data", "raster", dtype)
|
||||
ds = RasterDataset(root)
|
||||
x = ds[ds.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert x["image"].dtype == torch.float32
|
||||
|
@ -377,14 +375,15 @@ class TestNonGeoClassificationDataset:
|
|||
class TestIntersectionDataset:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self) -> IntersectionDataset:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
|
||||
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326"))
|
||||
transforms = nn.Identity()
|
||||
return IntersectionDataset(ds1, ds2, transforms=transforms)
|
||||
|
||||
def test_getitem(self, dataset: IntersectionDataset) -> None:
|
||||
query = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
assert dataset[query] == {"index": query}
|
||||
query = dataset.bounds
|
||||
sample = dataset[query]
|
||||
assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_len(self, dataset: IntersectionDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
@ -403,27 +402,69 @@ class TestIntersectionDataset:
|
|||
):
|
||||
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]
|
||||
|
||||
def test_different_crs(self) -> None:
|
||||
ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005))
|
||||
ds2 = CustomGeoDataset(
|
||||
BoundingBox(
|
||||
-3547229.913123814,
|
||||
6360089.518213182,
|
||||
-3547229.913123814,
|
||||
6360089.518213182,
|
||||
-3547229.913123814,
|
||||
6360089.518213182,
|
||||
),
|
||||
crs=CRS.from_epsg(32616),
|
||||
)
|
||||
def test_different_crs_12(self) -> None:
|
||||
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
|
||||
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
|
||||
ds = IntersectionDataset(ds1, ds2)
|
||||
assert len(ds) == 1
|
||||
sample = ds[ds.bounds]
|
||||
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
|
||||
assert ds1.res == ds2.res == ds.res == 2
|
||||
assert len(ds1) == len(ds2) == len(ds) == 1
|
||||
assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res(self) -> None:
|
||||
ds1 = CustomGeoDataset(res=1)
|
||||
ds2 = CustomGeoDataset(res=2)
|
||||
def test_different_crs_12_3(self) -> None:
    """Intersect three rasters with differing CRSs, grouped as (1 & 2) & 3."""
    names = ("res_2_epsg_4087", "res_2_epsg_4326", "res_2_epsg_32631")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    first_pair = ds1 & ds2
    ds = first_pair & ds3
    sample = ds[ds.bounds]
    # Every member, and the combination, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_crs_1_23(self) -> None:
    """Intersect three rasters with differing CRSs, grouped as 1 & (2 & 3)."""
    names = ("res_2_epsg_4087", "res_2_epsg_4326", "res_2_epsg_32631")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    last_pair = ds2 & ds3
    ds = ds1 & last_pair
    sample = ds[ds.bounds]
    # Every member, and the combination, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_12(self) -> None:
|
||||
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
|
||||
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
|
||||
ds = IntersectionDataset(ds1, ds2)
|
||||
assert len(ds) == 1
|
||||
sample = ds[ds.bounds]
|
||||
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
|
||||
assert ds1.res == ds2.res == ds.res == 2
|
||||
assert len(ds1) == len(ds2) == len(ds) == 1
|
||||
assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_12_3(self) -> None:
    """Intersect three rasters with differing resolutions as (1 & 2) & 3."""
    names = ("res_2_epsg_4087", "res_4_epsg_4087", "res_8_epsg_4087")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    first_pair = ds1 & ds2
    ds = first_pair & ds3
    sample = ds[ds.bounds]
    # Every member, and the combination, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_1_23(self) -> None:
    """Intersect three rasters with differing resolutions as 1 & (2 & 3)."""
    names = ("res_2_epsg_4087", "res_4_epsg_4087", "res_8_epsg_4087")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    last_pair = ds2 & ds3
    ds = ds1 & last_pair
    sample = ds[ds.bounds]
    # Every member, and the combination, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_no_overlap(self) -> None:
|
||||
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
|
||||
|
@ -433,7 +474,7 @@ class TestIntersectionDataset:
|
|||
IntersectionDataset(ds1, ds2)
|
||||
|
||||
def test_invalid_query(self, dataset: IntersectionDataset) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
query = BoundingBox(-1, -1, -1, -1, -1, -1)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
|
@ -443,14 +484,15 @@ class TestIntersectionDataset:
|
|||
class TestUnionDataset:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self) -> UnionDataset:
|
||||
ds1 = CustomGeoDataset(bounds=BoundingBox(0, 1, 0, 1, 0, 1))
|
||||
ds2 = CustomGeoDataset(bounds=BoundingBox(2, 3, 2, 3, 2, 3))
|
||||
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
|
||||
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326"))
|
||||
transforms = nn.Identity()
|
||||
return UnionDataset(ds1, ds2, transforms=transforms)
|
||||
|
||||
def test_getitem(self, dataset: UnionDataset) -> None:
|
||||
query = BoundingBox(0, 1, 0, 1, 0, 1)
|
||||
assert dataset[query] == {"index": query}
|
||||
query = dataset.bounds
|
||||
sample = dataset[query]
|
||||
assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_len(self, dataset: UnionDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
@ -461,6 +503,76 @@ class TestUnionDataset:
|
|||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 2" in out
|
||||
|
||||
def test_different_crs_12(self) -> None:
    """Union two rasters stored in different CRSs."""
    names = ("res_2_epsg_4087", "res_2_epsg_4326")
    ds1, ds2 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    ds = UnionDataset(ds1, ds2)
    sample = ds[ds.bounds]
    # Both members, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds.res == 2
    assert len(ds1) == len(ds2) == 1
    assert len(ds) == 2
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_crs_12_3(self) -> None:
    """Union three rasters with differing CRSs, grouped as (1 | 2) | 3."""
    names = ("res_2_epsg_4087", "res_2_epsg_4326", "res_2_epsg_32631")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    first_pair = ds1 | ds2
    ds = first_pair | ds3
    sample = ds[ds.bounds]
    # Every member, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == 1
    assert len(ds) == 3
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_crs_1_23(self) -> None:
    """Union three rasters with differing CRSs, grouped as 1 | (2 | 3)."""
    names = ("res_2_epsg_4087", "res_2_epsg_4326", "res_2_epsg_32631")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    last_pair = ds2 | ds3
    ds = ds1 | last_pair
    sample = ds[ds.bounds]
    # Every member, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == 1
    assert len(ds) == 3
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_12(self) -> None:
    """Union two rasters stored at different resolutions."""
    names = ("res_2_epsg_4087", "res_4_epsg_4087")
    ds1, ds2 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    ds = UnionDataset(ds1, ds2)
    sample = ds[ds.bounds]
    # Both members, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds.res == 2
    assert len(ds1) == len(ds2) == 1
    assert len(ds) == 2
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_12_3(self) -> None:
    """Union three rasters with differing resolutions as (1 | 2) | 3."""
    names = ("res_2_epsg_4087", "res_4_epsg_4087", "res_8_epsg_4087")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    first_pair = ds1 | ds2
    ds = first_pair | ds3
    sample = ds[ds.bounds]
    # Every member, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == 1
    assert len(ds) == 3
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_different_res_1_23(self) -> None:
    """Union three rasters with differing resolutions as 1 | (2 | 3)."""
    names = ("res_2_epsg_4087", "res_4_epsg_4087", "res_8_epsg_4087")
    ds1, ds2, ds3 = (
        RasterDataset(os.path.join("tests", "data", "raster", name))
        for name in names
    )
    last_pair = ds2 | ds3
    ds = ds1 | last_pair
    sample = ds[ds.bounds]
    # Every member, and the union, must end up on ds1's CRS/resolution.
    assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
    assert ds1.res == ds2.res == ds3.res == ds.res == 2
    assert len(ds1) == len(ds2) == len(ds3) == 1
    assert len(ds) == 3
    assert isinstance(sample["image"], torch.Tensor)
|
||||
|
||||
def test_nongeo_dataset(self) -> None:
|
||||
ds1 = CustomNonGeoDataset()
|
||||
ds2 = CustomNonGeoDataset()
|
||||
|
@ -473,22 +585,8 @@ class TestUnionDataset:
|
|||
with pytest.raises(ValueError, match=msg):
|
||||
UnionDataset(ds3, ds1) # type: ignore[arg-type]
|
||||
|
||||
def test_different_crs(self) -> None:
|
||||
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
|
||||
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
|
||||
ds = UnionDataset(ds1, ds2)
|
||||
assert ds.crs == ds1.crs
|
||||
assert len(ds) == 2
|
||||
|
||||
def test_different_res(self) -> None:
|
||||
ds1 = CustomGeoDataset(res=1)
|
||||
ds2 = CustomGeoDataset(res=2)
|
||||
ds = UnionDataset(ds1, ds2)
|
||||
assert ds.res == ds1.res
|
||||
assert len(ds) == 2
|
||||
|
||||
def test_invalid_query(self, dataset: UnionDataset) -> None:
|
||||
query = BoundingBox(4, 5, 4, 5, 4, 5)
|
||||
query = BoundingBox(-1, -1, -1, -1, -1, -1)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
|
|
|
@ -23,7 +23,6 @@ import torch
|
|||
from rasterio.crs import CRS
|
||||
from rasterio.io import DatasetReader
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rasterio.windows import from_bounds
|
||||
from rtree.index import Index, Property
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -73,9 +72,8 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
dataset = landsat7 | landsat8
|
||||
"""
|
||||
|
||||
#: Resolution of the dataset in units of CRS.
|
||||
res: float
|
||||
_crs: CRS
|
||||
_crs = CRS.from_epsg(4326)
|
||||
_res = 0.0
|
||||
|
||||
# NOTE: according to the Python docs:
|
||||
#
|
||||
|
@ -213,12 +211,10 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
|
||||
@property
|
||||
def crs(self) -> CRS:
|
||||
""":term:`coordinate reference system (CRS)` for the dataset.
|
||||
""":term:`coordinate reference system (CRS)` of the dataset.
|
||||
|
||||
Returns:
|
||||
the :term:`coordinate reference system (CRS)`
|
||||
|
||||
.. versionadded:: 0.2
|
||||
The :term:`coordinate reference system (CRS)`.
|
||||
"""
|
||||
return self._crs
|
||||
|
||||
|
@ -229,17 +225,16 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index.
|
||||
|
||||
Args:
|
||||
new_crs: new :term:`coordinate reference system (CRS)`
|
||||
|
||||
.. versionadded:: 0.2
|
||||
new_crs: New :term:`coordinate reference system (CRS)`.
|
||||
"""
|
||||
if new_crs == self._crs:
|
||||
if new_crs == self.crs:
|
||||
return
|
||||
|
||||
print(f"Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}")
|
||||
new_index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
|
||||
project = pyproj.Transformer.from_crs(
|
||||
pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True
|
||||
pyproj.CRS(str(self.crs)), pyproj.CRS(str(new_crs)), always_xy=True
|
||||
).transform
|
||||
for hit in self.index.intersection(self.index.bounds, objects=True):
|
||||
old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds
|
||||
|
@ -252,6 +247,28 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
self._crs = new_crs
|
||||
self.index = new_index
|
||||
|
||||
@property
def res(self) -> float:
    """Resolution of the dataset in units of CRS.

    Returns:
        The resolution currently stored on the dataset (``_res``).
    """
    return self._res
|
||||
|
||||
@res.setter
def res(self, new_res: float) -> None:
    """Set a new resolution for this GeoDataset.

    Assigning the current value does nothing; otherwise the change is
    announced on stdout before the stored resolution is replaced.

    Args:
        new_res: New resolution.
    """
    if new_res != self.res:
        print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
        self._res = new_res
|
||||
|
||||
|
||||
class RasterDataset(GeoDataset):
|
||||
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
|
||||
|
@ -399,7 +416,7 @@ class RasterDataset(GeoDataset):
|
|||
raise AssertionError(msg)
|
||||
|
||||
self._crs = cast(CRS, crs)
|
||||
self.res = cast(float, res)
|
||||
self._res = cast(float, res)
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve image/mask and metadata indexed by query.
|
||||
|
@ -477,22 +494,7 @@ class RasterDataset(GeoDataset):
|
|||
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
|
||||
|
||||
bounds = (query.minx, query.miny, query.maxx, query.maxy)
|
||||
if len(vrt_fhs) == 1:
|
||||
src = vrt_fhs[0]
|
||||
out_width = round((query.maxx - query.minx) / self.res)
|
||||
out_height = round((query.maxy - query.miny) / self.res)
|
||||
count = len(band_indexes) if band_indexes else src.count
|
||||
out_shape = (count, out_height, out_width)
|
||||
dest = src.read(
|
||||
indexes=band_indexes,
|
||||
out_shape=out_shape,
|
||||
window=from_bounds(*bounds, src.transform),
|
||||
boundless=True,
|
||||
)
|
||||
else:
|
||||
dest, _ = rasterio.merge.merge(
|
||||
vrt_fhs, bounds, self.res, indexes=band_indexes
|
||||
)
|
||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes)
|
||||
|
||||
# fix numpy dtypes which are not supported by pytorch tensors
|
||||
if dest.dtype == np.uint16:
|
||||
|
@ -574,7 +576,6 @@ class VectorDataset(GeoDataset):
|
|||
super().__init__(transforms)
|
||||
|
||||
self.root = root
|
||||
self.res = res
|
||||
self.label_name = label_name
|
||||
|
||||
# Populate the dataset index
|
||||
|
@ -605,6 +606,7 @@ class VectorDataset(GeoDataset):
|
|||
raise FileNotFoundError(msg)
|
||||
|
||||
self._crs = crs
|
||||
self._res = res
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve image/mask and metadata indexed by query.
|
||||
|
@ -844,23 +846,9 @@ class IntersectionDataset(GeoDataset):
|
|||
if not isinstance(ds, GeoDataset):
|
||||
raise ValueError("IntersectionDataset only supports GeoDatasets")
|
||||
|
||||
self._crs = dataset1.crs
|
||||
self.crs = dataset1.crs
|
||||
self.res = dataset1.res
|
||||
|
||||
# Force dataset2 to have the same CRS/res as dataset1
|
||||
if dataset1.crs != dataset2.crs:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} CRS from "
|
||||
f"{dataset2.crs} to {dataset1.crs}"
|
||||
)
|
||||
dataset2.crs = dataset1.crs
|
||||
if dataset1.res != dataset2.res:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} resolution from "
|
||||
f"{dataset2.res} to {dataset1.res}"
|
||||
)
|
||||
dataset2.res = dataset1.res
|
||||
|
||||
# Merge dataset indices into a single index
|
||||
self._merge_dataset_indices()
|
||||
|
||||
|
@ -917,6 +905,46 @@ class IntersectionDataset(GeoDataset):
|
|||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
def crs(self) -> CRS:
    """:term:`coordinate reference system (CRS)` shared by both datasets.

    Returns:
        The :term:`coordinate reference system (CRS)` stored on this dataset.
    """
    return self._crs
|
||||
|
||||
@crs.setter
def crs(self, new_crs: CRS) -> None:
    """Set the :term:`coordinate reference system (CRS)` of both datasets.

    The value is recorded on this dataset and then assigned to each of the
    two wrapped datasets in order.

    Args:
        new_crs: New :term:`coordinate reference system (CRS)`.
    """
    self._crs = new_crs
    for position in (0, 1):
        self.datasets[position].crs = new_crs
|
||||
|
||||
@property
def res(self) -> float:
    """Resolution of both datasets in units of CRS.

    Returns:
        The resolution stored on this dataset (``_res``).
    """
    return self._res
|
||||
|
||||
@res.setter
def res(self, new_res: float) -> None:
    """Set the resolution of both datasets.

    The value is recorded on this dataset and then assigned to each of the
    two wrapped datasets in order.

    Args:
        new_res: New resolution.
    """
    self._res = new_res
    for position in (0, 1):
        self.datasets[position].res = new_res
|
||||
|
||||
|
||||
class UnionDataset(GeoDataset):
|
||||
"""Dataset representing the union of two GeoDatasets.
|
||||
|
@ -970,23 +998,9 @@ class UnionDataset(GeoDataset):
|
|||
if not isinstance(ds, GeoDataset):
|
||||
raise ValueError("UnionDataset only supports GeoDatasets")
|
||||
|
||||
self._crs = dataset1.crs
|
||||
self.crs = dataset1.crs
|
||||
self.res = dataset1.res
|
||||
|
||||
# Force dataset2 to have the same CRS/res as dataset1
|
||||
if dataset1.crs != dataset2.crs:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} CRS from "
|
||||
f"{dataset2.crs} to {dataset1.crs}"
|
||||
)
|
||||
dataset2.crs = dataset1.crs
|
||||
if dataset1.res != dataset2.res:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} resolution from "
|
||||
f"{dataset2.res} to {dataset1.res}"
|
||||
)
|
||||
dataset2.res = dataset1.res
|
||||
|
||||
# Merge dataset indices into a single index
|
||||
self._merge_dataset_indices()
|
||||
|
||||
|
@ -1040,3 +1054,43 @@ class UnionDataset(GeoDataset):
|
|||
type: UnionDataset
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
def crs(self) -> CRS:
    """:term:`coordinate reference system (CRS)` shared by both datasets.

    Returns:
        The :term:`coordinate reference system (CRS)` stored on this dataset.
    """
    return self._crs
|
||||
|
||||
@crs.setter
def crs(self, new_crs: CRS) -> None:
    """Set the :term:`coordinate reference system (CRS)` of both datasets.

    The value is recorded on this dataset and then assigned to each of the
    two wrapped datasets in order.

    Args:
        new_crs: New :term:`coordinate reference system (CRS)`.
    """
    self._crs = new_crs
    for position in (0, 1):
        self.datasets[position].crs = new_crs
|
||||
|
||||
@property
def res(self) -> float:
    """Resolution of both datasets in units of CRS.

    Returns:
        The resolution stored on this dataset (``_res``).
    """
    return self._res
|
||||
|
||||
@res.setter
def res(self, new_res: float) -> None:
    """Set the resolution of both datasets.

    The value is recorded on this dataset and then assigned to each of the
    two wrapped datasets in order.

    Args:
        new_res: New resolution.
    """
    self._res = new_res
    for position in (0, 1):
        self.datasets[position].res = new_res
|
||||
|
|
Загрузка…
Ссылка в новой задаче