TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
TorchGeo logo

TorchGeo is a PyTorch domain library, similar to torchvision, providing datasets, samplers, transforms, and pre-trained models specific to geospatial data.

The goal of this library is to make it simple:

  1. for machine learning experts to work with geospatial data, and
  2. for remote sensing experts to explore machine learning solutions.

Installation

The recommended way to install TorchGeo is with pip:

$ pip install torchgeo

For conda and spack installation instructions, see the documentation.

Documentation

You can find the documentation for TorchGeo on ReadTheDocs. This includes API documentation, contributing instructions, and several tutorials. For more details, check out our paper and blog.

Example Usage

The following sections give basic examples of what you can do with TorchGeo.

First we'll import various classes and functions used in the following sections:

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask

Geospatial datasets and samplers

Many remote sensing applications involve working with geospatial datasets—datasets with geographic metadata. These datasets can be challenging to work with due to the sheer variety of data. Geospatial imagery is often multispectral with a different number of spectral bands and spatial resolution for every satellite. In addition, each file may be in a different coordinate reference system (CRS), requiring the data to be reprojected into a matching CRS.

Example application in which we combine Landsat and CDL and sample from both

In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and Landsat 8 by taking the union of these two datasets.

landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2])
landsat = landsat7 | landsat8

Next, we take the intersection between this dataset and the CDL dataset. We want to take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different coordinate reference systems (CRS) or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.

cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl

This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire continental United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of samplers. In this example, we'll use a random sampler that returns 256 x 256 pixel images and 10,000 samples per epoch. We also use a custom collation function to combine each sample dictionary into a mini-batch of samples.

sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)

This data loader can now be used in your normal training/evaluation pipeline.

for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]

    # train a model, or make predictions using a pre-trained model
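
Under the hood, each sampler yields geospatial bounding boxes that can also be used to index the dataset directly. A minimal sketch, reusing the dataset and sampler defined above, to inspect a single query and the patch it selects:

query = next(iter(sampler))   # a spatiotemporal bounding box produced by the sampler
patch = dataset[query]        # dictionary of tensors for that region
print(patch["image"].shape, patch["mask"].shape)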

Many applications involve intelligently composing datasets based on geospatial metadata like this. For example, users may want to:

  • Combine datasets for multiple image sources and treat them as equivalent (e.g., Landsat 7 and 8)
  • Combine datasets for disparate geospatial locations (e.g., Chesapeake NY and PA)

These combinations require that all queries are present in at least one dataset, and can be created using a UnionDataset. Similarly, users may want to:

  • Combine image and target labels and sample from both simultaneously (e.g., Landsat and CDL)
  • Combine datasets for multiple image sources for multimodal learning or data fusion (e.g., Landsat and Sentinel)

These combinations require that all queries are present in both datasets, and can be created using an IntersectionDataset. TorchGeo automatically composes these datasets for you when you use the intersection (&) and union (|) operators.
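
These operators are shorthand for the underlying classes; a minimal sketch, reusing the datasets from above, showing the equivalent explicit constructors:

from torchgeo.datasets import IntersectionDataset, UnionDataset

landsat = UnionDataset(landsat7, landsat8)     # same as landsat7 | landsat8
dataset = IntersectionDataset(landsat, cdl)    # same as landsat & cdl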

Benchmark datasets

TorchGeo includes a number of benchmark datasets—datasets that include both input images and target labels. This includes datasets for tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.

If you've used torchvision before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

dataset = VHR10(root="...", download=True, checksum=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model
Example predictions from a Mask R-CNN model trained on the NWPU VHR-10 dataset

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each TorchGeo dataset returns a dictionary mapping keys to PyTorch Tensors rather than a tuple.
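
To see what a single sample contains, you can index the dataset directly; a minimal sketch (the exact keys and tensor shapes depend on the dataset and task):

sample = dataset[0]
for key, value in sample.items():
    # e.g., "image", "label", "mask", or "boxes", depending on the dataset
    print(key, value.shape if hasattr(value, "shape") else value)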

Reproducibility with PyTorch Lightning

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created PyTorch Lightning datamodules with well-defined train-val-test splits and trainers for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the Inria Aerial Image Labeling dataset is as easy as a few imports and four lines of code.

datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")

trainer.fit(model=task, datamodule=datamodule)
Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset
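
The same Trainer can then be reused for evaluation; a minimal sketch, assuming the task and datamodule defined above:

trainer.validate(model=task, datamodule=datamodule)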

In our GitHub repo, we provide train.py and evaluate.py scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the conf directory for example configuration files that can be customized for different training runs.

$ python train.py config_file=conf/landcoverai.yaml

Citation

If you use this software in your work, please cite our paper:

@inproceedings{Stewart_TorchGeo_Deep_Learning_2022,
    address = {Seattle, Washington},
    author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
    booktitle = {Proceedings of the 30th International Conference on Advances in Geographic Information Systems},
    doi = {10.1145/3557915.3560953},
    month = {11},
    pages = {1--12},
    publisher = {Association for Computing Machinery},
    series = {SIGSPATIAL '22},
    title = {{TorchGeo}: Deep Learning With Geospatial Data},
    url = {https://dl.acm.org/doi/10.1145/3557915.3560953},
    year = {2022}
}

Contributing

This project welcomes contributions and suggestions. If you would like to submit a pull request, see our Contribution Guide for more information.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.