* Add tests to catch reprojection issues

* More specific tests

* Fix boundless reprojection

* black/flake8

* Add getters and setters for all crs/res attributes

* Actually fix reprojection bug

* Fix tests

* Get all tests passing

* More specific tests

* Remove aux files

* Ignore *.aux.xml files

* Fix mypy

* Increase coverage

* Use dtype properly

* Increase coverage

* Add newline
This commit is contained in:
Adam J. Stewart 2023-05-18 23:33:47 -05:00 коммит произвёл GitHub
Родитель f562378bcb
Коммит bbf6108516
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
18 изменённых файлов: 348 добавлений и 135 удалений

1
.gitignore поставляемый
Просмотреть файл

@ -4,6 +4,7 @@
/output/
*.pdf
/results/
*.aux.xml
# Spack
.spack-env/

Просмотреть файл

@ -2,41 +2,99 @@
# Licensed under the MIT License.
import os
from typing import Optional
import numpy as np
import rasterio
import rasterio.transform
from torchvision.datasets.utils import calculate_md5
import rasterio as rio
from rasterio.transform import from_bounds
from rasterio.warp import calculate_default_transform, reproject
RES = [2, 4, 8]
EPSG = [4087, 4326, 32631]
SIZE = 16
def generate_test_data(fn: str) -> str:
"""Creates test data with uint32 datatype.
def write_raster(
res: int = RES[0],
epsg: int = EPSG[0],
dtype: str = "uint8",
path: Optional[str] = None,
) -> None:
"""Write a raster file.
Args:
fn (str): Filename to write
Returns:
str: md5 hash of created archive
res: Resolution.
epsg: EPSG of file.
dtype: Data type.
path: File path.
"""
size = SIZE // res
profile = {
"driver": "GTiff",
"dtype": "uint32",
"dtype": dtype,
"count": 1,
"crs": "epsg:4326",
"transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
"height": 4,
"width": 4,
"compress": "lzw",
"predictor": 2,
"crs": f"epsg:{epsg}",
"transform": from_bounds(0, 0, SIZE, SIZE, size, size),
"height": size,
"width": size,
"nodata": 0,
}
with rasterio.open(fn, "w", **profile) as f:
f.write(np.random.randint(0, 256, size=(1, 4, 4)))
if path is None:
name = f"res_{res}_epsg_{epsg}"
path = os.path.join(name, f"{name}.tif")
md5: str = calculate_md5(fn)
return md5
directory = os.path.dirname(path)
os.makedirs(directory, exist_ok=True)
with rio.open(path, "w", **profile) as f:
x = np.ones((1, size, size))
f.write(x)
def reproject_raster(res: int, src_epsg: int, dst_epsg: int) -> None:
"""Reproject a raster file.
Args:
res: Resolution.
src_epsg: EPSG of source file.
dst_epsg: EPSG of destination file.
"""
src_name = f"res_{res}_epsg_{src_epsg}"
src_path = os.path.join(src_name, f"{src_name}.tif")
with rio.open(src_path) as src:
dst_crs = f"epsg:{dst_epsg}"
transform, width, height = calculate_default_transform(
src.crs, dst_crs, src.width, src.height, *src.bounds
)
profile = src.profile.copy()
profile.update(
{"crs": dst_crs, "transform": transform, "width": width, "height": height}
)
dst_name = f"res_{res}_epsg_{dst_epsg}"
os.makedirs(dst_name, exist_ok=True)
dst_path = os.path.join(dst_name, f"{dst_name}.tif")
with rio.open(dst_path, "w", **profile) as dst:
reproject(
source=rio.band(src, 1),
destination=rio.band(dst, 1),
src_transform=src.transform,
src_crs=src.crs,
dst_transform=dst.transform,
dst_crs=dst.crs,
)
if __name__ == "__main__":
md5_hash = generate_test_data(os.path.join(os.getcwd(), "test0.tif"))
print(md5_hash)
for res in RES:
src_epsg = EPSG[0]
write_raster(res, src_epsg)
for dst_epsg in EPSG[1:]:
reproject_raster(res, src_epsg, dst_epsg)
for dtype in ["uint16", "uint32"]:
path = os.path.join(dtype, f"{dtype}.tif")
write_raster(dtype=dtype, path=path)
with open(os.path.join(dtype, "corrupted.tif"), "w") as f:
f.write("not a tif file\n")

Двоичные данные
tests/data/raster/res_2_epsg_32631/res_2_epsg_32631.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_2_epsg_4087/res_2_epsg_4087.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_2_epsg_4326/res_2_epsg_4326.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_4_epsg_32631/res_4_epsg_32631.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_4_epsg_4087/res_4_epsg_4087.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_4_epsg_4326/res_4_epsg_4326.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_8_epsg_32631/res_8_epsg_32631.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_8_epsg_4087/res_8_epsg_4087.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/res_8_epsg_4326/res_8_epsg_4326.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/raster/test0.tif

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1 @@
not a tif file

Двоичные данные
tests/data/raster/uint16/uint16.tif Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1 @@
not a tif file

Двоичные данные
tests/data/raster/uint32/uint32.tif Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -30,7 +30,7 @@ class CustomGeoDataset(GeoDataset):
def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(3005),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
) -> None:
super().__init__()
@ -74,7 +74,7 @@ class TestGeoDataset:
def test_len(self, dataset: GeoDataset) -> None:
assert len(dataset) == 1
@pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)])
@pytest.mark.parametrize("crs", [CRS.from_epsg(4087), CRS.from_epsg(32631)])
def test_crs(self, dataset: GeoDataset, crs: CRS) -> None:
dataset.crs = crs
@ -157,7 +157,7 @@ class TestRasterDataset:
def naip(self, request: SubRequest) -> NAIP:
root = os.path.join("tests", "data", "naip")
bands = request.param[0]
crs = CRS.from_epsg(3005)
crs = CRS.from_epsg(4087)
transforms = nn.Identity()
cache = request.param[1]
return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache)
@ -178,11 +178,6 @@ class TestRasterDataset:
cache = request.param[1]
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
@pytest.fixture()
def custom_dtype_ds(self) -> RasterDataset:
root = os.path.join("tests", "data", "raster")
return RasterDataset(root)
def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds]
assert isinstance(x, dict)
@ -197,8 +192,11 @@ class TestRasterDataset:
assert isinstance(x["image"], torch.Tensor)
assert len(sentinel.bands) == x["image"].shape[0]
def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None:
x = custom_dtype_ds[custom_dtype_ds.bounds]
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
def test_getitem_uint_dtype(self, dtype: str) -> None:
root = os.path.join("tests", "data", "raster", dtype)
ds = RasterDataset(root)
x = ds[ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].dtype == torch.float32
@ -377,14 +375,15 @@ class TestNonGeoClassificationDataset:
class TestIntersectionDataset:
@pytest.fixture(scope="class")
def dataset(self) -> IntersectionDataset:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326"))
transforms = nn.Identity()
return IntersectionDataset(ds1, ds2, transforms=transforms)
def test_getitem(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 1, 2, 3, 4, 5)
assert dataset[query] == {"index": query}
query = dataset.bounds
sample = dataset[query]
assert isinstance(sample["image"], torch.Tensor)
def test_len(self, dataset: IntersectionDataset) -> None:
assert len(dataset) == 1
@ -403,27 +402,69 @@ class TestIntersectionDataset:
):
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(
BoundingBox(
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
),
crs=CRS.from_epsg(32616),
)
def test_different_crs_12(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 1
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
def test_different_crs_12_3(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631"))
ds = (ds1 & ds2) & ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_different_crs_1_23(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631"))
ds = ds1 & (ds2 & ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_12(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 1
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_12_3(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087"))
ds = (ds1 & ds2) & ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_1_23(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087"))
ds = ds1 & (ds2 & ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample["image"], torch.Tensor)
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
@ -433,7 +474,7 @@ class TestIntersectionDataset:
IntersectionDataset(ds1, ds2)
def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
query = BoundingBox(-1, -1, -1, -1, -1, -1)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
@ -443,14 +484,15 @@ class TestIntersectionDataset:
class TestUnionDataset:
@pytest.fixture(scope="class")
def dataset(self) -> UnionDataset:
ds1 = CustomGeoDataset(bounds=BoundingBox(0, 1, 0, 1, 0, 1))
ds2 = CustomGeoDataset(bounds=BoundingBox(2, 3, 2, 3, 2, 3))
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4326"))
transforms = nn.Identity()
return UnionDataset(ds1, ds2, transforms=transforms)
def test_getitem(self, dataset: UnionDataset) -> None:
query = BoundingBox(0, 1, 0, 1, 0, 1)
assert dataset[query] == {"index": query}
query = dataset.bounds
sample = dataset[query]
assert isinstance(sample["image"], torch.Tensor)
def test_len(self, dataset: UnionDataset) -> None:
assert len(dataset) == 2
@ -461,6 +503,76 @@ class TestUnionDataset:
assert "bbox: BoundingBox" in out
assert "size: 2" in out
def test_different_crs_12(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds = UnionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert len(ds1) == len(ds2) == 1
assert len(ds) == 2
assert isinstance(sample["image"], torch.Tensor)
def test_different_crs_12_3(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631"))
ds = (ds1 | ds2) | ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample["image"], torch.Tensor)
def test_different_crs_1_23(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4326"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_32631"))
ds = ds1 | (ds2 | ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_12(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds = UnionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert len(ds1) == len(ds2) == 1
assert len(ds) == 2
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_12_3(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087"))
ds = (ds1 | ds2) | ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample["image"], torch.Tensor)
def test_different_res_1_23(self) -> None:
ds1 = RasterDataset(os.path.join("tests", "data", "raster", "res_2_epsg_4087"))
ds2 = RasterDataset(os.path.join("tests", "data", "raster", "res_4_epsg_4087"))
ds3 = RasterDataset(os.path.join("tests", "data", "raster", "res_8_epsg_4087"))
ds = ds1 | (ds2 | ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample["image"], torch.Tensor)
def test_nongeo_dataset(self) -> None:
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
@ -473,22 +585,8 @@ class TestUnionDataset:
with pytest.raises(ValueError, match=msg):
UnionDataset(ds3, ds1) # type: ignore[arg-type]
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
ds = UnionDataset(ds1, ds2)
assert ds.crs == ds1.crs
assert len(ds) == 2
def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
ds = UnionDataset(ds1, ds2)
assert ds.res == ds1.res
assert len(ds) == 2
def test_invalid_query(self, dataset: UnionDataset) -> None:
query = BoundingBox(4, 5, 4, 5, 4, 5)
query = BoundingBox(-1, -1, -1, -1, -1, -1)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):

Просмотреть файл

@ -23,7 +23,6 @@ import torch
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset
@ -73,9 +72,8 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
dataset = landsat7 | landsat8
"""
#: Resolution of the dataset in units of CRS.
res: float
_crs: CRS
_crs = CRS.from_epsg(4326)
_res = 0.0
# NOTE: according to the Python docs:
#
@ -213,12 +211,10 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
@property
def crs(self) -> CRS:
""":term:`coordinate reference system (CRS)` for the dataset.
""":term:`coordinate reference system (CRS)` of the dataset.
Returns:
the :term:`coordinate reference system (CRS)`
.. versionadded:: 0.2
The :term:`coordinate reference system (CRS)`.
"""
return self._crs
@ -229,17 +225,16 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index.
Args:
new_crs: new :term:`coordinate reference system (CRS)`
.. versionadded:: 0.2
new_crs: New :term:`coordinate reference system (CRS)`.
"""
if new_crs == self._crs:
if new_crs == self.crs:
return
print(f"Converting {self.__class__.__name__} CRS from {self.crs} to {new_crs}")
new_index = Index(interleaved=False, properties=Property(dimension=3))
project = pyproj.Transformer.from_crs(
pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True
pyproj.CRS(str(self.crs)), pyproj.CRS(str(new_crs)), always_xy=True
).transform
for hit in self.index.intersection(self.index.bounds, objects=True):
old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds
@ -252,6 +247,28 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
self._crs = new_crs
self.index = new_index
@property
def res(self) -> float:
"""Resolution of the dataset in units of CRS.
Returns:
The resolution of the dataset.
"""
return self._res
@res.setter
def res(self, new_res: float) -> None:
"""Change the resolution of a GeoDataset.
Args:
new_res: New resolution.
"""
if new_res == self.res:
return
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res
class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
@ -399,7 +416,7 @@ class RasterDataset(GeoDataset):
raise AssertionError(msg)
self._crs = cast(CRS, crs)
self.res = cast(float, res)
self._res = cast(float, res)
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@ -477,22 +494,7 @@ class RasterDataset(GeoDataset):
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
bounds = (query.minx, query.miny, query.maxx, query.maxy)
if len(vrt_fhs) == 1:
src = vrt_fhs[0]
out_width = round((query.maxx - query.minx) / self.res)
out_height = round((query.maxy - query.miny) / self.res)
count = len(band_indexes) if band_indexes else src.count
out_shape = (count, out_height, out_width)
dest = src.read(
indexes=band_indexes,
out_shape=out_shape,
window=from_bounds(*bounds, src.transform),
boundless=True,
)
else:
dest, _ = rasterio.merge.merge(
vrt_fhs, bounds, self.res, indexes=band_indexes
)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes)
# fix numpy dtypes which are not supported by pytorch tensors
if dest.dtype == np.uint16:
@ -574,7 +576,6 @@ class VectorDataset(GeoDataset):
super().__init__(transforms)
self.root = root
self.res = res
self.label_name = label_name
# Populate the dataset index
@ -605,6 +606,7 @@ class VectorDataset(GeoDataset):
raise FileNotFoundError(msg)
self._crs = crs
self._res = res
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@ -844,23 +846,9 @@ class IntersectionDataset(GeoDataset):
if not isinstance(ds, GeoDataset):
raise ValueError("IntersectionDataset only supports GeoDatasets")
self._crs = dataset1.crs
self.crs = dataset1.crs
self.res = dataset1.res
# Force dataset2 to have the same CRS/res as dataset1
if dataset1.crs != dataset2.crs:
print(
f"Converting {dataset2.__class__.__name__} CRS from "
f"{dataset2.crs} to {dataset1.crs}"
)
dataset2.crs = dataset1.crs
if dataset1.res != dataset2.res:
print(
f"Converting {dataset2.__class__.__name__} resolution from "
f"{dataset2.res} to {dataset1.res}"
)
dataset2.res = dataset1.res
# Merge dataset indices into a single index
self._merge_dataset_indices()
@ -917,6 +905,46 @@ class IntersectionDataset(GeoDataset):
bbox: {self.bounds}
size: {len(self)}"""
@property
def crs(self) -> CRS:
""":term:`coordinate reference system (CRS)` of both datasets.
Returns:
The :term:`coordinate reference system (CRS)`.
"""
return self._crs
@crs.setter
def crs(self, new_crs: CRS) -> None:
"""Change the :term:`coordinate reference system (CRS)` of both datasets.
Args:
new_crs: New :term:`coordinate reference system (CRS)`.
"""
self._crs = new_crs
self.datasets[0].crs = new_crs
self.datasets[1].crs = new_crs
@property
def res(self) -> float:
"""Resolution of both datasets in units of CRS.
Returns:
Resolution of both datasets.
"""
return self._res
@res.setter
def res(self, new_res: float) -> None:
"""Change the resolution of both datasets.
Args:
new_res: New resolution.
"""
self._res = new_res
self.datasets[0].res = new_res
self.datasets[1].res = new_res
class UnionDataset(GeoDataset):
"""Dataset representing the union of two GeoDatasets.
@ -970,23 +998,9 @@ class UnionDataset(GeoDataset):
if not isinstance(ds, GeoDataset):
raise ValueError("UnionDataset only supports GeoDatasets")
self._crs = dataset1.crs
self.crs = dataset1.crs
self.res = dataset1.res
# Force dataset2 to have the same CRS/res as dataset1
if dataset1.crs != dataset2.crs:
print(
f"Converting {dataset2.__class__.__name__} CRS from "
f"{dataset2.crs} to {dataset1.crs}"
)
dataset2.crs = dataset1.crs
if dataset1.res != dataset2.res:
print(
f"Converting {dataset2.__class__.__name__} resolution from "
f"{dataset2.res} to {dataset1.res}"
)
dataset2.res = dataset1.res
# Merge dataset indices into a single index
self._merge_dataset_indices()
@ -1040,3 +1054,43 @@ class UnionDataset(GeoDataset):
type: UnionDataset
bbox: {self.bounds}
size: {len(self)}"""
@property
def crs(self) -> CRS:
""":term:`coordinate reference system (CRS)` of both datasets.
Returns:
The :term:`coordinate reference system (CRS)`.
"""
return self._crs
@crs.setter
def crs(self, new_crs: CRS) -> None:
"""Change the :term:`coordinate reference system (CRS)` of both datasets.
Args:
new_crs: New :term:`coordinate reference system (CRS)`.
"""
self._crs = new_crs
self.datasets[0].crs = new_crs
self.datasets[1].crs = new_crs
@property
def res(self) -> float:
"""Resolution of both datasets in units of CRS.
Returns:
The resolution of both datasets.
"""
return self._res
@res.setter
def res(self, new_res: float) -> None:
"""Change the resolution of both datasets.
Args:
new_res: New resolution.
"""
self._res = new_res
self.datasets[0].res = new_res
self.datasets[1].res = new_res