From 60eb61b5fafab9cf0d23a136961ab2c50c9951d8 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Sun, 22 Jan 2023 23:25:49 +0100 Subject: [PATCH] Add Multi-Weight Support API (#917) * load pretrained weights * change name millionaid * restructure and additional weights * rename sentinel1 weights * add vit small weights * forgot to add vit.py * struggling with test * wrong name failing test * feedback on tests * increase test coverage * fix failing test * fix failing test * fix failing test and add vit tests * fix failing vit test * torchgeo.models.utils * forgot utils file * typo num channels * nitpick docs, version torchvision * another try min dependencies * add documentation table * expand pytests to test pretrained weights on tasks * reverse changes to byol task * add tests to init pretrained weights from config * forgot to add the conf files * change path * increase test coverage * vit tests all pass locally including slow * now remote * fix tests another one * add a draft tutorial * run black on tutorial notebook * Tutorial typo fixes * Lower min torch/vision versions * Fix bad rebase * Remove dead code * Flake8 fixes * Consistent in_chans * Black fixes * bison > yacs * Remove one more reference * Download modified weights from hugging face * Add entrypoints * Add torch.hub support * progress arg is required * Fix model loading for resnet18 * Add transforms, update tests * VIT -> ViT * add seco weights * Fix type hints * Link to timm docs * Fix pydocstyle * Try to fix timm docs link * Fix tests * Nuke ignores * Ignore timm links * Add model API methods * Add to __init__ and document * Test model API functions * fix tests * Use correct documentation link for intersphinx * Typos * Fix Windows tests * meth -> func * Explicit function scope * weight-specific filename * Support enums in classification trainer * Update other trainers too * Fix regression tests * Fix classification tests * Fix byol tests * Fix types * progress_bar is required arg * Test weight enums * Fix pickling * Fix regression tests * Improve coverage of classification tests * Improve coverage of BYOL tests * Update resnet table * Update ViT table * Update get_state_dict usage * Remove unused YAML files * Update table widths * Documentation improvements * Tweak tables * Try to fix Windows tests * Revert "Try to fix Windows tests" This reverts commit 1325b13ff779b28adcaca36725e098ae8352a1d6. * Monkeypatch everything * Revert "Monkeypatch everything" This reverts commit e3e8d7d04231f8c0a39b5accd8f3d977aa7cbab2. * Revert "Revert "Monkeypatch everything"" This reverts commit 9b27bd705b06a743c092301c36802ce6e9503898. * Patch things not at the source * Fix missing import Co-authored-by: Adam J. Stewart --- conf/bigearthnet.yaml | 2 +- conf/cowc_counting.yaml | 4 +- conf/cyclone.yaml | 4 +- conf/eurosat.yaml | 2 +- conf/nasa_marine_debris.yaml | 1 - conf/resisc45.yaml | 2 +- conf/so2sat.yaml | 2 +- docs/api/models.rst | 31 ++- docs/api/resnet_pretrained_weights.csv | 9 + docs/api/vit_pretrained_weights.csv | 3 + docs/conf.py | 13 +- docs/index.rst | 1 + docs/tutorials/pretrained_weights.ipynb | 254 +++++++++++++++++ environment.yml | 4 +- hubconf.py | 14 + requirements/min.old | 4 +- setup.cfg | 8 +- tests/conf/bigearthnet_all.yaml | 2 +- tests/conf/bigearthnet_s1.yaml | 2 +- tests/conf/bigearthnet_s2.yaml | 2 +- tests/conf/byol.yaml | 21 -- tests/conf/chesapeake_cvpr_7.yaml | 2 +- tests/conf/chesapeake_cvpr_prior.yaml | 2 +- tests/conf/cowc_counting.yaml | 3 +- tests/conf/cyclone.yaml | 3 +- tests/conf/eurosat.yaml | 2 +- tests/conf/resisc45.yaml | 2 +- tests/conf/so2sat_supervised.yaml | 2 +- tests/conf/so2sat_unsupervised.yaml | 2 +- tests/conf/ucmerced.yaml | 2 +- tests/data/models/resnet50-sentinel2-2.pt.zip | Bin 110708 -> 0 bytes tests/models/test_api.py | 50 ++++ tests/models/test_resnet.py | 103 ++++--- tests/models/test_vit.py | 49 ++++ tests/trainers/test_byol.py | 57 ++-- tests/trainers/test_classification.py | 75 ++--- tests/trainers/test_regression.py | 57 ++-- torchgeo/models/__init__.py | 16 +- torchgeo/models/api.py | 90 ++++++ torchgeo/models/farseg.py | 18 +- torchgeo/models/resnet.py | 260 +++++++++++++----- torchgeo/models/vit.py | 98 +++++++ torchgeo/trainers/byol.py | 52 ++-- torchgeo/trainers/classification.py | 65 ++--- torchgeo/trainers/detection.py | 43 ++- torchgeo/trainers/regression.py | 59 ++-- 46 files changed, 1101 insertions(+), 396 deletions(-) create mode 100644 docs/api/resnet_pretrained_weights.csv create mode 100644 docs/api/vit_pretrained_weights.csv create mode 100644 docs/tutorials/pretrained_weights.ipynb create mode 100644 hubconf.py delete mode 100644 tests/conf/byol.yaml delete mode 100644 tests/data/models/resnet50-sentinel2-2.pt.zip create mode 100644 tests/models/test_api.py create mode 100644 tests/models/test_vit.py create mode 100644 torchgeo/models/api.py create mode 100644 torchgeo/models/vit.py diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml index 81d0e8389..2ebef57f6 100644 --- a/conf/bigearthnet.yaml +++ b/conf/bigearthnet.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 14 num_classes: 19 datamodule: diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml index 6817577d7..91f0d9921 100644 --- a/conf/cowc_counting.yaml +++ b/conf/cowc_counting.yaml @@ -6,9 +6,11 @@ experiment: name: cowc_counting_test module: model: resnet18 + weights: null + num_outputs: 1 + in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "data/cowc_counting" seed: 0 diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml index 5ddc36fee..c0e038b38 100644 --- a/conf/cyclone.yaml +++ b/conf/cyclone.yaml @@ -6,9 +6,11 @@ experiment: name: "cyclone_test" module: model: "resnet18" + weights: null + num_outputs: 1 + in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "data/cyclone" seed: 0 diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml index 5abde6d45..89dddfd19 100644 --- a/conf/eurosat.yaml +++ b/conf/eurosat.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 13 num_classes: 10 datamodule: diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index b0b101ad0..3b9458265 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -15,7 +15,6 @@ experiment: module: model: "faster-rcnn" backbone: "resnet50" - pretrained: True num_classes: 2 learning_rate: 1.2e-4 learning_rate_schedule_patience: 6 diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml index 4dc34b13c..435df7b94 100644 --- a/conf/resisc45.yaml +++ b/conf/resisc45.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 45 datamodule: diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index 4caf2a01b..e8157c5c1 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -10,7 +10,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/docs/api/models.rst b/docs/api/models.rst index 253a1ecf1..4e9889c1f 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -21,7 +21,7 @@ Fully-convolutional Network .. autoclass:: FCN FC Siamese Networks -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^ .. autoclass:: FCSiamConc .. autoclass:: FCSiamDiff @@ -34,4 +34,33 @@ RCF Extractor ResNet ^^^^^^ +.. autofunction:: resnet18 .. autofunction:: resnet50 +.. autoclass:: ResNet18_Weights +.. autoclass:: ResNet50_Weights + +.. csv-table:: + :widths: 45 10 10 10 15 10 10 10 + :header-rows: 1 + :align: center + :file: resnet_pretrained_weights.csv + +Vision Transformer +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: vit_small_patch16_224 +.. autoclass:: ViTSmall16_Weights + +.. csv-table:: + :widths: 45 10 10 10 15 10 10 10 + :header-rows: 1 + :align: center + :file: vit_pretrained_weights.csv + +Utility Functions +^^^^^^^^^^^^^^^^^ + +.. autofunction:: get_model +.. autofunction:: get_model_weights +.. autofunction:: get_weight +.. autofunction:: list_models diff --git a/docs/api/resnet_pretrained_weights.csv b/docs/api/resnet_pretrained_weights.csv new file mode 100644 index 000000000..d1fa7c413 --- /dev/null +++ b/docs/api/resnet_pretrained_weights.csv @@ -0,0 +1,9 @@ +Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD +ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,,,, +ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,, +ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.27,93.14,,46.94 +ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,,,, +ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,91.8,99.1,60.9, +ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,, +ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.7,99.1,63.6, +ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.81,,, diff --git a/docs/api/vit_pretrained_weights.csv b/docs/api/vit_pretrained_weights.csv new file mode 100644 index 000000000..1ac4e9452 --- /dev/null +++ b/docs/api/vit_pretrained_weights.csv @@ -0,0 +1,3 @@ +Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD +VITSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,89.9,98.6,61.6, +VITSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.5,99.0,62.2, diff --git a/docs/conf.py b/docs/conf.py index 6b7ce0b4a..06d6c9da2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,14 +56,12 @@ needs_sphinx = "4.0" nitpicky = True nitpick_ignore = [ - # https://github.com/sphinx-doc/sphinx/issues/8127 - ("py:class", ".."), - # TODO: can't figure out why this isn't found - ("py:class", "LightningDataModule"), - ("py:class", "pytorch_lightning.core.module.LightningModule"), - # Undocumented class - ("py:class", "torchvision.models.resnet.ResNet"), + # Undocumented classes ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), + ("py:class", "timm.models.resnet.ResNet"), + ("py:class", "timm.models.vision_transformer.VisionTransformer"), + ("py:class", "torchvision.models._api.WeightsEnum"), + ("py:class", "torchvision.models.resnet.ResNet"), ] @@ -114,6 +112,7 @@ intersphinx_mapping = { "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None), "rtree": ("https://rtree.readthedocs.io/en/stable/", None), "segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None), + "timm": ("https://huggingface.co/docs/timm/main/en/", None), "torch": ("https://pytorch.org/docs/stable", None), "torchvision": ("https://pytorch.org/vision/stable", None), } diff --git a/docs/index.rst b/docs/index.rst index 6228db380..44f637070 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ torchgeo tutorials/indices tutorials/trainers tutorials/benchmarking + tutorials/pretrained_weights .. toctree:: :maxdepth: 1 diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb new file mode 100644 index 000000000..9e2d4969a --- /dev/null +++ b/docs/tutorials/pretrained_weights.ipynb @@ -0,0 +1,254 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation. All rights reserved.\n", + "\n", + "Licensed under the MIT License." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pretrained Weights\n", + "\n", + "In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n", + "\n", + "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n", + "\n", + "Next, we import TorchGeo and any other libraries we need." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import os\n", + "import csv\n", + "import tempfile\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pytorch_lightning as pl\n", + "import timm\n", + "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "\n", + "from torchgeo.datamodules import EuroSATDataModule\n", + "from torchgeo.trainers import ClassificationTask\n", + "from torchgeo.models import ResNet50_Weights, VITSmall16_Weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n", + "# skip the expensive training.\n", + "in_tests = \"PYTEST_CURRENT_TEST\" in os.environ" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Datamodule\n", + "\n", + "We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n", + "\n", + "datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Weights\n", + "\n", + "Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n", + "\n", + "To access these weights you can do the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "weights = ResNet50_Weights.SENTINEL2_ALL_MOCO" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "task = ClassificationTask(\n", + " model=\"resnet50\",\n", + " loss=\"ce\",\n", + " weights=weights,\n", + " in_channels=13,\n", + " num_classes=10,\n", + " learning_rate=0.001,\n", + " learning_rate_schedule_patience=5,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "in_chans = weights.meta[\"in_chans\"]\n", + "model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n", + "model.load_state_dict(weights.get_state_dict(), strict=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training\n", + "\n", + "To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n", + "\n", + "checkpoint_callback = ModelCheckpoint(\n", + " monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n", + ")\n", + "\n", + "early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n", + "\n", + "csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint_callback, early_stopping_callback],\n", + " logger=[csv_logger],\n", + " default_root_dir=experiment_dir,\n", + " min_epochs=1,\n", + " max_epochs=10,\n", + " fast_dev_run=in_tests,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(model=task, datamodule=datamodule)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torchEnv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/environment.yml b/environment.yml index 4f8c98f7d..7406b4401 100644 --- a/environment.yml +++ b/environment.yml @@ -11,12 +11,12 @@ dependencies: - pycocotools>=2 - pyproj>=2.2 - python>=3.7 - - pytorch>=1.9 + - pytorch>=1.12 - pyvista>=0.20 - rarfile>=3 - rasterio>=1.0.20 - shapely>=1.3 - - torchvision>=0.10 + - torchvision>=0.13 - pip: - black[jupyter]>=21.8 - flake8>=3.8 diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 000000000..3e17c4ad7 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo pre-trained model repository configuration file. + +* https://pytorch.org/hub/ +* https://pytorch.org/docs/stable/hub.html +""" + +from torchgeo.models import resnet18, resnet50, vit_small_patch16_224 + +__all__ = ("resnet18", "resnet50", "vit_small_patch16_224") + +dependencies = ["timm"] diff --git a/requirements/min.old b/requirements/min.old index e35c12be2..6d22b601b 100644 --- a/requirements/min.old +++ b/requirements/min.old @@ -18,9 +18,9 @@ scikit-learn==0.21.0 segmentation-models-pytorch==0.2.0 shapely==1.3.0 timm==0.4.12 -torch==1.9.0 +torch==1.12.0 torchmetrics==0.10.0 -torchvision==0.10.0 +torchvision==0.13.0 # datasets h5py==2.6.0 diff --git a/setup.cfg b/setup.cfg index ebd22e0f1..ce7b2446b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,12 +57,12 @@ install_requires = shapely>=1.3,<3 # timm 0.4.12 required by segmentation-models-pytorch timm>=0.4.12,<0.7 - # torch 1.9+ required by torchvision - torch>=1.9,<2 + # torch 1.12+ required by torchvision + torch>=1.12,<2 # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics torchmetrics>=0.10,<0.12 - # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks - torchvision>=0.10,<0.15 + # torchvision 0.13+ required for torchvision.models._api.WeightsEnum + torchvision>=0.13,<0.15 python_requires = ~= 3.7 packages = find: diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index 2ac68ca4e..e885c9db4 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 14 num_classes: 19 datamodule: diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index a5427bff5..09b71cbd8 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 2 num_classes: 19 datamodule: diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 49ea9c723..487b14338 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 12 num_classes: 19 datamodule: diff --git a/tests/conf/byol.yaml b/tests/conf/byol.yaml deleted file mode 100644 index 982fcd888..000000000 --- a/tests/conf/byol.yaml +++ /dev/null @@ -1,21 +0,0 @@ -experiment: - task: "ssl" - name: "test_byol" - module: - model: "byol" - backbone: "resnet18" - input_channels: 4 - weights: imagenet - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - datamodule: - root: "tests/data/chesapeake/cvpr" - download: true - train_splits: - - "de-test" - val_splits: - - "de-test" - test_splits: - - "de-test" - batch_size: 1 - num_workers: 0 diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 5597890e9..9a34e401d 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -10,7 +10,7 @@ experiment: num_classes: 7 num_filters: 1 ignore_index: null - weights: imagenet + weights: null datamodule: root: "tests/data/chesapeake/cvpr" download: true diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index 3f4d440e2..907b17ac7 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -10,7 +10,7 @@ experiment: num_classes: 5 num_filters: 1 ignore_index: null - weights: imagenet + weights: null datamodule: root: "tests/data/chesapeake/cvpr" download: true diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index fe12f68e7..a4f256981 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -2,12 +2,11 @@ experiment: task: cowc_counting module: model: resnet18 - weights: "random" + weights: null num_outputs: 1 in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: True datamodule: root: "tests/data/cowc_counting" download: true diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index 72b5c5e0b..fd0fd42b4 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -2,12 +2,11 @@ experiment: task: "cyclone" module: model: "resnet18" - weights: "random" + weights: null num_outputs: 1 in_channels: 3 learning_rate: 1e-3 learning_rate_schedule_patience: 2 - pretrained: False datamodule: root: "tests/data/cyclone" download: true diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index e674d2f09..a4cbc9eb5 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 13 num_classes: 2 datamodule: diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 0b545cc18..fd354ad09 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 3 datamodule: diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_supervised.yaml index 476644ffe..0cbe484d6 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_supervised.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_unsupervised.yaml index e7aeda254..02c1e6a32 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_unsupervised.yaml @@ -5,7 +5,7 @@ experiment: model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - weights: "random" + weights: null in_channels: 3 num_classes: 17 datamodule: diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 7c2995b47..1a2ddf8ad 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -3,7 +3,7 @@ experiment: module: loss: "ce" model: "resnet18" - weights: "random" + weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 in_channels: 3 diff --git a/tests/data/models/resnet50-sentinel2-2.pt.zip b/tests/data/models/resnet50-sentinel2-2.pt.zip deleted file mode 100644 index 4727d4cf0df3f1b7b7ece7a27df6edcc0a1f9486..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 110708 zcmeF)cU)6jzA$_g1(m8Oy$BXWdPiF5peP8^yR^`x_ZF4jK~O-DNbjA1bP(yiH|b3X zy|*N9;FSB!{meY?+?i+YIWzA3k!VQvUhBKp`mM5e)_yLFdFdv`MT{#LUAaR?_%3mJO zovdoBm%0*S%Q=@%VBO+G39CnI9j8uR7LNR>Ei^6d+1di~C+RPB=Bo4`Ffi(Bk6Aq~ z)c)?YupQXj0AK1ia?E0&9nq*GD)woUo#> z>@K9?R-YoVHH{)$MN-yC!Xnb>0*>)^F^EWC+UPX2(Dqg~x6O`g++c!6l&Tt6Nyg_4 z6$R$FL56teyR{GE5%;HzFITq5FFBNSuB53~KNa176U4ZjImp`HX7q-m%r=JmLAEm2 zfH7Y2%x8pa?m^rcm$q^zn#&5+ z$GOa-O!`zt$))-RiQC}fb>R!6b`^~uGGI;APS^Bg^*)(Nb=d)s$Sp3PiX_C-xoC7-? zztM|M0VGdkt6Gd3O-t2uXJssDpA4jvyp)aS7B(za)G&ZVOc^&5%TcZ{WLwnJ`qJv@}klis}bIV!BCi#U;q^3=AXKBi)$_Br5l*42iad7dPWj$5Z zv|Rf1Rm*y9p7Bq{ZdGtf_PiY(hfhQn(8{ipR0m!L=jxU5;3gP|Y>1 zxKGi>OKz%TwyMV$Q&sMMW0{Asv<+5_WI6%7l) zfxXxvFI$B1RI+sD9psg!=IrB{wO*-=h^}H#-8N32XSQtBeC4S5abjKDs{PBP(*rnF z;mD_T*QtBC2NOJ%^quRSCGT-qGBsYPYgg$TNt*6c)NmbSc5}9^FSR3IgGXtiJ|JjG zTt{!Se9&Ne+sLgHPBQs&BWUR;r|G^$nnQ>Y`<3oY$+=Y*on0TT zUXo8qtmTrrrK60d`(|k`LL}cyRHYHi7Yw^>@!B>{*$37x5JSd|H&R9`V?$BJUx-2u z2F4!m3+E$5^Y;DZ*22Dz(TuM~a-DGXCpV3y1-qvi-)$6j;@tL$inzmb!qHjt?!dcX zY_e7*A@i30)TXsNd>tVt$VNDIYn%0xRgF_SO0Srvf0f?l#g%p{BpBu#gP2|StKCLy z@`7)8I+-J07K8d%SzTnW1QDakG#I`jZ5R=7mA6$E$z8-=r>l@c>cb-4(dlkvtgpzv zybI)tIWx`MQ5pq2oKRb2*UQf;`;07ui;7h5u07``Ppt?YD7w!;ftQgks*oZ79P*oSox_AudvHR>jQ~o zY;qy0-u1O3=)AEVwrc+5`QSRQqoAwcn3 zD$@5GL;4NUX`duoQkkCH7w{H)2;chBap5M;_t_r-)%N^w>yGk_xG8|mF9 za-wJfgF3!fx380S6Gc-!zAhjC!-TCqfRdX#Mg}v^d&N*Go%DwEqEctc?aoe%M$^1y zi)FS?@>95q#z`DF?}KUW)?KiT&82QGUNOis?3inE=OF%ENEjmNJCM0_`@v8SDHE|u zfFbjY3WIuc66|I!hxmHb&h;pBCCU$1&B`T|t<5egyq8g?h4qCC&Xr${`A~HU<6geT zvxJau5AUkDy(3YmxSP}K`_3Ci=K?p$Z->dMg~d@hJRK3aBY;0T7cuc8Ku2OFq2RIj zi-K64GR<#rM@DG%D?nPqwZjs`Q`K zAuv-U#gDL8lcBd=zWYv2YWZfA+@3c}5^T7i>%qlzB75K0UGqKamp+B`Ji77;9-x)d zwC|e0Ys-$)mENNLS!!MEE{oW#zQKIPpmq}8-D}Z-`>wG=I;;mJ!;B_+s!6W5vQ#P1 ze%v3)yv$9v^N1pygrHRvEnlvO#Pxg`4d908@T9?t< zB<#C}-M=@j^KPA}525AyIt5$7^-8KAUcifS_#N^nZ|SS9PesD)eq4Qb)}=;rgx8{dn@uf`^i@$bz`|qO%)9hLdiykWDxuSiXFVj0~E*b(vfm zH*=?Cn3cQzZNB#&%+RAD*Q)gxKX;?|xISC}R*lS1P&j=293(ry8h`vgV=DZZDieg; z%2_rHs%{Nxo@RYAIc2tu6rI(*yX|hfO_wmmE?(0u5XE=szW$sCJoWA9P^T<*%)BnW zboetXZ&9*4zGnMU0nEd8m#!@zni%SP|1(B!QjSZBIq|E>3l1LZtAlrvg~{=HOnyXq zkwkd1;)S5AOUPprI`1?s6Dcvfn;M^c`skZk^bdWwOZrl7mfX(_2_Lh~qT^CorCgxpo5YG@0^AUbz8}Sj+fG6ah&~oohDQ1Ak6?kYj#fsfzv%-%el)hJp-XagSE#6R4qGWGKA_cEu=zw>c zMtI7yY``+SJKYL{46~QiyGai%>_wGyRe|)$S2kMQrLR|RNF}jd^Xc{Na3NnMeHK`+ z_O9uwsbF-V%*WS#5>fZbIMR7WTFGzvF-k6V{wp6tdvLvt$vME51;r&D_il^F z#^oAY;@_}~hyiy{a#tRWiM%U#4Q64Tav@o2hpFlE*ncTxwiR3QoO9QFDy|u9d9^|o zH}(!p>BS3crYS9tQ71d$NqTpI&%yF~j~LpYtA{@Ep>?Qg%-bYyB~-eW{Decr)+^$I ziBj@Q3fm>q!fY-^+-x71RMC*FIh8uRckK0$jqiLW^S*OPV4alct9Kcl@nd2GDk-J3 zu!(RXyP%1gPMyBDpQ5UX23FU^^m%W%Q6}X+B^AbbT|MTXRKQ9i3LYP4rIYF0f1N%` zox)E2mEM{ZS{H!YyOaS>ND+EN8qburd(-rZ{K5-Ya+~x!PFS{Jel0lpqu>Q~a$>*Y z%Z{PmpV|v|t_aW+;t*E$)bkJeGaOaPh!a^Cim*oX$UbrZAmmO7ST;??WU@4VB!H@Qd>XlK-KCWz0tbgP&iKA*e%tY}*Doz9KE zb}d(*6jcRRywqVKid1A`69xHAAKt63vLaS?Kd9bLO8yY?o;)4n0^254DSt3lCLychmqd3&krelAh@wRlsKrtl$-NM zc8WF7yU?lkr^xu5n*xeG8#2;IdmWk0Uo||hZqHSJBpkS(85(}=i#^w42Zby@MUPr) z{+;rIH@eq8G5RJZf}$&=JnuP#^qRbtdQr>ZU@u{wMsV+9C!$Yh_-p>ftmr4uLm8S-gI}S!XuqTpny2>yw zLm=i#*M}1^>4~ni#V={7f<$%S>_rw3Hh=Y&BrGNjBHvD@#y6~&$MqA`)6(oy}3&I^eSc2ed`D04Lzjmk3TPAhF(9`d&j0IH~pY?=#})zprwGX z6fg0D&4aK5iRVYb$udqa*xu-TBWetCz5)H(>ityI@FAV4JVq+Pqcl?qzDxZ9F|_`% z@iCGTSOKFl9b4bJHyN6>K3!`NOj^V-yO2(N19l}-KZV%sYdWMuAxZYSnwVl{c<4f( z;=8uz52a0G>0P-YU4>dfIwNnd1x{Y$c71Q6lvvnDDrDyQl;+9JI@Z1KX0>HYW1g_3 zyrlYEDQCQ|=6J1l96xn6=?QgCW3hWRZ;=TmgopTBh#3~sx>PM&_6G1-+|!S{U23O~ zeO(}e??QmK@4QRJy5*-Qb<|1w*UV_r_lX=ZGQDJ|IPS{7s`PVUxJv)d{$04-VjsiZ zPl0`=4?bDSB$TfuB+|291}koNDMmW+kS~3<$zyjJKy~%|-nd$Sw4LVkX3~*WX?$Ih z>r2q}noBdX=ydIKe8Ov3NV3tQ7;Mq<>VWyKZk4)y3YD04ee?Y`EwAnT zeKAX^d147m?v(C+N2`8z97~cBsvC~<&FgflmfzMzS9QdZ1Ep?IOUW6}mqNGI32foC zMbuL8>m}6OcYD4&()8aT4XB@ss-qvcBplH6!bEWG@#oai3GiP5? zSs~G??f?S+L&AWt@WT-zkgQcD^oP7Op~5VH5e>t=H;# z?|Lw%Y<$@jdJOu!d*=zrB?&w@;s_Z^0{hI61Wx9g?RtvOIdLJ`q{5D;pvTYYHlx2g zqK3xSK?$WNtavP}({)D|u=g+T&R>V0(mbOcdMk39*Q*|*b6Vko6PEd9TbWl9E4LL1 zdm@|1*L(S|`&38_5J`B*aotX$y%7>_LfI3!WLo*m{vp=00JHlUi}7ZnO0p>$w1fPg z>Cwmcl|2LgtH);aDF$B#5>oIVI7no^5Rpe8Uq^hCR#}l~5f&G1%qQCwiM>=IUtcqz z53%bH@yBA_uS=Hqu9>_la}qJF^f0mZxW4XZ-4ZUS-@qUDR^~Oi-z~zBJ8esBQco#k zFL68;-*+B=W-7yV32~W2L=4N;kdW5vI@#Nn=i~idgVgx?OZJ;>kU)56g?9??dnDrPjrG zUue#{T6|s8XT%Ks{>*kg`NWg28ZUDrbN0TEi4578P=yUi-SZlPd_#@kflg{5vO>{bVM5(_@hziCN6Nu6>)SKS>h`t|^B zO^WZYF8H8xmR2~(FW}Mj*vRllqv4{HUh=&WSK=Rr_urkj4~-XXiS>{7X&Lp!b*T55|9y*JVJ85;D({V1@- zt%3E)EjFqiK7gwe)`VyB-K66->|5Mf9a+@E!tBx%Y14Zxj9U#W$IfcE!(>{LDP4lY zo3FcEjG)IO+ufZ_w`aO6kF!d4)$M~p_w{{2eb(FiNwi*Ote%gHEyzZ&THj{MsI>$cFE}mt()coR@e&< zwm+N}rr(2zC7<4e)`Jgmg;Dunk?JEH1euz(M8`^WUyf%Zc+Y&hMs!1Sw+5$0m56C> z(0jvy`5NfTn)Kt70CP~}K8{&!(!QohO`INLX~KKTfg1~QWl#PwA}AXYc}$RnO8Q6= z#Y(S+a(F}w%1E@?ynJ;J`!e~J3KKS&gGXM@P)>jJ%l6tOILb~SQSN#BtBcp(mVc+F zQnJ+~DRlK8B*=10Uvd@0^iXF8Yh^N1 z>$yo{1A_@!3jV;u2mH^?g+IOWHw}6!U>JH8y%Brq#A@HWVm+d_sqUo^N%Ir+U5?1} zx5q*MY38Gt6BUklVx6muntbRUT@J6l2O*q7=|B#VCI*@a6p>MT*&1lwPRAIif}897 zd_usyCMUdyHs3U1l}?!L4E2013Wye5CY?ZKu%N*UD^F8=nI%+}<&kT`zw44$;> zrM^3OWOtpT@SF8YVj9`|m_iqd`9fp1T5i&E)`yFF`o?G6v5C1G-g<+>=gy0(kMD5c zMJ%JsX;(O>J<|ow%RI7QIVj_vjZl7Dq~k%C`90i6x@GSQMAUHZnO`ppY}E)>xvC+H zOU4DImEFTln^tR^R^;|FGKoqq?1&NNZ(|8u=j)R#KVNe#IX)7IlPA9P+Hy}GM7S$TDEQ!Mcj$IG8)`MEKy3% z@#B%WCAyI*#j)`7UYj4e>b)jr!K?3vi=OHVcXMF8cH#}wEyA~k-&XhQ;SU}{7244} zcU0QLZ1%gm9;f_9k2Qm|ka2vn^s^{UMWDI5K-2M4;$p|lN=Mx6l9&``^TN|4os4KQ4tkL_B4A$V^uF+g!RiFx$KM?!;n~CHg=xOE&&8vDS`diHAQmVNM({Ox-p~i( z6+*c?^iUe7J0;x8@yaky1@lkZrJiJ(wbycW{e~XlagJ0GL*u%p?YH9iyq5O5r>)S9 zaN^cU6IrnAEKOfzECW)lmH?Ugk{P*Er*RtX?7HN)t1EtNg~v(S0FiGfZUkiw&O>2? zcPSo%vWy;hDz!g9%=nPFslY70e3s-!Uo5)0tuMgJuP=@)>aw54I8@EHgwAX{v^?1 zp@-l4!KeB~*x^HQOisSZ@smd-23Yw#1H4|?K23w-2A$7QxcC;PU!+_))@G*@$8dG* z9%u*Ua}U%K>k1@HI*1)t(0WuSBGE;p-x|RssQz8=0rLKob;j{C;gqVtJM5g~90s(U z=u!aF#v?K{%ra4pXw@*x3Vzz$^3=b>w`J@NnZ?MkwF!9~i=|SdrB+Zbqk&)J$TP9k zG}1gGK`18^6Rc4fp0md*+({=O_*LC1+QstIL(XyAfyRPO=RKjmIkw zC#xNAZ7Aq8x;1rFi!Dyl+ChEEn1*L9$;nvPH6?il=qTD-^n)pH(H-DKJj3JUDw1S|KQpSNxDiKSAblt3X zSp6eNd|QIV=D+xoQK|*$#-poWv{GU8kUlG;L38uSRoeY+LV?1UXXQfmZti44Pu5BD{gp}_b%tiI3YQB^BI~~(WL4T*CYp<*BUx?xy!A=Bv#qo z)2(>sTjSV6vrgao1@5Ixk@anm$J9G|Fbi|z$=}?_^NgH&aq`U5%H`G+31_-b$BFnN zCMM@d-q0*TT7=_P;=}D3jrlu9^U3Vng@i}W{_gsESmS7dR&*6~&kFpc6Xa15>~X(I zm_i?lQ`k+=p(mgoR2yEJ7RaiS%Q~o`%c7|pkely2G%N;(qc>dB13L`1$4U>f`bKtY zjKcL~m?GAD&Gn=d5Z7aRmt zQr$|2WgixAKST-|cU_Cs_3jwO<5bIQtbIyHl-!t~)sk->Vb|7D&b}Q-_=r`)bNoDJ zZ1jz&ZJj*a_tRJnGg_5r-(Mm;k@f$C4!j8SE#3Yh=dH4=y#oz6jS(DL_#XDvu}5VD zf9wgkz=c_7lsvwEaOSUQ2+&lDio?M!?R+VxkflkudAt<`J0fEz0`yS+uq=!bchJyQb?j-1oS6v~x1`hcg=SSqz4Q3(b8-VB)L% zx#wX6o|YRs>#7Vth7IAh4dEHU!+XfD=JxuMS?B~c_n|9a(;TkcZnvHQ*L`sVGd(PN zbdg$~HylfDeCmM()?C)8DjlpPrXiQ#kfQaN3)kUg^NK_tfC4T-mG3FHAdyyMS)P#s z3(A6m#+Jk!yaci3Z^+_vhGsQWBDQ8fD9`(A%&Qm~3QJj1(Uo0|7{~9JHmo3?VhMS6 z6GwSrHI}jU zw%G8j8$G&6-y;VXjDx%m@Ivqs3kz+~&|hNEHPA*WyOek;AMG5cfj(-Ag4HK}yO?6f zZ_dI6!IOsmMz8Gb!|*X>cN3ReG{l^eErNoEJm{)jPHgq~4rk-g!)eVtuGJv4>eY=8m$Jf@23#D+4>8mo>#zEkF(2^OU#5|L#`U5uC5kT z6dK)Ea=|?30V4+;Fe)X`0V5rP1*IVG#tB*CJqsAR;RzkiTVK%uLlRw-ZP39#Xl}Xs zZi+5>AFFvxJ*luTy7tj4Ww*AQXSW8Rn`_CADh@-PTO9e8G4;00p3lflN%C{)^0Os% zbfqQLG|7b)midhch^y9< z*sKBT^rf*H6voQb3Ieeg5JNZ;Cey`^cf*nPQ@Euk2PnvWyl!IXt40(_kdK@~kpirvsS_VwE1|M$2H0I$(^TDJ?QvU9o5;cRoUTbLx)Rdfo zL8DXFh?9BY#{0sJ<`Duh3pam6ffd#`-U$j4RoWE9nLQL{o(ql2g*FX!i@}!?9aPZ! zcPc`mNLq3C9etX}PK{8w$cJ2>ROCzO24ydnb-A@9Xlga4-pdWelG8mMh%+=hrNibG zJyuAJVo8|X5I-iMMKqsumR+@~Q0F!vtNO#@eQiDY3;3+MTfTb_J3ak*R;g zYc6dFtqRO(X(xEwI)`09KjWphy(!A(a@LaDFv{!qOX@WGnimE|7s`GcxhHNhML5nX z4=#W(Pl8@`-ptM(x&RwWZ5v88fO8}|jP_J$C3K-%822N=*!t1gYe(o%qCd&&rFt!d z(XZ)^OFym(4Mt>q$F7OZU{!qmU^RqfQNWljEH00r?SxlbS14ifV|a~}3{8AGx~LC7 zMYoKww(J;I6pl@^a!1*%VTSx!2ZKddT02B|xhQ&9<-}8LA;BbmN)Ztnh9$wljV;dp zPrYR-Ut*!#0(9_Jc&c2utYMKB+}SiTMkv7b&2r#85{95FrWm~?T{{fsbPht2RT-I3UMDYWgI8{hlN6yK!b0hpvqJ(>Tc0OE1n^`v zme#V)lj!;J#$v52($*)|mz+jMuYZFbg)U!zO5&82xAMtcb&q91x!Qb{ip%&XvFUi<1~USFtKRvF{_Dd}^~6{#M$@Mnw)}BSg}!9Bgi8s z*h8>M__{t6w^079zs)(nY-(P0Xg_%TGbXs>+a*xkf~wxnj{d_<%B1U)>EGx%&5Ql?thajl zSwp3C_``de52{~_X2r+kKAL*CyI&_dkb;MMXRdD@YKAB!orsMvtu zR6~DcnTtb5D=ko~l(<}rtlT|Vi&i3!lunf#y(3=)N>$E)sPT>R>O~3}GW1{VQ7G;{bo=?mOuKlfLS?kV(L$TPB#%O4m;(J*oGBCpF@tWQf2`d=pSbYi zr9_8haWR9tTvF(g!V+Mx-1_v^TmuZ|tk0(%G~R%2L=6S2Gs6?`mt)UD<0yI*nUkA0 zn+VHxPTqbv$!9#|)A5qlF7iAw#-&BXCcTiv&zH1Rk+P(|V@GQnL|AUrfQz5gD4kL` zI;)~=QA1qmKvwxIc!KAfC5b)`del`*JSA+pR~o32D=?(t^ib0&G`G@wNLw7Sn+DJ)?#iALhCa-s zHM-;p4|wpJ2MtZXXgpg-2XgFh9hl;aqO?2H=gIX;L3dz#9&Q(Cq&&Do5Eju%TXavN zl&1sTr#kgVcV8*B8>!WUgtCH#_?xJS^uzE9n|~QW9->E(f{A8Ag+s22qpr3VLib8a zeTLZ2N01=yvr1@cg9V)gc{Tesaf;2wRyS=k3YdwX%EazoG7bbs({Gk6K(6X+mUYq; zxL<3&*{zLy$W)1b4FYeSK6e&ZOd18M$kcUKaJB7Jy?AM#x`~GykhdC(QCimGR7eYFYZGr zOQ^-F<$7pvmqT7$F3QOWJyr;9oqe;ZF0VbNihiSamlXW`?a$tDA4&y!^wDm^5qMv`~P*V;J@wN^CxK9kiXk>+vab&_t9qF>~P z7wuIx%41+41)v}~!L^*7qDdm@gtf69>G2(G4D0J4$WvD&1_sOQQW!%w4Mv^h=_z~l zQUb`exv<2eQ{RPDKqwUE)y&R09NQ*{`ia@mS^po%gO`4luREoxUPyUSbgNxAr^S&2 zasu`v=$IPOS}K^Jtlit2QyZ%>E+pwHwn(qVHpdP#u+QUs9}7W7?m#CYZnYTewX3ir z3SP)+<5Ap*-AuSQMAj;+pSs^gg?Ysk?te>ytx~W#P;IzM;m5 zkcrNTh{CPj@~#RCgyj~|KTiIi@J%e{jnqlbw^Hg}X&wF^guc=IJGj>Dry6xo`bl{H$;xoB7c^AxN;>uT9O>Uz?iV{nHB= zTHm^9BHh-Hl7DIKKll~)pEvd|&KmojpF7wUfSLYodDq|MJEnRy<@oPu7M?Z+S`05* zHJuicELjbm+9-5RAxCF0cX-){A6f7=TUEoS+QPm&ey5)0?4>^DC4r^UPA$q_z9A7;ZtBxL^ zRTD3&vY=@HYC{(=`g=q78r8KFMkFr)ugW?276BG;L5_d~+^3I{h0cNo4QQwmnS0vu zs(CGLt@yN2c$7s51(Oy+UBE!KF4g)jg|YuOj95Ry#Q5S51>N8Oz9YTC_n|ajor%C8 zy}8hFlWz=&?SEJ;{B!Jh6M9xSr);;9OAXZ%ewj>QMmZ&TF0JH5{^iUEjQ+RL6b7<# z1LPO)zEVEAt(|@_mU_}1bwW@Oy$(cnZuCeD zazv)d^1%ojTs{2tv`tz4!_u--UKhW5*+_w@DSWo2X?=0MWn_Lt3yi+rvh46`9pnJ9 z6doxr#H(2Y-_@zKN0B+9F6{Rqz+VyABeDn3(qB4b=-C`(|9?#f9@*qj35I5tcPT9* zvHz910=5kZ00MtXfmN!rS@*hP@MO4AcJZY7OV5QYruM%&g>3u#L?${X)2Q0X69Q~Q z>;#<}M1lDRsKXgO)qsIlKCx z0Uy*?SFym`*Kxq{{v+L(>6O+5TF$X9XO$3m00MvjAOHybHvuHHEh)0*1nEa}rzX&~ zEom@R#`MIHcRS-lS6P2 zf=n3LIB{jO%3YUVoL*R@ZceVE*(NwF+qOD6G4I(yTwY~>$juWJq@O2>et6&O+#Mo#)ZsIP0C7(7(~s^4j?;J+0#HbXLe~p@nV{ zB!*iXH1YJpeiVp=H5&9^qqu!Oe=DAS=iqwLmh(yNe*%{LL6F=rA{`0|I;2}+2vQNC zOI?ecXi3(PTMXJ`ic`sc2*(SPK2i>**t+;n><0KMAOHybdkToWFzS8mchCU$w;Xsu zoa-p@WSqV-LK+s|K>}Cj$yHGvqS%-EJwj|A4|SYEL-{gp0C$m?KFWOoOhahPA?#XP9OP)Q5@}% zf@sK_!cH$7?S$1pSVBF}*L$}?C)*d~tD9%nO;mK_XP1Rocn|o<77*c(e=60$mjMAl z01)`k7jQQoURNG(`O1a?slS$B%==`cEoG>C(uH{JDwL)h;i@rv6 zi^(~6hp@RKQXk~cz2w+mIe@@hfWS{6U?R;u+KpaKr0n{H%!_CH{MXe>J&PzL2UMeX z8lUw;2ZtR%H1Sc5l24sxFHgT(lHXWIGJMF4G|v95+w&eIB74O7;q2LN-}WeSUrmsb zXL**1XKjlXsWxh$Mx0KIdW~X7`C``A{sf>v76Ac301yBK00BVYXBK!poc?hYJmu-` zL9@lsrxaN&C-ET{x+-;jZ*algrhgZ_=P=Rj*YN8Y5PjMD^27wFk;9?S;ytMR^nykk z^v)4xNffAg8`}hZUAhhnq>nntui?ISeB+`}m+ArO`^E?JG(WR2kas}fCl%-kH7kOm z7ag&>QxTE&an*sY$O2)=8<>&LvJ@1T2FlKddkPCO2L;8ff5Sl&)xo2W{laWE5nx1t zHi~WNzvhGbfBEOh9&$$9U45BmyXdgW0Pz{}nl}MdU3rqOmL~V{)p57`XVfM1b`r|k zQMwBa=*yq}G8VP8>)!;H?fyy2B9QH$Ti`R*pdko-^NJ}c`#aCLN8u8NqsOh5Hx+cX zyGVGSK~!{y5rUwHzr>=t@vgHu>laKf=sQ)8?3f_0KWr~aATgL1x@FFvm(f=`jTvlN zkE8?@k5n}utx%uMPygh%#Q3?(*S}xx5B;brRbR~qf5mWr3p3$2cT+Vs56cC$6^^oFf}g)sPWB7=z$30*S$9Cuvs}-|NXKLq!Z#cXpJrYI1T;u-o_S#NI=c zV{NS~Nloo*%msN>BnCG6Vs-d}d@`!G26aBP{1XrWNdN=@f&V>$Y3sJ6Pw|4mmUR4o zE8x~< zLuC)2{r2wHSd=h6$LBdQmJ9_)|I>>@eSTA9LJvXipU_;dl@Kzr83>elHxc=%B}wBM zNqeW@CuR73XF7*9pxSAg3P$JnRs{XRdTaS#%nu@ZJv^EUEHkMSz6s+XHQT}ParuzVThh> zVm!h;r8%;r)Ggfxhl8OMzpR~*UjxBfs)3*p7p@K0z-G<$HW~(3_|d$JbrKulMg8>>QpQp zPJO#`RSfSc@0Xw;yYa?L+T!Gz$%(2zZhh2nj~y!gfw%iMGJtU9a(OLH{2DF|CEqD< zKJ9e@m7Xbl401yBK00BVY4-{aU`fPH)tnrv_mLrS%feQKc$J-uY z#-az{hSpvE?_}?njMB$97Dg4~hh{HfOhXpK;{29|Gj2%V2wDiMIlX|v^azr&jg3(- ze_BY7{x!dB7tKj+G$`v)nmol^H~V5oA-SLKqRb8x-Vyd;J)+FT33Y>}>=}%>9M#9X z0W5v>2OfD~*MI;Z00;mAfB+x>2mk_r03ZMe00Mx(|099d68@1-BCUufpa2Z`r8+UQMnGWOo?Vjt0Ct1k)!M$;V5|Hp%p&IIQHhNm& zm&d>=&FT-D-~Pvaf?aw3!|by+L^_e88wbbOgOvU0AvSw+cNe>DHXTAe6ZXVWzis)q zF9NTC03ZMe00MvjAn%N!`DhmuRRh;V(7Rp>HXd!@;+=r|6KlHYW28M@7gAo}HN7(J_q_ov4P$6j2`X2aTP zUd4wlpY3T~_+k>une~Z z-3d52PR>Q&^yTd@_kKaYWdHKNgGywJlSS9wphr_Xl~*$r!e$H)n0fe5Ej>WK00BS% z5C8-K0YCr{00aO5KmZT`1ONd*01yBK00BS%5C8-KfnNo_kH8FW(tkzY8$ws+cK?qi ziX68$-z0`5?pAh)F{Q;S_TfFgU5HD|TW||U!J_I7$tK;{D@HYDXADmL#H*LSU%jy9 zGR=zR854FC7Y)$?SOEb*01yBK00BS%5C8-K0YKo-DZp1FXnvoj_$1PQSRqd^o0smL zp%a*~@4<)2BYh!r`5@)q^?mGQd(ycl@faAV(4)DA^n6~0`=NLD=ffhujJ46k9oi_c zdj>2mk_rz~3Wa$=Lm|HK-kYK*y4- z<)UIGakq87OC?2tiGHWibtr*2J86PSkUX`4R26-n*yy}lT|3>Jvm;2^ercZ zm)-BbRr;)tnl42Cc4wF8g^-@%-`8V#xKDuP2d9a$W6J9`UI>aD;ccF9K3<}%Rv05d z_>DtRJJD7f60dfij~`6P&ziw2ourl}8|Q{WzisaC$24za8`h-#;$aX=! z`g+9ON&(FXslxaU61cjXMoyq(em~h?Wbv0wIOYuDSH5( zC;t8#1!j{0>lMVjqQ5ECy6}aN*8o<5`H4f}i9K=mi`q3_ceu9jzf~52!~z0<03ZMe z00MvjAOHve0)W8(V*!+%APHn)sEnI zA^KalNHAWbR>S&_T!WJM-W}$EDvmKR9+ZPze|b6n2mk_r03ZMe00MvjAOHve0)PM@@UIm3J_>`p_fFz_ zJqJ?QWQnMLvv<)<422a$F=9>9Y+af0IfSl8)J^;nbG0aIPh<(K6iY}%$0X}Y+l{1d z>@MqRTu_&KO~XHV?bE+9gTS@{0YCr{00aO5KmZT`1ONd*01yBK00BVYeoJXyFO^5+jg$5ZCsyo?vVkd4u~0-m%UJ%dr5YVikIomu}>jayR$Dvulhr zAgjv+7u)HzZiP%PC&RH3Ds%@_36Jj)w#ze!0DHQfQaWreo-uWd+gkZ}-vupQRo6G#~&700MvjAOHve0)PM@00;mA zfB+x>2>h7@rUhKpxgQDekNWck>hBfV-lKaw3C)otKeA^_=#UP1Dcnx6rcxQ_m{fa{ z^|y3Gbhl8`+s8-3D?!hMbjHe2q^<48_|x5!EUPlMWwHbNW>H^454!Te)pq#J z(`Lso(Tegk?zzdY%&?2iwxXAx4D`!jS|(D`=YA$_ynRXhUL(ybMO5OmC;Vfid8HN1 z&&V$-u)*%bLe_#Um`((j9fcj>9vNG$FM4!yb|1uc@3=eF78jlnq<#N=OeTEmJ5rMr#V}W!JO`={nVVW`-X3dFESniaDPkTNA!1Bu0{@vFuc5Akr+RDg ztRQz){sKmaFtZxS?Np6i_r>S@F@LEu6c>V6?inIQ)sKDMBC33$yCkZ6Ed-UbHPyQX z8T4YRV#?-zrSlkE=MIG{lLds~`&F{!g_?nyUc&@e|v-esIMDF~P zzs9R0Rze@ubjun@vnzo>pJUmFyxuhO*>a*^-npG5;`B{;2OX$|!LdH@b zAV<-Mf(nKbYX3 z0|dVPhYBP}pS5^jH;|i6Q94oom+A_S>}*2Xo0!L6*xC4)z2U*Owjwz4UTnvz?bHB! zX{zL?8NaKRY$uE=9B~&I_Jp9o3JuM7mMSo8CU_ zTh#9O{k6Wu(9*~Az6Bp9#^%@WkncyS^fl2H0-+jXhN!~yVu8^DRLm=$U|pwE;!x3qvsfq|6q@Zi%{p5xnM2>$!{n$5fd1Yu?O%b<$ z7CAD}kz*Bz9Dj77_~Z8fdg(DYNUd@sM>bC`^TT<~?>W+7gg%PGPaEoOywl$NL5iSE zMI*Vm|F@ct;rwzPw0_!0G^IJ^>Ag?Gxle;pXF3TPSVNx*?yo|{z)%<%ayv)^ONs-U zUm`%RlR>2C>f`+JkuW)4<48m|3USIx2`UuWbD-Y!=|X*vLD2tk27iJ*Zwh8CVw>NQ zUw08qeRDn@>!LXukBKkqnNevA7P(uP?pm9ta`W>}TCAdnV@+gisJ!yyN$mc-E-!5c zF4X-FkfSy7*+o4ee>`jf7LjV%04s3;?f*L5I{kAExBdkBKob65 zfoT^d^jL-Ebe+KJ>9S;|o6-&flVQ+ofzz+Aet^Hi+3;wj@5jv06Lb5#zfn+)JI@Mg zXDnp3bWd)qLi@7jH_#NS2pH#(8fo$JWI_AN1|j#Y8WUPJ-Jv*T4(Q z<(;lIHc5G&O!8p|le?1wf!2l}Q3?_GB|`415eYvK-npJdG9*NM<*H1h@R~{)`O7*7 z6RchR{cna5%XC6SCMOAW+5{yYRuKt!!ABmF6)_o0?ng6~eQ&yS9YRBh95xUxu1N<@ zBP&iiAN`syzF(*WMW;7ozF)wWS zB&itsA_9+tB>B%ilqG}x0a=~T(pUVZ>>yh=G98o1!C6oq+fg%{NRm%c;ouhf8$xO+ zn6ju%%|Ls0g$)7tyX=0~Idyb88Z$nj>cr*qp=(uZWpazl%ai@tR*Zd|Z%SEN`u0S| z{Q~XvmT5etBk-RNL`qhWmV`m(RyY-1jaI1D`*Z08iJeDtb-VlL-q6;j$ozjjTMJNA zXBI|NjV-&d&I)l>5_ZD2iXgEv2!)Gbr%-S~gXI;Tp{*1jKvGECgn*gs78#(@UM-;% z2}*$gQotZj31MJR1Tho{%EN?M1oMETK>{Wu+5e|IJ9p;ZOzt`7Ki}j0_s(QOd26n} zIgGM*4m1vWlHCETEt9y}MnQVc5IZK(eR4QBFhIPGZZPD=(K>dooRho70<7N4Peev!|(%%vuB)#FKCT`HZn9~Ad4X4cjSVrt#uES%!k=XD;B zfm! z5VBdGk?~iQzaoKz2G1K%X3`O$N$xqcG--q6q#BsyiXuXd4drWgOY5WnwRthHg-rZ3DW5&!IDv=E{^Y>uvQh%l(> z(y=MtX}L8Pb%_2={W}rg@L-keHOb5Ta3dOi6&2zix&KajF(Z{be`Wa^+rLfni?w5? znxdF_>u+)7qcE5g=S{(O{@LCV;?GbmtV#SR*#6HV#eXbY76R;iM}^>@z1z70D$PfV zd!T_6KnK&8tk)6zj}5D|WB;Qmdb(hzO-X#MMvSYEN`p&Y(L6x#Tc!UrSF!yWgU-$+ zjM|!%1V~3RG<#R;5g;g6IZDKD&BU5V0}RZ?#6+371V}6Y7={5HAk5Ys2Ks8#7a)N$ z^V^rXO2>6ZYE7(Tz83}i-X${_959y6a4h3Wfat9R_eq}X#Lc5qXrP+5{iOhiU~7W6 zpF&U2xo#9J|F*gk4rP}l8HzEOfR<(?%*_1JfHy2SMP7n~b{|^b$PVe1n;kDWe^ks| zJjFAH&6U-;ai*UxsVah|eh%?tR0sJ0qiw!@&+!0sZvos?Hl@{F&ny@Jwc~b&h2(Q$ zwhslnmX>KtKpk%0Rr?k*11F51LWfE2Kwp$FfBk*Tp--j#!J&2y^~q1$5A5oir!@;_ z6oTQDU#f%;Ih44r1FE8+c}Hi}qo8?i6!gFC{V?KE zu|zjT3rZslU?*()f^!?5Wk)B`%%N;luqGyYCB*{N|0q8rKh=%` z3Wb9iitW{Ec6`AW18D}OS?{W{ax;aj6xv7!uG zb+LX$aF~7~>p+l02Uod|hlaJ=6HiH|1rz^BW!2$zeS}Wes6;g)Ep_p%4Ar#jSvsXM z?j=Y~%0B~Bm3I-B_*tQW@qU17{_7&mlnEQ$UY76gNH=}3-275?NW>6t>h&1}u<>ut z!C3CQ_SoEH@hRy?u|^D=HAOj|etGm_S1deJvB& zog$pFao}>lq0{YiS!NKIA0f~r2z1_{rQUj>Fy8e<2-M4G zmmy5Cnp(XE}9KBegBgKFlXu?|VPV zJ2~}AbILZmV#P&<@7hhk9I)GMx6dv;_&-qu+QIy`)?IcyW6d6%-NFCABi`e|kJ)KC zS!qn4 None: + model = get_model(builder.__name__) + assert isinstance(model, nn.Module) + + +@pytest.mark.parametrize("builder", builders) +def test_get_model_weights(builder: Callable[..., nn.Module]) -> None: + weights = get_model_weights(builder) + assert isinstance(weights, enum.EnumMeta) + weights = get_model_weights(builder.__name__) + assert isinstance(weights, enum.EnumMeta) + + +@pytest.mark.parametrize("enum", enums) +def test_get_weight(enum: WeightsEnum) -> None: + for weight in enum: + assert weight == get_weight(str(weight)) + + +def test_list_models() -> None: + models = [builder.__name__ for builder in builders] + assert set(models) == set(list_models()) diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index a46c4771b..bb72a2e00 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -1,61 +1,74 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict import pytest +import timm import torch +import torchvision +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum -import torchgeo.models.resnet -from torchgeo.datasets.utils import extract_archive -from torchgeo.models import resnet50 +from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 -def load_state_dict_from_file( - file: str, - model_dir: Optional[str] = None, - map_location: Optional[Any] = None, - progress: Optional[bool] = True, - check_hash: Optional[bool] = False, - file_name: Optional[str] = None, -) -> Any: - """Mockup of ``torch.hub.load_state_dict_from_url``.""" - return torch.load(file) +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict -@pytest.mark.parametrize( - "model_class,sensor,bands,in_channels,num_classes", - [(resnet50, "sentinel2", "all", 10, 17)], -) -def test_resnet( - monkeypatch: MonkeyPatch, - tmp_path: Path, - model_class: Module, - sensor: str, - bands: str, - in_channels: int, - num_classes: int, -) -> None: - extract_archive( - os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"), - str(tmp_path), - ) +class TestResNet18: + @pytest.fixture(params=[*ResNet18_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param - new_model_urls = { - "sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}} - } + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls) - monkeypatch.setattr( - torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file - ) + def test_resnet(self) -> None: + resnet18() - model = model_class(sensor, bands, pretrained=True) - x = torch.zeros(1, in_channels, 256, 256) - y = model(x) - assert isinstance(y, torch.Tensor) - assert y.size() == torch.Size([1, 17]) + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet18(weights=mocked_weights) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet18(weights=weights) + + +class TestResNet50: + @pytest.fixture(params=[*ResNet50_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_resnet(self) -> None: + resnet50() + + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet50(weights=mocked_weights) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet50(weights=weights) diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py new file mode 100644 index 000000000..1cd3ee1cd --- /dev/null +++ b/tests/models/test_vit.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path +from typing import Any, Dict + +import pytest +import timm +import torch +import torchvision +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224 + + +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + +class TestViTSmall16: + @pytest.fixture(params=[*ViTSmall16_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_vit(self) -> None: + vit_small_patch16_224() + + def test_vit_weights(self, mocked_weights: WeightsEnum) -> None: + vit_small_patch16_224(weights=mocked_weights) + + @pytest.mark.slow + def test_vit_download(self, weights: WeightsEnum) -> None: + vit_small_patch16_224(weights=weights) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index bc131ddc6..492ca87ed 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -2,21 +2,33 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest +import timm +import torch import torch.nn as nn +import torchvision +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torchvision.models import resnet18 +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation from .test_utils import SegmentationTestModel +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestBYOL: def test_custom_augment_fn(self) -> None: backbone = resnet18() @@ -45,7 +57,7 @@ class TestBYOLTask: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -64,30 +76,31 @@ class TestBYOLTask: trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: - return {"backbone": "resnet18", "weights": "random", "in_channels": 3} + def model_kwargs(self) -> Dict[str, Any]: + return {"backbone": "resnet18", "weights": None, "in_channels": 3} - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str - ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["backbone"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint BYOLTask(**model_kwargs) - def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["backbone"] = "invalid_backbone" - match = "Model type 'invalid_backbone' is not a valid timm model." - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = mocked_weights + BYOLTask(**model_kwargs) - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): - BYOLTask(**model_kwargs) + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + BYOLTask(**model_kwargs) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 7059e59cd..922e171d4 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -2,14 +2,18 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest import timm +import torch +import torchvision from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( BigEarthNetDataModule, @@ -18,6 +22,7 @@ from torchgeo.datamodules import ( So2SatDataModule, UCMercedDataModule, ) +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import ClassificationTestModel @@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestClassificationTask: @pytest.mark.parametrize( "name,classname", @@ -46,7 +56,7 @@ class TestClassificationTask: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -66,7 +76,7 @@ class TestClassificationTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -83,49 +93,52 @@ class TestClassificationTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", "in_channels": 13, "loss": "ce", "num_classes": 10, - "weights": "random", + "weights": None, } - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_ALL_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["model"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): + model_kwargs["weights"] = mocked_weights + with pytest.warns(UserWarning): ClassificationTask(**model_kwargs) - def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + with pytest.warns(UserWarning): + ClassificationTask(**model_kwargs) + + def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" - match = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch ) -> None: monkeypatch.delattr(EuroSATDataModule, "plot") datamodule = EuroSATDataModule( @@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask: ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", "in_channels": 14, "loss": "bce", "num_classes": 19, - "weights": "random", + "weights": None, } - def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: + def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: model_kwargs["loss"] = "invalid_loss" match = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch ) -> None: monkeypatch.delattr(BigEarthNetDataModule, "plot") datamodule = BigEarthNetDataModule( diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 67d70334d..22524b0a8 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -2,18 +2,30 @@ # Licensed under the MIT License. import os +from pathlib import Path from typing import Any, Dict, Type, cast import pytest +import timm +import torch +import torchvision +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +from torchvision.models._api import WeightsEnum from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask from .test_utils import RegressionTestModel +def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: + state_dict: Dict[str, Any] = torch.load(url) + return state_dict + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -25,7 +37,7 @@ class TestRegressionTask: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -46,7 +58,7 @@ class TestRegressionTask: def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) + conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) # Instantiate datamodule datamodule_kwargs = conf_dict["datamodule"] @@ -63,36 +75,39 @@ class TestRegressionTask: trainer.fit(model=model, datamodule=datamodule) @pytest.fixture - def model_kwargs(self) -> Dict[Any, Any]: + def model_kwargs(self) -> Dict[str, Any]: return { "model": "resnet18", - "weights": "random", + "weights": None, "num_outputs": 1, "in_channels": 3, } - def test_invalid_pretrained( - self, model_kwargs: Dict[Any, Any], checkpoint: str - ) -> None: - model_kwargs["weights"] = checkpoint - model_kwargs["model"] = "resnet50" - match = "Trying to load resnet18 weights into a resnet50" - with pytest.raises(ValueError, match=match): - RegressionTask(**model_kwargs) + @pytest.fixture + def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum: + weights = ResNet18_Weights.SENTINEL2_RGB_MOCO + path = tmp_path / f"{weights}.pth" + model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"]) + torch.save(model.state_dict(), path) + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights - def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: + def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: model_kwargs["weights"] = checkpoint with pytest.warns(UserWarning): RegressionTask(**model_kwargs) - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" - match = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=match): + def test_weight_enum( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = mocked_weights + with pytest.warns(UserWarning): RegressionTask(**model_kwargs) - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): + def test_weight_str( + self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum + ) -> None: + model_kwargs["weights"] = str(mocked_weights) + with pytest.warns(UserWarning): RegressionTask(**model_kwargs) diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 1f4afdf3d..e1b20236e 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -3,14 +3,17 @@ """TorchGeo models.""" +from .api import get_model, get_model_weights, get_weight, list_models from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg from .farseg import FarSeg from .fcn import FCN from .fcsiam import FCSiamConc, FCSiamDiff from .rcf import RCF -from .resnet import resnet50 +from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .vit import ViTSmall16_Weights, vit_small_patch16_224 __all__ = ( + # models "ChangeMixin", "ChangeStar", "ChangeStarFarSeg", @@ -19,5 +22,16 @@ __all__ = ( "FCSiamConc", "FCSiamDiff", "RCF", + "resnet18", "resnet50", + "vit_small_patch16_224", + # weights + "ResNet50_Weights", + "ResNet18_Weights", + "ViTSmall16_Weights", + # utilities + "get_model", + "get_model_weights", + "get_weight", + "list_models", ) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py new file mode 100644 index 000000000..95cf5b5c8 --- /dev/null +++ b/torchgeo/models/api.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""APIs for querying and loading pre-trained model weights. + +See the following references for design details: + +* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/ +* https://pytorch.org/vision/stable/models.html +* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py +""" # noqa: E501 + +from typing import Any, Callable, List, Union + +import torch.nn as nn +from torchvision.models._api import WeightsEnum + +from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .vit import ViTSmall16_Weights, vit_small_patch16_224 + +_model = { + "resnet18": resnet18, + "resnet50": resnet50, + "vit_small_patch16_224": vit_small_patch16_224, +} + +_model_weights = { + resnet18: ResNet18_Weights, + resnet50: ResNet50_Weights, + vit_small_patch16_224: ViTSmall16_Weights, + "resnet18": ResNet18_Weights, + "resnet50": ResNet50_Weights, + "vit_small_patch16_224": ViTSmall16_Weights, +} + + +def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module: + """Get an instantiated model from its name. + + .. versionadded:: 0.4 + + Args: + name: Name of the model. + *args: Additional arguments passed to the model builder method. + **kwargs: Additional keyword arguments passed to the model builder method. + + Returns: + An instantiated model. + """ + model: nn.Module = _model[name](*args, **kwargs) + return model + + +def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum: + """Get the weights enum class associated with a given model. + + .. versionadded:: 0.4 + + Args: + name: Model builder function or the name under which it is registered. + + Returns: + The weights enum class associated with the model. + """ + return _model_weights[name] + + +def get_weight(name: str) -> WeightsEnum: + """Get the weights enum value by its full name. + + .. versionadded:: 0.4 + + Args: + name: Name of the weight enum entry. + + Returns: + The requested weight enum. + """ + return eval(name) + + +def list_models() -> List[str]: + """List the registered models. + + .. versionadded:: 0.4 + + Returns: + A list of registered models. + """ + return list(_model.keys()) diff --git a/torchgeo/models/farseg.py b/torchgeo/models/farseg.py index 50db33bab..13f6601f9 100644 --- a/torchgeo/models/farseg.py +++ b/torchgeo/models/farseg.py @@ -9,7 +9,6 @@ from typing import List, cast import torch.nn.functional as F import torchvision -from packaging.version import parse from torch import Tensor from torch.nn.modules import ( BatchNorm2d, @@ -62,17 +61,14 @@ class FarSeg(Module): else: raise ValueError(f"unknown backbone: {backbone}.") kwargs = {} - if parse(torchvision.__version__) >= parse("0.13"): - if backbone_pretrained: - kwargs = { - "weights": getattr( - torchvision.models, f"ResNet{backbone[6:]}_Weights" - ).DEFAULT - } - else: - kwargs = {"weights": None} + if backbone_pretrained: + kwargs = { + "weights": getattr( + torchvision.models, f"ResNet{backbone[6:]}_Weights" + ).DEFAULT + } else: - kwargs = {"pretrained": backbone_pretrained} + kwargs = {"weights": None} self.backbone = getattr(resnet, backbone)(**kwargs) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 13126d922..18812156e 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -3,83 +3,208 @@ """Pre-trained ResNet models.""" -from typing import Any, List, Type, Union +from typing import Any, Optional +import kornia.augmentation as K +import timm import torch.nn as nn -from torch.hub import load_state_dict_from_url -from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet +from timm.models import ResNet +from torchvision.models._api import Weights, WeightsEnum -MODEL_URLS = { - "sentinel2": { - "all": { - "resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt" - } - } -} +from ..transforms import AugmentationSequential + +__all__ = ["ResNet50_Weights", "ResNet18_Weights"] + +_zhu_xlab_transforms = AugmentationSequential( + K.Resize(256), K.CenterCrop(224), data_keys=["image"] +) + +# https://github.com/pytorch/vision/pull/6883 +# https://github.com/pytorch/vision/pull/7107 +# Can be removed once torchvision>=0.15 is required +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] -IN_CHANNELS = {"sentinel2": {"all": 10}} +class ResNet18_Weights(WeightsEnum): # type: ignore[misc] + """ResNet18 weights. -NUM_CLASSES = {"sentinel2": 17} + For `timm `_ + *resnet18* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/" + "resolve/main/resnet18_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/" + "resolve/main/resnet18_sentinel2_rgb_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 3, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_SECO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/" + "resolve/main/resnet18_sentinel2_rgb_seco.ckpt" + ), + transforms=nn.Identity(), + meta={ + "dataset": "SeCo Dataset", + "in_chans": 3, + "model": "resnet18", + "publication": "https://arxiv.org/abs/2103.16607", + "repo": "https://github.com/ServiceNow/seasonal-contrast", + "ssl_method": "seco", + }, + ) -def _resnet( - sensor: str, - bands: str, - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any, +class ResNet50_Weights(WeightsEnum): # type: ignore[misc] + """ResNet50 weights. + + For `timm `_ + *resnet50* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL1_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/" + "resolve/main/resnet50_sentinel1_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 2, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/" + "resolve/main/resnet50_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_RGB_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/" + "resolve/main/resnet50_sentinel2_rgb_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_DINO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/" + "resolve/main/resnet50_sentinel2_all_dino.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "dino", + }, + ) + + SENTINEL2_RGB_SECO = Weights( + url=( + "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/" + "resolve/main/resnet50_sentinel2_rgb_seco.ckpt" + ), + transforms=nn.Identity(), + meta={ + "dataset": "SeCo Dataset", + "in_chans": 3, + "model": "resnet50", + "publication": "https://arxiv.org/abs/2103.16607", + "repo": "https://github.com/ServiceNow/seasonal-contrast", + "ssl_method": "seco", + }, + ) + + +def resnet18( + weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any ) -> ResNet: - """Resnet model. + """ResNet-18 model. If you use this model in your research, please cite the following paper: * https://arxiv.org/pdf/1512.03385.pdf + .. versionadded:: 0.4 + Args: - sensor: imagery source which determines number of input channels - bands: which spectral bands to consider: "all", "rgb", etc. - arch: ResNet version specifying number of layers - block: type of network block - layers: number of layers per block - pretrained: if True, returns a model pre-trained on ``sensor`` imagery - progress: if True, displays a progress bar of the download to stderr + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model` + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model` Returns: - A ResNet-50 model + A ResNet-18 model. """ - # Initialize a new model - model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs) + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] - # Replace the first layer with the correct number of input channels - model.conv1 = nn.Conv2d( - IN_CHANNELS[sensor][bands], - out_channels=64, - kernel_size=7, - stride=1, - padding=2, - bias=False, - ) + model: ResNet = timm.create_model("resnet18", *args, **kwargs) - # Load pretrained weights - if pretrained: - state_dict = load_state_dict_from_url( - MODEL_URLS[sensor][bands][arch], progress=progress - ) - model.load_state_dict(state_dict) + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) return model def resnet50( - sensor: str, - bands: str, - pretrained: bool = False, - progress: bool = True, - **kwargs: Any, + weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any ) -> ResNet: """ResNet-50 model. @@ -87,22 +212,23 @@ def resnet50( * https://arxiv.org/pdf/1512.03385.pdf + .. versionchanged:: 0.4 + Switched to multi-weight support API. + Args: - sensor: imagery source which determines number of input channels - bands: which spectral bands to consider: "all", "rgb", etc. - pretrained: if True, returns a model pre-trained on ``sensor`` imagery - progress: if True, displays a progress bar of the download to stderr + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. Returns: - A ResNet-50 model + A ResNet-50 model. """ - return _resnet( - sensor, - bands, - "resnet50", - Bottleneck, - [3, 4, 6, 3], - pretrained, - progress, - **kwargs, - ) + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] + + model: ResNet = timm.create_model("resnet50", *args, **kwargs) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + + return model diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py new file mode 100644 index 000000000..b8d0874d4 --- /dev/null +++ b/torchgeo/models/vit.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Pre-trained Vision Transformer models.""" + +from typing import Any, Optional + +import kornia.augmentation as K +import timm +from timm.models.vision_transformer import VisionTransformer +from torchvision.models._api import Weights, WeightsEnum + +from ..transforms import AugmentationSequential + +__all__ = ["ViTSmall16_Weights"] + +_zhu_xlab_transforms = AugmentationSequential( + K.Resize(256), K.CenterCrop(224), data_keys=["image"] +) + +# https://github.com/pytorch/vision/pull/6883 +# https://github.com/pytorch/vision/pull/7107 +# Can be removed once torchvision>=0.15 is required +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] + + +class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] + """Vision Transformer Samll Patch Size 16 weights. + + For `timm `_ + *vit_small_patch16_224* implementation. + + .. versionadded:: 0.4 + """ + + SENTINEL2_ALL_MOCO = Weights( + url=( + "https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/" + "resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "vit_small_patch16_224", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "moco", + }, + ) + + SENTINEL2_ALL_DINO = Weights( + url=( + "https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/" + "resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth" + ), + transforms=_zhu_xlab_transforms, + meta={ + "dataset": "SSL4EO-S12", + "in_chans": 13, + "model": "vit_small_patch16_224", + "publication": "https://arxiv.org/abs/2211.07044", + "repo": "https://github.com/zhu-xlab/SSL4EO-S12", + "ssl_method": "dino", + }, + ) + + +def vit_small_patch16_224( + weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any +) -> VisionTransformer: + """Vision Transform (ViT) small patch size 16 model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2010.11929 + + .. versionadded:: 0.4 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. + + Returns: + A ViT small 16 model. + """ + if weights: + kwargs["in_chans"] = weights.meta["in_chans"] + + model: VisionTransformer = timm.create_model( + "vit_small_patch16_224", *args, **kwargs + ) + + if weights: + model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + + return model diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index fbbd894b3..d733df142 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform from torch import Tensor, optim from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau +from torchvision.models._api import WeightsEnum +from ..models import get_weight from . import utils @@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" + # Create model in_channels = self.hyperparams["in_channels"] - backbone_name = self.hyperparams["backbone"] + weights = self.hyperparams["weights"] + backbone = timm.create_model( + self.hyperparams["backbone"], + in_chans=in_channels, + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if backbone_name in valid_models: - backbone = timm.create_model( - backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained - ) - else: - raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["backbone"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['backbone']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) backbone = utils.load_state_dict(backbone, state_dict) self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256)) @@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule): Keyword Args: in_channels: Number of input channels to model backbone: Name of the timm model to use - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. learning_rate: Learning rate for optimizer learning_rate_schedule_patience: Patience for learning rate scheduler diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index ca4c8dcf6..41d7c922e 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -22,8 +22,10 @@ from torchmetrics.classification import ( MultilabelAccuracy, MultilabelFBetaScore, ) +from torchvision.models._api import WeightsEnum -from ..datasets.utils import unbind_samples +from ..datasets import unbind_samples +from ..models import get_weight from . import utils @@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule): def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" - in_channels = self.hyperparams["in_channels"] - model = self.hyperparams["model"] + # Create model + weights = self.hyperparams["weights"] + self.model = timm.create_model( + self.hyperparams["model"], + num_classes=self.hyperparams["num_classes"], + in_chans=self.hyperparams["in_channels"], + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if model in valid_models: - self.model = timm.create_model( - model, - num_classes=self.hyperparams["num_classes"], - in_chans=in_channels, - pretrained=imagenet_pretrained, - ) - else: - raise ValueError(f"Model type '{model}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['model']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) def config_task(self) -> None: @@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule): Keyword Args: model: Name of the classification model use loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal' - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. num_classes: Number of prediction classes in_channels: Number of input channels to model learning_rate: Learning rate for optimizer @@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask): """ super().__init__(**kwargs) - # Creates `self.hparams` from kwargs - self.save_hyperparameters() # type: ignore[operator] - self.hyperparams = cast(Dict[str, Any], self.hparams) - - self.config_task() - self.train_metrics = MetricCollection( { "OverallAccuracy": MultilabelAccuracy( diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index e4982ff1f..b50d273d4 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -8,33 +8,29 @@ from typing import Any, Dict, List, cast import matplotlib.pyplot as plt import pytorch_lightning as pl import torch -import torchvision -from packaging.version import parse from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision.models import resnet as R from torchvision.models.detection import FasterRCNN from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.rpn import AnchorGenerator from torchvision.ops import MultiScaleRoIAlign -if parse(torchvision.__version__) >= parse("0.13"): - from torchvision.models import resnet as R - - BACKBONE_WEIGHT_MAP = { - "resnet18": R.ResNet18_Weights.DEFAULT, - "resnet34": R.ResNet34_Weights.DEFAULT, - "resnet50": R.ResNet50_Weights.DEFAULT, - "resnet101": R.ResNet101_Weights.DEFAULT, - "resnet152": R.ResNet152_Weights.DEFAULT, - "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT, - "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT, - "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT, - "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT, - } - from ..datasets.utils import unbind_samples +BACKBONE_WEIGHT_MAP = { + "resnet18": R.ResNet18_Weights.DEFAULT, + "resnet34": R.ResNet34_Weights.DEFAULT, + "resnet50": R.ResNet50_Weights.DEFAULT, + "resnet101": R.ResNet101_Weights.DEFAULT, + "resnet152": R.ResNet152_Weights.DEFAULT, + "resnext50_32x4d": R.ResNeXt50_32X4D_Weights.DEFAULT, + "resnext101_32x8d": R.ResNeXt101_32X8D_Weights.DEFAULT, + "wide_resnet50_2": R.Wide_ResNet50_2_Weights.DEFAULT, + "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT, +} + class ObjectDetectionTask(pl.LightningModule): """LightningModule for object detection of images. @@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule): "backbone_name": self.hyperparams["backbone"], "trainable_layers": self.hyperparams.get("trainable_layers", 3), } - if parse(torchvision.__version__) >= parse("0.13"): - if backbone_pretrained: - kwargs["weights"] = BACKBONE_WEIGHT_MAP[ - self.hyperparams["backbone"] - ] - else: - kwargs["weights"] = None + if backbone_pretrained: + kwargs["weights"] = BACKBONE_WEIGHT_MAP[ + self.hyperparams["backbone"] + ] else: - kwargs["pretrained"] = backbone_pretrained + kwargs["weights"] = None backbone = resnet_fpn_backbone(**kwargs) else: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 8a7056f12..73bffd593 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -14,8 +14,10 @@ import torch.nn.functional as F from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchvision.models._api import WeightsEnum -from ..datasets.utils import unbind_samples +from ..datasets import unbind_samples +from ..models import get_weight from . import utils @@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule): def config_task(self) -> None: """Configures the task based on kwargs parameters.""" - in_channels = self.hyperparams["in_channels"] - model = self.hyperparams["model"] + # Create model + weights = self.hyperparams["weights"] + self.model = timm.create_model( + self.hyperparams["model"], + num_classes=self.hyperparams["num_outputs"], + in_chans=self.hyperparams["in_channels"], + pretrained=weights is True, + ) - imagenet_pretrained = False - custom_pretrained = False - if self.hyperparams["weights"] and not os.path.exists( - self.hyperparams["weights"] - ): - if self.hyperparams["weights"] not in ["imagenet", "random"]: - raise ValueError( - f"Weight type '{self.hyperparams['weights']}' is not valid." - ) + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) else: - imagenet_pretrained = self.hyperparams["weights"] == "imagenet" - custom_pretrained = False - else: - custom_pretrained = True - - # Create the model - valid_models = timm.list_models(pretrained=imagenet_pretrained) - if model in valid_models: - self.model = timm.create_model( - model, - num_classes=self.hyperparams["num_outputs"], - in_chans=in_channels, - pretrained=imagenet_pretrained, - ) - else: - raise ValueError(f"Model type '{model}' is not a valid timm model.") - - if custom_pretrained: - name, state_dict = utils.extract_backbone(self.hyperparams["weights"]) - - if self.hyperparams["model"] != name: - raise ValueError( - f"Trying to load {name} weights into a " - f"{self.hyperparams['model']}" - ) + state_dict = get_weight(weights).get_state_dict(progress=True) self.model = utils.load_state_dict(self.model, state_dict) def __init__(self, **kwargs: Any) -> None: @@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule): Keyword Args: model: Name of the timm model to use - weights: Either "random" or "imagenet" + weights: Either a weight enum, the string representation of a weight enum, + True for ImageNet weights, False or None for random weights, + or the path to a saved model state dict. num_outputs: Number of prediction outputs in_channels: Number of input channels to model learning_rate: Learning rate for optimizer