* load pretrained weights

* change name millionaid

* restructure and additional weights

* rename sentinel1 weights

* add vit small weights

* forgot to add vit.py

* struggling with test

* wrong name failing test

* feedback on tests

* increase test coverage

* fix failing test

* fix failing test

* fix failing test and add vit tests

* fix failing vit test

* torchgeo.models.utils

* forgot utils file

* typo num channels

* nitpick docs, version torchvision

* another try min dependencies

* add documentation table

* expand pytests to test pretrained weights on tasks

* reverse changes to byol task

* add tests to init pretrained weights from config

* forgot to add the conf files

* change path

* increase test coverage

* vit tests all pass locally including slow

* now remote

* fix tests another one

* add a draft tutorial

* run black on tutorial notebook

* Tutorial typo fixes

* Lower min torch/vision versions

* Fix bad rebase

* Remove dead code

* Flake8 fixes

* Consistent in_chans

* Black fixes

* bison > yacs

* Remove one more reference

* Download modified weights from hugging face

* Add entrypoints

* Add torch.hub support

* progress arg is required

* Fix model loading for resnet18

* Add transforms, update tests

* VIT -> ViT

* add seco weights

* Fix type hints

* Link to timm docs

* Fix pydocstyle

* Try to fix timm docs link

* Fix tests

* Nuke ignores

* Ignore timm links

* Add model API methods

* Add to __init__ and document

* Test model API functions

* fix tests

* Use correct documentation link for intersphinx

* Typos

* Fix Windows tests

* meth -> func

* Explicit function scope

* weight-specific filename

* Support enums in classification trainer

* Update other trainers too

* Fix regression tests

* Fix classification tests

* Fix byol tests

* Fix types

* progress_bar is required arg

* Test weight enums

* Fix pickling

* Fix regression tests

* Improve coverage of classification tests

* Improve coverage of BYOL tests

* Update resnet table

* Update ViT table

* Update get_state_dict usage

* Remove unused YAML files

* Update table widths

* Documentation improvements

* Tweak tables

* Try to fix Windows tests

* Revert "Try to fix Windows tests"

This reverts commit 1325b13ff7.

* Monkeypatch everything

* Revert "Monkeypatch everything"

This reverts commit e3e8d7d042.

* Revert "Revert "Monkeypatch everything""

This reverts commit 9b27bd705b.

* Patch things not at the source

* Fix missing import

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2023-01-22 23:25:49 +01:00 committed by GitHub
Parent 0d04e06791
Commit 60eb61b5fa
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
46 changed files: 1101 additions and 396 deletions


@ -10,7 +10,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 14
num_classes: 19
datamodule:


@ -6,9 +6,11 @@ experiment:
name: cowc_counting_test
module:
model: resnet18
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root: "data/cowc_counting"
seed: 0


@ -6,9 +6,11 @@ experiment:
name: "cyclone_test"
module:
model: "resnet18"
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root: "data/cyclone"
seed: 0


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 13
num_classes: 10
datamodule:


@ -15,7 +15,6 @@ experiment:
module:
model: "faster-rcnn"
backbone: "resnet50"
pretrained: True
num_classes: 2
learning_rate: 1.2e-4
learning_rate_schedule_patience: 6


@ -10,7 +10,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 3
num_classes: 45
datamodule:


@ -10,7 +10,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 3
num_classes: 17
datamodule:


@ -21,7 +21,7 @@ Fully-convolutional Network
.. autoclass:: FCN
FC Siamese Networks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
.. autoclass:: FCSiamConc
.. autoclass:: FCSiamDiff
@ -34,4 +34,33 @@ RCF Extractor
ResNet
^^^^^^
.. autofunction:: resnet18
.. autofunction:: resnet50
.. autoclass:: ResNet18_Weights
.. autoclass:: ResNet50_Weights
.. csv-table::
:widths: 45 10 10 10 15 10 10 10
:header-rows: 1
:align: center
:file: resnet_pretrained_weights.csv
Vision Transformer
^^^^^^^^^^^^^^^^^^
.. autofunction:: vit_small_patch16_224
.. autoclass:: ViTSmall16_Weights
.. csv-table::
:widths: 45 10 10 10 15 10 10 10
:header-rows: 1
:align: center
:file: vit_pretrained_weights.csv
Utility Functions
^^^^^^^^^^^^^^^^^
.. autofunction:: get_model
.. autofunction:: get_model_weights
.. autofunction:: get_weight
.. autofunction:: list_models
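
For orientation, a minimal sketch of how the new utility functions fit together (assumes torchgeo>=0.4 is installed; downloading the weights requires network access):

from torchgeo.models import ResNet18_Weights, get_model, list_models

# Registered builder names: ['resnet18', 'resnet50', 'vit_small_patch16_224']
print(list_models())

# Build a ResNet-18 whose in_chans is taken from the weights' metadata (13 Sentinel-2 bands).
model = get_model("resnet18", weights=ResNet18_Weights.SENTINEL2_ALL_MOCO)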


@ -0,0 +1,9 @@
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.27,93.14,,46.94
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.81,,,


@ -0,0 +1,3 @@
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,89.9,98.6,61.6,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.5,99.0,62.2,


@ -56,14 +56,12 @@ needs_sphinx = "4.0"
nitpicky = True
nitpick_ignore = [
# https://github.com/sphinx-doc/sphinx/issues/8127
("py:class", ".."),
# TODO: can't figure out why this isn't found
("py:class", "LightningDataModule"),
("py:class", "pytorch_lightning.core.module.LightningModule"),
# Undocumented class
("py:class", "torchvision.models.resnet.ResNet"),
# Undocumented classes
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
("py:class", "timm.models.resnet.ResNet"),
("py:class", "timm.models.vision_transformer.VisionTransformer"),
("py:class", "torchvision.models._api.WeightsEnum"),
("py:class", "torchvision.models.resnet.ResNet"),
]
@ -114,6 +112,7 @@ intersphinx_mapping = {
"rasterio": ("https://rasterio.readthedocs.io/en/stable/", None),
"rtree": ("https://rtree.readthedocs.io/en/stable/", None),
"segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None),
"timm": ("https://huggingface.co/docs/timm/main/en/", None),
"torch": ("https://pytorch.org/docs/stable", None),
"torchvision": ("https://pytorch.org/vision/stable", None),
}


@ -33,6 +33,7 @@ torchgeo
tutorials/indices
tutorials/trainers
tutorials/benchmarking
tutorials/pretrained_weights
.. toctree::
:maxdepth: 1


@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pretrained Weights\n",
"\n",
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n",
"\n",
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n",
"\n",
"Next, we import TorchGeo and any other libraries we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import os\n",
"import csv\n",
"import tempfile\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
"import timm\n",
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from pytorch_lightning.loggers import CSVLogger\n",
"\n",
"from torchgeo.datamodules import EuroSATDataModule\n",
"from torchgeo.trainers import ClassificationTask\n",
"from torchgeo.models import ResNet50_Weights, VITSmall16_Weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
"# skip the expensive training.\n",
"in_tests = \"PYTEST_CURRENT_TEST\" in os.environ"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Datamodule\n",
"\n",
"We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n",
"\n",
"datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weights\n",
"\n",
"Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n",
"\n",
"To access these weights you can do the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weights = ResNet50_Weights.SENTINEL2_ALL_MOCO"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task = ClassificationTask(\n",
" model=\"resnet50\",\n",
" loss=\"ce\",\n",
" weights=weights,\n",
" in_channels=13,\n",
" num_classes=10,\n",
" learning_rate=0.001,\n",
" learning_rate_schedule_patience=5,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n",
"model.load_state_dict(weights.get_state_dict(), strict=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training\n",
"\n",
"To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n",
"\n",
"checkpoint_callback = ModelCheckpoint(\n",
" monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n",
")\n",
"\n",
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
"\n",
"csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" logger=[csv_logger],\n",
" default_root_dir=experiment_dir,\n",
" min_epochs=1,\n",
" max_epochs=10,\n",
" fast_dev_run=in_tests,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model=task, datamodule=datamodule)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torchEnv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@ -11,12 +11,12 @@ dependencies:
- pycocotools>=2
- pyproj>=2.2
- python>=3.7
- pytorch>=1.9
- pytorch>=1.12
- pyvista>=0.20
- rarfile>=3
- rasterio>=1.0.20
- shapely>=1.3
- torchvision>=0.10
- torchvision>=0.13
- pip:
- black[jupyter]>=21.8
- flake8>=3.8

hubconf.py (new file, 14 lines)

@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""TorchGeo pre-trained model repository configuration file.
* https://pytorch.org/hub/
* https://pytorch.org/docs/stable/hub.html
"""
from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
__all__ = ("resnet18", "resnet50", "vit_small_patch16_224")
dependencies = ["timm"]
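With this hubconf.py in place, the builders can also be loaded through torch.hub. A hedged sketch — the "microsoft/torchgeo" repository path and "main" ref are assumptions, not taken from this diff:

import torch

# Fetch the repo via torch.hub and call its resnet18 entrypoint with random weights.
model = torch.hub.load("microsoft/torchgeo:main", "resnet18", weights=None)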


@ -18,9 +18,9 @@ scikit-learn==0.21.0
segmentation-models-pytorch==0.2.0
shapely==1.3.0
timm==0.4.12
torch==1.9.0
torch==1.12.0
torchmetrics==0.10.0
torchvision==0.10.0
torchvision==0.13.0
# datasets
h5py==2.6.0


@ -57,12 +57,12 @@ install_requires =
shapely>=1.3,<3
# timm 0.4.12 required by segmentation-models-pytorch
timm>=0.4.12,<0.7
# torch 1.9+ required by torchvision
torch>=1.9,<2
# torch 1.12+ required by torchvision
torch>=1.12,<2
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
torchmetrics>=0.10,<0.12
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
torchvision>=0.10,<0.15
# torchvision 0.13+ required for torchvision.models._api.WeightsEnum
torchvision>=0.13,<0.15
python_requires = ~= 3.7
packages = find:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 14
num_classes: 19
datamodule:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 2
num_classes: 19
datamodule:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 12
num_classes: 19
datamodule:


@ -1,21 +0,0 @@
experiment:
task: "ssl"
name: "test_byol"
module:
model: "byol"
backbone: "resnet18"
input_channels: 4
weights: imagenet
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
root: "tests/data/chesapeake/cvpr"
download: true
train_splits:
- "de-test"
val_splits:
- "de-test"
test_splits:
- "de-test"
batch_size: 1
num_workers: 0


@ -10,7 +10,7 @@ experiment:
num_classes: 7
num_filters: 1
ignore_index: null
weights: imagenet
weights: null
datamodule:
root: "tests/data/chesapeake/cvpr"
download: true


@ -10,7 +10,7 @@ experiment:
num_classes: 5
num_filters: 1
ignore_index: null
weights: imagenet
weights: null
datamodule:
root: "tests/data/chesapeake/cvpr"
download: true


@ -2,12 +2,11 @@ experiment:
task: cowc_counting
module:
model: resnet18
weights: "random"
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: True
datamodule:
root: "tests/data/cowc_counting"
download: true


@ -2,12 +2,11 @@ experiment:
task: "cyclone"
module:
model: "resnet18"
weights: "random"
weights: null
num_outputs: 1
in_channels: 3
learning_rate: 1e-3
learning_rate_schedule_patience: 2
pretrained: False
datamodule:
root: "tests/data/cyclone"
download: true


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 13
num_classes: 2
datamodule:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 3
num_classes: 3
datamodule:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 3
num_classes: 17
datamodule:


@ -5,7 +5,7 @@ experiment:
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
weights: null
in_channels: 3
num_classes: 17
datamodule:


@ -3,7 +3,7 @@ experiment:
module:
loss: "ce"
model: "resnet18"
weights: "random"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3

Binary data
tests/data/models/resnet50-sentinel2-2.pt.zip

Binary file not shown.

tests/models/test_api.py (new file, 50 lines)

@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import enum
from typing import Callable
import pytest
import torch.nn as nn
from torchvision.models._api import WeightsEnum
from torchgeo.models import (
ResNet18_Weights,
ResNet50_Weights,
ViTSmall16_Weights,
get_model,
get_model_weights,
get_weight,
list_models,
resnet18,
resnet50,
vit_small_patch16_224,
)
builders = [resnet18, resnet50, vit_small_patch16_224]
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights]
@pytest.mark.parametrize("builder", builders)
def test_get_model(builder: Callable[..., nn.Module]) -> None:
model = get_model(builder.__name__)
assert isinstance(model, nn.Module)
@pytest.mark.parametrize("builder", builders)
def test_get_model_weights(builder: Callable[..., nn.Module]) -> None:
weights = get_model_weights(builder)
assert isinstance(weights, enum.EnumMeta)
weights = get_model_weights(builder.__name__)
assert isinstance(weights, enum.EnumMeta)
@pytest.mark.parametrize("enum", enums)
def test_get_weight(enum: WeightsEnum) -> None:
for weight in enum:
assert weight == get_weight(str(weight))
def test_list_models() -> None:
models = [builder.__name__ for builder in builders]
assert set(models) == set(list_models())


@ -1,61 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict
import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
import torchgeo.models.resnet
from torchgeo.datasets.utils import extract_archive
from torchgeo.models import resnet50
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
def load_state_dict_from_file(
file: str,
model_dir: Optional[str] = None,
map_location: Optional[Any] = None,
progress: Optional[bool] = True,
check_hash: Optional[bool] = False,
file_name: Optional[str] = None,
) -> Any:
"""Mockup of ``torch.hub.load_state_dict_from_url``."""
return torch.load(file)
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
@pytest.mark.parametrize(
"model_class,sensor,bands,in_channels,num_classes",
[(resnet50, "sentinel2", "all", 10, 17)],
)
def test_resnet(
monkeypatch: MonkeyPatch,
tmp_path: Path,
model_class: Module,
sensor: str,
bands: str,
in_channels: int,
num_classes: int,
) -> None:
extract_archive(
os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"),
str(tmp_path),
)
class TestResNet18:
@pytest.fixture(params=[*ResNet18_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
new_model_urls = {
"sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}}
}
@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls)
monkeypatch.setattr(
torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file
)
def test_resnet(self) -> None:
resnet18()
model = model_class(sensor, bands, pretrained=True)
x = torch.zeros(1, in_channels, 256, 256)
y = model(x)
assert isinstance(y, torch.Tensor)
assert y.size() == torch.Size([1, 17])
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet18(weights=mocked_weights)
@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet18(weights=weights)
class TestResNet50:
@pytest.fixture(params=[*ResNet50_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_resnet(self) -> None:
resnet50()
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet50(weights=mocked_weights)
@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet50(weights=weights)

tests/models/test_vit.py (new file, 49 lines)

@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from pathlib import Path
from typing import Any, Dict
import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchvision.models._api import WeightsEnum
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_vit(self) -> None:
vit_small_patch16_224()
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=mocked_weights)
@pytest.mark.slow
def test_vit_download(self, weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=weights)


@ -2,21 +2,33 @@
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Any, Dict, Type, cast
import pytest
import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
from .test_utils import SegmentationTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestBYOL:
def test_custom_augment_fn(self) -> None:
backbone = resnet18()
@ -45,7 +57,7 @@ class TestBYOLTask:
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -64,30 +76,31 @@ class TestBYOLTask:
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {"backbone": "resnet18", "weights": "random", "in_channels": 3}
def model_kwargs(self) -> Dict[str, Any]:
return {"backbone": "resnet18", "weights": None, "in_channels": 3}
def test_invalid_pretrained(
self, model_kwargs: Dict[Any, Any], checkpoint: str
@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
BYOLTask(**model_kwargs)
def test_weight_enum(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = checkpoint
model_kwargs["backbone"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
model_kwargs["weights"] = mocked_weights
BYOLTask(**model_kwargs)
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
BYOLTask(**model_kwargs)
def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["backbone"] = "invalid_backbone"
match = "Model type 'invalid_backbone' is not a valid timm model."
with pytest.raises(ValueError, match=match):
BYOLTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
def test_weight_str(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = str(mocked_weights)
BYOLTask(**model_kwargs)


@ -2,14 +2,18 @@
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Any, Dict, Type, cast
import pytest
import timm
import torch
import torchvision
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import (
BigEarthNetDataModule,
@ -18,6 +22,7 @@ from torchgeo.datamodules import (
So2SatDataModule,
UCMercedDataModule,
)
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
from .test_utils import ClassificationTestModel
@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestClassificationTask:
@pytest.mark.parametrize(
"name,classname",
@ -46,7 +56,7 @@ class TestClassificationTask:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -66,7 +76,7 @@ class TestClassificationTask:
def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -83,49 +93,52 @@ class TestClassificationTask:
trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
def model_kwargs(self) -> Dict[str, Any]:
return {
"model": "resnet18",
"in_channels": 13,
"loss": "ce",
"num_classes": 10,
"weights": "random",
"weights": None,
}
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs)
def test_invalid_pretrained(
self, model_kwargs: Dict[Any, Any], checkpoint: str
def test_weight_enum(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = checkpoint
model_kwargs["model"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
model_kwargs["weights"] = mocked_weights
with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs)
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
def test_weight_str(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = str(mocked_weights)
with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs)
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(EuroSATDataModule, "plot")
datamodule = EuroSATDataModule(
@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask:
) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask:
def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask:
trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
def model_kwargs(self) -> Dict[str, Any]:
return {
"model": "resnet18",
"in_channels": 14,
"loss": "bce",
"num_classes": 19,
"weights": "random",
"weights": None,
}
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(**model_kwargs)
def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
) -> None:
monkeypatch.delattr(BigEarthNetDataModule, "plot")
datamodule = BigEarthNetDataModule(


@ -2,18 +2,30 @@
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Any, Dict, Type, cast
import pytest
import timm
import torch
import torchvision
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import RegressionTask
from .test_utils import RegressionTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestRegressionTask:
@pytest.mark.parametrize(
"name,classname",
@ -25,7 +37,7 @@ class TestRegressionTask:
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -46,7 +58,7 @@ class TestRegressionTask:
def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
@ -63,36 +75,39 @@ class TestRegressionTask:
trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
def model_kwargs(self) -> Dict[str, Any]:
return {
"model": "resnet18",
"weights": "random",
"weights": None,
"num_outputs": 1,
"in_channels": 3,
}
def test_invalid_pretrained(
self, model_kwargs: Dict[Any, Any], checkpoint: str
) -> None:
model_kwargs["weights"] = checkpoint
model_kwargs["model"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)
@pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs)
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
def test_weight_enum(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = mocked_weights
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
def test_weight_str(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = str(mocked_weights)
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs)


@ -3,14 +3,17 @@
"""TorchGeo models."""
from .api import get_model, get_model_weights, get_weight, list_models
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
from .farseg import FarSeg
from .fcn import FCN
from .fcsiam import FCSiamConc, FCSiamDiff
from .rcf import RCF
from .resnet import resnet50
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .vit import ViTSmall16_Weights, vit_small_patch16_224
__all__ = (
# models
"ChangeMixin",
"ChangeStar",
"ChangeStarFarSeg",
@ -19,5 +22,16 @@ __all__ = (
"FCSiamConc",
"FCSiamDiff",
"RCF",
"resnet18",
"resnet50",
"vit_small_patch16_224",
# weights
"ResNet50_Weights",
"ResNet18_Weights",
"ViTSmall16_Weights",
# utilities
"get_model",
"get_model_weights",
"get_weight",
"list_models",
)

torchgeo/models/api.py (new file, 90 lines)

@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""APIs for querying and loading pre-trained model weights.
See the following references for design details:
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
""" # noqa: E501
from typing import Any, Callable, List, Union
import torch.nn as nn
from torchvision.models._api import WeightsEnum
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .vit import ViTSmall16_Weights, vit_small_patch16_224
_model = {
"resnet18": resnet18,
"resnet50": resnet50,
"vit_small_patch16_224": vit_small_patch16_224,
}
_model_weights = {
resnet18: ResNet18_Weights,
resnet50: ResNet50_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
"resnet18": ResNet18_Weights,
"resnet50": ResNet50_Weights,
"vit_small_patch16_224": ViTSmall16_Weights,
}
def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
"""Get an instantiated model from its name.
.. versionadded:: 0.4
Args:
name: Name of the model.
*args: Additional arguments passed to the model builder method.
**kwargs: Additional keyword arguments passed to the model builder method.
Returns:
An instantiated model.
"""
model: nn.Module = _model[name](*args, **kwargs)
return model
def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum:
"""Get the weights enum class associated with a given model.
.. versionadded:: 0.4
Args:
name: Model builder function or the name under which it is registered.
Returns:
The weights enum class associated with the model.
"""
return _model_weights[name]
def get_weight(name: str) -> WeightsEnum:
"""Get the weights enum value by its full name.
.. versionadded:: 0.4
Args:
name: Name of the weight enum entry.
Returns:
The requested weight enum.
"""
return eval(name)
def list_models() -> List[str]:
"""List the registered models.
.. versionadded:: 0.4
Returns:
A list of registered models.
"""
return list(_model.keys())
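
A small usage sketch of how these lookups relate to one another (the enum entry name is taken from the weight tables above):

from torchgeo.models import get_model_weights, get_weight

# Resolve the weights enum for a builder name, then round-trip one entry by its string name.
weights_enum = get_model_weights("resnet18")  # ResNet18_Weights
weight = get_weight("ResNet18_Weights.SENTINEL2_RGB_SECO")
assert weight in weights_enum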


@ -9,7 +9,6 @@ from typing import List, cast
import torch.nn.functional as F
import torchvision
from packaging.version import parse
from torch import Tensor
from torch.nn.modules import (
BatchNorm2d,
@ -62,7 +61,6 @@ class FarSeg(Module):
else:
raise ValueError(f"unknown backbone: {backbone}.")
kwargs = {}
if parse(torchvision.__version__) >= parse("0.13"):
if backbone_pretrained:
kwargs = {
"weights": getattr(
@ -71,8 +69,6 @@ class FarSeg(Module):
}
else:
kwargs = {"weights": None}
else:
kwargs = {"pretrained": backbone_pretrained}
self.backbone = getattr(resnet, backbone)(**kwargs)


@ -3,83 +3,208 @@
"""Pre-trained ResNet models."""
from typing import Any, List, Type, Union
from typing import Any, Optional
import kornia.augmentation as K
import timm
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum
MODEL_URLS = {
"sentinel2": {
"all": {
"resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt"
}
}
}
from ..transforms import AugmentationSequential
__all__ = ["ResNet50_Weights", "ResNet18_Weights"]
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
)
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
IN_CHANNELS = {"sentinel2": {"all": 10}}
class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet18 weights.
NUM_CLASSES = {"sentinel2": 17}
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet18* implementation.
.. versionadded:: 0.4
"""
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/"
"resolve/main/resnet18_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/"
"resolve/main/resnet18_sentinel2_rgb_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 3,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_SECO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/"
"resolve/main/resnet18_sentinel2_rgb_seco.ckpt"
),
transforms=nn.Identity(),
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2103.16607",
"repo": "https://github.com/ServiceNow/seasonal-contrast",
"ssl_method": "seco",
},
)
def _resnet(
sensor: str,
bands: str,
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any,
class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet50 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet50* implementation.
.. versionadded:: 0.4
"""
SENTINEL1_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/"
"resolve/main/resnet50_sentinel1_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 2,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/"
"resolve/main/resnet50_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/"
"resolve/main/resnet50_sentinel2_rgb_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_DINO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/"
"resolve/main/resnet50_sentinel2_all_dino.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "dino",
},
)
SENTINEL2_RGB_SECO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/"
"resolve/main/resnet50_sentinel2_rgb_seco.ckpt"
),
transforms=nn.Identity(),
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2103.16607",
"repo": "https://github.com/ServiceNow/seasonal-contrast",
"ssl_method": "seco",
},
)
def resnet18(
weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any
) -> ResNet:
"""Resnet model.
"""ResNet-18 model.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/pdf/1512.03385.pdf
.. versionadded:: 0.4
Args:
sensor: imagery source which determines number of input channels
bands: which spectral bands to consider: "all", "rgb", etc.
arch: ResNet version specifying number of layers
block: type of network block
layers: number of layers per block
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
progress: if True, displays a progress bar of the download to stderr
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :func:`timm.create_model`.
**kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
Returns:
A ResNet-50 model
A ResNet-18 model.
"""
# Initialize a new model
model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs)
if weights:
kwargs["in_chans"] = weights.meta["in_chans"]
# Replace the first layer with the correct number of input channels
model.conv1 = nn.Conv2d(
IN_CHANNELS[sensor][bands],
out_channels=64,
kernel_size=7,
stride=1,
padding=2,
bias=False,
)
model: ResNet = timm.create_model("resnet18", *args, **kwargs)
# Load pretrained weights
if pretrained:
state_dict = load_state_dict_from_url(
MODEL_URLS[sensor][bands][arch], progress=progress
)
model.load_state_dict(state_dict)
if weights:
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
return model
def resnet50(
sensor: str,
bands: str,
pretrained: bool = False,
progress: bool = True,
**kwargs: Any,
weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any
) -> ResNet:
"""ResNet-50 model.
@ -87,22 +212,23 @@ def resnet50(
* https://arxiv.org/pdf/1512.03385.pdf
.. versionchanged:: 0.4
Switched to multi-weight support API.
Args:
sensor: imagery source which determines number of input channels
bands: which spectral bands to consider: "all", "rgb", etc.
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
progress: if True, displays a progress bar of the download to stderr
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :func:`timm.create_model`.
**kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
Returns:
A ResNet-50 model
A ResNet-50 model.
"""
return _resnet(
sensor,
bands,
"resnet50",
Bottleneck,
[3, 4, 6, 3],
pretrained,
progress,
**kwargs,
)
if weights:
kwargs["in_chans"] = weights.meta["in_chans"]
model: ResNet = timm.create_model("resnet50", *args, **kwargs)
if weights:
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
return model
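
For example, a hedged sketch of the new builder in use (the 224x224 input size is illustrative; downloading the weights requires network access):

import torch

from torchgeo.models import ResNet50_Weights, resnet50

# in_chans is inferred from the weights' metadata (13 Sentinel-2 bands).
model = resnet50(weights=ResNet50_Weights.SENTINEL2_ALL_MOCO)

x = torch.rand(1, 13, 224, 224)
with torch.no_grad():
    y = model(x)  # shape [1, 1000]: timm's default 1000-class head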

torchgeo/models/vit.py (new file, 98 lines)

@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Pre-trained Vision Transformer models."""
from typing import Any, Optional
import kornia.augmentation as K
import timm
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum
from ..transforms import AugmentationSequential
__all__ = ["ViTSmall16_Weights"]
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
)
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
"""Vision Transformer Samll Patch Size 16 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*vit_small_patch16_224* implementation.
.. versionadded:: 0.4
"""
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/"
"resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "vit_small_patch16_224",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_DINO = Weights(
url=(
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/"
"resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "vit_small_patch16_224",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "dino",
},
)
def vit_small_patch16_224(
weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any
) -> VisionTransformer:
"""Vision Transform (ViT) small patch size 16 model.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/abs/2010.11929
.. versionadded:: 0.4
Args:
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :func:`timm.create_model`.
**kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
Returns:
A ViT small 16 model.
"""
if weights:
kwargs["in_chans"] = weights.meta["in_chans"]
model: VisionTransformer = timm.create_model(
"vit_small_patch16_224", *args, **kwargs
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
return model
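
Similarly, a hedged sketch of the ViT builder; extra keyword arguments such as num_classes are forwarded to timm.create_model:

from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224

# 13-channel ViT-Small/16 with SSL4EO-S12 DINO weights and a 10-class head.
model = vit_small_patch16_224(
    weights=ViTSmall16_Weights.SENTINEL2_ALL_DINO, num_classes=10
)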


@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform
from torch import Tensor, optim
from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.models._api import WeightsEnum
from ..models import get_weight
from . import utils
@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule):
def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
# Create model
in_channels = self.hyperparams["in_channels"]
backbone_name = self.hyperparams["backbone"]
imagenet_pretrained = False
custom_pretrained = False
if self.hyperparams["weights"] and not os.path.exists(
self.hyperparams["weights"]
):
if self.hyperparams["weights"] not in ["imagenet", "random"]:
raise ValueError(
f"Weight type '{self.hyperparams['weights']}' is not valid."
)
else:
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
custom_pretrained = False
else:
custom_pretrained = True
# Create the model
valid_models = timm.list_models(pretrained=imagenet_pretrained)
if backbone_name in valid_models:
weights = self.hyperparams["weights"]
backbone = timm.create_model(
backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained
self.hyperparams["backbone"],
in_chans=in_channels,
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.")
if custom_pretrained:
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
if self.hyperparams["backbone"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hyperparams['backbone']}"
)
state_dict = get_weight(weights).get_state_dict(progress=True)
backbone = utils.load_state_dict(backbone, state_dict)
self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))
@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule):
Keyword Args:
in_channels: Number of input channels to model
backbone: Name of the timm model to use
weights: Either "random" or "imagenet"
weights: Either a weight enum, the string representation of a weight enum,
True for ImageNet weights, False or None for random weights,
or the path to a saved model state dict.
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler


@ -22,8 +22,10 @@ from torchmetrics.classification import (
MultilabelAccuracy,
MultilabelFBetaScore,
)
from torchvision.models._api import WeightsEnum
from ..datasets.utils import unbind_samples
from ..datasets import unbind_samples
from ..models import get_weight
from . import utils
@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule):
def config_model(self) -> None:
"""Configures the model based on kwargs parameters passed to the constructor."""
in_channels = self.hyperparams["in_channels"]
model = self.hyperparams["model"]
imagenet_pretrained = False
custom_pretrained = False
if self.hyperparams["weights"] and not os.path.exists(
self.hyperparams["weights"]
):
if self.hyperparams["weights"] not in ["imagenet", "random"]:
raise ValueError(
f"Weight type '{self.hyperparams['weights']}' is not valid."
)
else:
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
custom_pretrained = False
else:
custom_pretrained = True
# Create the model
valid_models = timm.list_models(pretrained=imagenet_pretrained)
if model in valid_models:
# Create model
weights = self.hyperparams["weights"]
self.model = timm.create_model(
model,
self.hyperparams["model"],
num_classes=self.hyperparams["num_classes"],
in_chans=in_channels,
pretrained=imagenet_pretrained,
in_chans=self.hyperparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
raise ValueError(f"Model type '{model}' is not a valid timm model.")
if custom_pretrained:
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
if self.hyperparams["model"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hyperparams['model']}"
)
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
def config_task(self) -> None:
@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule):
Keyword Args:
model: Name of the classification model use
loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
weights: Either "random" or "imagenet"
weights: Either a weight enum, the string representation of a weight enum,
True for ImageNet weights, False or None for random weights,
or the path to a saved model state dict.
num_classes: Number of prediction classes
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask):
"""
super().__init__(**kwargs)
# Creates `self.hparams` from kwargs
self.save_hyperparameters() # type: ignore[operator]
self.hyperparams = cast(Dict[str, Any], self.hparams)
self.config_task()
self.train_metrics = MetricCollection(
{
"OverallAccuracy": MultilabelAccuracy(


@ -8,18 +8,16 @@ from typing import Any, Dict, List, cast
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torchvision
from packaging.version import parse
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models import resnet as R
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
if parse(torchvision.__version__) >= parse("0.13"):
from torchvision.models import resnet as R
from ..datasets.utils import unbind_samples
BACKBONE_WEIGHT_MAP = {
"resnet18": R.ResNet18_Weights.DEFAULT,
@ -33,8 +31,6 @@ if parse(torchvision.__version__) >= parse("0.13"):
"wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
}
from ..datasets.utils import unbind_samples
class ObjectDetectionTask(pl.LightningModule):
"""LightningModule for object detection of images.
@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule):
"backbone_name": self.hyperparams["backbone"],
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
}
if parse(torchvision.__version__) >= parse("0.13"):
if backbone_pretrained:
kwargs["weights"] = BACKBONE_WEIGHT_MAP[
self.hyperparams["backbone"]
]
else:
kwargs["weights"] = None
else:
kwargs["pretrained"] = backbone_pretrained
backbone = resnet_fpn_backbone(**kwargs)
else:


@ -14,8 +14,10 @@ import torch.nn.functional as F
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision.models._api import WeightsEnum
from ..datasets.utils import unbind_samples
from ..datasets import unbind_samples
from ..models import get_weight
from . import utils
@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule):
def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
in_channels = self.hyperparams["in_channels"]
model = self.hyperparams["model"]
imagenet_pretrained = False
custom_pretrained = False
if self.hyperparams["weights"] and not os.path.exists(
self.hyperparams["weights"]
):
if self.hyperparams["weights"] not in ["imagenet", "random"]:
raise ValueError(
f"Weight type '{self.hyperparams['weights']}' is not valid."
)
else:
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
custom_pretrained = False
else:
custom_pretrained = True
# Create the model
valid_models = timm.list_models(pretrained=imagenet_pretrained)
if model in valid_models:
# Create model
weights = self.hyperparams["weights"]
self.model = timm.create_model(
model,
self.hyperparams["model"],
num_classes=self.hyperparams["num_outputs"],
in_chans=in_channels,
pretrained=imagenet_pretrained,
in_chans=self.hyperparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
raise ValueError(f"Model type '{model}' is not a valid timm model.")
if custom_pretrained:
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
if self.hyperparams["model"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hyperparams['model']}"
)
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
def __init__(self, **kwargs: Any) -> None:
@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule):
Keyword Args:
model: Name of the timm model to use
weights: Either "random" or "imagenet"
weights: Either a weight enum, the string representation of a weight enum,
True for ImageNet weights, False or None for random weights,
or the path to a saved model state dict.
num_outputs: Number of prediction outputs
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
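
Taken together, the trainers now accept a weight enum, its string representation, True for ImageNet weights, False or None for random weights, or a checkpoint path. A hedged sketch using the string form, which the YAML configs can also express:

from torchgeo.trainers import ClassificationTask

task = ClassificationTask(
    model="resnet18",
    loss="ce",
    weights="ResNet18_Weights.SENTINEL2_ALL_MOCO",  # resolved via torchgeo.models.get_weight
    in_channels=13,
    num_classes=19,
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
)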