diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml index 81d0e8389..2ebef57f6 100644 --- a/conf/bigearthnet.yaml +++ b/conf/bigearthnet.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 14 num_classes: 19 datamodule: diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml index 6817577d7..91f0d9921 100644 --- a/conf/cowc_counting.yaml +++ b/conf/cowc_counting.yaml @@ -6,9 +6,11 @@ experiment: name: cowc_counting_test module: model: resnet18 + weights: null + num_outputs: 1 + in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "data/cowc_counting" seed: 0 diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml index 5ddc36fee..c0e038b38 100644 --- a/conf/cyclone.yaml +++ b/conf/cyclone.yaml @@ -6,9 +6,11 @@ experiment: name: "cyclone_test" module: model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "data/cyclone" seed: 0 diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml index 5abde6d45..89dddfd19 100644 --- a/conf/eurosat.yaml +++ b/conf/eurosat.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 13 num_classes: 10 datamodule: diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index b0b101ad0..3b9458265 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -15,7 +15,6 @@ experiment: module: model: "faster-rcnn" backbone: "resnet50" - pretrained: True num_classes: 2 learning_rate: 1.2e-4 learning_rate_schedule_patience: 6 diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml index 4dc34b13c..435df7b94 100644 --- a/conf/resisc45.yaml +++ b/conf/resisc45.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 45 datamodule: diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index 4caf2a01b..e8157c5c1 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/docs/api/models.rst b/docs/api/models.rst index 253a1ecf1..4e9889c1f 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -21,7 +21,7 @@ Fully-convolutional Network .. autoclass:: FCN FC Siamese Networks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^ .. autoclass:: FCSiamConc .. autoclass:: FCSiamDiff @@ -34,4 +34,33 @@ RCF Extractor ResNet ^^^^^^ +.. autofunction:: resnet18 .. autofunction:: resnet50 +.. autoclass:: ResNet18_Weights +.. autoclass:: ResNet50_Weights + +.. csv-table:: + :widths: 45 10 10 10 15 10 10 10 + :header-rows: 1 + :align: center + :file: resnet_pretrained_weights.csv + +Vision Transformer +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: vit_small_patch16_224 +.. autoclass:: ViTSmall16_Weights + +.. csv-table:: + :widths: 45 10 10 10 15 10 10 10 + :header-rows: 1 + :align: center + :file: vit_pretrained_weights.csv + +Utility Functions +^^^^^^^^^^^^^^^^^ + +.. autofunction:: get_model +.. autofunction:: get_model_weights +.. autofunction:: get_weight +.. autofunction:: list_models diff --git a/docs/api/resnet_pretrained_weights.csv b/docs/api/resnet_pretrained_weights.csv new file mode 100644 index 000000000..d1fa7c413 --- /dev/null +++ b/docs/api/resnet_pretrained_weights.csv @@ -0,0 +1,9 @@ +Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD +ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,,,, +ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,, +ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.27,93.14,,46.94 +ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,,,, +ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,91.8,99.1,60.9, +ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,, +ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.7,99.1,63.6, +ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.81,,, diff --git a/docs/api/vit_pretrained_weights.csv b/docs/api/vit_pretrained_weights.csv new file mode 100644 index 000000000..1ac4e9452 --- /dev/null +++ b/docs/api/vit_pretrained_weights.csv @@ -0,0 +1,3 @@ +Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD +VITSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,89.9,98.6,61.6, +VITSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.5,99.0,62.2, diff --git a/docs/conf.py b/docs/conf.py index 6b7ce0b4a..06d6c9da2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,14 +56,12 @@ needs_sphinx = "4.0" nitpicky = True nitpick_ignore = [ - # https://github.com/sphinx-doc/sphinx/issues/8127 - ("py:class", ".."), - # TODO: can't figure out why this isn't found - ("py:class", "LightningDataModule"), - ("py:class", "pytorch_lightning.core.module.LightningModule"), - # Undocumented class - ("py:class", "torchvision.models.resnet.ResNet"), + # Undocumented classes ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), + ("py:class", "timm.models.resnet.ResNet"), + ("py:class", "timm.models.vision_transformer.VisionTransformer"), + ("py:class", "torchvision.models._api.WeightsEnum"), + ("py:class", "torchvision.models.resnet.ResNet"), ] @@ -114,6 +112,7 @@ intersphinx_mapping = { "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None), "rtree": ("https://rtree.readthedocs.io/en/stable/", None), "segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None), + "timm": ("https://huggingface.co/docs/timm/main/en/", None), "torch": ("https://pytorch.org/docs/stable", None), "torchvision": ("https://pytorch.org/vision/stable", None), } diff --git a/docs/index.rst b/docs/index.rst index 6228db380..44f637070 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ torchgeo tutorials/indices tutorials/trainers tutorials/benchmarking + tutorials/pretrained_weights .. toctree:: :maxdepth: 1 diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb new file mode 100644 index 000000000..9e2d4969a --- /dev/null +++ b/docs/tutorials/pretrained_weights.ipynb @@ -0,0 +1,254 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation. All rights reserved.\n", + "\n", + "Licensed under the MIT License." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pretrained Weights\n", + "\n", + "In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n", + "\n", + "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n", + "\n", + "Next, we import TorchGeo and any other libraries we need." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import os\n", + "import csv\n", + "import tempfile\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pytorch_lightning as pl\n", + "import timm\n", + "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "\n", + "from torchgeo.datamodules import EuroSATDataModule\n", + "from torchgeo.trainers import ClassificationTask\n", + "from torchgeo.models import ResNet50_Weights, VITSmall16_Weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n", + "# skip the expensive training.\n", + "in_tests = \"PYTEST_CURRENT_TEST\" in os.environ" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Datamodule\n", + "\n", + "We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n", + "\n", + "datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Weights\n", + "\n", + "Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n", + "\n", + "To access these weights you can do the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weights = ResNet50_Weights.SENTINEL2_ALL_MOCO" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "task = ClassificationTask(\n", + " model=\"resnet50\",\n", + " loss=\"ce\",\n", + " weights=weights,\n", + " in_channels=13,\n", + " num_classes=10,\n", + " learning_rate=0.001,\n", + " learning_rate_schedule_patience=5,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "in_chans = weights.meta[\"in_chans\"]\n", + "model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n", + "model.load_state_dict(weights.get_state_dict(), strict=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training\n", + "\n", + "To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n", + "\n", + "checkpoint_callback = ModelCheckpoint(\n", + " monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n", + ")\n", + "\n", + "early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n", + "\n", + "csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint_callback, early_stopping_callback],\n", + " logger=[csv_logger],\n", + " default_root_dir=experiment_dir,\n", + " min_epochs=1,\n", + " max_epochs=10,\n", + " fast_dev_run=in_tests,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(model=task, datamodule=datamodule)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torchEnv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/environment.yml b/environment.yml index 4f8c98f7d..7406b4401 100644 --- a/environment.yml +++ b/environment.yml @@ -11,12 +11,12 @@ dependencies: - pycocotools>=2 - pyproj>=2.2 - python>=3.7 - - pytorch>=1.9 + - pytorch>=1.12 - pyvista>=0.20 - rarfile>=3 - rasterio>=1.0.20 - shapely>=1.3 - - torchvision>=0.10 + - torchvision>=0.13 - pip: - black[jupyter]>=21.8 - flake8>=3.8 diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 000000000..3e17c4ad7 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo pre-trained model repository configuration file. + +* https://pytorch.org/hub/ +* https://pytorch.org/docs/stable/hub.html +""" + +from torchgeo.models import resnet18, resnet50, vit_small_patch16_224 + +__all__ = ("resnet18", "resnet50", "vit_small_patch16_224") + +dependencies = ["timm"] diff --git a/requirements/min.old b/requirements/min.old index e35c12be2..6d22b601b 100644 --- a/requirements/min.old +++ b/requirements/min.old @@ -18,9 +18,9 @@ scikit-learn==0.21.0 segmentation-models-pytorch==0.2.0 shapely==1.3.0 timm==0.4.12 -torch==1.9.0 +torch==1.12.0 torchmetrics==0.10.0 -torchvision==0.10.0 +torchvision==0.13.0 # datasets h5py==2.6.0 diff --git a/setup.cfg b/setup.cfg index ebd22e0f1..ce7b2446b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,12 +57,12 @@ install_requires = shapely>=1.3,<3 # timm 0.4.12 required by segmentation-models-pytorch timm>=0.4.12,<0.7 - # torch 1.9+ required by torchvision - torch>=1.9,<2 + # torch 1.12+ required by torchvision + torch>=1.12,<2 # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics torchmetrics>=0.10,<0.12 - # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks - torchvision>=0.10,<0.15 + # torchvision 0.13+ required for torchvision.models._api.WeightsEnum + torchvision>=0.13,<0.15 python_requires = ~= 3.7 packages = find: diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index 2ac68ca4e..e885c9db4 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 14 num_classes: 19 datamodule: diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index a5427bff5..09b71cbd8 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 2 num_classes: 19 datamodule: diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 49ea9c723..487b14338 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 12 num_classes: 19 datamodule: diff --git a/tests/conf/byol.yaml b/tests/conf/byol.yaml deleted file mode 100644 index 982fcd888..000000000 --- a/tests/conf/byol.yaml +++ /dev/null @@ -1,21 +0,0 @@ -experiment: - task: "ssl" - name: "test_byol" - module: - model: "byol" - backbone: "resnet18" - input_channels: 4 - weights: imagenet - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - datamodule: - root: "tests/data/chesapeake/cvpr" - download: true - train_splits: - - "de-test" - val_splits: - - "de-test" - test_splits: - - "de-test" - batch_size: 1 - num_workers: 0 diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 5597890e9..9a34e401d 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -10,7 +10,7 @@ experiment: num_classes: 7 num_filters: 1 ignore_index: null - weights: imagenet + weights: null datamodule: root: "tests/data/chesapeake/cvpr" download: true diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index 3f4d440e2..907b17ac7 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -10,7 +10,7 @@ experiment: num_classes: 5 num_filters: 1 ignore_index: null - weights: imagenet + weights: null datamodule: root: "tests/data/chesapeake/cvpr" download: true diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index fe12f68e7..a4f256981 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -2,12 +2,11 @@ experiment: task: cowc_counting module: model: resnet18 - weights: "random" + weights: null num_outputs: 1 in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "tests/data/cowc_counting" download: true diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index 72b5c5e0b..fd0fd42b4 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -2,12 +2,11 @@ experiment: task: "cyclone" module: model: "resnet18" - weights: "random" + weights: null num_outputs: 1 in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: False datamodule: root: "tests/data/cyclone" download: true diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index e674d2f09..a4cbc9eb5 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 13 num_classes: 2 datamodule: diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 0b545cc18..fd354ad09 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 3 datamodule: diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_supervised.yaml index 476644ffe..0cbe484d6 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_supervised.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_unsupervised.yaml index e7aeda254..02c1e6a32 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_unsupervised.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 7c2995b47..1a2ddf8ad 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -3,7 +3,7 @@ experiment: module: loss: "ce" model: "resnet18" - weights: "random" + weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/tests/data/models/resnet50-sentinel2-2.pt.zip b/tests/data/models/resnet50-sentinel2-2.pt.zip deleted file mode 100644 index 4727d4cf0..000000000 Binary files a/tests/data/models/resnet50-sentinel2-2.pt.zip and /dev/null differ diff --git a/tests/models/test_api.py b/tests/models/test_api.py new file mode 100644 index 000000000..3aa126435 --- /dev/null +++ b/tests/models/test_api.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import enum +from typing import Callable + +import pytest +import torch.nn as nn +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ( + ResNet18_Weights, + ResNet50_Weights, + ViTSmall16_Weights, + get_model, + get_model_weights, + get_weight, + list_models, + resnet18, + resnet50, + vit_small_patch16_224, +) + +builders = [resnet18, resnet50, vit_small_patch16_224] +enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights] + + +@pytest.mark.parametrize("builder", builders) +def test_get_model(builder: Callable[..., nn.Module]) -> None: + model = get_model(builder.__name__) + assert isinstance(model, nn.Module) + + +@pytest.mark.parametrize("builder", builders) +def test_get_model_weights(builder: Callable[..., nn.Module]) -> None: + weights = get_model_weights(builder) + assert isinstance(weights, enum.EnumMeta) + weights = get_model_weights(builder.__name__) + assert isinstance(weights, enum.EnumMeta) + + +@pytest.mark.parametrize("enum", enums) +def test_get_weight(enum: WeightsEnum) -> None: + for weight in enum: + assert weight == get_weight(str(weight)) + + +def test_list_models() -> None: + models = [builder.__name__ for builder in builders] + assert set(models) == set(list_models()) diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index a46c4771b..bb72a2e00 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -1,61 +1,74 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict import pytest +import timm import torch +import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum -import torchgeo.models.resnet -from torchgeo.datasets.utils import extract_archive -from torchgeo.models import resnet50 +from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 -def load_state_dict_from_file( - file: str, - model_dir: Optional[str] = None, - map_location: Optional[Any] = None, - progress: Optional[bool] = True, - check_hash: Optional[bool] = False, - file_name: Optional[str] = None, -) -> Any: - """Mockup of ``torch.hub.load_state_dict_from_url``.""" - return torch.load(file) +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict -@pytest.mark.parametrize( - "model_class,sensor,bands,in_channels,num_classes", - [(resnet50, "sentinel2", "all", 10, 17)], -) -def test_resnet( - monkeypatch: MonkeyPatch, - tmp_path: Path, - model_class: Module, - sensor: str, - bands: str, - in_channels: int, - num_classes: int, -) -> None: - extract_archive( - os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"), - str(tmp_path), - ) +class TestResNet18: + @pytest.fixture(params=[*ResNet18_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param - new_model_urls = { - "sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}} - } + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls) - monkeypatch.setattr( - torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file - ) + def test_resnet(self) -> None: + resnet18() - model = model_class(sensor, bands, pretrained=True) - x = torch.zeros(1, in_channels, 256, 256) - y = model(x) - assert isinstance(y, torch.Tensor) - assert y.size() == torch.Size([1, 17]) + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet18(weights=mocked_weights) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet18(weights=weights) + + +class TestResNet50: + @pytest.fixture(params=[*ResNet50_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_resnet(self) -> None: + resnet50() + + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet50(weights=mocked_weights) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet50(weights=weights) diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py new file mode 100644 index 000000000..1cd3ee1cd --- /dev/null +++ b/tests/models/test_vit.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path +from typing import Any, Dict + +import pytest +import timm +import torch +import torchvision +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224 + + +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + +class TestViTSmall16: + @pytest.fixture(params=[*ViTSmall16_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_vit(self) -> None: + vit_small_patch16_224() + + def test_vit_weights(self, mocked_weights: WeightsEnum) -> None: + vit_small_patch16_224(weights=mocked_weights) + + @pytest.mark.slow + def test_vit_download(self, weights: WeightsEnum) -> None: + vit_small_patch16_224(weights=weights) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index bc131ddc6..492ca87ed 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -2,21 +2,33 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest +import timm +import torch import torch.nn as nn +import torchvision +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torchvision.models import resnet18 +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation from .test_utils import SegmentationTestModel +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestBYOL: def test_custom_augment_fn(self) -> None: backbone = resnet18() @@ -45,7 +57,7 @@ class TestBYOLTask: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -64,30 +76,31 @@ class TestBYOLTask: trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: - return {"backbone": "resnet18", "weights": "random", "in_channels": 3} + def model_kwargs(self) -> Dict[str, Any]: + return {"backbone": "resnet18", "weights": None, "in_channels": 3} - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str - ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["backbone"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint BYOLTask(**model_kwargs) - def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["backbone"] = "invalid_backbone" - match = "Model type 'invalid_backbone' is not a valid timm model." - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = mocked_weights + BYOLTask(**model_kwargs) - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + BYOLTask(**model_kwargs) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 7059e59cd..922e171d4 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -2,14 +2,18 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest import timm +import torch +import torchvision from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( BigEarthNetDataModule, @@ -18,6 +22,7 @@ from torchgeo.datamodules import ( So2SatDataModule, UCMercedDataModule, ) +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import ClassificationTestModel @@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestClassificationTask: @pytest.mark.parametrize( "name,classname", @@ -46,7 +56,7 @@ class TestClassificationTask: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -66,7 +76,7 @@ class TestClassificationTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -83,49 +93,52 @@ class TestClassificationTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", "in_channels": 13, "loss": "ce", "num_classes": 10, - "weights": "random", + "weights": None, } - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_ALL_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["model"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): + model_kwargs["weights"] = mocked_weights + with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) - def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + with pytest.warns(UserWarning): + ClassificationTask(**model_kwargs) + + def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" - match = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch ) -> None: monkeypatch.delattr(EuroSATDataModule, "plot") datamodule = EuroSATDataModule( @@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask: ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", "in_channels": 14, "loss": "bce", "num_classes": 19, - "weights": "random", + "weights": None, } - def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch ) -> None: monkeypatch.delattr(BigEarthNetDataModule, "plot") datamodule = BigEarthNetDataModule( diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 67d70334d..22524b0a8 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -2,18 +2,30 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest +import timm +import torch +import torchvision +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask from .test_utils import RegressionTestModel +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -25,7 +37,7 @@ class TestRegressionTask: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -46,7 +58,7 @@ class TestRegressionTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -63,36 +75,39 @@ class TestRegressionTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", - "weights": "random", + "weights": None, "num_outputs": 1, "in_channels": 3, } - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str - ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["model"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): - RegressionTask(**model_kwargs) + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): RegressionTask(**model_kwargs) - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" - match = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=match): + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = mocked_weights + with pytest.warns(UserWarning): RegressionTask(**model_kwargs) - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + with pytest.warns(UserWarning): RegressionTask(**model_kwargs) diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 1f4afdf3d..e1b20236e 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -3,14 +3,17 @@ """TorchGeo models.""" +from .api import get_model, get_model_weights, get_weight, list_models from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg from .farseg import FarSeg from .fcn import FCN from .fcsiam import FCSiamConc, FCSiamDiff from .rcf import RCF -from .resnet import resnet50 +from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .vit import ViTSmall16_Weights, vit_small_patch16_224 __all__ = ( + # models "ChangeMixin", "ChangeStar", "ChangeStarFarSeg", @@ -19,5 +22,16 @@ __all__ = ( "FCSiamConc", "FCSiamDiff", "RCF", + "resnet18", "resnet50", + "vit_small_patch16_224", + # weights + "ResNet50_Weights", + "ResNet18_Weights", + "ViTSmall16_Weights", + # utilities + "get_model", + "get_model_weights", + "get_weight", + "list_models", ) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py new file mode 100644 index 000000000..95cf5b5c8 --- /dev/null +++ b/torchgeo/models/api.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""APIs for querying and loading pre-trained model weights. + +See the following references for design details: + +* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/ +* https://pytorch.org/vision/stable/models.html +* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py +""" # noqa: E501 + +from typing import Any, Callable, List, Union + +import torch.nn as nn +from torchvision.models._api import WeightsEnum + +from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .vit import ViTSmall16_Weights, vit_small_patch16_224 + +_model = { + "resnet18": resnet18, + "resnet50": resnet50, + "vit_small_patch16_224": vit_small_patch16_224, +} + +_model_weights = { + resnet18: ResNet18_Weights, + resnet50: ResNet50_Weights, + vit_small_patch16_224: ViTSmall16_Weights, + "resnet18": ResNet18_Weights, + "resnet50": ResNet50_Weights, + "vit_small_patch16_224": ViTSmall16_Weights, +} + + +def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module: + """Get an instantiated model from its name. + + .. versionadded:: 0.4 + + Args: + name: Name of the model. + *args: Additional arguments passed to the model builder method. + **kwargs: Additional keyword arguments passed to the model builder method. + + Returns: + An instantiated model. + """ + model: nn.Module = _model[name](*args, **kwargs) + return model + + +def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum: + """Get the weights enum class associated with a given model. + + .. versionadded:: 0.4 + + Args: + name: Model builder function or the name under which it is registered. + + Returns: + The weights enum class associated with the model. + """ + return _model_weights[name] + + +def get_weight(name: str) -> WeightsEnum: + """Get the weights enum value by its full name. + + .. versionadded:: 0.4 + + Args: + name: Name of the weight enum entry. + + Returns: + The requested weight enum. + """ + return eval(name) + + +def list_models() -> List[str]: + """List the registered models. + + .. versionadded:: 0.4 + + Returns: + A list of registered models. + """ + return list(_model.keys()) diff --git a/torchgeo/models/farseg.py b/torchgeo/models/farseg.py index 50db33bab..13f6601f9 100644 --- a/torchgeo/models/farseg.py +++ b/torchgeo/models/farseg.py @@ -9,7 +9,6 @@ from typing import List, cast import torch.nn.functional as F import torchvision -from packaging.version import parse from torch import Tensor from torch.nn.modules import ( BatchNorm2d, @@ -62,17 +61,14 @@ class FarSeg(Module): else: raise ValueError(f"unknown backbone: {backbone}.") kwargs = {} - if parse(torchvision.__version__) >= parse("0.13"): - if backbone_pretrained: - kwargs = { - "weights": getattr( - torchvision.models, f"ResNet{backbone[6:]}_Weights" - ).DEFAULT - } - else: - kwargs = {"weights": None} + if backbone_pretrained: + kwargs = { + "weights": getattr( + torchvision.models, f"ResNet{backbone[6:]}_Weights" + ).DEFAULT + } else: - kwargs = {"pretrained": backbone_pretrained} + kwargs = {"weights": None} self.backbone = getattr(resnet, backbone)(**kwargs) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 13126d922..18812156e 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -3,83 +3,208 @@ """Pre-trained ResNet models.""" -from typing import Any, List, Type, Union +from typing import Any, Optional +import kornia.augmentation as K +import timm import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet +from timm.models import ResNet +from torchvision.models._api import Weights, WeightsEnum -MODEL_URLS = { - "sentinel2": { - "all": { - "resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt" - } - } -} +from ..transforms import AugmentationSequential + +__all__ = ["ResNet50_Weights", "ResNet18_Weights"] + +_zhu_xlab_transforms = AugmentationSequential( + K.Resize(256), K.CenterCrop(224), data_keys=["image"] +) + +# https://github.com/pytorch/vision/pull/6883 +# https://github.com/pytorch/vision/pull/7107 +# Can be removed once torchvision>=0.15 is required +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] -IN_CHANNELS = {"sentinel2": {"all": 10}} +class ResNet18_Weights(WeightsEnum): # type: ignore[misc] + """ResNet18 weights. -NUM_CLASSES = {"sentinel2": 17} + For `timm `_ + *resnet18* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/" + "resolve/main/resnet18_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/" + "resolve/main/resnet18_sentinel2_rgb_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 3, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_SECO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/" + "resolve/main/resnet18_sentinel2_rgb_seco.ckpt" + ), + transforms=nn.Identity(), + meta={ + "dataset": "SeCo Dataset", + "in_chans": 3, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2103.16607", + "repo": "https://github.com/ServiceNow/seasonal-contrast", + "ssl_method": "seco", + }, + ) -def _resnet( - sensor: str, - bands: str, - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any, +class ResNet50_Weights(WeightsEnum): # type: ignore[misc] + """ResNet50 weights. + + For `timm `_ + *resnet50* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL1_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/" + "resolve/main/resnet50_sentinel1_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 2, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/" + "resolve/main/resnet50_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/" + "resolve/main/resnet50_sentinel2_rgb_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_DINO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/" + "resolve/main/resnet50_sentinel2_all_dino.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "dino", + }, + ) + + SENTINEL2_RGB_SECO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/" + "resolve/main/resnet50_sentinel2_rgb_seco.ckpt" + ), + transforms=nn.Identity(), + meta={ + "dataset": "SeCo Dataset", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2103.16607", + "repo": "https://github.com/ServiceNow/seasonal-contrast", + "ssl_method": "seco", + }, + ) + + +def resnet18( + weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any ) -> ResNet: - """Resnet model. + """ResNet-18 model. If you use this model in your research, please cite the following paper: * https://arxiv.org/pdf/1512.03385.pdf + .. versionadded:: 0.4 + Args: - sensor: imagery source which determines number of input channels - bands: which spectral bands to consider: "all", "rgb", etc. - arch: ResNet version specifying number of layers - block: type of network block - layers: number of layers per block - pretrained: if True, returns a model pre-trained on ``sensor`` imagery - progress: if True, displays a progress bar of the download to stderr + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model` + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model` Returns: - A ResNet-50 model + A ResNet-18 model. """ - # Initialize a new model - model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs) + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] - # Replace the first layer with the correct number of input channels - model.conv1 = nn.Conv2d( - IN_CHANNELS[sensor][bands], - out_channels=64, - kernel_size=7, - stride=1, - padding=2, - bias=False, - ) + model: ResNet = timm.create_model("resnet18", *args, **kwargs) - # Load pretrained weights - if pretrained: - state_dict = load_state_dict_from_url( - MODEL_URLS[sensor][bands][arch], progress=progress - ) - model.load_state_dict(state_dict) + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) return model def resnet50( - sensor: str, - bands: str, - pretrained: bool = False, - progress: bool = True, - **kwargs: Any, + weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any ) -> ResNet: """ResNet-50 model. @@ -87,22 +212,23 @@ def resnet50( * https://arxiv.org/pdf/1512.03385.pdf + .. versionchanged:: 0.4 + Switched to multi-weight support API. + Args: - sensor: imagery source which determines number of input channels - bands: which spectral bands to consider: "all", "rgb", etc. - pretrained: if True, returns a model pre-trained on ``sensor`` imagery - progress: if True, displays a progress bar of the download to stderr + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. Returns: - A ResNet-50 model + A ResNet-50 model. """ - return _resnet( - sensor, - bands, - "resnet50", - Bottleneck, - [3, 4, 6, 3], - pretrained, - progress, - **kwargs, - ) + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] + + model: ResNet = timm.create_model("resnet50", *args, **kwargs) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + + return model diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py new file mode 100644 index 000000000..b8d0874d4 --- /dev/null +++ b/torchgeo/models/vit.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Pre-trained Vision Transformer models.""" + +from typing import Any, Optional + +import kornia.augmentation as K +import timm +from timm.models.vision_transformer import VisionTransformer +from torchvision.models._api import Weights, WeightsEnum + +from ..transforms import AugmentationSequential + +__all__ = ["ViTSmall16_Weights"] + +_zhu_xlab_transforms = AugmentationSequential( + K.Resize(256), K.CenterCrop(224), data_keys=["image"] +) + +# https://github.com/pytorch/vision/pull/6883 +# https://github.com/pytorch/vision/pull/7107 +# Can be removed once torchvision>=0.15 is required +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] + + +class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] + """Vision Transformer Samll Patch Size 16 weights. + + For `timm `_ + *vit_small_patch16_224* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/" + "resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "vit_small_patch16_224", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_DINO = Weights( + url=( + "https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/" + "resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "vit_small_patch16_224", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "dino", + }, + ) + + +def vit_small_patch16_224( + weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any +) -> VisionTransformer: + """Vision Transform (ViT) small patch size 16 model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2010.11929 + + .. versionadded:: 0.4 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. + + Returns: + A ViT small 16 model. + """ + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] + + model: VisionTransformer = timm.create_model( + "vit_small_patch16_224", *args, **kwargs + ) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + + return model diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index fbbd894b3..d733df142 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform from torch import Tensor, optim from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau +from torchvision.models._api import WeightsEnum +from ..models import get_weight from . import utils @@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" + # Create model in_channels = self.hyperparams["in_channels"] - backbone_name = self.hyperparams["backbone"] + weights = self.hyperparams["weights"] + backbone = timm.create_model( + self.hyperparams["backbone"], + in_chans=in_channels, + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if backbone_name in valid_models: - backbone = timm.create_model( - backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained - ) - else: - raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["backbone"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['backbone']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) backbone = utils.load_state_dict(backbone, state_dict) self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256)) @@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule): Keyword Args: in_channels: Number of input channels to model backbone: Name of the timm model to use - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. learning_rate: Learning rate for optimizer learning_rate_schedule_patience: Patience for learning rate scheduler diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index ca4c8dcf6..41d7c922e 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -22,8 +22,10 @@ from torchmetrics.classification import ( MultilabelAccuracy, MultilabelFBetaScore, ) +from torchvision.models._api import WeightsEnum -from ..datasets.utils import unbind_samples +from ..datasets import unbind_samples +from ..models import get_weight from . import utils @@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule): def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" - in_channels = self.hyperparams["in_channels"] - model = self.hyperparams["model"] + # Create model + weights = self.hyperparams["weights"] + self.model = timm.create_model( + self.hyperparams["model"], + num_classes=self.hyperparams["num_classes"], + in_chans=self.hyperparams["in_channels"], + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if model in valid_models: - self.model = timm.create_model( - model, - num_classes=self.hyperparams["num_classes"], - in_chans=in_channels, - pretrained=imagenet_pretrained, - ) - else: - raise ValueError(f"Model type '{model}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['model']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) def config_task(self) -> None: @@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule): Keyword Args: model: Name of the classification model use loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal' - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. num_classes: Number of prediction classes in_channels: Number of input channels to model learning_rate: Learning rate for optimizer @@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask): """ super().__init__(**kwargs) - # Creates `self.hparams` from kwargs - self.save_hyperparameters() # type: ignore[operator] - self.hyperparams = cast(Dict[str, Any], self.hparams) - - self.config_task() - self.train_metrics = MetricCollection( { "OverallAccuracy": MultilabelAccuracy( diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index e4982ff1f..b50d273d4 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -8,33 +8,29 @@ from typing import Any, Dict, List, cast import matplotlib.pyplot as plt import pytorch_lightning as pl import torch -import torchvision -from packaging.version import parse from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision.models import resnet as R from torchvision.models.detection import FasterRCNN from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import MultiScaleRoIAlign -if parse(torchvision.__version__) >= parse("0.13"): - from torchvision.models import resnet as R - - BACKBONE_WEIGHT_MAP = { - "resnet18": R.ResNet18_Weights.DEFAULT, - "resnet34": R.ResNet34_Weights.DEFAULT, - "resnet50": R.ResNet50_Weights.DEFAULT, - "resnet101": R.ResNet101_Weights.DEFAULT, - "resnet152": R.ResNet152_Weights.DEFAULT, - "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT, - "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT, - "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT, - "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT, - } - from ..datasets.utils import unbind_samples +BACKBONE_WEIGHT_MAP = { + "resnet18": R.ResNet18_Weights.DEFAULT, + "resnet34": R.ResNet34_Weights.DEFAULT, + "resnet50": R.ResNet50_Weights.DEFAULT, + "resnet101": R.ResNet101_Weights.DEFAULT, + "resnet152": R.ResNet152_Weights.DEFAULT, + "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT, + "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT, + "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT, + "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT, +} + class ObjectDetectionTask(pl.LightningModule): """LightningModule for object detection of images. @@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule): "backbone_name": self.hyperparams["backbone"], "trainable_layers": self.hyperparams.get("trainable_layers", 3), } - if parse(torchvision.__version__) >= parse("0.13"): - if backbone_pretrained: - kwargs["weights"] = BACKBONE_WEIGHT_MAP[ - self.hyperparams["backbone"] - ] - else: - kwargs["weights"] = None + if backbone_pretrained: + kwargs["weights"] = BACKBONE_WEIGHT_MAP[ + self.hyperparams["backbone"] + ] else: - kwargs["pretrained"] = backbone_pretrained + kwargs["weights"] = None backbone = resnet_fpn_backbone(**kwargs) else: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 8a7056f12..73bffd593 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -14,8 +14,10 @@ import torch.nn.functional as F from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchvision.models._api import WeightsEnum -from ..datasets.utils import unbind_samples +from ..datasets import unbind_samples +from ..models import get_weight from . import utils @@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters.""" - in_channels = self.hyperparams["in_channels"] - model = self.hyperparams["model"] + # Create model + weights = self.hyperparams["weights"] + self.model = timm.create_model( + self.hyperparams["model"], + num_classes=self.hyperparams["num_outputs"], + in_chans=self.hyperparams["in_channels"], + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if model in valid_models: - self.model = timm.create_model( - model, - num_classes=self.hyperparams["num_outputs"], - in_chans=in_channels, - pretrained=imagenet_pretrained, - ) - else: - raise ValueError(f"Model type '{model}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['model']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) def __init__(self, **kwargs: Any) -> None: @@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule): Keyword Args: model: Name of the timm model to use - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. num_outputs: Number of prediction outputs in_channels: Number of input channels to model learning_rate: Learning rate for optimizer