Overhaul BoundingBox and ZipDataset classes (#144)

* Adding a UnionDataset

* Adding contains method to BoundingBox

* Finishing UnionDataset

* Add __contains__ method

* Overhaul BoundingBox, add set arithmetic

* mypy fixes

* pydocstyle fixes

* Ignore erroneous pydocstyle warnings

* rtree only supports tuples, not BoundingBoxes

* mypy fixes

* Use custom collate function to handle BoundingBoxes

* Add back support for Python 3.6

* Add tests for all new BoundingBox features

* Rename ZipDataset to IntersectionDataset

* Merge indices of IntersectionDataset, auto-convert CRS/res

* Get tests to pass

* Fix more tests

* Test more of RasterDataset/VectorDataset directly

* Increase UnionDataset test coverage

* IntersectionDataset stacks tensors, UnionDataset merges tensors

* Support collating dicts with differing keys, add tests

* Style fixes

* Samplers: compute intersection between index and ROI

* Update README with example usage

* GeoDataset addition is deprecated

* Add note about CRS/res

* More documentation for Intersection/UnionDatasets

* Use collate function in tutorial

* Don't use multiple workers

* Fix typo

* Drop support for adding GeoDatasets

* Remove unused import

* Add comment explaining coverage config settings

* Collation function needed for benchmark script

* Add more explanation to README

* Correct Landsat 8 bands

* Print warning when changing CRS/res

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2021-12-03 14:40:50 -08:00 committed by GitHub
Parent 6f23fa42c5
Commit 5d407b76b5
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 1269 additions and 335 deletions

View file

@ -7,7 +7,7 @@ The goal of this library is to make it simple:
1. for machine learning experts to use geospatial data in their workflows, and
2. for remote sensing experts to use their data in machine learning workflows.
See our [installation instructions](#installation-instructions), [documentation](#documentation), and [examples](#example-usage) to learn how to use torchgeo.
See our [installation instructions](#installation), [documentation](#documentation), and [examples](#example-usage) to learn how to use TorchGeo.
External links:
[![docs](https://readthedocs.org/projects/torchgeo/badge/?version=latest)](https://torchgeo.readthedocs.io/en/latest/?badge=latest)
@ -21,7 +21,7 @@ Tests:
[![style](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml)
[![tests](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml)
## Installation instructions
## Installation
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
@ -33,13 +33,83 @@ For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation
## Documentation
You can find the documentation for torchgeo on [ReadTheDocs](https://torchgeo.readthedocs.io).
You can find the documentation for TorchGeo on [ReadTheDocs](https://torchgeo.readthedocs.io).
## Example usage
## Example Usage
The following sections give basic examples of what you can do with torchgeo. For more examples, check out our [tutorials](https://torchgeo.readthedocs.io/en/latest/tutorials/getting_started.html).
The following sections give basic examples of what you can do with TorchGeo. For more examples, check out our [tutorials](https://torchgeo.readthedocs.io/en/latest/tutorials/getting_started.html).
### Train and test models using our PyTorch Lightning based training script
First we'll import various classes and functions used in the following sections:
```python
from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, COWCDetection, Landsat7, Landsat8, stack_samples
from torchgeo.samplers import RandomGeoSampler
```
### Benchmark datasets
TorchGeo includes a number of [*benchmark*](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#non-geospatial-datasets) datasets, datasets that include both input images and target labels. This includes datasets for tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.
If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Cars Overhead With Context (COWC) car detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.
```python
dataset = COWCDetection(root="...", split="train", download=True, checksum=True)
```
This dataset can then be passed to a PyTorch data loader.
```python
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
```
The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch Tensor.
```python
for batch in dataloader:
    image = batch["image"]
    label = batch["label"]
    # train a model, or make predictions using a pre-trained model
```
### Geospatial datasets
Many remote sensing applications involve working with [*generic*](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#geospatial-datasets) geospatial data. This data can be challenging to work with due to the sheer variety of data. Geospatial imagery is often multispectral with a different number of spectral bands and spatial resolution for every satellite. In addition, each file may be in a different coordinate reference system (CRS), requiring the data to be reprojected into a matching CRS.
In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and Landsat 8 by taking the union of these two datasets.
```python
landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9"])
landsat = landsat7 | landsat8
```
Next, we take the intersection between this dataset and the Cropland Data Layer (CDL) dataset. We want to take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different coordinate reference systems (CRS) or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.
```python
cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl
```
This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire continental United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of [*samplers*](https://torchgeo.readthedocs.io/en/latest/api/samplers.html). In this example, we'll use a random sampler that returns 256x256 pixel images and an epoch length of 10,000 images. We also use a custom collation function to combine each sample dictionary into a mini-batch of samples.
```python
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
```
This data loader can now be used in your normal training/evaluation pipeline.
```python
for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]
    # train a model, or make predictions using a pre-trained model
```
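For exhaustive, grid-based inference over an entire scene, TorchGeo also provides a `GridGeoSampler` (used in the benchmark script elsewhere in this commit). A minimal sketch, with hypothetical `size`/`stride` values:
```python
from torchgeo.samplers import GridGeoSampler

# Cover the dataset with a regular grid of 256x256 patches; a stride equal
# to the patch size yields non-overlapping tiles.
sampler = GridGeoSampler(dataset, size=256, stride=256)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
```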
### Train and test models using our PyTorch Lightning-based training script
We provide a script, `train.py`, for training models using a subset of the datasets. We do this with the PyTorch Lightning `LightningModule`s and `LightningDataModule`s implemented under the `torchgeo.trainers` namespace.
The `train.py` script is configurable via the command line and/or via YAML configuration files. See the [conf/](conf/) directory for example configuration files that can be customized for different training runs.
@ -48,20 +118,6 @@ The `train.py` script is configurable via the command line and/or via YAML confi
$ python train.py config_file=conf/landcoverai.yaml
```
### Download and use the Tropical Cyclone Wind Estimation Competition dataset
This dataset is from a competition hosted by [Driven Data](https://www.drivendata.org/) in collaboration with [Radiant Earth](https://www.radiant.earth/). See [here](https://www.drivendata.org/competitions/72/predict-wind-speeds/) for more information.
Using this dataset in torchgeo is as simple as importing and instantiating the appropriate class.
```python
import torchgeo.datasets
dataset = torchgeo.datasets.TropicalCycloneWindEstimation(split="train", download=True)
print(dataset[0]["image"].shape)
print(dataset[0]["label"])
```
## Citation
If you use this software in your work, please cite [our paper](https://arxiv.org/abs/2111.08872):

View file

@ -14,7 +14,7 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet34
from torchgeo.datasets import CDL, Landsat8
from torchgeo.datasets import CDL, Landsat8, stack_samples
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
@ -129,7 +129,7 @@ def main(args: argparse.Namespace) -> None:
landsat = Landsat8(
args.landsat_root, crs=cdl.crs, res=cdl.res, cache=args.cache, bands=bands
)
dataset = landsat + cdl
dataset = landsat & cdl
# Initialize samplers
if args.epoch_size:
@ -158,7 +158,10 @@ def main(args: argparse.Namespace) -> None:
if isinstance(sampler, RandomBatchGeoSampler):
dataloader = DataLoader(
dataset, batch_sampler=sampler, num_workers=args.num_workers
dataset,
batch_sampler=sampler,
num_workers=args.num_workers,
collate_fn=stack_samples,
)
else:
dataloader = DataLoader(
@ -166,6 +169,7 @@ def main(args: argparse.Namespace) -> None:
batch_size=args.batch_size,
sampler=sampler, # type: ignore[arg-type]
num_workers=args.num_workers,
collate_fn=stack_samples,
)
tic = time.time()

View file

@ -10,7 +10,7 @@ In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets`
Geospatial Datasets
-------------------
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`ZipDataset`.
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`.
Canadian Building Footprints
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -245,13 +245,24 @@ VisionClassificationDataset
.. autoclass:: VisionClassificationDataset
ZipDataset
^^^^^^^^^^
IntersectionDataset
^^^^^^^^^^^^^^^^^^^
.. autoclass:: ZipDataset
.. autoclass:: IntersectionDataset
UnionDataset
^^^^^^^^^^^^
.. autoclass:: UnionDataset
Utilities
---------
.. autoclass:: BoundingBox
.. autofunction:: collate_dict
Collation Functions
^^^^^^^^^^^^^^^^^^^
.. autofunction:: stack_samples
.. autofunction:: concat_samples
.. autofunction:: merge_samples
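A rough sketch of how these three collation functions differ, mirroring the test expectations added in this commit:
```python
import torch

from torchgeo.datasets import concat_samples, merge_samples, stack_samples

samples = [{"image": torch.tensor([1, 2, 0])}, {"image": torch.tensor([0, 0, 3])}]

stack_samples(samples)["image"]   # tensor([[1, 2, 0], [0, 0, 3]]) -- adds a batch dimension
concat_samples(samples)["image"]  # tensor([1, 2, 0, 0, 0, 3])     -- concatenates along dim 0
merge_samples(samples)["image"]   # tensor([1, 2, 3])              -- non-zero values win
```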

View file

@ -72,9 +72,8 @@
"\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchgeo.datasets import NAIP, ChesapeakeDE\n",
"from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples\n",
"from torchgeo.datasets.utils import download_url\n",
"from torchgeo.models import FCN\n",
"from torchgeo.samplers import RandomGeoSampler"
]
},
@ -293,7 +292,7 @@
"id": "OWUhlfpD22IX"
},
"source": [
"Finally, we create a ZipDataset so that we can automatically sample from both GeoDatasets simultaneously."
"Finally, we create an IntersectionDataset so that we can automatically sample from both GeoDatasets simultaneously."
]
},
{
@ -305,7 +304,7 @@
},
"outputs": [],
"source": [
"dataset = naip + chesapeake"
"dataset = naip & chesapeake"
]
},
{
@ -353,7 +352,7 @@
},
"outputs": [],
"source": [
"dataloader = DataLoader(dataset, sampler=sampler)"
"dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)"
]
},
{
@ -393,7 +392,7 @@
"timeout": 1200
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -406,7 +405,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"version": "3.9.7"
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {

View file

@ -35,6 +35,14 @@ exclude = '''
)/
'''
[tool.coverage.report]
# Ignore warnings for overloads
# https://github.com/nedbat/coveragepy/issues/970#issuecomment-612602180
exclude_lines = [
"pragma: no cover",
"@overload",
]
[tool.isort]
profile = "black"
known_first_party = ["docs", "tests", "torchgeo"]
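For context on the coverage exclusion above: `@overload` stubs exist only for static type checkers and their bodies never execute at runtime, so they would otherwise always show up as uncovered lines. A minimal sketch with a hypothetical function:
```python
from typing import Union, overload

@overload
def double(x: int) -> int: ...
@overload
def double(x: float) -> float: ...

def double(x: Union[int, float]) -> Union[int, float]:
    # Only this implementation body ever runs; the stubs above are dead
    # code at runtime, hence the "@overload" coverage exclusion.
    return x * 2
```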

View file

@ -28,6 +28,8 @@ setup_requires =
# setuptools 42+ required for metadata.license_files support in setup.cfg
setuptools>=42
install_requires =
# dataclasses was added in Python 3.7
dataclasses;python_version<'3.7'
einops
# fiona 1.5+ required for fiona.transform module
fiona>=1.5

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

View file

@ -0,0 +1,7 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
]
}

Binary data
tests/data/naip/m_3807511_ne_18_060_20190605.tif Normal file

Binary file not shown.

View file

View file

@ -14,7 +14,12 @@ from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
from torchgeo.datasets import (
BoundingBox,
CanadianBuildingFootprints,
IntersectionDataset,
UnionDataset,
)
def download_url(url: str, root: str, *args: str) -> None:
@ -66,9 +71,13 @@ class TestCanadianBuildingFootprints:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_add(self, dataset: CanadianBuildingFootprints) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: CanadianBuildingFootprints) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: CanadianBuildingFootprints) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
CanadianBuildingFootprints(root=dataset.root, download=True)

View file

@ -16,7 +16,7 @@ from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -51,15 +51,19 @@ class TestCDL:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_add(self, dataset: CDL) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: CDL) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: CDL) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_full_year(self, dataset: CDL) -> None:
bbox = dataset.bounds
time = datetime(2021, 6, 1).timestamp()
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
next(dataset.index.intersection(query))
next(dataset.index.intersection(tuple(query)))
def test_already_extracted(self, dataset: CDL) -> None:
CDL(root=dataset.root, download=True)

View file

@ -20,7 +20,8 @@ from torchgeo.datasets import (
Chesapeake13,
ChesapeakeCVPR,
ChesapeakeCVPRDataModule,
ZipDataset,
IntersectionDataset,
UnionDataset,
)
@ -55,9 +56,13 @@ class TestChesapeake13:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_add(self, dataset: Chesapeake13) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: Chesapeake13) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: Chesapeake13) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_already_downloaded(self, dataset: Chesapeake13) -> None:
Chesapeake13(root=dataset.root, download=True)
@ -128,9 +133,13 @@ class TestChesapeakeCVPR:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_add(self, dataset: ChesapeakeCVPR) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: ChesapeakeCVPR) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: ChesapeakeCVPR) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
ChesapeakeCVPR(root=dataset.root, download=True)

View file

@ -13,14 +13,17 @@ from rasterio.crs import CRS
from torch.utils.data import ConcatDataset
from torchgeo.datasets import (
NAIP,
BoundingBox,
CanadianBuildingFootprints,
GeoDataset,
Landsat8,
IntersectionDataset,
RasterDataset,
Sentinel2,
UnionDataset,
VectorDataset,
VisionClassificationDataset,
VisionDataset,
ZipDataset,
)
@ -32,8 +35,8 @@ class CustomGeoDataset(GeoDataset):
res: float = 1,
) -> None:
super().__init__()
self.index.insert(0, bounds)
self.crs = crs
self.index.insert(0, tuple(bounds))
self._crs = crs
self.res = res
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
@ -60,28 +63,56 @@ class TestGeoDataset:
def test_len(self, dataset: GeoDataset) -> None:
assert len(dataset) == 1
def test_add_two(self) -> None:
@pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)])
def test_crs(self, dataset: GeoDataset, crs: CRS) -> None:
dataset.crs = crs
def test_and_two(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
dataset = ds1 + ds2
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 2
dataset = ds1 & ds2
assert isinstance(dataset, IntersectionDataset)
assert len(dataset) == 1
def test_add_three(self) -> None:
def test_and_three(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
ds3 = CustomGeoDataset()
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ZipDataset)
assert len(dataset) == 3
dataset = ds1 & ds2 & ds3
assert isinstance(dataset, IntersectionDataset)
assert len(dataset) == 1
def test_add_four(self) -> None:
def test_and_four(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
ds3 = CustomGeoDataset()
ds4 = CustomGeoDataset()
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ZipDataset)
dataset = (ds1 & ds2) & (ds3 & ds4)
assert isinstance(dataset, IntersectionDataset)
assert len(dataset) == 1
def test_or_two(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
dataset = ds1 | ds2
assert isinstance(dataset, UnionDataset)
assert len(dataset) == 2
def test_or_three(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
ds3 = CustomGeoDataset()
dataset = ds1 | ds2 | ds3
assert isinstance(dataset, UnionDataset)
assert len(dataset) == 3
def test_or_four(self) -> None:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
ds3 = CustomGeoDataset()
ds4 = CustomGeoDataset()
dataset = (ds1 | ds2) | (ds3 | ds4)
assert isinstance(dataset, UnionDataset)
assert len(dataset) == 4
def test_str(self, dataset: GeoDataset) -> None:
@ -94,34 +125,75 @@ class TestGeoDataset:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoDataset() # type: ignore[abstract]
def test_add_vision(self, dataset: GeoDataset) -> None:
def test_and_vision(self, dataset: GeoDataset) -> None:
ds2 = CustomVisionDataset()
with pytest.raises(ValueError, match="ZipDataset only supports GeoDatasets"):
dataset + ds2 # type: ignore[operator]
with pytest.raises(
ValueError, match="IntersectionDataset only supports GeoDatasets"
):
dataset & ds2 # type: ignore[operator]
class TestRasterDataset:
@pytest.fixture(params=[True, False])
def dataset(self, request: SubRequest) -> Landsat8:
root = os.path.join("tests", "data", "landsat8")
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
def naip(self, request: SubRequest) -> NAIP:
root = os.path.join("tests", "data", "naip")
crs = CRS.from_epsg(3005)
transforms = nn.Identity() # type: ignore[attr-defined]
cache = request.param
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)
return NAIP(root, crs=crs, transforms=transforms, cache=cache)
def test_getitem(self, dataset: Landsat8) -> None:
x = dataset[dataset.bounds]
@pytest.fixture(params=[True, False])
def sentinel(self, request: SubRequest) -> Sentinel2:
root = os.path.join("tests", "data", "sentinel2")
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
transforms = nn.Identity() # type: ignore[attr-defined]
cache = request.param
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_getitem_separate_files(self, sentinel: Sentinel2) -> None:
x = sentinel[sentinel.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_invalid_query(self, sentinel: Sentinel2) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds: .*"
):
sentinel[query]
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
RasterDataset(str(tmp_path))
class TestVectorDataset:
@pytest.fixture
def dataset(self) -> CanadianBuildingFootprints:
root = os.path.join("tests", "data", "cbf")
transforms = nn.Identity() # type: ignore[attr-defined]
return CanadianBuildingFootprints(root, res=0.1, transforms=transforms)
def test_getitem(self, dataset: CanadianBuildingFootprints) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
VectorDataset(str(tmp_path))
@ -174,7 +246,8 @@ class TestVisionDataset:
class TestVisionClassificationDataset:
@pytest.fixture(scope="class")
def dataset(self, root: str) -> VisionClassificationDataset:
return VisionClassificationDataset(root)
transforms = nn.Identity() # type: ignore[attr-defined]
return VisionClassificationDataset(root, transforms=transforms)
@pytest.fixture(scope="class")
def root(self) -> str:
@ -220,51 +293,99 @@ class TestVisionClassificationDataset:
assert "size: 2" in str(dataset)
class TestZipDataset:
class TestIntersectionDataset:
@pytest.fixture(scope="class")
def dataset(self) -> ZipDataset:
def dataset(self) -> IntersectionDataset:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
return ZipDataset([ds1, ds2])
return ds1 & ds2
def test_getitem(self, dataset: ZipDataset) -> None:
def test_getitem(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 1, 2, 3, 4, 5)
assert dataset[query] == {"index": query}
def test_len(self, dataset: ZipDataset) -> None:
def test_len(self, dataset: IntersectionDataset) -> None:
assert len(dataset) == 1
def test_str(self, dataset: IntersectionDataset) -> None:
out = str(dataset)
assert "type: IntersectionDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 1" in out
def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
with pytest.raises(
ValueError, match="IntersectionDataset only supports GeoDatasets"
):
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
IntersectionDataset(ds1, ds2)
def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
IntersectionDataset(ds1, ds2)
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
IntersectionDataset(ds1, ds2)
def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
class TestUnionDataset:
@pytest.fixture(scope="class")
def dataset(self) -> UnionDataset:
ds1 = CustomGeoDataset()
ds2 = CustomGeoDataset()
return ds1 | ds2
def test_getitem(self, dataset: UnionDataset) -> None:
query = BoundingBox(0, 1, 2, 3, 4, 5)
assert dataset[query] == {"index": query}
def test_len(self, dataset: UnionDataset) -> None:
assert len(dataset) == 2
def test_str(self, dataset: ZipDataset) -> None:
def test_str(self, dataset: UnionDataset) -> None:
out = str(dataset)
assert "type: ZipDataset" in out
assert "type: UnionDataset" in out
assert "bbox: BoundingBox" in out
assert "size: 2" in out
def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
with pytest.raises(ValueError, match="ZipDataset only supports GeoDatasets"):
ZipDataset([ds1, ds2]) # type: ignore[list-item]
with pytest.raises(ValueError, match="UnionDataset only supports GeoDatasets"):
UnionDataset(ds1, ds2) # type: ignore[arg-type]
def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
with pytest.raises(ValueError, match="Datasets must be in the same CRS"):
ZipDataset([ds1, ds2])
UnionDataset(ds1, ds2)
def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
ds2 = CustomGeoDataset(res=2)
with pytest.raises(ValueError, match="Datasets must have the same resolution"):
ZipDataset([ds1, ds2])
UnionDataset(ds1, ds2)
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
with pytest.raises(ValueError, match="Datasets have no overlap"):
ZipDataset([ds1, ds2])
UnionDataset(ds1, ds2)
def test_invalid_query(self, dataset: ZipDataset) -> None:
def test_invalid_query(self, dataset: UnionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"

View file

@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
class TestLandsat8:
@ -35,9 +35,13 @@ class TestLandsat8:
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_add(self, dataset: Landsat8) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: Landsat8) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: Landsat8) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_plot(self, dataset: Landsat8) -> None:
query = dataset.bounds

View file

@ -12,7 +12,13 @@ import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import NAIP, BoundingBox, NAIPChesapeakeDataModule, ZipDataset
from torchgeo.datasets import (
NAIP,
BoundingBox,
IntersectionDataset,
NAIPChesapeakeDataModule,
UnionDataset,
)
class TestNAIP:
@ -31,9 +37,13 @@ class TestNAIP:
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_add(self, dataset: NAIP) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: NAIP) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: NAIP) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_plot(self, dataset: NAIP) -> None:
query = dataset.bounds

View file

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
from torchgeo.datasets import BoundingBox, IntersectionDataset, Sentinel2, UnionDataset
class TestSentinel2:
@ -29,9 +29,13 @@ class TestSentinel2:
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_add(self, dataset: Sentinel2) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_and(self, dataset: Sentinel2) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
def test_or(self, dataset: Sentinel2) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):

View file

@ -6,11 +6,12 @@ import glob
import math
import os
import pickle
import re
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Generator, Tuple
from typing import Any, Dict, Generator, List, Tuple
import numpy as np
import pytest
@ -22,14 +23,16 @@ from torch.utils.data import TensorDataset
import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
collate_dict,
concat_samples,
dataset_split,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
extract_archive,
merge_samples,
percentile_normalization,
stack_samples,
working_dir,
)
@ -164,7 +167,13 @@ def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
class TestBoundingBox:
def test_new_init(self) -> None:
def test_repr_str(self) -> None:
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)"
assert repr(bbox) == expected
assert str(bbox) == expected
def test_getitem(self) -> None:
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
assert bbox.minx == 0
@ -176,13 +185,143 @@ class TestBoundingBox:
assert bbox[0] == 0
assert bbox[-1] == 5
assert bbox[1:3] == (1, 2)
assert bbox[1:3] == [1, 2]
def test_repr_str(self) -> None:
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)"
assert repr(bbox) == expected
assert str(bbox) == expected
def test_iter(self) -> None:
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
assert tuple(bbox) == (0, 1, 2, 3, 4, 5)
i = 0
for _ in bbox:
i += 1
assert i == 6
@pytest.mark.parametrize(
"test_input,expected",
[
# Same box
((0, 1, 0, 1, 0, 1), True),
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), True),
# bbox1 strictly within bbox2
((-1, 2, -1, 2, -1, 2), True),
# bbox2 strictly within bbox1
((0.25, 0.75, 0.25, 0.75, 0.25, 0.75), False),
# One corner of bbox1 within bbox2
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), False),
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), False),
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), False),
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), False),
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), False),
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), False),
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), False),
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), False),
# No overlap
((0.5, 1.5, 0.5, 1.5, 2, 3), False),
((0.5, 1.5, 2, 3, 0.5, 1.5), False),
((2, 3, 0.5, 1.5, 0.5, 1.5), False),
((2, 3, 2, 3, 2, 3), False),
],
)
def test_contains(
self,
test_input: Tuple[float, float, float, float, float, float],
expected: bool,
) -> None:
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
bbox2 = BoundingBox(*test_input)
assert (bbox1 in bbox2) == expected
@pytest.mark.parametrize(
"test_input,expected",
[
# Same box
((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)),
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), (0, 1, 0, 1, 0, 1)),
# bbox1 strictly within bbox2
((-1, 2, -1, 2, -1, 2), (-1, 2, -1, 2, -1, 2)),
# bbox2 strictly within bbox1
((0.25, 0.75, 0.25, 0.75, 0.25, 0.75), (0, 1, 0, 1, 0, 1)),
# One corner of bbox1 within bbox2
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), (0, 1.5, 0, 1.5, 0, 1.5)),
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), (0, 1.5, -0.5, 1, 0, 1.5)),
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), (0, 1.5, 0, 1.5, -0.5, 1)),
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), (0, 1.5, -0.5, 1, -0.5, 1)),
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), (-0.5, 1, 0, 1.5, 0, 1.5)),
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), (-0.5, 1, -0.5, 1, 0, 1.5)),
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), (-0.5, 1, 0, 1.5, -0.5, 1)),
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), (-0.5, 1, -0.5, 1, -0.5, 1)),
# No overlap
((0.5, 1.5, 0.5, 1.5, 2, 3), (0, 1.5, 0, 1.5, 0, 3)),
((0.5, 1.5, 2, 3, 0.5, 1.5), (0, 1.5, 0, 3, 0, 1.5)),
((2, 3, 0.5, 1.5, 0.5, 1.5), (0, 3, 0, 1.5, 0, 1.5)),
((2, 3, 2, 3, 2, 3), (0, 3, 0, 3, 0, 3)),
],
)
def test_or(
self,
test_input: Tuple[float, float, float, float, float, float],
expected: Tuple[float, float, float, float, float, float],
) -> None:
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
bbox2 = BoundingBox(*test_input)
bbox3 = BoundingBox(*expected)
assert (bbox1 | bbox2) == bbox3
@pytest.mark.parametrize(
"test_input,expected",
[
# Same box
((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)),
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), (0, 1, 0, 1, 0, 1)),
# bbox1 strictly within bbox2
((-1, 2, -1, 2, -1, 2), (0, 1, 0, 1, 0, 1)),
# bbox2 strictly within bbox1
(
(0.25, 0.75, 0.25, 0.75, 0.25, 0.75),
(0.25, 0.75, 0.25, 0.75, 0.25, 0.75),
),
# One corner of bbox1 within bbox2
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), (0.5, 1, 0.5, 1, 0.5, 1)),
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), (0.5, 1, 0, 0.5, 0.5, 1)),
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), (0.5, 1, 0.5, 1, 0, 0.5)),
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), (0.5, 1, 0, 0.5, 0, 0.5)),
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), (0, 0.5, 0.5, 1, 0.5, 1)),
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), (0, 0.5, 0, 0.5, 0.5, 1)),
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), (0, 0.5, 0.5, 1, 0, 0.5)),
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), (0, 0.5, 0, 0.5, 0, 0.5)),
],
)
def test_and_intersection(
self,
test_input: Tuple[float, float, float, float, float, float],
expected: Tuple[float, float, float, float, float, float],
) -> None:
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
bbox2 = BoundingBox(*test_input)
bbox3 = BoundingBox(*expected)
assert (bbox1 & bbox2) == bbox3
@pytest.mark.parametrize(
"test_input",
[
# No overlap
(0.5, 1.5, 0.5, 1.5, 2, 3),
(0.5, 1.5, 2, 3, 0.5, 1.5),
(2, 3, 0.5, 1.5, 0.5, 1.5),
(2, 3, 2, 3, 2, 3),
],
)
def test_and_no_intersection(
self, test_input: Tuple[float, float, float, float, float, float]
) -> None:
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
bbox2 = BoundingBox(*test_input)
with pytest.raises(
ValueError,
match=re.escape(f"Bounding boxes {bbox1} and {bbox2} do not overlap"),
):
bbox1 & bbox2
@pytest.mark.parametrize(
"test_input,expected",
@ -306,26 +445,103 @@ def test_disambiguate_timestamp(
assert math.isclose(maxt, max_datetime)
def test_collate_dict() -> None:
samples = [
{
"foo": torch.tensor(1), # type: ignore[attr-defined]
"bar": torch.tensor(2), # type: ignore[attr-defined]
"crs": CRS.from_epsg(3005),
},
{
"foo": torch.tensor(3), # type: ignore[attr-defined]
"bar": torch.tensor(4), # type: ignore[attr-defined]
"crs": CRS.from_epsg(3005),
},
]
sample = collate_dict(samples)
assert torch.allclose( # type: ignore[attr-defined]
sample["foo"], torch.tensor([1, 3]) # type: ignore[attr-defined]
)
assert torch.allclose( # type: ignore[attr-defined]
sample["bar"], torch.tensor([2, 4]) # type: ignore[attr-defined]
)
class TestCollateFunctionsMatchingKeys:
@pytest.fixture(scope="class")
def samples(self) -> List[Dict[str, Any]]:
return [
{
"image": torch.tensor([1, 2, 0]), # type: ignore[attr-defined]
"crs": CRS.from_epsg(2000),
},
{
"image": torch.tensor([0, 0, 3]), # type: ignore[attr-defined]
"crs": CRS.from_epsg(2001),
},
]
def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[2, 3]
)
assert torch.allclose( # type: ignore[attr-defined]
sample["image"],
torch.tensor([[1, 2, 0], [0, 0, 3]]), # type: ignore[attr-defined]
)
assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)]
def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([6]) # type: ignore[attr-defined]
assert torch.allclose( # type: ignore[attr-defined]
sample["image"],
torch.tensor([1, 2, 0, 0, 0, 3]), # type: ignore[attr-defined]
)
assert sample["crs"] == CRS.from_epsg(2000)
def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = merge_samples(samples)
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
assert torch.allclose( # type: ignore[attr-defined]
sample["image"], torch.tensor([1, 2, 3]) # type: ignore[attr-defined]
)
assert sample["crs"] == CRS.from_epsg(2001)
class TestCollateFunctionsDifferingKeys:
@pytest.fixture(scope="class")
def samples(self) -> List[Dict[str, Any]]:
return [
{
"image": torch.tensor([1, 2, 0]), # type: ignore[attr-defined]
"crs1": CRS.from_epsg(2000),
},
{
"mask": torch.tensor([0, 0, 3]), # type: ignore[attr-defined]
"crs2": CRS.from_epsg(2001),
},
]
def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[1, 3]
)
assert sample["mask"].size() == torch.Size([1, 3]) # type: ignore[attr-defined]
assert torch.allclose( # type: ignore[attr-defined]
sample["image"], torch.tensor([[1, 2, 0]]) # type: ignore[attr-defined]
)
assert torch.allclose( # type: ignore[attr-defined]
sample["mask"], torch.tensor([[0, 0, 3]]) # type: ignore[attr-defined]
)
assert sample["crs1"] == [CRS.from_epsg(2000)]
assert sample["crs2"] == [CRS.from_epsg(2001)]
def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
assert sample["mask"].size() == torch.Size([3]) # type: ignore[attr-defined]
assert torch.allclose( # type: ignore[attr-defined]
sample["image"], torch.tensor([1, 2, 0]) # type: ignore[attr-defined]
)
assert torch.allclose( # type: ignore[attr-defined]
sample["mask"], torch.tensor([0, 0, 3]) # type: ignore[attr-defined]
)
assert sample["crs1"] == CRS.from_epsg(2000)
assert sample["crs2"] == CRS.from_epsg(2001)
def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = merge_samples(samples)
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
assert sample["mask"].size() == torch.Size([3]) # type: ignore[attr-defined]
assert torch.allclose( # type: ignore[attr-defined]
sample["image"], torch.tensor([1, 2, 0]) # type: ignore[attr-defined]
)
assert torch.allclose( # type: ignore[attr-defined]
sample["mask"], torch.tensor([0, 0, 3]) # type: ignore[attr-defined]
)
assert sample["crs1"] == CRS.from_epsg(2000)
assert sample["crs2"] == CRS.from_epsg(2001)
def test_existing_directory(tmp_path: Path) -> None:

View file

@ -28,7 +28,7 @@ class CustomBatchGeoSampler(BatchGeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
super().__init__()
self.crs = crs
self._crs = crs
self.res = res
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
@ -56,8 +56,9 @@ class TestBatchGeoSampler:
continue
def test_abstract(self) -> None:
ds = CustomGeoDataset()
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(None) # type: ignore[abstract]
BatchGeoSampler(ds) # type: ignore[abstract]
class TestRandomBatchGeoSampler:
@ -85,6 +86,16 @@ class TestRandomBatchGeoSampler:
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size
def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi)
for batch in sampler:
for query in batch:
assert query in roi
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:

View file

@ -28,7 +28,7 @@ class CustomGeoSampler(GeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
super().__init__()
self.crs = crs
self._crs = crs
self.res = res
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
@ -46,6 +46,11 @@ class TestGeoSampler:
def test_len(self, sampler: CustomGeoSampler) -> None:
assert len(sampler) == 2
def test_abstract(self) -> None:
ds = CustomGeoDataset()
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(ds) # type: ignore[abstract]
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None:
@ -54,10 +59,6 @@ class TestGeoSampler:
for _ in dl:
continue
def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(None) # type: ignore[abstract]
class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
@ -83,6 +84,15 @@ class TestRandomGeoSampler:
def test_len(self, sampler: RandomGeoSampler) -> None:
assert len(sampler) == sampler.length
def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomGeoSampler(ds, 2, 10, roi=roi)
for query in sampler:
assert query in roi
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None:
@ -116,6 +126,21 @@ class TestGridGeoSampler:
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)
def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
assert len(sampler) == length
def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = GridGeoSampler(ds, 2, 1, roi=roi)
for query in sampler:
assert query in roi
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None:
@ -123,9 +148,3 @@ class TestGridGeoSampler:
dl = DataLoader(ds, sampler=sampler, num_workers=num_workers)
for _ in dl:
continue
def test_len(self, sampler: GridGeoSampler) -> None:
rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1
cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
assert len(sampler) == length

View file

@ -29,11 +29,12 @@ from .etci2021 import ETCI2021, ETCI2021DataModule
from .eurosat import EuroSAT, EuroSATDataModule
from .geo import (
GeoDataset,
IntersectionDataset,
RasterDataset,
UnionDataset,
VectorDataset,
VisionClassificationDataset,
VisionDataset,
ZipDataset,
)
from .gid15 import GID15
from .landcoverai import LandCoverAI, LandCoverAIDataModule
@ -63,7 +64,7 @@ from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat, So2SatDataModule
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from .ucmerced import UCMerced, UCMercedDataModule
from .utils import BoundingBox, collate_dict
from .utils import BoundingBox, concat_samples, merge_samples, stack_samples
from .vaihingen import Vaihingen2D, Vaihingen2DDataModule
from .xview import XView2, XView2DataModule
from .zuericrop import ZueriCrop
@ -147,14 +148,17 @@ __all__ = (
"ZueriCrop",
# Base classes
"GeoDataset",
"IntersectionDataset",
"RasterDataset",
"UnionDataset",
"VectorDataset",
"VisionDataset",
"VisionClassificationDataset",
"ZipDataset",
# Utilities
"BoundingBox",
"collate_dict",
"concat_samples",
"merge_samples",
"stack_samples",
)
# https://stackoverflow.com/questions/40018681

View file

@ -32,6 +32,7 @@ from .utils import (
download_and_extract_archive,
download_url,
extract_archive,
stack_samples,
)
# https://github.com/pytorch/pytorch/issues/60979
@ -434,7 +435,7 @@ class ChesapeakeCVPR(GeoDataset):
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(query, objects=True)
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}
@ -783,7 +784,10 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
length=self.patches_per_tile * len(self.train_dataset),
)
return DataLoader(
self.train_dataset, batch_sampler=sampler, num_workers=self.num_workers
self.train_dataset,
batch_sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
@ -802,6 +806,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
@ -820,4 +825,5 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
batch_size=self.batch_size,
sampler=sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)

View file

@ -6,7 +6,6 @@
import abc
import functools
import glob
import math
import os
import re
import sys
@ -16,8 +15,10 @@ import fiona
import fiona.transform
import matplotlib.pyplot as plt
import numpy as np
import pyproj
import rasterio
import rasterio.merge
import shapely
import torch
from rasterio.crs import CRS
from rasterio.io import DatasetReader
@ -29,7 +30,7 @@ from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader
from .utils import BoundingBox, disambiguate_timestamp
from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
@ -46,20 +47,53 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
* :term:`coordinate reference system (CRS)`
* resolution
These kind of datasets are special because they can be combined. For example:
:class:`GeoDataset` is a special class of datasets. Unlike :class:`VisionDataset`,
the presence of geospatial information allows two or more datasets to be combined
based on latitude/longitude. This allows users to do things like:
* Combine Landsat8 and CDL to train a model for crop classification
* Combine NAIP and Chesapeake to train a model for land cover mapping
* Combine image and target labels and sample from both simultaneously
(e.g. Landsat and CDL)
* Combine datasets for multiple image sources for multimodal learning or data fusion
(e.g. Landsat and Sentinel)
This isn't true for :class:`VisionDataset`, where the lack of geospatial information
prohibits swapping image sources or target labels.
These combinations require that all queries are present in *both* datasets,
and can be combined using an :class:`IntersectionDataset`:
.. code-block:: python
dataset = landsat & cdl
Users may also want to:
* Combine datasets for multiple image sources and treat them as equivalent
(e.g. Landsat 7 and Landsat 8)
* Combine datasets for disparate geospatial locations
(e.g. Chesapeake NY and PA)
These combinations require that all queries are present in *at least one* dataset,
and can be combined using a :class:`UnionDataset`:
.. code-block:: python
dataset = landsat7 | landsat8
"""
#: :term:`coordinate reference system (CRS)` for the dataset.
crs: CRS
#: Resolution of the dataset in units of CRS.
res: float
_crs: CRS
# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
#
# the correct way to handle __add__ not being supported is to set it to None,
# not to return NotImplemented or raise NotImplementedError. The downside of
# this is that we have no way to explain to a user why they get an error and
# what they should do instead (use __and__ or __or__).
#: :class:`GeoDataset` addition can be ambiguous and is no longer supported.
#: Users should instead use the intersection or union operator.
__add__ = None # type: ignore[assignment]
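# For illustration (hypothetical usage, assuming the Landsat/CDL datasets from
# the README): addition now fails fast, while the set operators work:
#
#     landsat & cdl   # -> IntersectionDataset
#     landsat | cdl   # -> UnionDataset
#     landsat + cdl   # -> TypeError: unsupported operand type(s) for +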
def __init__(
self, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None
@ -89,8 +123,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
IndexError: if query is not found in the index
"""
def __add__(self, other: "GeoDataset") -> "ZipDataset": # type: ignore[override]
"""Merge two GeoDatasets.
def __and__(self, other: "GeoDataset") -> "IntersectionDataset":
"""Take the intersection of two :class:`GeoDataset`.
Args:
other: another dataset
@ -99,11 +133,23 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
a single dataset
Raises:
ValueError: if other is not a GeoDataset, or if datasets do not overlap,
or if datasets do not have the same
:term:`coordinate reference system (CRS)`
ValueError: if other is not a :class:`GeoDataset`
"""
return ZipDataset([self, other])
return IntersectionDataset(self, other)
def __or__(self, other: "GeoDataset") -> "UnionDataset":
"""Take the union of two GeoDatasets.
Args:
other: another dataset
Returns:
a single dataset
Raises:
ValueError: if other is not a :class:`GeoDataset`
"""
return UnionDataset(self, other)
def __len__(self) -> int:
"""Return the number of files in the dataset.
@ -135,6 +181,43 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""
return BoundingBox(*self.index.bounds)
@property
def crs(self) -> CRS:
""":term:`coordinate reference system (CRS)` for the dataset.
Returns:
the :term:`coordinate reference system (CRS)`
"""
return self._crs
@crs.setter
def crs(self, new_crs: CRS) -> None:
"""Change the :term:`coordinate reference system (CRS)` of a GeoDataset.
If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index.
Args:
new_crs: new :term:`coordinate reference system (CRS)`
"""
if new_crs == self._crs:
return
new_index = Index(interleaved=False, properties=Property(dimension=3))
project = pyproj.Transformer.from_crs(
pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True
).transform
for hit in self.index.intersection(self.index.bounds, objects=True):
old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds
old_box = shapely.geometry.box(old_minx, old_miny, old_maxx, old_maxy)
new_box = shapely.ops.transform(project, old_box)
new_minx, new_miny, new_maxx, new_maxy = new_box.bounds
new_bounds = (new_minx, new_maxx, new_miny, new_maxy, mint, maxt)
new_index.insert(hit.id, new_bounds, hit.object)
self._crs = new_crs
self.index = new_index
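# For illustration (hypothetical usage): assigning a new CRS rebuilds the
# spatial index in place, e.g.
#
#     cdl = CDL(root="...")
#     cdl.crs = CRS.from_epsg(4326)  # reprojects every bounding box in cdl.index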
class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
@ -253,7 +336,7 @@ class RasterDataset(GeoDataset):
f"No {self.__class__.__name__} data was found in '{root}'"
)
self.crs = cast(CRS, crs)
self._crs = cast(CRS, crs)
self.res = cast(float, res)
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
@ -268,7 +351,7 @@ class RasterDataset(GeoDataset):
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(query, objects=True)
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
if not filepaths:
@ -478,7 +561,7 @@ class VectorDataset(GeoDataset):
f"No {self.__class__.__name__} data was found in '{root}'"
)
self.crs = crs
self._crs = crs
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@ -492,7 +575,7 @@ class VectorDataset(GeoDataset):
Raises:
IndexError: if query is not found in the index
"""
hits = self.index.intersection(query, objects=True)
hits = self.index.intersection(tuple(query), objects=True)
filepaths = [hit.object for hit in hits]
if not filepaths:
@@ -676,46 +759,80 @@ class VisionClassificationDataset(VisionDataset, ImageFolder): # type: ignore[m
return tensor, label
class ZipDataset(GeoDataset):
"""Dataset for merging two or more GeoDatasets.
class IntersectionDataset(GeoDataset):
"""Dataset representing the intersection of two GeoDatasets.
For example, this allows you to combine an image source like Landsat8 with a target
label like CDL.
This allows users to do things like:
* Combine image and target labels and sample from both simultaneously
(e.g. Landsat and CDL)
* Combine datasets for multiple image sources for multimodal learning or data fusion
(e.g. Landsat and Sentinel)
These combinations require that all queries are present in *both* datasets,
and can be combined using an :class:`IntersectionDataset`:
.. code-block:: python
dataset = landsat & cdl
"""
def __init__(self, datasets: Sequence[GeoDataset]) -> None:
def __init__(
self,
dataset1: GeoDataset,
dataset2: GeoDataset,
collate_fn: Callable[
[Sequence[Dict[str, Any]]], Dict[str, Any]
] = concat_samples,
) -> None:
"""Initialize a new Dataset instance.
Args:
datasets: list of datasets to merge
dataset1: the first dataset
dataset2: the second dataset
collate_fn: function used to collate samples
Raises:
ValueError: if datasets contains non-GeoDatasets, do not overlap, are not in
the same :term:`coordinate reference system (CRS)`, or do not have the
same resolution
ValueError: if either dataset is not a :class:`GeoDataset`
"""
for ds in datasets:
super().__init__()
self.datasets = [dataset1, dataset2]
self.collate_fn = collate_fn
for ds in self.datasets:
if not isinstance(ds, GeoDataset):
raise ValueError("ZipDataset only supports GeoDatasets")
raise ValueError("IntersectionDataset only supports GeoDatasets")
crs = datasets[0].crs
res = datasets[0].res
for ds in datasets:
if ds.crs != crs:
raise ValueError("Datasets must be in the same CRS")
if not math.isclose(ds.res, res):
# TODO: relax this constraint someday
raise ValueError("Datasets must have the same resolution")
self._crs = dataset1.crs
self.res = dataset1.res
self.datasets = datasets
self.crs = crs
self.res = res
# Force dataset2 to have the same CRS/res as dataset1
if dataset1.crs != dataset2.crs:
print(
f"Converting {dataset2.__class__.__name__} CRS from "
f"{dataset2.crs} to {dataset1.crs}"
)
dataset2.crs = dataset1.crs
if dataset1.res != dataset2.res:
print(
f"Converting {dataset2.__class__.__name__} resolution from "
f"{dataset2.res} to {dataset1.res}"
)
dataset2.res = dataset1.res
# Make sure datasets have overlap
try:
self.bounds
except ValueError:
raise ValueError("Datasets have no overlap")
# Merge dataset indices into a single index
self._merge_dataset_indices()
def _merge_dataset_indices(self) -> None:
"""Create a new R-tree out of the individual indices from two datasets."""
i = 0
ds1, ds2 = self.datasets
for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
box1 = BoundingBox(*hit1.bounds)
box2 = BoundingBox(*hit2.bounds)
self.index.insert(i, tuple(box1 & box2))
i += 1
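A minimal sketch of how this class is meant to be used (the `landsat` and `cdl` variables are assumed to be pre-built `GeoDataset` instances, not defined in this PR): the `&` operator is shorthand for the constructor, and `collate_fn` controls how the two per-dataset samples are combined into one:

```python
# Sketch: `&` and the explicit constructor are equivalent here.
from torchgeo.datasets import IntersectionDataset
from torchgeo.datasets.utils import concat_samples

dataset = landsat & cdl  # same as the line below
dataset = IntersectionDataset(landsat, cdl, collate_fn=concat_samples)

# Querying returns one dict whose tensors come from *both* datasets
# (in practice the query is a small BoundingBox produced by a sampler).
sample = dataset[dataset.bounds]
```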
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
@@ -734,21 +851,10 @@ class ZipDataset(GeoDataset):
f"query: {query} not found in index with bounds: {self.bounds}"
)
# TODO: use collate_dict here to concatenate instead of replace.
# For example, if using Landsat + Sentinel + CDL, don't want to remove Landsat
# images and replace with Sentinel images.
sample = {}
for ds in self.datasets:
sample.update(ds[query])
return sample
# All datasets are guaranteed to have a valid query
samples = [ds[query] for ds in self.datasets]
def __len__(self) -> int:
"""Return the number of files in the dataset.
Returns:
length of the dataset
"""
return sum(map(len, self.datasets))
return self.collate_fn(samples)
def __str__(self) -> str:
"""Return the informal string representation of the object.
@@ -758,23 +864,117 @@ class ZipDataset(GeoDataset):
"""
return f"""\
{self.__class__.__name__} Dataset
type: ZipDataset
type: IntersectionDataset
bbox: {self.bounds}
size: {len(self)}"""
@property
def bounds(self) -> BoundingBox:
"""Bounds of the index.
class UnionDataset(GeoDataset):
"""Dataset representing the union of two GeoDatasets.
This allows users to do things like:
* Combine datasets for multiple image sources and treat them as equivalent
(e.g. Landsat 7 and Landsat 8)
* Combine datasets for disparate geospatial locations
(e.g. Chesapeake NY and PA)
These combinations require that all queries are present in *at least one* dataset,
and can be combined using a :class:`UnionDataset`:
.. code-block:: python
dataset = landsat7 | landsat8
"""
def __init__(
self,
dataset1: GeoDataset,
dataset2: GeoDataset,
collate_fn: Callable[
[Sequence[Dict[str, Any]]], Dict[str, Any]
] = merge_samples,
) -> None:
"""Initialize a new Dataset instance.
Args:
dataset1: the first dataset
dataset2: the second dataset
collate_fn: function used to collate samples
Raises:
ValueError: if either dataset is not a :class:`GeoDataset`
"""
super().__init__()
self.datasets = [dataset1, dataset2]
self.collate_fn = collate_fn
for ds in self.datasets:
if not isinstance(ds, GeoDataset):
raise ValueError("UnionDataset only supports GeoDatasets")
self._crs = dataset1.crs
self.res = dataset1.res
# Force dataset2 to have the same CRS/res as dataset1
if dataset1.crs != dataset2.crs:
print(
f"Converting {dataset2.__class__.__name__} CRS from "
f"{dataset2.crs} to {dataset1.crs}"
)
dataset2.crs = dataset1.crs
if dataset1.res != dataset2.res:
print(
f"Converting {dataset2.__class__.__name__} resolution from "
f"{dataset2.res} to {dataset1.res}"
)
dataset2.res = dataset1.res
# Merge dataset indices into a single index
self._merge_dataset_indices()
def _merge_dataset_indices(self) -> None:
"""Create a new R-tree out of the individual indices from two datasets."""
i = 0
for ds in self.datasets:
hits = ds.index.intersection(ds.index.bounds, objects=True)
for hit in hits:
self.index.insert(i, hit.bounds)
i += 1
Returns:
(minx, maxx, miny, maxy, mint, maxt) of the dataset
"""
# We want to compute the intersection of all dataset bounds, not the union
minx = max([ds.bounds[0] for ds in self.datasets])
maxx = min([ds.bounds[1] for ds in self.datasets])
miny = max([ds.bounds[2] for ds in self.datasets])
maxy = min([ds.bounds[3] for ds in self.datasets])
mint = max([ds.bounds[4] for ds in self.datasets])
maxt = min([ds.bounds[5] for ds in self.datasets])
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Args:
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
Returns:
sample of data/labels and metadata at that index
Raises:
IndexError: if query is not within bounds of the index
"""
if not query.intersects(self.bounds):
raise IndexError(
f"query: {query} not found in index with bounds: {self.bounds}"
)
# Not all datasets are guaranteed to have a valid query
samples = []
for ds in self.datasets:
if ds.index.intersection(tuple(query)):
samples.append(ds[query])
return self.collate_fn(samples)
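For comparison, a sketch of the union case (again assuming pre-built `landsat7`/`landsat8` datasets and a small query box): the query only has to hit one of the two indices, and the default `merge_samples` takes the elementwise maximum, so nodata zeros from one dataset are filled by the other:

```python
# Sketch: union of two sensors treated as equivalent.
dataset = landsat7 | landsat8
sample = dataset[query]  # query: a BoundingBox inside dataset.bounds
```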
def __str__(self) -> str:
"""Return the informal string representation of the object.
Returns:
informal string representation
"""
return f"""\
{self.__class__.__name__} Dataset
type: UnionDataset
bbox: {self.bounds}
size: {len(self)}"""
@@ -12,7 +12,7 @@ from ..samplers.batch import RandomBatchGeoSampler
from ..samplers.single import GridGeoSampler
from .chesapeake import Chesapeake13
from .geo import RasterDataset
from .utils import BoundingBox
from .utils import BoundingBox, stack_samples
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
@@ -143,7 +143,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
chesapeake.res,
transforms=self.naip_transform,
)
self.dataset = chesapeake + naip
self.dataset = chesapeake & naip
# TODO: figure out better train/val/test split
roi = self.dataset.bounds
@@ -166,7 +166,10 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
training data loader
"""
return DataLoader(
self.dataset, batch_sampler=self.train_sampler, num_workers=self.num_workers
self.dataset,
batch_sampler=self.train_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def val_dataloader(self) -> DataLoader[Any]:
@@ -180,6 +183,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
batch_size=self.batch_size,
sampler=self.val_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
def test_dataloader(self) -> DataLoader[Any]:
@@ -193,4 +197,5 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
batch_size=self.batch_size,
sampler=self.test_sampler,
num_workers=self.num_workers,
collate_fn=stack_samples,
)
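Outside of PyTorch Lightning the same wiring looks like this (a sketch: `dataset` is assumed to be any existing `GeoDataset`, and the sampler parameters are illustrative). Because samples are dictionaries rather than bare tensors, `DataLoader` needs `stack_samples` in place of its default collate function:

```python
# Sketch: plain DataLoader over a GeoDataset.
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

sampler = RandomGeoSampler(dataset, size=256, length=100)
loader = DataLoader(
    dataset, batch_size=8, sampler=sampler, collate_fn=stack_samples
)
for batch in loader:
    images = batch["image"]  # tensors stacked along a new batch dimension
```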
@@ -4,6 +4,7 @@
"""Common dataset utilities."""
import bz2
import collections
import contextlib
import gzip
import lzma
@@ -11,8 +12,21 @@ import os
import sys
import tarfile
import zipfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
overload,
)
import numpy as np
import rasterio
@@ -30,7 +44,9 @@ __all__ = (
"BoundingBox",
"disambiguate_timestamp",
"working_dir",
"collate_dict",
"stack_samples",
"concat_samples",
"merge_samples",
"rasterio_loader",
"dataset_split",
"sort_sentinel2_bands",
@@ -184,97 +200,134 @@ def download_radiant_mlhub_collection(
collection.download(output_dir=download_root, api_key=api_key)
class BoundingBox(Tuple[float, float, float, float, float, float]):
"""Data class for indexing spatiotemporal data.
@dataclass(frozen=True)
class BoundingBox:
"""Data class for indexing spatiotemporal data."""
Attributes:
minx (float): western boundary
maxx (float): eastern boundary
miny (float): southern boundary
maxy (float): northern boundary
mint (float): earliest boundary
maxt (float): latest boundary
"""
#: western boundary
minx: float
#: eastern boundary
maxx: float
#: southern boundary
miny: float
#: northern boundary
maxy: float
#: earliest boundary
mint: float
#: latest boundary
maxt: float
def __new__(
cls,
minx: float,
maxx: float,
miny: float,
maxy: float,
mint: float,
maxt: float,
) -> "BoundingBox":
"""Create a new instance of BoundingBox.
Args:
minx: western boundary
maxx: eastern boundary
miny: southern boundary
maxy: northern boundary
mint: earliest boundary
maxt: latest boundary
def __post_init__(self) -> None:
"""Validate the arguments passed to :meth:`__init__`.
Raises:
ValueError: if bounding box is invalid
(minx > maxx, miny > maxy, or mint > maxt)
"""
if minx > maxx:
raise ValueError(f"Bounding box is invalid: 'minx={minx}' > 'maxx={maxx}'")
if miny > maxy:
raise ValueError(f"Bounding box is invalid: 'miny={miny}' > 'maxy={maxy}'")
if mint > maxt:
raise ValueError(f"Bounding box is invalid: 'mint={mint}' > 'maxt={maxt}'")
if self.minx > self.maxx:
raise ValueError(
f"Bounding box is invalid: 'minx={self.minx}' > 'maxx={self.maxx}'"
)
if self.miny > self.maxy:
raise ValueError(
f"Bounding box is invalid: 'miny={self.miny}' > 'maxy={self.maxy}'"
)
if self.mint > self.maxt:
raise ValueError(
f"Bounding box is invalid: 'mint={self.mint}' > 'maxt={self.maxt}'"
)
# Using super() doesn't work with mypy, see:
# https://stackoverflow.com/q/60611012/5828163
return tuple.__new__(cls, [minx, maxx, miny, maxy, mint, maxt])
def __init__(
self,
minx: float,
maxx: float,
miny: float,
maxy: float,
mint: float,
maxt: float,
) -> None:
"""Initialize a new instance of BoundingBox.
Args:
minx: western boundary
maxx: eastern boundary
miny: southern boundary
maxy: northern boundary
mint: earliest boundary
maxt: latest boundary
"""
self.minx = minx
self.maxx = maxx
self.miny = miny
self.maxy = maxy
self.mint = mint
self.maxt = maxt
def __getnewargs__(self) -> Tuple[float, float, float, float, float, float]:
"""Values passed to the ``__new__()`` method upon unpickling.
Returns:
tuple of bounds
"""
return self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt
def __repr__(self) -> str:
"""Return the formal string representation of the object.
Returns:
formal string representation
"""
return (
f"{self.__class__.__name__}(minx={self.minx}, maxx={self.maxx}, "
f"miny={self.miny}, maxy={self.maxy}, mint={self.mint}, maxt={self.maxt})"
)
# https://github.com/PyCQA/pydocstyle/issues/525
@overload
def __getitem__(self, key: int) -> float: # noqa: D105
pass
@overload
def __getitem__(self, key: slice) -> List[float]: # noqa: D105
pass
def __getitem__(self, key: Union[int, slice]) -> Union[float, List[float]]:
"""Index the (minx, maxx, miny, maxy, mint, maxt) tuple.
Args:
key: integer or slice object
Returns:
the value(s) at that index
Raises:
IndexError: if key is out of bounds
"""
return [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt][key]
def __iter__(self) -> Iterator[float]:
"""Container iterator.
Returns:
iterator object that iterates over all objects in the container
"""
yield from [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt]
def __contains__(self, other: "BoundingBox") -> bool:
"""Whether or not other is within the bounds of this bounding box.
Args:
other: another bounding box
Returns:
True if other is within this bounding box, else False
"""
return (
(self.minx <= other.minx <= self.maxx)
and (self.minx <= other.maxx <= self.maxx)
and (self.miny <= other.miny <= self.maxy)
and (self.miny <= other.maxy <= self.maxy)
and (self.mint <= other.mint <= self.maxt)
and (self.mint <= other.maxt <= self.maxt)
)
def __or__(self, other: "BoundingBox") -> "BoundingBox":
"""The union operator.
Args:
other: another bounding box
Returns:
the minimum bounding box that contains both self and other
"""
return BoundingBox(
min(self.minx, other.minx),
max(self.maxx, other.maxx),
min(self.miny, other.miny),
max(self.maxy, other.maxy),
min(self.mint, other.mint),
max(self.maxt, other.maxt),
)
def __and__(self, other: "BoundingBox") -> "BoundingBox":
"""The intersection operator.
Args:
other: another bounding box
Returns:
the intersection of self and other
Raises:
ValueError: if self and other do not intersect
"""
try:
return BoundingBox(
max(self.minx, other.minx),
min(self.maxx, other.maxx),
max(self.miny, other.miny),
min(self.maxy, other.maxy),
max(self.mint, other.mint),
min(self.maxt, other.maxt),
)
except ValueError:
raise ValueError(f"Bounding boxes {self} and {other} do not overlap")
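The new tuple-like behavior and set arithmetic can be exercised directly (a quick sketch; the coordinate values below are arbitrary):

```python
from torchgeo.datasets.utils import BoundingBox

a = BoundingBox(0, 10, 0, 10, 0, 1)
b = BoundingBox(5, 15, 5, 15, 0, 1)

a & b        # BoundingBox(minx=5, maxx=10, miny=5, maxy=10, mint=0, maxt=1)
a | b        # BoundingBox(minx=0, maxx=15, miny=0, maxy=15, mint=0, maxt=1)
(a & b) in a  # True: the intersection lies within a
tuple(a)      # (0, 10, 0, 10, 0, 1), the form rtree expects
minx, maxx, miny, maxy, mint, maxt = a  # __iter__ allows unpacking
a[1], a[:2]   # 10, [0, 10] via __getitem__
BoundingBox(10, 0, 0, 10, 0, 1)  # raises ValueError: minx > maxx
```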
def intersects(self, other: "BoundingBox") -> bool:
"""Whether or not two bounding boxes intersect.
@@ -370,8 +423,27 @@ def working_dir(dirname: str, create: bool = False) -> Iterator[None]:
os.chdir(cwd)
def collate_dict(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Merge a list of samples to form a mini-batch of Tensors.
def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List[Any]]:
"""Convert a list of dictionaries to a dictionary of lists.
Args:
samples: a list of dictionaries
Returns:
a dictionary of lists
"""
collated = collections.defaultdict(list)
for sample in samples:
for key, value in sample.items():
collated[key].append(value)
return collated
def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
"""Stack a list of samples along a new axis.
Useful for forming a mini-batch of samples to pass to
:class:`torch.utils.data.DataLoader`.
Args:
samples: list of samples
@@ -379,14 +451,55 @@ def collate_dict(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
Returns:
a single sample
"""
collated = {}
for key, value in samples[0].items():
if isinstance(value, Tensor):
collated[key] = torch.stack([sample[key] for sample in samples])
collated: Dict[Any, Any] = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.stack(value)
return collated
def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
"""Concatenate a list of samples along an existing axis.
Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`.
Args:
samples: list of samples
Returns:
a single sample
"""
collated: Dict[Any, Any] = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.cat(value) # type: ignore[attr-defined]
else:
collated[key] = [
sample[key] for sample in samples
] # type: ignore[assignment]
collated[key] = value[0]
return collated
def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
"""Merge a list of samples.
Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`.
Args:
samples: list of samples
Returns:
a single sample
"""
collated: Dict[Any, Any] = {}
for sample in samples:
for key, value in sample.items():
if key in collated and isinstance(value, Tensor):
# Take the maximum so that nodata values (zeros) get replaced
# by data values whenever possible
collated[key] = torch.maximum( # type: ignore[attr-defined]
collated[key], value
)
else:
collated[key] = value
return collated
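A toy comparison of the three collate functions (the tensor shapes are chosen only for illustration):

```python
import torch
from torchgeo.datasets.utils import concat_samples, merge_samples, stack_samples

s1 = {"image": torch.zeros(3, 4, 4)}
s2 = {"image": torch.ones(3, 4, 4)}

stack_samples([s1, s2])["image"].shape   # (2, 3, 4, 4): new batch axis
concat_samples([s1, s2])["image"].shape  # (6, 4, 4): cat along existing axis
merge_samples([s1, s2])["image"]         # elementwise max, so all ones
```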
@@ -7,6 +7,7 @@ import abc
import random
from typing import Iterator, List, Optional, Tuple, Union
from rtree.index import Index, Property
from torch.utils.data import Sampler
from torchgeo.datasets.geo import GeoDataset
@@ -28,6 +29,27 @@ class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
longitude, height, width, projection, coordinate system, and time.
"""
def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None:
"""Initialize a new Sampler instance.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
if roi is None:
self.index = dataset.index
roi = BoundingBox(*self.index.bounds)
else:
self.index = Index(interleaved=False, properties=Property(dimension=3))
hits = dataset.index.intersection(tuple(roi), objects=True)
for hit in hits:
bbox = BoundingBox(*hit.bounds) & roi
self.index.insert(hit.id, tuple(bbox), hit.object)
self.res = dataset.res
self.roi = roi
@abc.abstractmethod
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return a batch of indices of a dataset.
@@ -42,9 +64,6 @@ class RandomBatchGeoSampler(BatchGeoSampler):
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a tile-based dataset if possible.
"""
def __init__(
@@ -72,15 +91,11 @@ class RandomBatchGeoSampler(BatchGeoSampler):
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
self.index = dataset.index
self.res = dataset.res
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.batch_size = batch_size
self.length = length
if roi is None:
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(self.index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
@@ -7,6 +7,7 @@ import abc
import random
from typing import Iterator, Optional, Tuple, Union
from rtree.index import Index, Property
from torch.utils.data import Sampler
from torchgeo.datasets.geo import GeoDataset
@@ -28,6 +29,27 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC):
longitude, height, width, projection, coordinate system, and time.
"""
def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None:
"""Initialize a new Sampler instance.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
if roi is None:
self.index = dataset.index
roi = BoundingBox(*self.index.bounds)
else:
self.index = Index(interleaved=False, properties=Property(dimension=3))
hits = dataset.index.intersection(tuple(roi), objects=True)
for hit in hits:
bbox = BoundingBox(*hit.bounds) & roi
self.index.insert(hit.id, tuple(bbox), hit.object)
self.res = dataset.res
self.roi = roi
@abc.abstractmethod
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@@ -70,14 +92,10 @@ class RandomGeoSampler(GeoSampler):
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
self.index = dataset.index
self.res = dataset.res
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.length = length
if roi is None:
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(self.index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
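A sketch of the new ``roi`` behavior (assuming `dataset` is an existing `GeoDataset`; the coordinates are placeholders): the base sampler now builds a private index clipped to the region of interest, so every chip it yields comes from inside it:

```python
from torchgeo.datasets.utils import BoundingBox
from torchgeo.samplers import RandomGeoSampler

roi = BoundingBox(0, 10_000, 0, 10_000, 0, 1)  # placeholder extent
sampler = RandomGeoSampler(dataset, size=256, length=50, roi=roi)
for bbox in sampler:
    sample = dataset[bbox]  # chips are drawn only from within roi
```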
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@@ -117,9 +135,6 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a non-tile-based dataset if possible.
"""
def __init__(
@@ -145,13 +160,10 @@ class GridGeoSampler(GeoSampler):
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
self.index = dataset.index
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)
if roi is None:
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(self.index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
self.length: int = 0
for hit in self.hits: