зеркало из https://github.com/microsoft/torchgeo.git
Overhaul BoundingBox and ZipDataset classes (#144)
* Adding a UnionDataset * Adding contains method to BoundingBox * Finishing UnionDataset * Add __contains__ method * Overhaul BoundingBox, add set arithmetic * mypy fixes * pydocstyle fixes * Ignore erroneous pydocstyle warnings * rtree only supports tuples, not BoundingBoxes * mypy fixes * Use custom collate function to handle BoundingBoxes * Add back support for Python 3.6 * Add tests for all new BoundingBox features * Rename ZipDataset to IntersectionDataset * Merge indices of IntersectionDataset, auto-convert CRS/res * Get tests to pass * Fix more tests * Test more of RasterDataset/VectorDataset directly * Increase UnionDataset test coverage * IntersectionDataset stacks tensors, UnionDataset merges tensors * Support collating dicts with differing keys, add tests * Style fixes * Samplers: compute intersection between index and ROI * Update README with example usage * GeoDataset addition is deprecated * Add note about CRS/res * More documentation for Intersection/UnionDatasets * Use collate function in tutorial * Don't use multiple workers * Fix typo * Drop support for adding GeoDatasets * Remove unused import * Add comment explaining coverage config settings * Collation function needed for benchmark script * Add more explanation to README * Correct Landsat 8 bands * Print warning when changing CRS/res Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
6f23fa42c5
Коммит
5d407b76b5
96
README.md
96
README.md
|
@ -7,7 +7,7 @@ The goal of this library is to make it simple:
|
|||
1. for machine learning experts to use geospatial data in their workflows, and
|
||||
2. for remote sensing experts to use their data in machine learning workflows.
|
||||
|
||||
See our [installation instructions](#installation-instructions), [documentation](#documentation), and [examples](#example-usage) to learn how to use torchgeo.
|
||||
See our [installation instructions](#installation), [documentation](#documentation), and [examples](#example-usage) to learn how to use TorchGeo.
|
||||
|
||||
External links:
|
||||
[![docs](https://readthedocs.org/projects/torchgeo/badge/?version=latest)](https://torchgeo.readthedocs.io/en/latest/?badge=latest)
|
||||
|
@ -21,7 +21,7 @@ Tests:
|
|||
[![style](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml)
|
||||
[![tests](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml)
|
||||
|
||||
## Installation instructions
|
||||
## Installation
|
||||
|
||||
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
|
||||
|
||||
|
@ -33,13 +33,83 @@ For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation
|
|||
|
||||
## Documentation
|
||||
|
||||
You can find the documentation for torchgeo on [ReadTheDocs](https://torchgeo.readthedocs.io).
|
||||
You can find the documentation for TorchGeo on [ReadTheDocs](https://torchgeo.readthedocs.io).
|
||||
|
||||
## Example usage
|
||||
## Example Usage
|
||||
|
||||
The following sections give basic examples of what you can do with torchgeo. For more examples, check out our [tutorials](https://torchgeo.readthedocs.io/en/latest/tutorials/getting_started.html).
|
||||
The following sections give basic examples of what you can do with TorchGeo. For more examples, check out our [tutorials](https://torchgeo.readthedocs.io/en/latest/tutorials/getting_started.html).
|
||||
|
||||
### Train and test models using our PyTorch Lightning based training script
|
||||
First we'll import various classes and functions used in the following sections:
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
from torchgeo.datasets import CDL, COWCDetection, Landsat7, Landsat8, stack_samples
|
||||
from torchgeo.samplers import RandomGeoSampler
|
||||
```
|
||||
|
||||
### Benchmark datasets
|
||||
|
||||
TorchGeo includes a number of [*benchmark*](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#non-geospatial-datasets) datasets, datasets that include both input images and target labels. This includes datasets for tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.
|
||||
|
||||
If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Cars Overhead With Context (COWC) car detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.
|
||||
|
||||
```python
|
||||
dataset = COWCDetection(root="...", split="train", download=True, checksum=True)
|
||||
```
|
||||
|
||||
This dataset can then be passed to a PyTorch data loader.
|
||||
|
||||
```python
|
||||
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
|
||||
```
|
||||
|
||||
The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch Tensor.
|
||||
|
||||
```python
|
||||
for batch in dataloader:
|
||||
image = batch["image"]
|
||||
label = batch["label"]
|
||||
|
||||
# train a model, or make predictions using a pre-trained model
|
||||
```
|
||||
|
||||
### Geospatial datasets
|
||||
|
||||
Many remote sensing applications involve working with [*generic*](https://torchgeo.readthedocs.io/en/latest/api/datasets.html#geospatial-datasets) geospatial data. This data can be challenging to work with due to the sheer variety of data. Geospatial imagery is often multispectral with a different number of spectral bands and spatial resolution for every satellite. In addition, each file may be in a different coordinate reference system (CRS), requiring the data to be reprojected into a matching CRS.
|
||||
|
||||
In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.
|
||||
|
||||
```python
|
||||
landsat7 = Landsat7(root="...")
|
||||
landsat8 = Landsat8(root="...", bands=["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9"])
|
||||
landsat = landsat7 | landsat8
|
||||
```
|
||||
|
||||
Next, we take the intersection between this dataset and the Cropland Data Layer (CDL) dataset. We want to take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different coordinate reference systems (CRS) or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.
|
||||
|
||||
```python
|
||||
cdl = CDL(root="...", download=True, checksum=True)
|
||||
dataset = landsat & cdl
|
||||
```
|
||||
|
||||
This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire continental United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of [*samplers*](https://torchgeo.readthedocs.io/en/latest/api/samplers.html). In this example, we'll use a random sampler that returns 256x256 pixel images and an epoch length of 10,000 images. We also use a custom collation function to combine each sample dictionary into a mini-batch of samples.
|
||||
|
||||
```python
|
||||
sampler = RandomGeoSampler(dataset, size=256, length=10000)
|
||||
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
|
||||
```
|
||||
|
||||
This data loader can now be used in your normal training/evaluation pipeline.
|
||||
|
||||
```python
|
||||
for batch in dataloader:
|
||||
image = batch["image"]
|
||||
mask = batch["mask"]
|
||||
|
||||
# train a model, or make predictions using a pre-trained model
|
||||
```
|
||||
|
||||
### Train and test models using our PyTorch Lightning-based training script
|
||||
|
||||
We provide a script, `train.py` for training models using a subset of the datasets. We do this with the PyTorch Lightning `LightningModule`s and `LightningDataModule`s implemented under the `torchgeo.trainers` namespace.
|
||||
The `train.py` script is configurable via the command line and/or via YAML configuration files. See the [conf/](conf/) directory for example configuration files that can be customized for different training runs.
|
||||
|
@ -48,20 +118,6 @@ The `train.py` script is configurable via the command line and/or via YAML confi
|
|||
$ python train.py config_file=conf/landcoverai.yaml
|
||||
```
|
||||
|
||||
### Download and use the Tropical Cyclone Wind Estimation Competition dataset
|
||||
|
||||
This dataset is from a competition hosted by [Driven Data](https://www.drivendata.org/) in collaboration with [Radiant Earth](https://www.radiant.earth/). See [here](https://www.drivendata.org/competitions/72/predict-wind-speeds/) for more information.
|
||||
|
||||
Using this dataset in torchgeo is as simple as importing and instantiating the appropriate class.
|
||||
|
||||
```python
|
||||
import torchgeo.datasets
|
||||
|
||||
dataset = torchgeo.datasets.TropicalCycloneWindEstimation(split="train", download=True)
|
||||
print(dataset[0]["image"].shape)
|
||||
print(dataset[0]["label"])
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this software in your work, please cite [our paper](https://arxiv.org/abs/2111.08872):
|
||||
|
|
10
benchmark.py
10
benchmark.py
|
@ -14,7 +14,7 @@ import torch.optim as optim
|
|||
from torch.utils.data import DataLoader
|
||||
from torchvision.models import resnet34
|
||||
|
||||
from torchgeo.datasets import CDL, Landsat8
|
||||
from torchgeo.datasets import CDL, Landsat8, stack_samples
|
||||
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
|
||||
|
||||
|
||||
|
@ -129,7 +129,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
landsat = Landsat8(
|
||||
args.landsat_root, crs=cdl.crs, res=cdl.res, cache=args.cache, bands=bands
|
||||
)
|
||||
dataset = landsat + cdl
|
||||
dataset = landsat & cdl
|
||||
|
||||
# Initialize samplers
|
||||
if args.epoch_size:
|
||||
|
@ -158,7 +158,10 @@ def main(args: argparse.Namespace) -> None:
|
|||
|
||||
if isinstance(sampler, RandomBatchGeoSampler):
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_sampler=sampler, num_workers=args.num_workers
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
else:
|
||||
dataloader = DataLoader(
|
||||
|
@ -166,6 +169,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
batch_size=args.batch_size,
|
||||
sampler=sampler, # type: ignore[arg-type]
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
||||
tic = time.time()
|
||||
|
|
|
@ -10,7 +10,7 @@ In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets`
|
|||
Geospatial Datasets
|
||||
-------------------
|
||||
|
||||
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`ZipDataset`.
|
||||
:class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`IntersectionDataset` and :class:`UnionDataset`.
|
||||
|
||||
Canadian Building Footprints
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -245,13 +245,24 @@ VisionClassificationDataset
|
|||
|
||||
.. autoclass:: VisionClassificationDataset
|
||||
|
||||
ZipDataset
|
||||
^^^^^^^^^^
|
||||
IntersectionDataset
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: ZipDataset
|
||||
.. autoclass:: IntersectionDataset
|
||||
|
||||
UnionDataset
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: UnionDataset
|
||||
|
||||
Utilities
|
||||
---------
|
||||
|
||||
.. autoclass:: BoundingBox
|
||||
.. autofunction:: collate_dict
|
||||
|
||||
Collation Functions
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: stack_samples
|
||||
.. autofunction:: concat_samples
|
||||
.. autofunction:: merge_samples
|
||||
|
|
|
@ -72,9 +72,8 @@
|
|||
"\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"\n",
|
||||
"from torchgeo.datasets import NAIP, ChesapeakeDE\n",
|
||||
"from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples\n",
|
||||
"from torchgeo.datasets.utils import download_url\n",
|
||||
"from torchgeo.models import FCN\n",
|
||||
"from torchgeo.samplers import RandomGeoSampler"
|
||||
]
|
||||
},
|
||||
|
@ -293,7 +292,7 @@
|
|||
"id": "OWUhlfpD22IX"
|
||||
},
|
||||
"source": [
|
||||
"Finally, we create a ZipDataset so that we can automatically sample from both GeoDatasets simultaneously."
|
||||
"Finally, we create an IntersectionDataset so that we can automatically sample from both GeoDatasets simultaneously."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -305,7 +304,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = naip + chesapeake"
|
||||
"dataset = naip & chesapeake"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -353,7 +352,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataloader = DataLoader(dataset, sampler=sampler)"
|
||||
"dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -393,7 +392,7 @@
|
|||
"timeout": 1200
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -406,7 +405,8 @@
|
|||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"version": "3.9.7"
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"widgets": {
|
||||
"application/vnd.jupyter.widget-state+json": {
|
||||
|
|
|
@ -35,6 +35,14 @@ exclude = '''
|
|||
)/
|
||||
'''
|
||||
|
||||
[tool.coverage.report]
|
||||
# Ignore warnings for overloads
|
||||
# https://github.com/nedbat/coveragepy/issues/970#issuecomment-612602180
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"@overload",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
known_first_party = ["docs", "tests", "torchgeo"]
|
||||
|
|
|
@ -28,6 +28,8 @@ setup_requires =
|
|||
# setuptools 42+ required for metadata.license_files support in setup.cfg
|
||||
setuptools>=42
|
||||
install_requires =
|
||||
# dataclasses was added in Python 3.7
|
||||
dataclasses;python_version<'3.7'
|
||||
einops
|
||||
# fiona 1.5+ required for fiona.transform module
|
||||
fiona>=1.5
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.0, 0.0 ], [ 0.0, 1.0 ], [ 1.0, 1.0 ], [ 1.0, 0.0 ], [ 0.0, 0.0 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичный файл не отображается.
|
@ -14,7 +14,12 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
CanadianBuildingFootprints,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -66,9 +71,13 @@ class TestCanadianBuildingFootprints:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
CanadianBuildingFootprints(root=dataset.root, download=True)
|
||||
|
|
|
@ -16,7 +16,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
|
||||
from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -51,15 +51,19 @@ class TestCDL:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: CDL) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: CDL) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: CDL) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_full_year(self, dataset: CDL) -> None:
|
||||
bbox = dataset.bounds
|
||||
time = datetime(2021, 6, 1).timestamp()
|
||||
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
|
||||
next(dataset.index.intersection(query))
|
||||
next(dataset.index.intersection(tuple(query)))
|
||||
|
||||
def test_already_extracted(self, dataset: CDL) -> None:
|
||||
CDL(root=dataset.root, download=True)
|
||||
|
|
|
@ -20,7 +20,8 @@ from torchgeo.datasets import (
|
|||
Chesapeake13,
|
||||
ChesapeakeCVPR,
|
||||
ChesapeakeCVPRDataModule,
|
||||
ZipDataset,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
|
@ -55,9 +56,13 @@ class TestChesapeake13:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: Chesapeake13) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: Chesapeake13) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: Chesapeake13) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_already_downloaded(self, dataset: Chesapeake13) -> None:
|
||||
Chesapeake13(root=dataset.root, download=True)
|
||||
|
@ -128,9 +133,13 @@ class TestChesapeakeCVPR:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
|
||||
ChesapeakeCVPR(root=dataset.root, download=True)
|
||||
|
|
|
@ -13,14 +13,17 @@ from rasterio.crs import CRS
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import (
|
||||
NAIP,
|
||||
BoundingBox,
|
||||
CanadianBuildingFootprints,
|
||||
GeoDataset,
|
||||
Landsat8,
|
||||
IntersectionDataset,
|
||||
RasterDataset,
|
||||
Sentinel2,
|
||||
UnionDataset,
|
||||
VectorDataset,
|
||||
VisionClassificationDataset,
|
||||
VisionDataset,
|
||||
ZipDataset,
|
||||
)
|
||||
|
||||
|
||||
|
@ -32,8 +35,8 @@ class CustomGeoDataset(GeoDataset):
|
|||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, bounds)
|
||||
self.crs = crs
|
||||
self.index.insert(0, tuple(bounds))
|
||||
self._crs = crs
|
||||
self.res = res
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
|
||||
|
@ -60,28 +63,56 @@ class TestGeoDataset:
|
|||
def test_len(self, dataset: GeoDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add_two(self) -> None:
|
||||
@pytest.mark.parametrize("crs", [CRS.from_epsg(3005), CRS.from_epsg(32616)])
|
||||
def test_crs(self, dataset: GeoDataset, crs: CRS) -> None:
|
||||
dataset.crs = crs
|
||||
|
||||
def test_and_two(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 2
|
||||
dataset = ds1 & ds2
|
||||
assert isinstance(dataset, IntersectionDataset)
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add_three(self) -> None:
|
||||
def test_and_three(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
ds3 = CustomGeoDataset()
|
||||
dataset = ds1 + ds2 + ds3
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
assert len(dataset) == 3
|
||||
dataset = ds1 & ds2 & ds3
|
||||
assert isinstance(dataset, IntersectionDataset)
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add_four(self) -> None:
|
||||
def test_and_four(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
ds3 = CustomGeoDataset()
|
||||
ds4 = CustomGeoDataset()
|
||||
dataset = (ds1 + ds2) + (ds3 + ds4)
|
||||
assert isinstance(dataset, ZipDataset)
|
||||
dataset = (ds1 & ds2) & (ds3 & ds4)
|
||||
assert isinstance(dataset, IntersectionDataset)
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_or_two(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
dataset = ds1 | ds2
|
||||
assert isinstance(dataset, UnionDataset)
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_or_three(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
ds3 = CustomGeoDataset()
|
||||
dataset = ds1 | ds2 | ds3
|
||||
assert isinstance(dataset, UnionDataset)
|
||||
assert len(dataset) == 3
|
||||
|
||||
def test_or_four(self) -> None:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
ds3 = CustomGeoDataset()
|
||||
ds4 = CustomGeoDataset()
|
||||
dataset = (ds1 | ds2) | (ds3 | ds4)
|
||||
assert isinstance(dataset, UnionDataset)
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_str(self, dataset: GeoDataset) -> None:
|
||||
|
@ -94,34 +125,75 @@ class TestGeoDataset:
|
|||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
GeoDataset() # type: ignore[abstract]
|
||||
|
||||
def test_add_vision(self, dataset: GeoDataset) -> None:
|
||||
def test_and_vision(self, dataset: GeoDataset) -> None:
|
||||
ds2 = CustomVisionDataset()
|
||||
with pytest.raises(ValueError, match="ZipDataset only supports GeoDatasets"):
|
||||
dataset + ds2 # type: ignore[operator]
|
||||
with pytest.raises(
|
||||
ValueError, match="IntersectionDataset only supports GeoDatasets"
|
||||
):
|
||||
dataset & ds2 # type: ignore[operator]
|
||||
|
||||
|
||||
class TestRasterDataset:
|
||||
@pytest.fixture(params=[True, False])
|
||||
def dataset(self, request: SubRequest) -> Landsat8:
|
||||
root = os.path.join("tests", "data", "landsat8")
|
||||
bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7"]
|
||||
def naip(self, request: SubRequest) -> NAIP:
|
||||
root = os.path.join("tests", "data", "naip")
|
||||
crs = CRS.from_epsg(3005)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
cache = request.param
|
||||
return Landsat8(root, bands=bands, crs=crs, transforms=transforms, cache=cache)
|
||||
return NAIP(root, crs=crs, transforms=transforms, cache=cache)
|
||||
|
||||
def test_getitem(self, dataset: Landsat8) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
@pytest.fixture(params=[True, False])
|
||||
def sentinel(self, request: SubRequest) -> Sentinel2:
|
||||
root = os.path.join("tests", "data", "sentinel2")
|
||||
bands = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B09", "B11"]
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
cache = request.param
|
||||
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
|
||||
|
||||
def test_getitem_single_file(self, naip: NAIP) -> None:
|
||||
x = naip[naip.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
||||
def test_getitem_separate_files(self, sentinel: Sentinel2) -> None:
|
||||
x = sentinel[sentinel.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
||||
def test_invalid_query(self, sentinel: Sentinel2) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds: .*"
|
||||
):
|
||||
sentinel[query]
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
|
||||
RasterDataset(str(tmp_path))
|
||||
|
||||
|
||||
class TestVectorDataset:
|
||||
@pytest.fixture
|
||||
def dataset(self) -> CanadianBuildingFootprints:
|
||||
root = os.path.join("tests", "data", "cbf")
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return CanadianBuildingFootprints(root, res=0.1, transforms=transforms)
|
||||
|
||||
def test_getitem(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
query = BoundingBox(2, 2, 2, 2, 2, 2)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
dataset[query]
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
|
||||
VectorDataset(str(tmp_path))
|
||||
|
@ -174,7 +246,8 @@ class TestVisionDataset:
|
|||
class TestVisionClassificationDataset:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self, root: str) -> VisionClassificationDataset:
|
||||
return VisionClassificationDataset(root)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return VisionClassificationDataset(root, transforms=transforms)
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def root(self) -> str:
|
||||
|
@ -220,51 +293,99 @@ class TestVisionClassificationDataset:
|
|||
assert "size: 2" in str(dataset)
|
||||
|
||||
|
||||
class TestZipDataset:
|
||||
class TestIntersectionDataset:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self) -> ZipDataset:
|
||||
def dataset(self) -> IntersectionDataset:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
return ZipDataset([ds1, ds2])
|
||||
return ds1 & ds2
|
||||
|
||||
def test_getitem(self, dataset: ZipDataset) -> None:
|
||||
def test_getitem(self, dataset: IntersectionDataset) -> None:
|
||||
query = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: ZipDataset) -> None:
|
||||
def test_len(self, dataset: IntersectionDataset) -> None:
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_str(self, dataset: IntersectionDataset) -> None:
|
||||
out = str(dataset)
|
||||
assert "type: IntersectionDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 1" in out
|
||||
|
||||
def test_vision_dataset(self) -> None:
|
||||
ds1 = CustomVisionDataset()
|
||||
ds2 = CustomVisionDataset()
|
||||
with pytest.raises(
|
||||
ValueError, match="IntersectionDataset only supports GeoDatasets"
|
||||
):
|
||||
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]
|
||||
|
||||
def test_different_crs(self) -> None:
|
||||
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
|
||||
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
|
||||
IntersectionDataset(ds1, ds2)
|
||||
|
||||
def test_different_res(self) -> None:
|
||||
ds1 = CustomGeoDataset(res=1)
|
||||
ds2 = CustomGeoDataset(res=2)
|
||||
IntersectionDataset(ds1, ds2)
|
||||
|
||||
def test_no_overlap(self) -> None:
|
||||
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
|
||||
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
|
||||
IntersectionDataset(ds1, ds2)
|
||||
|
||||
def test_invalid_query(self, dataset: IntersectionDataset) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
):
|
||||
dataset[query]
|
||||
|
||||
|
||||
class TestUnionDataset:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self) -> UnionDataset:
|
||||
ds1 = CustomGeoDataset()
|
||||
ds2 = CustomGeoDataset()
|
||||
return ds1 | ds2
|
||||
|
||||
def test_getitem(self, dataset: UnionDataset) -> None:
|
||||
query = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
assert dataset[query] == {"index": query}
|
||||
|
||||
def test_len(self, dataset: UnionDataset) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_str(self, dataset: ZipDataset) -> None:
|
||||
def test_str(self, dataset: UnionDataset) -> None:
|
||||
out = str(dataset)
|
||||
assert "type: ZipDataset" in out
|
||||
assert "type: UnionDataset" in out
|
||||
assert "bbox: BoundingBox" in out
|
||||
assert "size: 2" in out
|
||||
|
||||
def test_vision_dataset(self) -> None:
|
||||
ds1 = CustomVisionDataset()
|
||||
ds2 = CustomVisionDataset()
|
||||
with pytest.raises(ValueError, match="ZipDataset only supports GeoDatasets"):
|
||||
ZipDataset([ds1, ds2]) # type: ignore[list-item]
|
||||
with pytest.raises(ValueError, match="UnionDataset only supports GeoDatasets"):
|
||||
UnionDataset(ds1, ds2) # type: ignore[arg-type]
|
||||
|
||||
def test_different_crs(self) -> None:
|
||||
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
|
||||
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
|
||||
with pytest.raises(ValueError, match="Datasets must be in the same CRS"):
|
||||
ZipDataset([ds1, ds2])
|
||||
UnionDataset(ds1, ds2)
|
||||
|
||||
def test_different_res(self) -> None:
|
||||
ds1 = CustomGeoDataset(res=1)
|
||||
ds2 = CustomGeoDataset(res=2)
|
||||
with pytest.raises(ValueError, match="Datasets must have the same resolution"):
|
||||
ZipDataset([ds1, ds2])
|
||||
UnionDataset(ds1, ds2)
|
||||
|
||||
def test_no_overlap(self) -> None:
|
||||
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
|
||||
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
|
||||
with pytest.raises(ValueError, match="Datasets have no overlap"):
|
||||
ZipDataset([ds1, ds2])
|
||||
UnionDataset(ds1, ds2)
|
||||
|
||||
def test_invalid_query(self, dataset: ZipDataset) -> None:
|
||||
def test_invalid_query(self, dataset: UnionDataset) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* not found in index with bounds:"
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
|
||||
from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
|
||||
|
||||
|
||||
class TestLandsat8:
|
||||
|
@ -35,9 +35,13 @@ class TestLandsat8:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: Landsat8) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: Landsat8) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: Landsat8) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_plot(self, dataset: Landsat8) -> None:
|
||||
query = dataset.bounds
|
||||
|
|
|
@ -12,7 +12,13 @@ import torch.nn as nn
|
|||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import NAIP, BoundingBox, NAIPChesapeakeDataModule, ZipDataset
|
||||
from torchgeo.datasets import (
|
||||
NAIP,
|
||||
BoundingBox,
|
||||
IntersectionDataset,
|
||||
NAIPChesapeakeDataModule,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestNAIP:
|
||||
|
@ -31,9 +37,13 @@ class TestNAIP:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: NAIP) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: NAIP) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: NAIP) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_plot(self, dataset: NAIP) -> None:
|
||||
query = dataset.bounds
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, Sentinel2, ZipDataset
|
||||
from torchgeo.datasets import BoundingBox, IntersectionDataset, Sentinel2, UnionDataset
|
||||
|
||||
|
||||
class TestSentinel2:
|
||||
|
@ -29,9 +29,13 @@ class TestSentinel2:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: Sentinel2) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
def test_and(self, dataset: Sentinel2) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
||||
def test_or(self, dataset: Sentinel2) -> None:
|
||||
ds = dataset | dataset
|
||||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
|
||||
|
|
|
@ -6,11 +6,12 @@ import glob
|
|||
import math
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Tuple
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -22,14 +23,16 @@ from torch.utils.data import TensorDataset
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets.utils import (
|
||||
BoundingBox,
|
||||
collate_dict,
|
||||
concat_samples,
|
||||
dataset_split,
|
||||
disambiguate_timestamp,
|
||||
download_and_extract_archive,
|
||||
download_radiant_mlhub_collection,
|
||||
download_radiant_mlhub_dataset,
|
||||
extract_archive,
|
||||
merge_samples,
|
||||
percentile_normalization,
|
||||
stack_samples,
|
||||
working_dir,
|
||||
)
|
||||
|
||||
|
@ -164,7 +167,13 @@ def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
|
|||
|
||||
|
||||
class TestBoundingBox:
|
||||
def test_new_init(self) -> None:
|
||||
def test_repr_str(self) -> None:
|
||||
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
|
||||
expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)"
|
||||
assert repr(bbox) == expected
|
||||
assert str(bbox) == expected
|
||||
|
||||
def test_getitem(self) -> None:
|
||||
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
|
||||
assert bbox.minx == 0
|
||||
|
@ -176,13 +185,143 @@ class TestBoundingBox:
|
|||
|
||||
assert bbox[0] == 0
|
||||
assert bbox[-1] == 5
|
||||
assert bbox[1:3] == (1, 2)
|
||||
assert bbox[1:3] == [1, 2]
|
||||
|
||||
def test_repr_str(self) -> None:
|
||||
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
|
||||
expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)"
|
||||
assert repr(bbox) == expected
|
||||
assert str(bbox) == expected
|
||||
def test_iter(self) -> None:
|
||||
bbox = BoundingBox(0, 1, 2, 3, 4, 5)
|
||||
|
||||
assert tuple(bbox) == (0, 1, 2, 3, 4, 5)
|
||||
|
||||
i = 0
|
||||
for _ in bbox:
|
||||
i += 1
|
||||
assert i == 6
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
[
|
||||
# Same box
|
||||
((0, 1, 0, 1, 0, 1), True),
|
||||
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), True),
|
||||
# bbox1 strictly within bbox2
|
||||
((-1, 2, -1, 2, -1, 2), True),
|
||||
# bbox2 strictly within bbox1
|
||||
((0.25, 0.75, 0.25, 0.75, 0.25, 0.75), False),
|
||||
# One corner of bbox1 within bbox2
|
||||
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), False),
|
||||
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), False),
|
||||
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), False),
|
||||
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), False),
|
||||
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), False),
|
||||
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), False),
|
||||
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), False),
|
||||
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), False),
|
||||
# No overlap
|
||||
((0.5, 1.5, 0.5, 1.5, 2, 3), False),
|
||||
((0.5, 1.5, 2, 3, 0.5, 1.5), False),
|
||||
((2, 3, 0.5, 1.5, 0.5, 1.5), False),
|
||||
((2, 3, 2, 3, 2, 3), False),
|
||||
],
|
||||
)
|
||||
def test_contains(
|
||||
self,
|
||||
test_input: Tuple[float, float, float, float, float, float],
|
||||
expected: bool,
|
||||
) -> None:
|
||||
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
|
||||
bbox2 = BoundingBox(*test_input)
|
||||
assert (bbox1 in bbox2) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
[
|
||||
# Same box
|
||||
((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)),
|
||||
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), (0, 1, 0, 1, 0, 1)),
|
||||
# bbox1 strictly within bbox2
|
||||
((-1, 2, -1, 2, -1, 2), (-1, 2, -1, 2, -1, 2)),
|
||||
# bbox2 strictly within bbox1
|
||||
((0.25, 0.75, 0.25, 0.75, 0.25, 0.75), (0, 1, 0, 1, 0, 1)),
|
||||
# One corner of bbox1 within bbox2
|
||||
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), (0, 1.5, 0, 1.5, 0, 1.5)),
|
||||
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), (0, 1.5, -0.5, 1, 0, 1.5)),
|
||||
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), (0, 1.5, 0, 1.5, -0.5, 1)),
|
||||
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), (0, 1.5, -0.5, 1, -0.5, 1)),
|
||||
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), (-0.5, 1, 0, 1.5, 0, 1.5)),
|
||||
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), (-0.5, 1, -0.5, 1, 0, 1.5)),
|
||||
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), (-0.5, 1, 0, 1.5, -0.5, 1)),
|
||||
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), (-0.5, 1, -0.5, 1, -0.5, 1)),
|
||||
# No overlap
|
||||
((0.5, 1.5, 0.5, 1.5, 2, 3), (0, 1.5, 0, 1.5, 0, 3)),
|
||||
((0.5, 1.5, 2, 3, 0.5, 1.5), (0, 1.5, 0, 3, 0, 1.5)),
|
||||
((2, 3, 0.5, 1.5, 0.5, 1.5), (0, 3, 0, 1.5, 0, 1.5)),
|
||||
((2, 3, 2, 3, 2, 3), (0, 3, 0, 3, 0, 3)),
|
||||
],
|
||||
)
|
||||
def test_or(
|
||||
self,
|
||||
test_input: Tuple[float, float, float, float, float, float],
|
||||
expected: Tuple[float, float, float, float, float, float],
|
||||
) -> None:
|
||||
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
|
||||
bbox2 = BoundingBox(*test_input)
|
||||
bbox3 = BoundingBox(*expected)
|
||||
assert (bbox1 | bbox2) == bbox3
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
[
|
||||
# Same box
|
||||
((0, 1, 0, 1, 0, 1), (0, 1, 0, 1, 0, 1)),
|
||||
((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), (0, 1, 0, 1, 0, 1)),
|
||||
# bbox1 strictly within bbox2
|
||||
((-1, 2, -1, 2, -1, 2), (0, 1, 0, 1, 0, 1)),
|
||||
# bbox2 strictly within bbox1
|
||||
(
|
||||
(0.25, 0.75, 0.25, 0.75, 0.25, 0.75),
|
||||
(0.25, 0.75, 0.25, 0.75, 0.25, 0.75),
|
||||
),
|
||||
# One corner of bbox1 within bbox2
|
||||
((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), (0.5, 1, 0.5, 1, 0.5, 1)),
|
||||
((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), (0.5, 1, 0, 0.5, 0.5, 1)),
|
||||
((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), (0.5, 1, 0.5, 1, 0, 0.5)),
|
||||
((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), (0.5, 1, 0, 0.5, 0, 0.5)),
|
||||
((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), (0, 0.5, 0.5, 1, 0.5, 1)),
|
||||
((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), (0, 0.5, 0, 0.5, 0.5, 1)),
|
||||
((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), (0, 0.5, 0.5, 1, 0, 0.5)),
|
||||
((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), (0, 0.5, 0, 0.5, 0, 0.5)),
|
||||
],
|
||||
)
|
||||
def test_and_intersection(
|
||||
self,
|
||||
test_input: Tuple[float, float, float, float, float, float],
|
||||
expected: Tuple[float, float, float, float, float, float],
|
||||
) -> None:
|
||||
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
|
||||
bbox2 = BoundingBox(*test_input)
|
||||
bbox3 = BoundingBox(*expected)
|
||||
assert (bbox1 & bbox2) == bbox3
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input",
|
||||
[
|
||||
# No overlap
|
||||
(0.5, 1.5, 0.5, 1.5, 2, 3),
|
||||
(0.5, 1.5, 2, 3, 0.5, 1.5),
|
||||
(2, 3, 0.5, 1.5, 0.5, 1.5),
|
||||
(2, 3, 2, 3, 2, 3),
|
||||
],
|
||||
)
|
||||
def test_and_no_intersection(
|
||||
self, test_input: Tuple[float, float, float, float, float, float]
|
||||
) -> None:
|
||||
bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
|
||||
bbox2 = BoundingBox(*test_input)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(f"Bounding boxes {bbox1} and {bbox2} do not overlap"),
|
||||
):
|
||||
bbox1 & bbox2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
|
@ -306,26 +445,103 @@ def test_disambiguate_timestamp(
|
|||
assert math.isclose(maxt, max_datetime)
|
||||
|
||||
|
||||
def test_collate_dict() -> None:
|
||||
samples = [
|
||||
{
|
||||
"foo": torch.tensor(1), # type: ignore[attr-defined]
|
||||
"bar": torch.tensor(2), # type: ignore[attr-defined]
|
||||
"crs": CRS.from_epsg(3005),
|
||||
},
|
||||
{
|
||||
"foo": torch.tensor(3), # type: ignore[attr-defined]
|
||||
"bar": torch.tensor(4), # type: ignore[attr-defined]
|
||||
"crs": CRS.from_epsg(3005),
|
||||
},
|
||||
]
|
||||
sample = collate_dict(samples)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["foo"], torch.tensor([1, 3]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["bar"], torch.tensor([2, 4]) # type: ignore[attr-defined]
|
||||
)
|
||||
class TestCollateFunctionsMatchingKeys:
|
||||
@pytest.fixture(scope="class")
|
||||
def samples(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"image": torch.tensor([1, 2, 0]), # type: ignore[attr-defined]
|
||||
"crs": CRS.from_epsg(2000),
|
||||
},
|
||||
{
|
||||
"image": torch.tensor([0, 0, 3]), # type: ignore[attr-defined]
|
||||
"crs": CRS.from_epsg(2001),
|
||||
},
|
||||
]
|
||||
|
||||
def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = stack_samples(samples)
|
||||
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
|
||||
[2, 3]
|
||||
)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"],
|
||||
torch.tensor([[1, 2, 0], [0, 0, 3]]), # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)]
|
||||
|
||||
def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = concat_samples(samples)
|
||||
assert sample["image"].size() == torch.Size([6]) # type: ignore[attr-defined]
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"],
|
||||
torch.tensor([1, 2, 0, 0, 0, 3]), # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs"] == CRS.from_epsg(2000)
|
||||
|
||||
def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = merge_samples(samples)
|
||||
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"], torch.tensor([1, 2, 3]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs"] == CRS.from_epsg(2001)
|
||||
|
||||
|
||||
class TestCollateFunctionsDifferingKeys:
|
||||
@pytest.fixture(scope="class")
|
||||
def samples(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"image": torch.tensor([1, 2, 0]), # type: ignore[attr-defined]
|
||||
"crs1": CRS.from_epsg(2000),
|
||||
},
|
||||
{
|
||||
"mask": torch.tensor([0, 0, 3]), # type: ignore[attr-defined]
|
||||
"crs2": CRS.from_epsg(2001),
|
||||
},
|
||||
]
|
||||
|
||||
def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = stack_samples(samples)
|
||||
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
|
||||
[1, 3]
|
||||
)
|
||||
assert sample["mask"].size() == torch.Size([1, 3]) # type: ignore[attr-defined]
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"], torch.tensor([[1, 2, 0]]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["mask"], torch.tensor([[0, 0, 3]]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs1"] == [CRS.from_epsg(2000)]
|
||||
assert sample["crs2"] == [CRS.from_epsg(2001)]
|
||||
|
||||
def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = concat_samples(samples)
|
||||
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
|
||||
assert sample["mask"].size() == torch.Size([3]) # type: ignore[attr-defined]
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"], torch.tensor([1, 2, 0]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["mask"], torch.tensor([0, 0, 3]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs1"] == CRS.from_epsg(2000)
|
||||
assert sample["crs2"] == CRS.from_epsg(2001)
|
||||
|
||||
def test_merge_samples(self, samples: List[Dict[str, Any]]) -> None:
|
||||
sample = merge_samples(samples)
|
||||
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
|
||||
assert sample["mask"].size() == torch.Size([3]) # type: ignore[attr-defined]
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["image"], torch.tensor([1, 2, 0]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
sample["mask"], torch.tensor([0, 0, 3]) # type: ignore[attr-defined]
|
||||
)
|
||||
assert sample["crs1"] == CRS.from_epsg(2000)
|
||||
assert sample["crs2"] == CRS.from_epsg(2001)
|
||||
|
||||
|
||||
def test_existing_directory(tmp_path: Path) -> None:
|
||||
|
|
|
@ -28,7 +28,7 @@ class CustomBatchGeoSampler(BatchGeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
|
||||
super().__init__()
|
||||
self.crs = crs
|
||||
self._crs = crs
|
||||
self.res = res
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
|
||||
|
@ -56,8 +56,9 @@ class TestBatchGeoSampler:
|
|||
continue
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
BatchGeoSampler(None) # type: ignore[abstract]
|
||||
BatchGeoSampler(ds) # type: ignore[abstract]
|
||||
|
||||
|
||||
class TestRandomBatchGeoSampler:
|
||||
|
@ -85,6 +86,16 @@ class TestRandomBatchGeoSampler:
|
|||
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
assert len(sampler) == sampler.length // sampler.batch_size
|
||||
|
||||
def test_roi(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
|
||||
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
|
||||
roi = BoundingBox(0, 10, 0, 10, 0, 10)
|
||||
sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi)
|
||||
for batch in sampler:
|
||||
for query in batch:
|
||||
assert query in roi
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
|
||||
|
|
|
@ -28,7 +28,7 @@ class CustomGeoSampler(GeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
|
||||
super().__init__()
|
||||
self.crs = crs
|
||||
self._crs = crs
|
||||
self.res = res
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:
|
||||
|
@ -46,6 +46,11 @@ class TestGeoSampler:
|
|||
def test_len(self, sampler: CustomGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
GeoSampler(ds) # type: ignore[abstract]
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: CustomGeoSampler, num_workers: int) -> None:
|
||||
|
@ -54,10 +59,6 @@ class TestGeoSampler:
|
|||
for _ in dl:
|
||||
continue
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
GeoSampler(None) # type: ignore[abstract]
|
||||
|
||||
|
||||
class TestRandomGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
|
@ -83,6 +84,15 @@ class TestRandomGeoSampler:
|
|||
def test_len(self, sampler: RandomGeoSampler) -> None:
|
||||
assert len(sampler) == sampler.length
|
||||
|
||||
def test_roi(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
|
||||
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
|
||||
roi = BoundingBox(0, 10, 0, 10, 0, 10)
|
||||
sampler = RandomGeoSampler(ds, 2, 10, roi=roi)
|
||||
for query in sampler:
|
||||
assert query in roi
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: RandomGeoSampler, num_workers: int) -> None:
|
||||
|
@ -116,6 +126,21 @@ class TestGridGeoSampler:
|
|||
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
|
||||
)
|
||||
|
||||
def test_len(self, sampler: GridGeoSampler) -> None:
|
||||
rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1
|
||||
cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1
|
||||
length = rows * cols * 2
|
||||
assert len(sampler) == length
|
||||
|
||||
def test_roi(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
|
||||
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
|
||||
roi = BoundingBox(0, 10, 0, 10, 0, 10)
|
||||
sampler = GridGeoSampler(ds, 2, 1, roi=roi)
|
||||
for query in sampler:
|
||||
assert query in roi
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(self, sampler: GridGeoSampler, num_workers: int) -> None:
|
||||
|
@ -123,9 +148,3 @@ class TestGridGeoSampler:
|
|||
dl = DataLoader(ds, sampler=sampler, num_workers=num_workers)
|
||||
for _ in dl:
|
||||
continue
|
||||
|
||||
def test_len(self, sampler: GridGeoSampler) -> None:
|
||||
rows = int((10 - sampler.size[0]) // sampler.stride[0]) + 1
|
||||
cols = int((20 - sampler.size[1]) // sampler.stride[1]) + 1
|
||||
length = rows * cols * 2
|
||||
assert len(sampler) == length
|
||||
|
|
|
@ -29,11 +29,12 @@ from .etci2021 import ETCI2021, ETCI2021DataModule
|
|||
from .eurosat import EuroSAT, EuroSATDataModule
|
||||
from .geo import (
|
||||
GeoDataset,
|
||||
IntersectionDataset,
|
||||
RasterDataset,
|
||||
UnionDataset,
|
||||
VectorDataset,
|
||||
VisionClassificationDataset,
|
||||
VisionDataset,
|
||||
ZipDataset,
|
||||
)
|
||||
from .gid15 import GID15
|
||||
from .landcoverai import LandCoverAI, LandCoverAIDataModule
|
||||
|
@ -63,7 +64,7 @@ from .sentinel import Sentinel, Sentinel2
|
|||
from .so2sat import So2Sat, So2SatDataModule
|
||||
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
|
||||
from .ucmerced import UCMerced, UCMercedDataModule
|
||||
from .utils import BoundingBox, collate_dict
|
||||
from .utils import BoundingBox, concat_samples, merge_samples, stack_samples
|
||||
from .vaihingen import Vaihingen2D, Vaihingen2DDataModule
|
||||
from .xview import XView2, XView2DataModule
|
||||
from .zuericrop import ZueriCrop
|
||||
|
@ -147,14 +148,17 @@ __all__ = (
|
|||
"ZueriCrop",
|
||||
# Base classes
|
||||
"GeoDataset",
|
||||
"IntersectionDataset",
|
||||
"RasterDataset",
|
||||
"UnionDataset",
|
||||
"VectorDataset",
|
||||
"VisionDataset",
|
||||
"VisionClassificationDataset",
|
||||
"ZipDataset",
|
||||
# Utilities
|
||||
"BoundingBox",
|
||||
"collate_dict",
|
||||
"concat_samples",
|
||||
"merge_samples",
|
||||
"stack_samples",
|
||||
)
|
||||
|
||||
# https://stackoverflow.com/questions/40018681
|
||||
|
|
|
@ -32,6 +32,7 @@ from .utils import (
|
|||
download_and_extract_archive,
|
||||
download_url,
|
||||
extract_archive,
|
||||
stack_samples,
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
|
@ -434,7 +435,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
Raises:
|
||||
IndexError: if query is not found in the index
|
||||
"""
|
||||
hits = self.index.intersection(query, objects=True)
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
filepaths = [hit.object for hit in hits]
|
||||
|
||||
sample = {"image": [], "mask": [], "crs": self.crs, "bbox": query}
|
||||
|
@ -783,7 +784,10 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
length=self.patches_per_tile * len(self.train_dataset),
|
||||
)
|
||||
return DataLoader(
|
||||
self.train_dataset, batch_sampler=sampler, num_workers=self.num_workers
|
||||
self.train_dataset,
|
||||
batch_sampler=sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
|
||||
|
@ -802,6 +806,7 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
batch_size=self.batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
|
@ -820,4 +825,5 @@ class ChesapeakeCVPRDataModule(LightningDataModule):
|
|||
batch_size=self.batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
import abc
|
||||
import functools
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
@ -16,8 +15,10 @@ import fiona
|
|||
import fiona.transform
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pyproj
|
||||
import rasterio
|
||||
import rasterio.merge
|
||||
import shapely
|
||||
import torch
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.io import DatasetReader
|
||||
|
@ -29,7 +30,7 @@ from torch.utils.data import Dataset
|
|||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.folder import default_loader as pil_loader
|
||||
|
||||
from .utils import BoundingBox, disambiguate_timestamp
|
||||
from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
|
@ -46,20 +47,53 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
* :term:`coordinate reference system (CRS)`
|
||||
* resolution
|
||||
|
||||
These kind of datasets are special because they can be combined. For example:
|
||||
:class:`GeoDataset` is a special class of datasets. Unlike :class:`VisionDataset`,
|
||||
the presence of geospatial information allows two or more datasets to be combined
|
||||
based on latitude/longitude. This allows users to do things like:
|
||||
|
||||
* Combine Landsat8 and CDL to train a model for crop classification
|
||||
* Combine NAIP and Chesapeake to train a model for land cover mapping
|
||||
* Combine image and target labels and sample from both simultaneously
|
||||
(e.g. Landsat and CDL)
|
||||
* Combine datasets for multiple image sources for multimodal learning or data fusion
|
||||
(e.g. Landsat and Sentinel)
|
||||
|
||||
This isn't true for :class:`VisionDataset`, where the lack of geospatial information
|
||||
prohibits swapping image sources or target labels.
|
||||
These combinations require that all queries are present in *both* datasets,
|
||||
and can be combined using an :class:`IntersectionDataset`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dataset = landsat & cdl
|
||||
|
||||
Users may also want to:
|
||||
|
||||
* Combine datasets for multiple image sources and treat them as equivalent
|
||||
(e.g. Landsat 7 and Landsat 8)
|
||||
* Combine datasets for disparate geospatial locations
|
||||
(e.g. Chesapeake NY and PA)
|
||||
|
||||
These combinations require that all queries are present in *at least one* dataset,
|
||||
and can be combined using a :class:`UnionDataset`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dataset = landsat7 | landsat8
|
||||
"""
|
||||
|
||||
#: :term:`coordinate reference system (CRS)` for the dataset.
|
||||
crs: CRS
|
||||
|
||||
#: Resolution of the dataset in units of CRS.
|
||||
res: float
|
||||
_crs: CRS
|
||||
|
||||
# NOTE: according to the Python docs:
|
||||
#
|
||||
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
|
||||
#
|
||||
# the correct way to handle __add__ not being supported is to set it to None,
|
||||
# not to return NotImplemented or raise NotImplementedError. The downside of
|
||||
# this is that we have no way to explain to a user why they get an error and
|
||||
# what they should do instead (use __and__ or __or__).
|
||||
|
||||
#: :class:`GeoDataset` addition can be ambiguous and is no longer supported.
|
||||
#: Users should instead use the intersection or union operator.
|
||||
__add__ = None # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self, transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None
|
||||
|
@ -89,8 +123,8 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
IndexError: if query is not found in the index
|
||||
"""
|
||||
|
||||
def __add__(self, other: "GeoDataset") -> "ZipDataset": # type: ignore[override]
|
||||
"""Merge two GeoDatasets.
|
||||
def __and__(self, other: "GeoDataset") -> "IntersectionDataset":
|
||||
"""Take the intersection of two :class:`GeoDataset`.
|
||||
|
||||
Args:
|
||||
other: another dataset
|
||||
|
@ -99,11 +133,23 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
a single dataset
|
||||
|
||||
Raises:
|
||||
ValueError: if other is not a GeoDataset, or if datasets do not overlap,
|
||||
or if datasets do not have the same
|
||||
:term:`coordinate reference system (CRS)`
|
||||
ValueError: if other is not a :class:`GeoDataset`
|
||||
"""
|
||||
return ZipDataset([self, other])
|
||||
return IntersectionDataset(self, other)
|
||||
|
||||
def __or__(self, other: "GeoDataset") -> "UnionDataset":
|
||||
"""Take the union of two GeoDatasets.
|
||||
|
||||
Args:
|
||||
other: another dataset
|
||||
|
||||
Returns:
|
||||
a single dataset
|
||||
|
||||
Raises:
|
||||
ValueError: if other is not a :class:`GeoDataset`
|
||||
"""
|
||||
return UnionDataset(self, other)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
@ -135,6 +181,43 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
"""
|
||||
return BoundingBox(*self.index.bounds)
|
||||
|
||||
@property
|
||||
def crs(self) -> CRS:
|
||||
""":term:`coordinate reference system (CRS)` for the dataset.
|
||||
|
||||
Returns:
|
||||
the :term:`coordinate reference system (CRS)`
|
||||
"""
|
||||
return self._crs
|
||||
|
||||
@crs.setter
|
||||
def crs(self, new_crs: CRS) -> None:
|
||||
"""Change the :term:`coordinate reference system (CRS)` of a GeoDataset.
|
||||
|
||||
If ``new_crs == self.crs``, does nothing, otherwise updates the R-tree index.
|
||||
|
||||
Args:
|
||||
new_crs: new :term:`coordinate reference system (CRS)`
|
||||
"""
|
||||
if new_crs == self._crs:
|
||||
return
|
||||
|
||||
new_index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
|
||||
project = pyproj.Transformer.from_crs(
|
||||
pyproj.CRS(str(self._crs)), pyproj.CRS(str(new_crs)), always_xy=True
|
||||
).transform
|
||||
for hit in self.index.intersection(self.index.bounds, objects=True):
|
||||
old_minx, old_maxx, old_miny, old_maxy, mint, maxt = hit.bounds
|
||||
old_box = shapely.geometry.box(old_minx, old_miny, old_maxx, old_maxy)
|
||||
new_box = shapely.ops.transform(project, old_box)
|
||||
new_minx, new_miny, new_maxx, new_maxy = new_box.bounds
|
||||
new_bounds = (new_minx, new_maxx, new_miny, new_maxy, mint, maxt)
|
||||
new_index.insert(hit.id, new_bounds, hit.object)
|
||||
|
||||
self._crs = new_crs
|
||||
self.index = new_index
|
||||
|
||||
|
||||
class RasterDataset(GeoDataset):
|
||||
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
|
||||
|
@ -253,7 +336,7 @@ class RasterDataset(GeoDataset):
|
|||
f"No {self.__class__.__name__} data was found in '{root}'"
|
||||
)
|
||||
|
||||
self.crs = cast(CRS, crs)
|
||||
self._crs = cast(CRS, crs)
|
||||
self.res = cast(float, res)
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
|
@ -268,7 +351,7 @@ class RasterDataset(GeoDataset):
|
|||
Raises:
|
||||
IndexError: if query is not found in the index
|
||||
"""
|
||||
hits = self.index.intersection(query, objects=True)
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
filepaths = [hit.object for hit in hits]
|
||||
|
||||
if not filepaths:
|
||||
|
@ -478,7 +561,7 @@ class VectorDataset(GeoDataset):
|
|||
f"No {self.__class__.__name__} data was found in '{root}'"
|
||||
)
|
||||
|
||||
self.crs = crs
|
||||
self._crs = crs
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
"""Retrieve image/mask and metadata indexed by query.
|
||||
|
@ -492,7 +575,7 @@ class VectorDataset(GeoDataset):
|
|||
Raises:
|
||||
IndexError: if query is not found in the index
|
||||
"""
|
||||
hits = self.index.intersection(query, objects=True)
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
filepaths = [hit.object for hit in hits]
|
||||
|
||||
if not filepaths:
|
||||
|
@ -676,46 +759,80 @@ class VisionClassificationDataset(VisionDataset, ImageFolder): # type: ignore[m
|
|||
return tensor, label
|
||||
|
||||
|
||||
class ZipDataset(GeoDataset):
|
||||
"""Dataset for merging two or more GeoDatasets.
|
||||
class IntersectionDataset(GeoDataset):
|
||||
"""Dataset representing the intersection of two GeoDatasets.
|
||||
|
||||
For example, this allows you to combine an image source like Landsat8 with a target
|
||||
label like CDL.
|
||||
This allows users to do things like:
|
||||
|
||||
* Combine image and target labels and sample from both simultaneously
|
||||
(e.g. Landsat and CDL)
|
||||
* Combine datasets for multiple image sources for multimodal learning or data fusion
|
||||
(e.g. Landsat and Sentinel)
|
||||
|
||||
These combinations require that all queries are present in *both* datasets,
|
||||
and can be combined using an :class:`IntersectionDataset`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dataset = landsat & cdl
|
||||
"""
|
||||
|
||||
def __init__(self, datasets: Sequence[GeoDataset]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dataset1: GeoDataset,
|
||||
dataset2: GeoDataset,
|
||||
collate_fn: Callable[
|
||||
[Sequence[Dict[str, Any]]], Dict[str, Any]
|
||||
] = concat_samples,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
||||
Args:
|
||||
datasets: list of datasets to merge
|
||||
dataset1: the first dataset
|
||||
dataset2: the second dataset
|
||||
collate_fn: function used to collate samples
|
||||
|
||||
Raises:
|
||||
ValueError: if datasets contains non-GeoDatasets, do not overlap, are not in
|
||||
the same :term:`coordinate reference system (CRS)`, or do not have the
|
||||
same resolution
|
||||
ValueError: if either dataset is not a :class:`GeoDataset`
|
||||
"""
|
||||
for ds in datasets:
|
||||
super().__init__()
|
||||
self.datasets = [dataset1, dataset2]
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
for ds in self.datasets:
|
||||
if not isinstance(ds, GeoDataset):
|
||||
raise ValueError("ZipDataset only supports GeoDatasets")
|
||||
raise ValueError("IntersectionDataset only supports GeoDatasets")
|
||||
|
||||
crs = datasets[0].crs
|
||||
res = datasets[0].res
|
||||
for ds in datasets:
|
||||
if ds.crs != crs:
|
||||
raise ValueError("Datasets must be in the same CRS")
|
||||
if not math.isclose(ds.res, res):
|
||||
# TODO: relax this constraint someday
|
||||
raise ValueError("Datasets must have the same resolution")
|
||||
self._crs = dataset1.crs
|
||||
self.res = dataset1.res
|
||||
|
||||
self.datasets = datasets
|
||||
self.crs = crs
|
||||
self.res = res
|
||||
# Force dataset2 to have the same CRS/res as dataset1
|
||||
if dataset1.crs != dataset2.crs:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} CRS from "
|
||||
f"{dataset2.crs} to {dataset1.crs}"
|
||||
)
|
||||
dataset2.crs = dataset1.crs
|
||||
if dataset1.res != dataset2.res:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} resolution from "
|
||||
f"{dataset2.res} to {dataset1.res}"
|
||||
)
|
||||
dataset2.res = dataset1.res
|
||||
|
||||
# Make sure datasets have overlap
|
||||
try:
|
||||
self.bounds
|
||||
except ValueError:
|
||||
raise ValueError("Datasets have no overlap")
|
||||
# Merge dataset indices into a single index
|
||||
self._merge_dataset_indices()
|
||||
|
||||
def _merge_dataset_indices(self) -> None:
|
||||
"""Create a new R-tree out of the individual indices from two datasets."""
|
||||
i = 0
|
||||
ds1, ds2 = self.datasets
|
||||
for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
|
||||
for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
|
||||
box1 = BoundingBox(*hit1.bounds)
|
||||
box2 = BoundingBox(*hit2.bounds)
|
||||
self.index.insert(i, tuple(box1 & box2))
|
||||
i += 1
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
"""Retrieve image and metadata indexed by query.
|
||||
|
@ -734,21 +851,10 @@ class ZipDataset(GeoDataset):
|
|||
f"query: {query} not found in index with bounds: {self.bounds}"
|
||||
)
|
||||
|
||||
# TODO: use collate_dict here to concatenate instead of replace.
|
||||
# For example, if using Landsat + Sentinel + CDL, don't want to remove Landsat
|
||||
# images and replace with Sentinel images.
|
||||
sample = {}
|
||||
for ds in self.datasets:
|
||||
sample.update(ds[query])
|
||||
return sample
|
||||
# All datasets are guaranteed to have a valid query
|
||||
samples = [ds[query] for ds in self.datasets]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of files in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return sum(map(len, self.datasets))
|
||||
return self.collate_fn(samples)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
@ -758,23 +864,117 @@ class ZipDataset(GeoDataset):
|
|||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: ZipDataset
|
||||
type: IntersectionDataset
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
||||
@property
|
||||
def bounds(self) -> BoundingBox:
|
||||
"""Bounds of the index.
|
||||
|
||||
class UnionDataset(GeoDataset):
|
||||
"""Dataset representing the union of two GeoDatasets.
|
||||
|
||||
This allows users to do things like:
|
||||
|
||||
* Combine datasets for multiple image sources and treat them as equivalent
|
||||
(e.g. Landsat 7 and Landsat 8)
|
||||
* Combine datasets for disparate geospatial locations
|
||||
(e.g. Chesapeake NY and PA)
|
||||
|
||||
These combinations require that all queries are present in *at least one* dataset,
|
||||
and can be combined using a :class:`UnionDataset`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
dataset = landsat7 | landsat8
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset1: GeoDataset,
|
||||
dataset2: GeoDataset,
|
||||
collate_fn: Callable[
|
||||
[Sequence[Dict[str, Any]]], Dict[str, Any]
|
||||
] = merge_samples,
|
||||
) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
||||
Args:
|
||||
dataset1: the first dataset
|
||||
dataset2: the second dataset
|
||||
collate_fn: function used to collate samples
|
||||
|
||||
Raises:
|
||||
ValueError: if either dataset is not a :class:`GeoDataset`
|
||||
"""
|
||||
super().__init__()
|
||||
self.datasets = [dataset1, dataset2]
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
for ds in self.datasets:
|
||||
if not isinstance(ds, GeoDataset):
|
||||
raise ValueError("UnionDataset only supports GeoDatasets")
|
||||
|
||||
self._crs = dataset1.crs
|
||||
self.res = dataset1.res
|
||||
|
||||
# Force dataset2 to have the same CRS/res as dataset1
|
||||
if dataset1.crs != dataset2.crs:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} CRS from "
|
||||
f"{dataset2.crs} to {dataset1.crs}"
|
||||
)
|
||||
dataset2.crs = dataset1.crs
|
||||
if dataset1.res != dataset2.res:
|
||||
print(
|
||||
f"Converting {dataset2.__class__.__name__} resolution from "
|
||||
f"{dataset2.res} to {dataset1.res}"
|
||||
)
|
||||
dataset2.res = dataset1.res
|
||||
|
||||
# Merge dataset indices into a single index
|
||||
self._merge_dataset_indices()
|
||||
|
||||
def _merge_dataset_indices(self) -> None:
|
||||
"""Create a new R-tree out of the individual indices from two datasets."""
|
||||
i = 0
|
||||
for ds in self.datasets:
|
||||
hits = ds.index.intersection(ds.index.bounds, objects=True)
|
||||
for hit in hits:
|
||||
self.index.insert(i, hit.bounds)
|
||||
i += 1
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
|
||||
"""Retrieve image and metadata indexed by query.
|
||||
|
||||
Args:
|
||||
query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
|
||||
|
||||
Returns:
|
||||
(minx, maxx, miny, maxy, mint, maxt) of the dataset
|
||||
"""
|
||||
# We want to compute the intersection of all dataset bounds, not the union
|
||||
minx = max([ds.bounds[0] for ds in self.datasets])
|
||||
maxx = min([ds.bounds[1] for ds in self.datasets])
|
||||
miny = max([ds.bounds[2] for ds in self.datasets])
|
||||
maxy = min([ds.bounds[3] for ds in self.datasets])
|
||||
mint = max([ds.bounds[4] for ds in self.datasets])
|
||||
maxt = min([ds.bounds[5] for ds in self.datasets])
|
||||
sample of data/labels and metadata at that index
|
||||
|
||||
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
Raises:
|
||||
IndexError: if query is not within bounds of the index
|
||||
"""
|
||||
if not query.intersects(self.bounds):
|
||||
raise IndexError(
|
||||
f"query: {query} not found in index with bounds: {self.bounds}"
|
||||
)
|
||||
|
||||
# Not all datasets are guaranteed to have a valid query
|
||||
samples = []
|
||||
for ds in self.datasets:
|
||||
if ds.index.intersection(tuple(query)):
|
||||
samples.append(ds[query])
|
||||
|
||||
return self.collate_fn(samples)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
Returns:
|
||||
informal string representation
|
||||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: UnionDataset
|
||||
bbox: {self.bounds}
|
||||
size: {len(self)}"""
|
||||
|
|
|
@ -12,7 +12,7 @@ from ..samplers.batch import RandomBatchGeoSampler
|
|||
from ..samplers.single import GridGeoSampler
|
||||
from .chesapeake import Chesapeake13
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox
|
||||
from .utils import BoundingBox, stack_samples
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
|
@ -143,7 +143,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
chesapeake.res,
|
||||
transforms=self.naip_transform,
|
||||
)
|
||||
self.dataset = chesapeake + naip
|
||||
self.dataset = chesapeake & naip
|
||||
|
||||
# TODO: figure out better train/val/test split
|
||||
roi = self.dataset.bounds
|
||||
|
@ -166,7 +166,10 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
training data loader
|
||||
"""
|
||||
return DataLoader(
|
||||
self.dataset, batch_sampler=self.train_sampler, num_workers=self.num_workers
|
||||
self.dataset,
|
||||
batch_sampler=self.train_sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
||||
def val_dataloader(self) -> DataLoader[Any]:
|
||||
|
@ -180,6 +183,7 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
batch_size=self.batch_size,
|
||||
sampler=self.val_sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> DataLoader[Any]:
|
||||
|
@ -193,4 +197,5 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule):
|
|||
batch_size=self.batch_size,
|
||||
sampler=self.test_sampler,
|
||||
num_workers=self.num_workers,
|
||||
collate_fn=stack_samples,
|
||||
)
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""Common dataset utilities."""
|
||||
|
||||
import bz2
|
||||
import collections
|
||||
import contextlib
|
||||
import gzip
|
||||
import lzma
|
||||
|
@ -11,8 +12,21 @@ import os
|
|||
import sys
|
||||
import tarfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
|
@ -30,7 +44,9 @@ __all__ = (
|
|||
"BoundingBox",
|
||||
"disambiguate_timestamp",
|
||||
"working_dir",
|
||||
"collate_dict",
|
||||
"stack_samples",
|
||||
"concat_samples",
|
||||
"merge_samples",
|
||||
"rasterio_loader",
|
||||
"dataset_split",
|
||||
"sort_sentinel2_bands",
|
||||
|
@ -184,97 +200,134 @@ def download_radiant_mlhub_collection(
|
|||
collection.download(output_dir=download_root, api_key=api_key)
|
||||
|
||||
|
||||
class BoundingBox(Tuple[float, float, float, float, float, float]):
|
||||
"""Data class for indexing spatiotemporal data.
|
||||
@dataclass(frozen=True)
|
||||
class BoundingBox:
|
||||
"""Data class for indexing spatiotemporal data."""
|
||||
|
||||
Attributes:
|
||||
minx (float): western boundary
|
||||
maxx (float): eastern boundary
|
||||
miny (float): southern boundary
|
||||
maxy (float): northern boundary
|
||||
mint (float): earliest boundary
|
||||
maxt (float): latest boundary
|
||||
"""
|
||||
#: western boundary
|
||||
minx: float
|
||||
#: eastern boundary
|
||||
maxx: float
|
||||
#: southern boundary
|
||||
miny: float
|
||||
#: northern boundary
|
||||
maxy: float
|
||||
#: earliest boundary
|
||||
mint: float
|
||||
#: latest boundary
|
||||
maxt: float
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
minx: float,
|
||||
maxx: float,
|
||||
miny: float,
|
||||
maxy: float,
|
||||
mint: float,
|
||||
maxt: float,
|
||||
) -> "BoundingBox":
|
||||
"""Create a new instance of BoundingBox.
|
||||
|
||||
Args:
|
||||
minx: western boundary
|
||||
maxx: eastern boundary
|
||||
miny: southern boundary
|
||||
maxy: northern boundary
|
||||
mint: earliest boundary
|
||||
maxt: latest boundary
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate the arguments passed to :meth:`__init__`.
|
||||
|
||||
Raises:
|
||||
ValueError: if bounding box is invalid
|
||||
(minx > maxx, miny > maxy, or mint > maxt)
|
||||
"""
|
||||
if minx > maxx:
|
||||
raise ValueError(f"Bounding box is invalid: 'minx={minx}' > 'maxx={maxx}'")
|
||||
if miny > maxy:
|
||||
raise ValueError(f"Bounding box is invalid: 'miny={miny}' > 'maxy={maxy}'")
|
||||
if mint > maxt:
|
||||
raise ValueError(f"Bounding box is invalid: 'mint={mint}' > 'maxt={maxt}'")
|
||||
if self.minx > self.maxx:
|
||||
raise ValueError(
|
||||
f"Bounding box is invalid: 'minx={self.minx}' > 'maxx={self.maxx}'"
|
||||
)
|
||||
if self.miny > self.maxy:
|
||||
raise ValueError(
|
||||
f"Bounding box is invalid: 'miny={self.miny}' > 'maxy={self.maxy}'"
|
||||
)
|
||||
if self.mint > self.maxt:
|
||||
raise ValueError(
|
||||
f"Bounding box is invalid: 'mint={self.mint}' > 'maxt={self.maxt}'"
|
||||
)
|
||||
|
||||
# Using super() doesn't work with mypy, see:
|
||||
# https://stackoverflow.com/q/60611012/5828163
|
||||
return tuple.__new__(cls, [minx, maxx, miny, maxy, mint, maxt])
|
||||
# https://github.com/PyCQA/pydocstyle/issues/525
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> float: # noqa: D105
|
||||
pass
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
minx: float,
|
||||
maxx: float,
|
||||
miny: float,
|
||||
maxy: float,
|
||||
mint: float,
|
||||
maxt: float,
|
||||
) -> None:
|
||||
"""Initialize a new instance of BoundingBox.
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> List[float]: # noqa: D105
|
||||
pass
|
||||
|
||||
def __getitem__(self, key: Union[int, slice]) -> Union[float, List[float]]:
|
||||
"""Index the (minx, maxx, miny, maxy, mint, maxt) tuple.
|
||||
|
||||
Args:
|
||||
minx: western boundary
|
||||
maxx: eastern boundary
|
||||
miny: southern boundary
|
||||
maxy: northern boundary
|
||||
mint: earliest boundary
|
||||
maxt: latest boundary
|
||||
"""
|
||||
self.minx = minx
|
||||
self.maxx = maxx
|
||||
self.miny = miny
|
||||
self.maxy = maxy
|
||||
self.mint = mint
|
||||
self.maxt = maxt
|
||||
|
||||
def __getnewargs__(self) -> Tuple[float, float, float, float, float, float]:
|
||||
"""Values passed to the ``__new__()`` method upon unpickling.
|
||||
key: integer or slice object
|
||||
|
||||
Returns:
|
||||
tuple of bounds
|
||||
"""
|
||||
return self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt
|
||||
the value(s) at that index
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the formal string representation of the object.
|
||||
Raises:
|
||||
IndexError: if key is out of bounds
|
||||
"""
|
||||
return [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt][key]
|
||||
|
||||
def __iter__(self) -> Iterator[float]:
|
||||
"""Container iterator.
|
||||
|
||||
Returns:
|
||||
formal string representation
|
||||
iterator object that iterates over all objects in the container
|
||||
"""
|
||||
yield from [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt]
|
||||
|
||||
def __contains__(self, other: "BoundingBox") -> bool:
|
||||
"""Whether or not other is within the bounds of this bounding box.
|
||||
|
||||
Args:
|
||||
other: another bounding box
|
||||
|
||||
Returns:
|
||||
True if other is within this bounding box, else False
|
||||
"""
|
||||
return (
|
||||
f"{self.__class__.__name__}(minx={self.minx}, maxx={self.maxx}, "
|
||||
f"miny={self.miny}, maxy={self.maxy}, mint={self.mint}, maxt={self.maxt})"
|
||||
(self.minx <= other.minx <= self.maxx)
|
||||
and (self.minx <= other.maxx <= self.maxx)
|
||||
and (self.miny <= other.miny <= self.maxy)
|
||||
and (self.miny <= other.maxy <= self.maxy)
|
||||
and (self.mint <= other.mint <= self.maxt)
|
||||
and (self.mint <= other.maxt <= self.maxt)
|
||||
)
|
||||
|
||||
def __or__(self, other: "BoundingBox") -> "BoundingBox":
|
||||
"""The union operator.
|
||||
|
||||
Args:
|
||||
other: another bounding box
|
||||
|
||||
Returns:
|
||||
the minimum bounding box that contains both self and other
|
||||
"""
|
||||
return BoundingBox(
|
||||
min(self.minx, other.minx),
|
||||
max(self.maxx, other.maxx),
|
||||
min(self.miny, other.miny),
|
||||
max(self.maxy, other.maxy),
|
||||
min(self.mint, other.mint),
|
||||
max(self.maxt, other.maxt),
|
||||
)
|
||||
|
||||
def __and__(self, other: "BoundingBox") -> "BoundingBox":
|
||||
"""The intersection operator.
|
||||
|
||||
Args:
|
||||
other: another bounding box
|
||||
|
||||
Returns:
|
||||
the intersection of self and other
|
||||
|
||||
Raises:
|
||||
ValueError: if self and other do not intersect
|
||||
"""
|
||||
try:
|
||||
return BoundingBox(
|
||||
max(self.minx, other.minx),
|
||||
min(self.maxx, other.maxx),
|
||||
max(self.miny, other.miny),
|
||||
min(self.maxy, other.maxy),
|
||||
max(self.mint, other.mint),
|
||||
min(self.maxt, other.maxt),
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError(f"Bounding boxes {self} and {other} do not overlap")
|
||||
|
||||
def intersects(self, other: "BoundingBox") -> bool:
|
||||
"""Whether or not two bounding boxes intersect.
|
||||
|
||||
|
@ -370,8 +423,27 @@ def working_dir(dirname: str, create: bool = False) -> Iterator[None]:
|
|||
os.chdir(cwd)
|
||||
|
||||
|
||||
def collate_dict(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Merge a list of samples to form a mini-batch of Tensors.
|
||||
def _list_dict_to_dict_list(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, List[Any]]:
|
||||
"""Convert a list of dictionaries to a dictionary of lists.
|
||||
|
||||
Args:
|
||||
samples: a list of dictionaries
|
||||
|
||||
Returns:
|
||||
a dictionary of lists
|
||||
"""
|
||||
collated = collections.defaultdict(list)
|
||||
for sample in samples:
|
||||
for key, value in sample.items():
|
||||
collated[key].append(value)
|
||||
return collated
|
||||
|
||||
|
||||
def stack_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
|
||||
"""Stack a list of samples along a new axis.
|
||||
|
||||
Useful for forming a mini-batch of samples to pass to
|
||||
:class:`torch.utils.data.DataLoader`.
|
||||
|
||||
Args:
|
||||
samples: list of samples
|
||||
|
@ -379,14 +451,55 @@ def collate_dict(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
Returns:
|
||||
a single sample
|
||||
"""
|
||||
collated = {}
|
||||
for key, value in samples[0].items():
|
||||
if isinstance(value, Tensor):
|
||||
collated[key] = torch.stack([sample[key] for sample in samples])
|
||||
collated: Dict[Any, Any] = _list_dict_to_dict_list(samples)
|
||||
for key, value in collated.items():
|
||||
if isinstance(value[0], Tensor):
|
||||
collated[key] = torch.stack(value)
|
||||
return collated
|
||||
|
||||
|
||||
def concat_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
|
||||
"""Concatenate a list of samples along an existing axis.
|
||||
|
||||
Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`.
|
||||
|
||||
Args:
|
||||
samples: list of samples
|
||||
|
||||
Returns:
|
||||
a single sample
|
||||
"""
|
||||
collated: Dict[Any, Any] = _list_dict_to_dict_list(samples)
|
||||
for key, value in collated.items():
|
||||
if isinstance(value[0], Tensor):
|
||||
collated[key] = torch.cat(value) # type: ignore[attr-defined]
|
||||
else:
|
||||
collated[key] = [
|
||||
sample[key] for sample in samples
|
||||
] # type: ignore[assignment]
|
||||
collated[key] = value[0]
|
||||
return collated
|
||||
|
||||
|
||||
def merge_samples(samples: Iterable[Dict[Any, Any]]) -> Dict[Any, Any]:
|
||||
"""Merge a list of samples.
|
||||
|
||||
Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`.
|
||||
|
||||
Args:
|
||||
samples: list of samples
|
||||
|
||||
Returns:
|
||||
a single sample
|
||||
"""
|
||||
collated: Dict[Any, Any] = {}
|
||||
for sample in samples:
|
||||
for key, value in sample.items():
|
||||
if key in collated and isinstance(value, Tensor):
|
||||
# Take the maximum so that nodata values (zeros) get replaced
|
||||
# by data values whenever possible
|
||||
collated[key] = torch.maximum( # type: ignore[attr-defined]
|
||||
collated[key], value
|
||||
)
|
||||
else:
|
||||
collated[key] = value
|
||||
return collated
|
||||
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets.geo import GeoDataset
|
||||
|
@ -28,6 +29,27 @@ class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
|
|||
longitude, height, width, projection, coordinate system, and time.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None:
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
Args:
|
||||
dataset: dataset to index from
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
if roi is None:
|
||||
self.index = dataset.index
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
else:
|
||||
self.index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
hits = dataset.index.intersection(tuple(roi), objects=True)
|
||||
for hit in hits:
|
||||
bbox = BoundingBox(*hit.bounds) & roi
|
||||
self.index.insert(hit.id, tuple(bbox), hit.object)
|
||||
|
||||
self.res = dataset.res
|
||||
self.roi = roi
|
||||
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return a batch of indices of a dataset.
|
||||
|
@ -42,9 +64,6 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
|
||||
This is particularly useful during training when you want to maximize the size of
|
||||
the dataset and return as many random :term:`chips <chip>` as possible.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -72,15 +91,11 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
super().__init__(dataset, roi)
|
||||
self.size = _to_tuple(size)
|
||||
self.batch_size = batch_size
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return the indices of a dataset.
|
||||
|
|
|
@ -7,6 +7,7 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets.geo import GeoDataset
|
||||
|
@ -28,6 +29,27 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC):
|
|||
longitude, height, width, projection, coordinate system, and time.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: GeoDataset, roi: Optional[BoundingBox] = None) -> None:
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
Args:
|
||||
dataset: dataset to index from
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
if roi is None:
|
||||
self.index = dataset.index
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
else:
|
||||
self.index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
hits = dataset.index.intersection(tuple(roi), objects=True)
|
||||
for hit in hits:
|
||||
bbox = BoundingBox(*hit.bounds) & roi
|
||||
self.index.insert(hit.id, tuple(bbox), hit.object)
|
||||
|
||||
self.res = dataset.res
|
||||
self.roi = roi
|
||||
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -70,14 +92,10 @@ class RandomGeoSampler(GeoSampler):
|
|||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
super().__init__(dataset, roi)
|
||||
self.size = _to_tuple(size)
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -117,9 +135,6 @@ class GridGeoSampler(GeoSampler):
|
|||
The overlap between each chip (``chip_size - stride``) should be approximately equal
|
||||
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
|
||||
the CNN.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a non-tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -145,13 +160,10 @@ class GridGeoSampler(GeoSampler):
|
|||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = dataset.index
|
||||
super().__init__(dataset, roi)
|
||||
self.size = _to_tuple(size)
|
||||
self.stride = _to_tuple(stride)
|
||||
if roi is None:
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(tuple(self.roi), objects=True))
|
||||
|
||||
self.length: int = 0
|
||||
for hit in self.hits:
|
||||
|
|
Загрузка…
Ссылка в новой задаче