* load pretrained weights

* change name millionaid

* restructure and additional weights

* rename sentinel1 weights

* add vit small weights

* forgot to add vit.py

* struggling with test

* wrong name failing test

* feedback on tests

* increase test coverage

* fix failing test

* fix failing test

* fix failing test and add vit tests

* fix failing vit test

* torchgeo.models.utils

* forgot utils file

* typo num channels

* nitpick docs, version torchvision

* another try min dependencies

* add documentation table

* expand pytests to test pretrained weights on tasks

* reverse changes to byol task

* add tests to init pretrained weights from config

* forgot to add the conf files

* change path

* increase test coverage

* vit tests all pass locally including slow

* now remote

* fix tests another one

* add a draft tutorial

* run black on tutorial notebook

* Tutorial typo fixes

* Lower min torch/vision versions

* Fix bad rebase

* Remove dead code

* Flake8 fixes

* Consistent in_chans

* Black fixes

* bison > yacs

* Remove one more reference

* Download modified weights from hugging face

* Add entrypoints

* Add torch.hub support

* progress arg is required

* Fix model loading for resnet18

* Add transforms, update tests

* VIT -> ViT

* add seco weights

* Fix type hints

* Link to timm docs

* Fix pydocstyle

* Try to fix timm docs link

* Fix tests

* Nuke ignores

* Ignore timm links

* Add model API methods

* Add to __init__ and document

* Test model API functions

* fix tests

* Use correct documentation link for intersphinx

* Typos

* Fix Windows tests

* meth -> func

* Explicit function scope

* weight-specific filename

* Support enums in classification trainer

* Update other trainers too

* Fix regression tests

* Fix classification tests

* Fix byol tests

* Fix types

* progress_bar is required arg

* Test weight enums

* Fix pickling

* Fix regression tests

* Improve coverage of classification tests

* Improve coverage of BYOL tests

* Update resnet table

* Update ViT table

* Update get_state_dict usage

* Remove unused YAML files

* Update table widths

* Documentation improvements

* Tweak tables

* Try to fix Windows tests

* Revert "Try to fix Windows tests"

This reverts commit 1325b13ff7.

* Monkeypatch everything

* Revert "Monkeypatch everything"

This reverts commit e3e8d7d042.

* Revert "Revert "Monkeypatch everything""

This reverts commit 9b27bd705b.

* Patch things not at the source

* Fix missing import

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2023-01-22 23:25:49 +01:00 committed by GitHub
Parent 0d04e06791
Commit 60eb61b5fa
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
46 changed files: 1101 additions and 396 deletions

@@ -10,7 +10,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 14
     num_classes: 19
   datamodule:

@@ -6,9 +6,11 @@ experiment:
   name: cowc_counting_test
   module:
     model: resnet18
+    weights: null
+    num_outputs: 1
+    in_channels: 3
     learning_rate: 1e-3
     learning_rate_schedule_patience: 2
-    pretrained: True
   datamodule:
     root: "data/cowc_counting"
     seed: 0

@@ -6,9 +6,11 @@ experiment:
   name: "cyclone_test"
   module:
     model: "resnet18"
+    weights: null
+    num_outputs: 1
+    in_channels: 3
     learning_rate: 1e-3
     learning_rate_schedule_patience: 2
-    pretrained: True
   datamodule:
     root: "data/cyclone"
     seed: 0

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 13
     num_classes: 10
   datamodule:

@@ -15,7 +15,6 @@ experiment:
   module:
     model: "faster-rcnn"
     backbone: "resnet50"
-    pretrained: True
     num_classes: 2
     learning_rate: 1.2e-4
     learning_rate_schedule_patience: 6

@@ -10,7 +10,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 3
     num_classes: 45
   datamodule:

@@ -10,7 +10,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 3
     num_classes: 17
   datamodule:

@@ -21,7 +21,7 @@ Fully-convolutional Network
 .. autoclass:: FCN
 
 FC Siamese Networks
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^
 
 .. autoclass:: FCSiamConc
 .. autoclass:: FCSiamDiff
@@ -34,4 +34,33 @@ RCF Extractor
 ResNet
 ^^^^^^
 
+.. autofunction:: resnet18
 .. autofunction:: resnet50
+.. autoclass:: ResNet18_Weights
+.. autoclass:: ResNet50_Weights
+
+.. csv-table::
+   :widths: 45 10 10 10 15 10 10 10
+   :header-rows: 1
+   :align: center
+   :file: resnet_pretrained_weights.csv
+
+Vision Transformer
+^^^^^^^^^^^^^^^^^^
+
+.. autofunction:: vit_small_patch16_224
+.. autoclass:: ViTSmall16_Weights
+
+.. csv-table::
+   :widths: 45 10 10 10 15 10 10 10
+   :header-rows: 1
+   :align: center
+   :file: vit_pretrained_weights.csv
+
+Utility Functions
+^^^^^^^^^^^^^^^^^
+
+.. autofunction:: get_model
+.. autofunction:: get_model_weights
+.. autofunction:: get_weight
+.. autofunction:: list_models
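
As a quick illustration of how the documented utility functions fit together, here is a minimal sketch (the printed list is illustrative, and downloading real weights requires network access):

from torchgeo.models import ResNet50_Weights, get_model, get_model_weights, list_models

# Discover registered builders and their associated weight enums.
print(list_models())                  # e.g. ['resnet18', 'resnet50', 'vit_small_patch16_224']
print(get_model_weights("resnet50"))  # the ResNet50_Weights enum class

# Instantiate a model with a specific set of pretrained weights.
model = get_model("resnet50", weights=ResNet50_Weights.SENTINEL2_ALL_MOCO)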


@ -0,0 +1,9 @@
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.27,93.14,,46.94
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.81,,,


@ -0,0 +1,3 @@
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,89.9,98.6,61.6,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.5,99.0,62.2,

@@ -56,14 +56,12 @@ needs_sphinx = "4.0"
 nitpicky = True
 nitpick_ignore = [
-    # https://github.com/sphinx-doc/sphinx/issues/8127
-    ("py:class", ".."),
-    # TODO: can't figure out why this isn't found
-    ("py:class", "LightningDataModule"),
-    ("py:class", "pytorch_lightning.core.module.LightningModule"),
-    # Undocumented class
-    ("py:class", "torchvision.models.resnet.ResNet"),
+    # Undocumented classes
     ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
+    ("py:class", "timm.models.resnet.ResNet"),
+    ("py:class", "timm.models.vision_transformer.VisionTransformer"),
+    ("py:class", "torchvision.models._api.WeightsEnum"),
+    ("py:class", "torchvision.models.resnet.ResNet"),
 ]
@@ -114,6 +112,7 @@ intersphinx_mapping = {
     "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None),
     "rtree": ("https://rtree.readthedocs.io/en/stable/", None),
     "segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None),
+    "timm": ("https://huggingface.co/docs/timm/main/en/", None),
     "torch": ("https://pytorch.org/docs/stable", None),
     "torchvision": ("https://pytorch.org/vision/stable", None),
 }

@@ -33,6 +33,7 @@ torchgeo
    tutorials/indices
    tutorials/trainers
    tutorials/benchmarking
+   tutorials/pretrained_weights
 
 .. toctree::
    :maxdepth: 1


@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pretrained Weights\n",
"\n",
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n",
"\n",
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n",
"\n",
"Next, we import TorchGeo and any other libraries we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import os\n",
"import csv\n",
"import tempfile\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
"import timm\n",
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from pytorch_lightning.loggers import CSVLogger\n",
"\n",
"from torchgeo.datamodules import EuroSATDataModule\n",
"from torchgeo.trainers import ClassificationTask\n",
"from torchgeo.models import ResNet50_Weights, VITSmall16_Weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
"# skip the expensive training.\n",
"in_tests = \"PYTEST_CURRENT_TEST\" in os.environ"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Datamodule\n",
"\n",
"We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n",
"\n",
"datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weights\n",
"\n",
"Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n",
"\n",
"To access these weights you can do the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"weights = ResNet50_Weights.SENTINEL2_ALL_MOCO"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task = ClassificationTask(\n",
" model=\"resnet50\",\n",
" loss=\"ce\",\n",
" weights=weights,\n",
" in_channels=13,\n",
" num_classes=10,\n",
" learning_rate=0.001,\n",
" learning_rate_schedule_patience=5,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n",
"model.load_state_dict(weights.get_state_dict(), strict=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training\n",
"\n",
"To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n",
"\n",
"checkpoint_callback = ModelCheckpoint(\n",
" monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n",
")\n",
"\n",
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
"\n",
"csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(\n",
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
" logger=[csv_logger],\n",
" default_root_dir=experiment_dir,\n",
" min_epochs=1,\n",
" max_epochs=10,\n",
" fast_dev_run=in_tests,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model=task, datamodule=datamodule)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torchEnv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -11,12 +11,12 @@ dependencies:
   - pycocotools>=2
   - pyproj>=2.2
   - python>=3.7
-  - pytorch>=1.9
+  - pytorch>=1.12
   - pyvista>=0.20
   - rarfile>=3
   - rasterio>=1.0.20
   - shapely>=1.3
-  - torchvision>=0.10
+  - torchvision>=0.13
   - pip:
     - black[jupyter]>=21.8
     - flake8>=3.8

14
hubconf.py Normal file
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo pre-trained model repository configuration file.

* https://pytorch.org/hub/
* https://pytorch.org/docs/stable/hub.html
"""

from torchgeo.models import resnet18, resnet50, vit_small_patch16_224

__all__ = ("resnet18", "resnet50", "vit_small_patch16_224")

dependencies = ["timm"]
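
With this entrypoint file in place, the models can also be fetched through torch.hub. A minimal sketch (assumes network access to the GitHub repository; called without weights, the builders return randomly initialized models):

import torch

# Enumerate the entrypoints exposed by hubconf.py, then build one of them.
print(torch.hub.list("microsoft/torchgeo"))
model = torch.hub.load("microsoft/torchgeo", "resnet18")

A weight enum from torchgeo.models can also be forwarded through the builder's weights keyword argument, e.g. torch.hub.load("microsoft/torchgeo", "resnet18", weights=ResNet18_Weights.SENTINEL2_RGB_MOCO).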

@@ -18,9 +18,9 @@ scikit-learn==0.21.0
 segmentation-models-pytorch==0.2.0
 shapely==1.3.0
 timm==0.4.12
-torch==1.9.0
+torch==1.12.0
 torchmetrics==0.10.0
-torchvision==0.10.0
+torchvision==0.13.0
 
 # datasets
 h5py==2.6.0

@@ -57,12 +57,12 @@ install_requires =
     shapely>=1.3,<3
     # timm 0.4.12 required by segmentation-models-pytorch
     timm>=0.4.12,<0.7
-    # torch 1.9+ required by torchvision
-    torch>=1.9,<2
+    # torch 1.12+ required by torchvision
+    torch>=1.12,<2
     # torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
     torchmetrics>=0.10,<0.12
-    # torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
-    torchvision>=0.10,<0.15
+    # torchvision 0.13+ required for torchvision.models._api.WeightsEnum
+    torchvision>=0.13,<0.15
 python_requires = ~= 3.7
 packages = find:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 14
     num_classes: 19
   datamodule:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 2
     num_classes: 19
   datamodule:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 12
     num_classes: 19
   datamodule:

@@ -1,21 +0,0 @@
-experiment:
-  task: "ssl"
-  name: "test_byol"
-  module:
-    model: "byol"
-    backbone: "resnet18"
-    input_channels: 4
-    weights: imagenet
-    learning_rate: 1e-3
-    learning_rate_schedule_patience: 6
-  datamodule:
-    root: "tests/data/chesapeake/cvpr"
-    download: true
-    train_splits:
-      - "de-test"
-    val_splits:
-      - "de-test"
-    test_splits:
-      - "de-test"
-    batch_size: 1
-    num_workers: 0

@@ -10,7 +10,7 @@ experiment:
     num_classes: 7
     num_filters: 1
     ignore_index: null
-    weights: imagenet
+    weights: null
   datamodule:
     root: "tests/data/chesapeake/cvpr"
     download: true

@@ -10,7 +10,7 @@ experiment:
     num_classes: 5
     num_filters: 1
     ignore_index: null
-    weights: imagenet
+    weights: null
   datamodule:
     root: "tests/data/chesapeake/cvpr"
     download: true

@@ -2,12 +2,11 @@ experiment:
   task: cowc_counting
   module:
     model: resnet18
-    weights: "random"
+    weights: null
     num_outputs: 1
     in_channels: 3
     learning_rate: 1e-3
     learning_rate_schedule_patience: 2
-    pretrained: True
   datamodule:
     root: "tests/data/cowc_counting"
     download: true

@@ -2,12 +2,11 @@ experiment:
   task: "cyclone"
   module:
     model: "resnet18"
-    weights: "random"
+    weights: null
     num_outputs: 1
     in_channels: 3
     learning_rate: 1e-3
     learning_rate_schedule_patience: 2
-    pretrained: False
   datamodule:
     root: "tests/data/cyclone"
     download: true

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 13
     num_classes: 2
   datamodule:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 3
     num_classes: 3
   datamodule:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 3
     num_classes: 17
   datamodule:

@@ -5,7 +5,7 @@ experiment:
     model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
-    weights: "random"
+    weights: null
     in_channels: 3
     num_classes: 17
   datamodule:

@@ -3,7 +3,7 @@ experiment:
   module:
     loss: "ce"
     model: "resnet18"
-    weights: "random"
+    weights: null
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     in_channels: 3

Binary data
tests/data/models/resnet50-sentinel2-2.pt.zip

Binary file not shown.

50
tests/models/test_api.py Normal file

@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import enum
from typing import Callable
import pytest
import torch.nn as nn
from torchvision.models._api import WeightsEnum
from torchgeo.models import (
ResNet18_Weights,
ResNet50_Weights,
ViTSmall16_Weights,
get_model,
get_model_weights,
get_weight,
list_models,
resnet18,
resnet50,
vit_small_patch16_224,
)
builders = [resnet18, resnet50, vit_small_patch16_224]
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights]
@pytest.mark.parametrize("builder", builders)
def test_get_model(builder: Callable[..., nn.Module]) -> None:
model = get_model(builder.__name__)
assert isinstance(model, nn.Module)
@pytest.mark.parametrize("builder", builders)
def test_get_model_weights(builder: Callable[..., nn.Module]) -> None:
weights = get_model_weights(builder)
assert isinstance(weights, enum.EnumMeta)
weights = get_model_weights(builder.__name__)
assert isinstance(weights, enum.EnumMeta)
@pytest.mark.parametrize("enum", enums)
def test_get_weight(enum: WeightsEnum) -> None:
for weight in enum:
assert weight == get_weight(str(weight))
def test_list_models() -> None:
models = [builder.__name__ for builder in builders]
assert set(models) == set(list_models())

@@ -1,61 +1,74 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
-import os
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict
 
 import pytest
+import timm
 import torch
+import torchvision
+from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
-from torch.nn.modules import Module
+from torchvision.models._api import WeightsEnum
 
-import torchgeo.models.resnet
-from torchgeo.datasets.utils import extract_archive
-from torchgeo.models import resnet50
+from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
 
 
-def load_state_dict_from_file(
-    file: str,
-    model_dir: Optional[str] = None,
-    map_location: Optional[Any] = None,
-    progress: Optional[bool] = True,
-    check_hash: Optional[bool] = False,
-    file_name: Optional[str] = None,
-) -> Any:
-    """Mockup of ``torch.hub.load_state_dict_from_url``."""
-    return torch.load(file)
+def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+    state_dict: Dict[str, Any] = torch.load(url)
+    return state_dict
 
 
-@pytest.mark.parametrize(
-    "model_class,sensor,bands,in_channels,num_classes",
-    [(resnet50, "sentinel2", "all", 10, 17)],
-)
-def test_resnet(
-    monkeypatch: MonkeyPatch,
-    tmp_path: Path,
-    model_class: Module,
-    sensor: str,
-    bands: str,
-    in_channels: int,
-    num_classes: int,
-) -> None:
-    extract_archive(
-        os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"),
-        str(tmp_path),
-    )
-
-    new_model_urls = {
-        "sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}}
-    }
-
-    monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls)
-    monkeypatch.setattr(
-        torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file
-    )
-
-    model = model_class(sensor, bands, pretrained=True)
-    x = torch.zeros(1, in_channels, 256, 256)
-    y = model(x)
-    assert isinstance(y, torch.Tensor)
-    assert y.size() == torch.Size([1, 17])
+class TestResNet18:
+    @pytest.fixture(params=[*ResNet18_Weights])
+    def weights(self, request: SubRequest) -> WeightsEnum:
+        return request.param
+
+    @pytest.fixture
+    def mocked_weights(
+        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
+    ) -> WeightsEnum:
+        path = tmp_path / f"{weights}.pth"
+        model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
+        torch.save(model.state_dict(), path)
+        monkeypatch.setattr(weights, "url", str(path))
+        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
+        return weights
+
+    def test_resnet(self) -> None:
+        resnet18()
+
+    def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
+        resnet18(weights=mocked_weights)
+
+    @pytest.mark.slow
+    def test_resnet_download(self, weights: WeightsEnum) -> None:
+        resnet18(weights=weights)
+
+
+class TestResNet50:
+    @pytest.fixture(params=[*ResNet50_Weights])
+    def weights(self, request: SubRequest) -> WeightsEnum:
+        return request.param
+
+    @pytest.fixture
+    def mocked_weights(
+        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
+    ) -> WeightsEnum:
+        path = tmp_path / f"{weights}.pth"
+        model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
+        torch.save(model.state_dict(), path)
+        monkeypatch.setattr(weights, "url", str(path))
+        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
+        return weights
+
+    def test_resnet(self) -> None:
+        resnet50()
+
+    def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
+        resnet50(weights=mocked_weights)
+
+    @pytest.mark.slow
+    def test_resnet_download(self, weights: WeightsEnum) -> None:
+        resnet50(weights=weights)

49
tests/models/test_vit.py Normal file

@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from pathlib import Path
from typing import Any, Dict
import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchvision.models._api import WeightsEnum
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def mocked_weights(
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
) -> WeightsEnum:
path = tmp_path / f"{weights}.pth"
model = timm.create_model(
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_vit(self) -> None:
vit_small_patch16_224()
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=mocked_weights)
@pytest.mark.slow
def test_vit_download(self, weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=weights)


@ -2,21 +2,33 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
from pathlib import Path
from typing import Any, Dict, Type, cast from typing import Any, Dict, Type, cast
import pytest import pytest
import timm
import torch
import torch.nn as nn import torch.nn as nn
import torchvision
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning import LightningDataModule, Trainer
from torchvision.models import resnet18 from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import BYOLTask from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
from .test_utils import SegmentationTestModel from .test_utils import SegmentationTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestBYOL: class TestBYOL:
def test_custom_augment_fn(self) -> None: def test_custom_augment_fn(self) -> None:
backbone = resnet18() backbone = resnet18()
@ -45,7 +57,7 @@ class TestBYOLTask:
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -64,30 +76,31 @@ class TestBYOLTask:
trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
@pytest.fixture @pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]: def model_kwargs(self) -> Dict[str, Any]:
return {"backbone": "resnet18", "weights": "random", "in_channels": 3} return {"backbone": "resnet18", "weights": None, "in_channels": 3}
def test_invalid_pretrained( @pytest.fixture
self, model_kwargs: Dict[Any, Any], checkpoint: str def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
BYOLTask(**model_kwargs)
def test_weight_enum(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None: ) -> None:
model_kwargs["weights"] = checkpoint model_kwargs["weights"] = mocked_weights
model_kwargs["backbone"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
BYOLTask(**model_kwargs) BYOLTask(**model_kwargs)
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: def test_weight_str(
model_kwargs["weights"] = checkpoint self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
BYOLTask(**model_kwargs) ) -> None:
model_kwargs["weights"] = str(mocked_weights)
def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["backbone"] = "invalid_backbone"
match = "Model type 'invalid_backbone' is not a valid timm model."
with pytest.raises(ValueError, match=match):
BYOLTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
BYOLTask(**model_kwargs) BYOLTask(**model_kwargs)


@ -2,14 +2,18 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
from pathlib import Path
from typing import Any, Dict, Type, cast from typing import Any, Dict, Type, cast
import pytest import pytest
import timm import timm
import torch
import torchvision
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning import LightningDataModule, Trainer
from torch.nn.modules import Module from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import ( from torchgeo.datamodules import (
BigEarthNetDataModule, BigEarthNetDataModule,
@ -18,6 +22,7 @@ from torchgeo.datamodules import (
So2SatDataModule, So2SatDataModule,
UCMercedDataModule, UCMercedDataModule,
) )
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
from .test_utils import ClassificationTestModel from .test_utils import ClassificationTestModel
@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs) return ClassificationTestModel(**kwargs)
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestClassificationTask: class TestClassificationTask:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,classname", "name,classname",
@ -46,7 +56,7 @@ class TestClassificationTask:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -66,7 +76,7 @@ class TestClassificationTask:
def test_no_logger(self) -> None: def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -83,49 +93,52 @@ class TestClassificationTask:
trainer.fit(model=model, datamodule=datamodule) trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture @pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]: def model_kwargs(self) -> Dict[str, Any]:
return { return {
"model": "resnet18", "model": "resnet18",
"in_channels": 13, "in_channels": 13,
"loss": "ce", "loss": "ce",
"num_classes": 10, "num_classes": 10,
"weights": "random", "weights": None,
} }
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: @pytest.fixture
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs) ClassificationTask(**model_kwargs)
def test_invalid_pretrained( def test_weight_enum(
self, model_kwargs: Dict[Any, Any], checkpoint: str self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None: ) -> None:
model_kwargs["weights"] = checkpoint model_kwargs["weights"] = mocked_weights
model_kwargs["model"] = "resnet50" with pytest.warns(UserWarning):
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs) ClassificationTask(**model_kwargs)
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: def test_weight_str(
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
) -> None:
model_kwargs["weights"] = str(mocked_weights)
with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs)
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss" model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid." match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match): with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs) ClassificationTask(**model_kwargs)
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
def test_missing_attributes( def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
) -> None: ) -> None:
monkeypatch.delattr(EuroSATDataModule, "plot") monkeypatch.delattr(EuroSATDataModule, "plot")
datamodule = EuroSATDataModule( datamodule = EuroSATDataModule(
@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask:
) -> None: ) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask:
def test_no_logger(self) -> None: def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask:
trainer.fit(model=model, datamodule=datamodule) trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture @pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]: def model_kwargs(self) -> Dict[str, Any]:
return { return {
"model": "resnet18", "model": "resnet18",
"in_channels": 14, "in_channels": 14,
"loss": "bce", "loss": "bce",
"num_classes": 19, "num_classes": 19,
"weights": "random", "weights": None,
} }
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
model_kwargs["loss"] = "invalid_loss" model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid." match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match): with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(**model_kwargs) MultiLabelClassificationTask(**model_kwargs)
def test_missing_attributes( def test_missing_attributes(
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
) -> None: ) -> None:
monkeypatch.delattr(BigEarthNetDataModule, "plot") monkeypatch.delattr(BigEarthNetDataModule, "plot")
datamodule = BigEarthNetDataModule( datamodule = BigEarthNetDataModule(


@ -2,18 +2,30 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import os import os
from pathlib import Path
from typing import Any, Dict, Type, cast from typing import Any, Dict, Type, cast
import pytest import pytest
import timm
import torch
import torchvision
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning import LightningDataModule, Trainer
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import RegressionTask from torchgeo.trainers import RegressionTask
from .test_utils import RegressionTestModel from .test_utils import RegressionTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestRegressionTask: class TestRegressionTask:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,classname", "name,classname",
@ -25,7 +37,7 @@ class TestRegressionTask:
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -46,7 +58,7 @@ class TestRegressionTask:
def test_no_logger(self) -> None: def test_no_logger(self) -> None:
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
# Instantiate datamodule # Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"] datamodule_kwargs = conf_dict["datamodule"]
@ -63,36 +75,39 @@ class TestRegressionTask:
trainer.fit(model=model, datamodule=datamodule) trainer.fit(model=model, datamodule=datamodule)
@pytest.fixture @pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]: def model_kwargs(self) -> Dict[str, Any]:
return { return {
"model": "resnet18", "model": "resnet18",
"weights": "random", "weights": None,
"num_outputs": 1, "num_outputs": 1,
"in_channels": 3, "in_channels": 3,
} }
def test_invalid_pretrained( @pytest.fixture
self, model_kwargs: Dict[Any, Any], checkpoint: str def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
) -> None: weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
model_kwargs["weights"] = checkpoint path = tmp_path / f"{weights}.pth"
model_kwargs["model"] = "resnet50" model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
match = "Trying to load resnet18 weights into a resnet50" torch.save(model.state_dict(), path)
with pytest.raises(ValueError, match=match): monkeypatch.setattr(weights, "url", str(path))
RegressionTask(**model_kwargs) monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None: def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
RegressionTask(**model_kwargs) RegressionTask(**model_kwargs)
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: def test_weight_enum(
model_kwargs["model"] = "invalid_model" self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
match = "Model type 'invalid_model' is not a valid timm model." ) -> None:
with pytest.raises(ValueError, match=match): model_kwargs["weights"] = mocked_weights
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs) RegressionTask(**model_kwargs)
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: def test_weight_str(
model_kwargs["weights"] = "invalid_weights" self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
match = "Weight type 'invalid_weights' is not valid." ) -> None:
with pytest.raises(ValueError, match=match): model_kwargs["weights"] = str(mocked_weights)
with pytest.warns(UserWarning):
RegressionTask(**model_kwargs) RegressionTask(**model_kwargs)

@@ -3,14 +3,17 @@
 """TorchGeo models."""
 
+from .api import get_model, get_model_weights, get_weight, list_models
 from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
 from .farseg import FarSeg
 from .fcn import FCN
 from .fcsiam import FCSiamConc, FCSiamDiff
 from .rcf import RCF
-from .resnet import resnet50
+from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
+from .vit import ViTSmall16_Weights, vit_small_patch16_224
 
 __all__ = (
+    # models
     "ChangeMixin",
     "ChangeStar",
     "ChangeStarFarSeg",
@@ -19,5 +22,16 @@ __all__ = (
     "FCSiamConc",
     "FCSiamDiff",
     "RCF",
+    "resnet18",
     "resnet50",
+    "vit_small_patch16_224",
+    # weights
+    "ResNet50_Weights",
+    "ResNet18_Weights",
+    "ViTSmall16_Weights",
+    # utilities
+    "get_model",
+    "get_model_weights",
+    "get_weight",
+    "list_models",
 )

90
torchgeo/models/api.py Normal file

@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""APIs for querying and loading pre-trained model weights.
See the following references for design details:
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
""" # noqa: E501
from typing import Any, Callable, List, Union
import torch.nn as nn
from torchvision.models._api import WeightsEnum
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .vit import ViTSmall16_Weights, vit_small_patch16_224
_model = {
"resnet18": resnet18,
"resnet50": resnet50,
"vit_small_patch16_224": vit_small_patch16_224,
}
_model_weights = {
resnet18: ResNet18_Weights,
resnet50: ResNet50_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
"resnet18": ResNet18_Weights,
"resnet50": ResNet50_Weights,
"vit_small_patch16_224": ViTSmall16_Weights,
}
def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
"""Get an instantiated model from its name.
.. versionadded:: 0.4
Args:
name: Name of the model.
*args: Additional arguments passed to the model builder method.
**kwargs: Additional keyword arguments passed to the model builder method.
Returns:
An instantiated model.
"""
model: nn.Module = _model[name](*args, **kwargs)
return model
def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum:
"""Get the weights enum class associated with a given model.
.. versionadded:: 0.4
Args:
name: Model builder function or the name under which it is registered.
Returns:
The weights enum class associated with the model.
"""
return _model_weights[name]
def get_weight(name: str) -> WeightsEnum:
"""Get the weights enum value by its full name.
.. versionadded:: 0.4
Args:
name: Name of the weight enum entry.
Returns:
The requested weight enum.
"""
return eval(name)
def list_models() -> List[str]:
"""List the registered models.
.. versionadded:: 0.4
Returns:
A list of registered models.
"""
return list(_model.keys())
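
A small usage sketch of get_weight, mirroring test_get_weight in tests/models/test_api.py above: the string form of a weight enum member resolves back to the member itself.

from torchgeo.models import ResNet18_Weights, get_weight

# str(weight) yields e.g. "ResNet18_Weights.SENTINEL2_ALL_MOCO",
# and get_weight() maps that name back to the enum member.
name = str(ResNet18_Weights.SENTINEL2_ALL_MOCO)
assert get_weight(name) == ResNet18_Weights.SENTINEL2_ALL_MOCO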

@@ -9,7 +9,6 @@ from typing import List, cast
 import torch.nn.functional as F
 import torchvision
-from packaging.version import parse
 from torch import Tensor
 from torch.nn.modules import (
     BatchNorm2d,
@@ -62,7 +61,6 @@ class FarSeg(Module):
         else:
             raise ValueError(f"unknown backbone: {backbone}.")
         kwargs = {}
-        if parse(torchvision.__version__) >= parse("0.13"):
         if backbone_pretrained:
             kwargs = {
                 "weights": getattr(
@@ -71,8 +69,6 @@ class FarSeg(Module):
             }
         else:
             kwargs = {"weights": None}
-        else:
-            kwargs = {"pretrained": backbone_pretrained}
 
         self.backbone = getattr(resnet, backbone)(**kwargs)


@ -3,83 +3,208 @@
"""Pre-trained ResNet models.""" """Pre-trained ResNet models."""
from typing import Any, List, Type, Union from typing import Any, Optional
import kornia.augmentation as K
import timm
import torch.nn as nn import torch.nn as nn
from torch.hub import load_state_dict_from_url from timm.models import ResNet
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet from torchvision.models._api import Weights, WeightsEnum
MODEL_URLS = { from ..transforms import AugmentationSequential
"sentinel2": {
"all": { __all__ = ["ResNet50_Weights", "ResNet18_Weights"]
"resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt"
} _zhu_xlab_transforms = AugmentationSequential(
} K.Resize(256), K.CenterCrop(224), data_keys=["image"]
} )
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
IN_CHANNELS = {"sentinel2": {"all": 10}} class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet18 weights.
NUM_CLASSES = {"sentinel2": 17} For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet18* implementation.
.. versionadded:: 0.4
"""
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/"
"resolve/main/resnet18_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/"
"resolve/main/resnet18_sentinel2_rgb_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 3,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_SECO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/"
"resolve/main/resnet18_sentinel2_rgb_seco.ckpt"
),
transforms=nn.Identity(),
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
"model": "resnet18",
"publication": "https://arxiv.org/abs/2103.16607",
"repo": "https://github.com/ServiceNow/seasonal-contrast",
"ssl_method": "seco",
},
)
def _resnet( class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
sensor: str, """ResNet50 weights.
bands: str,
arch: str, For `timm <https://github.com/rwightman/pytorch-image-models>`_
block: Type[Union[BasicBlock, Bottleneck]], *resnet50* implementation.
layers: List[int],
pretrained: bool, .. versionadded:: 0.4
progress: bool, """
**kwargs: Any,
SENTINEL1_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/"
"resolve/main/resnet50_sentinel1_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 2,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/"
"resolve/main/resnet50_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_RGB_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/"
"resolve/main/resnet50_sentinel2_rgb_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_DINO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/"
"resolve/main/resnet50_sentinel2_all_dino.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "dino",
},
)
SENTINEL2_RGB_SECO = Weights(
url=(
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/"
"resolve/main/resnet50_sentinel2_rgb_seco.ckpt"
),
transforms=nn.Identity(),
meta={
"dataset": "SeCo Dataset",
"in_chans": 3,
"model": "resnet50",
"publication": "https://arxiv.org/abs/2103.16607",
"repo": "https://github.com/ServiceNow/seasonal-contrast",
"ssl_method": "seco",
},
)
+def resnet18(
+    weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any
 ) -> ResNet:
-    """Resnet model.
+    """ResNet-18 model.
 
     If you use this model in your research, please cite the following paper:
 
     * https://arxiv.org/pdf/1512.03385.pdf
 
+    .. versionadded:: 0.4
+
     Args:
-        sensor: imagery source which determines number of input channels
-        bands: which spectral bands to consider: "all", "rgb", etc.
-        arch: ResNet version specifying number of layers
-        block: type of network block
-        layers: number of layers per block
-        pretrained: if True, returns a model pre-trained on ``sensor`` imagery
-        progress: if True, displays a progress bar of the download to stderr
+        weights: Pre-trained model weights to use.
+        *args: Additional arguments to pass to :func:`timm.create_model`
+        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`
 
     Returns:
-        A ResNet-50 model
+        A ResNet-18 model.
     """
-    # Initialize a new model
-    model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs)
-
-    # Replace the first layer with the correct number of input channels
-    model.conv1 = nn.Conv2d(
-        IN_CHANNELS[sensor][bands],
-        out_channels=64,
-        kernel_size=7,
-        stride=1,
-        padding=2,
-        bias=False,
-    )
-
-    # Load pretrained weights
-    if pretrained:
-        state_dict = load_state_dict_from_url(
-            MODEL_URLS[sensor][bands][arch], progress=progress
-        )
-        model.load_state_dict(state_dict)
+    if weights:
+        kwargs["in_chans"] = weights.meta["in_chans"]
+
+    model: ResNet = timm.create_model("resnet18", *args, **kwargs)
+
+    if weights:
+        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
 
     return model
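A quick usage sketch of the new constructor (not part of the diff; it assumes the enum and builder are re-exported from torchgeo.models, as the trainer imports later in this commit suggest):

    from torchgeo.models import ResNet18_Weights, resnet18

    # random initialization, timm defaults (3 input channels)
    model = resnet18()

    # 13-channel backbone initialized from the SSL4EO-S12 MoCo checkpoint;
    # in_chans is filled in automatically from weights.meta["in_chans"]
    model = resnet18(weights=ResNet18_Weights.SENTINEL2_ALL_MOCO)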
 def resnet50(
-    sensor: str,
-    bands: str,
-    pretrained: bool = False,
-    progress: bool = True,
-    **kwargs: Any,
+    weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any
 ) -> ResNet:
     """ResNet-50 model.

@@ -87,22 +212,23 @@ def resnet50(

     * https://arxiv.org/pdf/1512.03385.pdf
 
+    .. versionchanged:: 0.4
+       Switched to multi-weight support API.
+
     Args:
-        sensor: imagery source which determines number of input channels
-        bands: which spectral bands to consider: "all", "rgb", etc.
-        pretrained: if True, returns a model pre-trained on ``sensor`` imagery
-        progress: if True, displays a progress bar of the download to stderr
+        weights: Pre-trained model weights to use.
+        *args: Additional arguments to pass to :func:`timm.create_model`.
+        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
 
     Returns:
-        A ResNet-50 model
+        A ResNet-50 model.
     """
-    return _resnet(
-        sensor,
-        bands,
-        "resnet50",
-        Bottleneck,
-        [3, 4, 6, 3],
-        pretrained,
-        progress,
-        **kwargs,
-    )
+    if weights:
+        kwargs["in_chans"] = weights.meta["in_chans"]
+
+    model: ResNet = timm.create_model("resnet50", *args, **kwargs)
+
+    if weights:
+        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+
+    return model
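Each Weights entry also carries the preprocessing used during pretraining (here a resize to 256 followed by a 224 center crop). A sketch of applying it before inference, assuming weights.transforms is the AugmentationSequential instance defined above and that it accepts and returns a dict with an "image" key:

    import torch
    from torchgeo.models import ResNet50_Weights, resnet50

    weights = ResNet50_Weights.SENTINEL2_ALL_MOCO
    model = resnet50(weights=weights).eval()

    # 13-band Sentinel-2 patch with a batch dimension
    batch = {"image": torch.rand(1, weights.meta["in_chans"], 256, 256)}
    batch = weights.transforms(batch)  # resize to 256, center crop to 224
    with torch.no_grad():
        out = model(batch["image"])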
torchgeo/models/vit.py (new file)

@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Pre-trained Vision Transformer models."""
from typing import Any, Optional
import kornia.augmentation as K
import timm
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum
from ..transforms import AugmentationSequential
__all__ = ["ViTSmall16_Weights"]
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
)
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
"""Vision Transformer Samll Patch Size 16 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*vit_small_patch16_224* implementation.
.. versionadded:: 0.4
"""
SENTINEL2_ALL_MOCO = Weights(
url=(
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/"
"resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "vit_small_patch16_224",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "moco",
},
)
SENTINEL2_ALL_DINO = Weights(
url=(
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/"
"resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth"
),
transforms=_zhu_xlab_transforms,
meta={
"dataset": "SSL4EO-S12",
"in_chans": 13,
"model": "vit_small_patch16_224",
"publication": "https://arxiv.org/abs/2211.07044",
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
"ssl_method": "dino",
},
)
def vit_small_patch16_224(
weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any
) -> VisionTransformer:
"""Vision Transform (ViT) small patch size 16 model.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/abs/2010.11929
.. versionadded:: 0.4
Args:
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :func:`timm.create_model`.
**kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
Returns:
A ViT small 16 model.
"""
if weights:
kwargs["in_chans"] = weights.meta["in_chans"]
model: VisionTransformer = timm.create_model(
"vit_small_patch16_224", *args, **kwargs
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
return model
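Like the ResNet builders, extra arguments are forwarded to timm.create_model, so the pretrained ViT can be used directly as an embedding model. A sketch (num_classes=0 is timm's standard way of dropping the classification head; not part of the diff):

    import torch
    from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224

    weights = ViTSmall16_Weights.SENTINEL2_ALL_DINO
    encoder = vit_small_patch16_224(weights=weights, num_classes=0).eval()

    x = torch.rand(2, weights.meta["in_chans"], 224, 224)
    with torch.no_grad():
        embeddings = encoder(x)  # pooled features, dimension 384 for ViT-S/16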
torchgeo/trainers/byol.py
@@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform
 from torch import Tensor, optim
 from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torchvision.models._api import WeightsEnum
 
+from ..models import get_weight
 from . import utils

@@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule):
     def config_task(self) -> None:
         """Configures the task based on kwargs parameters passed to the constructor."""
+        # Create model
         in_channels = self.hyperparams["in_channels"]
-        backbone_name = self.hyperparams["backbone"]
-
-        imagenet_pretrained = False
-        custom_pretrained = False
-        if self.hyperparams["weights"] and not os.path.exists(
-            self.hyperparams["weights"]
-        ):
-            if self.hyperparams["weights"] not in ["imagenet", "random"]:
-                raise ValueError(
-                    f"Weight type '{self.hyperparams['weights']}' is not valid."
-                )
-            else:
-                imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
-            custom_pretrained = False
-        else:
-            custom_pretrained = True
-
-        # Create the model
-        valid_models = timm.list_models(pretrained=imagenet_pretrained)
-        if backbone_name in valid_models:
-            backbone = timm.create_model(
-                backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained
-            )
-        else:
-            raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.")
-
-        if custom_pretrained:
-            name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
-            if self.hyperparams["backbone"] != name:
-                raise ValueError(
-                    f"Trying to load {name} weights into a "
-                    f"{self.hyperparams['backbone']}"
-                )
-            backbone = utils.load_state_dict(backbone, state_dict)
+        weights = self.hyperparams["weights"]
+        backbone = timm.create_model(
+            self.hyperparams["backbone"],
+            in_chans=in_channels,
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            backbone = utils.load_state_dict(backbone, state_dict)
 
         self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))

@@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule):
         Keyword Args:
             in_channels: Number of input channels to model
             backbone: Name of the timm model to use
-            weights: Either "random" or "imagenet"
+            weights: Either a weight enum, the string representation of a weight enum,
+                True for ImageNet weights, False or None for random weights,
+                or the path to a saved model state dict.
             learning_rate: Learning rate for optimizer
             learning_rate_schedule_patience: Patience for learning rate scheduler
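With this change the trainer's weights argument accepts several forms. A sketch of the options (the string accepted by get_weight is assumed to be the enum's qualified name, and the remaining keyword arguments are illustrative values matching the docstring above):

    from torchgeo.models import ResNet18_Weights
    from torchgeo.trainers import BYOLTask

    common = dict(
        backbone="resnet18",
        in_channels=13,
        learning_rate=1e-3,
        learning_rate_schedule_patience=6,
    )

    task = BYOLTask(weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, **common)    # weight enum
    task = BYOLTask(weights="ResNet18_Weights.SENTINEL2_ALL_MOCO", **common)  # string form
    task = BYOLTask(weights=True, **common)                                   # ImageNet weights
    task = BYOLTask(weights=None, **common)                                   # random init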
torchgeo/trainers/classification.py
@@ -22,8 +22,10 @@ from torchmetrics.classification import (
     MultilabelAccuracy,
     MultilabelFBetaScore,
 )
+from torchvision.models._api import WeightsEnum
 
-from ..datasets.utils import unbind_samples
+from ..datasets import unbind_samples
+from ..models import get_weight
 from . import utils

@@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule):
     def config_model(self) -> None:
         """Configures the model based on kwargs parameters passed to the constructor."""
-        in_channels = self.hyperparams["in_channels"]
-        model = self.hyperparams["model"]
-
-        imagenet_pretrained = False
-        custom_pretrained = False
-        if self.hyperparams["weights"] and not os.path.exists(
-            self.hyperparams["weights"]
-        ):
-            if self.hyperparams["weights"] not in ["imagenet", "random"]:
-                raise ValueError(
-                    f"Weight type '{self.hyperparams['weights']}' is not valid."
-                )
-            else:
-                imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
-            custom_pretrained = False
-        else:
-            custom_pretrained = True
-
-        # Create the model
-        valid_models = timm.list_models(pretrained=imagenet_pretrained)
-        if model in valid_models:
-            self.model = timm.create_model(
-                model,
-                num_classes=self.hyperparams["num_classes"],
-                in_chans=in_channels,
-                pretrained=imagenet_pretrained,
-            )
-        else:
-            raise ValueError(f"Model type '{model}' is not a valid timm model.")
-
-        if custom_pretrained:
-            name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
-            if self.hyperparams["model"] != name:
-                raise ValueError(
-                    f"Trying to load {name} weights into a "
-                    f"{self.hyperparams['model']}"
-                )
-            self.model = utils.load_state_dict(self.model, state_dict)
+        # Create model
+        weights = self.hyperparams["weights"]
+        self.model = timm.create_model(
+            self.hyperparams["model"],
+            num_classes=self.hyperparams["num_classes"],
+            in_chans=self.hyperparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            self.model = utils.load_state_dict(self.model, state_dict)
 
     def config_task(self) -> None:

@@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule):
         Keyword Args:
             model: Name of the classification model to use
             loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
-            weights: Either "random" or "imagenet"
+            weights: Either a weight enum, the string representation of a weight enum,
+                True for ImageNet weights, False or None for random weights,
+                or the path to a saved model state dict.
             num_classes: Number of prediction classes
             in_channels: Number of input channels to model
             learning_rate: Learning rate for optimizer

@@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask):
         """
         super().__init__(**kwargs)
 
-        # Creates `self.hparams` from kwargs
-        self.save_hyperparameters()  # type: ignore[operator]
-        self.hyperparams = cast(Dict[str, Any], self.hparams)
-        self.config_task()
-
         self.train_metrics = MetricCollection(
             {
                 "OverallAccuracy": MultilabelAccuracy(
torchgeo/trainers/detection.py
@@ -8,18 +8,16 @@ from typing import Any, Dict, List, cast
 import matplotlib.pyplot as plt
 import pytorch_lightning as pl
 import torch
-import torchvision
-from packaging.version import parse
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
+from torchvision.models import resnet as R
 from torchvision.models.detection import FasterRCNN
 from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
 from torchvision.models.detection.rpn import AnchorGenerator
 from torchvision.ops import MultiScaleRoIAlign
 
-if parse(torchvision.__version__) >= parse("0.13"):
-    from torchvision.models import resnet as R
+from ..datasets.utils import unbind_samples
 
 BACKBONE_WEIGHT_MAP = {
     "resnet18": R.ResNet18_Weights.DEFAULT,

@@ -33,8 +31,6 @@ if parse(torchvision.__version__) >= parse("0.13"):
     "wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
 }
 
-from ..datasets.utils import unbind_samples
 
 class ObjectDetectionTask(pl.LightningModule):
     """LightningModule for object detection of images.

@@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule):
                     "backbone_name": self.hyperparams["backbone"],
                     "trainable_layers": self.hyperparams.get("trainable_layers", 3),
                 }
-                if parse(torchvision.__version__) >= parse("0.13"):
-                    if backbone_pretrained:
-                        kwargs["weights"] = BACKBONE_WEIGHT_MAP[
-                            self.hyperparams["backbone"]
-                        ]
-                    else:
-                        kwargs["weights"] = None
-                else:
-                    kwargs["pretrained"] = backbone_pretrained
+                if backbone_pretrained:
+                    kwargs["weights"] = BACKBONE_WEIGHT_MAP[
+                        self.hyperparams["backbone"]
+                    ]
+                else:
+                    kwargs["weights"] = None
 
                 backbone = resnet_fpn_backbone(**kwargs)
             else:
torchgeo/trainers/regression.py
@@ -14,8 +14,10 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
+from torchvision.models._api import WeightsEnum
 
-from ..datasets.utils import unbind_samples
+from ..datasets import unbind_samples
+from ..models import get_weight
 from . import utils

@@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule):
     def config_task(self) -> None:
         """Configures the task based on kwargs parameters."""
-        in_channels = self.hyperparams["in_channels"]
-        model = self.hyperparams["model"]
-
-        imagenet_pretrained = False
-        custom_pretrained = False
-        if self.hyperparams["weights"] and not os.path.exists(
-            self.hyperparams["weights"]
-        ):
-            if self.hyperparams["weights"] not in ["imagenet", "random"]:
-                raise ValueError(
-                    f"Weight type '{self.hyperparams['weights']}' is not valid."
-                )
-            else:
-                imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
-            custom_pretrained = False
-        else:
-            custom_pretrained = True
-
-        # Create the model
-        valid_models = timm.list_models(pretrained=imagenet_pretrained)
-        if model in valid_models:
-            self.model = timm.create_model(
-                model,
-                num_classes=self.hyperparams["num_outputs"],
-                in_chans=in_channels,
-                pretrained=imagenet_pretrained,
-            )
-        else:
-            raise ValueError(f"Model type '{model}' is not a valid timm model.")
-
-        if custom_pretrained:
-            name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
-            if self.hyperparams["model"] != name:
-                raise ValueError(
-                    f"Trying to load {name} weights into a "
-                    f"{self.hyperparams['model']}"
-                )
-            self.model = utils.load_state_dict(self.model, state_dict)
+        # Create model
+        weights = self.hyperparams["weights"]
+        self.model = timm.create_model(
+            self.hyperparams["model"],
+            num_classes=self.hyperparams["num_outputs"],
+            in_chans=self.hyperparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            self.model = utils.load_state_dict(self.model, state_dict)
 
     def __init__(self, **kwargs: Any) -> None:

@@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule):
         Keyword Args:
             model: Name of the timm model to use
-            weights: Either "random" or "imagenet"
+            weights: Either a weight enum, the string representation of a weight enum,
+                True for ImageNet weights, False or None for random weights,
+                or the path to a saved model state dict.
             num_outputs: Number of prediction outputs
             in_channels: Number of input channels to model
             learning_rate: Learning rate for optimizer