зеркало из https://github.com/microsoft/torchgeo.git
Add Multi-Weight Support API (#917)
* load pretrained weights * change name millionaid * restructure and additional weights * rename sentinel1 weights * add vit small weights * forgot to add vit.py * struggling with test * wrong name failing test * feedback on tests * increase test coverage * fix failing test * fix failing test * fix failing test and add vit tests * fix failing vit test * torchgeo.models.utils * forgot utils file * typo num channels * nitpick docs, version torchvision * another try min dependencies * add documentation table * expand pytests to test pretrained weights on tasks * reverse changes to byol task * add tests to init pretrained weights from config * forgot to add the conf files * change path * increase test coverage * vit tests all pass locally including slow * now remote * fix tests another one * add a draft tutorial * run black on tutorial notebook * Tutorial typo fixes * Lower min torch/vision versions * Fix bad rebase * Remove dead code * Flake8 fixes * Consistent in_chans * Black fixes * bison > yacs * Remove one more reference * Download modified weights from hugging face * Add entrypoints * Add torch.hub support * progress arg is required * Fix model loading for resnet18 * Add transforms, update tests * VIT -> ViT * add seco weights * Fix type hints * Link to timm docs * Fix pydocstyle * Try to fix timm docs link * Fix tests * Nuke ignores * Ignore timm links * Add model API methods * Add to __init__ and document * Test model API functions * fix tests * Use correct documentation link for intersphinx * Typos * Fix Windows tests * meth -> func * Explicit function scope * weight-specific filename * Support enums in classification trainer * Update other trainers too * Fix regression tests * Fix classification tests * Fix byol tests * Fix types * progress_bar is required arg * Test weight enums * Fix pickling * Fix regression tests * Improve coverage of classification tests * Improve coverage of BYOL tests * Update resnet table * Update ViT table * Update get_state_dict usage * Remove unused YAML files * Update table widths * Documentation improvements * Tweak tables * Try to fix Windows tests * Revert "Try to fix Windows tests" This reverts commit1325b13ff7
. * Monkeypatch everything * Revert "Monkeypatch everything" This reverts commite3e8d7d042
. * Revert "Revert "Monkeypatch everything"" This reverts commit9b27bd705b
. * Patch things not at the source * Fix missing import Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
0d04e06791
Коммит
60eb61b5fa
|
@ -10,7 +10,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 14
|
||||
num_classes: 19
|
||||
datamodule:
|
||||
|
|
|
@ -6,9 +6,11 @@ experiment:
|
|||
name: cowc_counting_test
|
||||
module:
|
||||
model: resnet18
|
||||
weights: null
|
||||
num_outputs: 1
|
||||
in_channels: 3
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: True
|
||||
datamodule:
|
||||
root: "data/cowc_counting"
|
||||
seed: 0
|
||||
|
|
|
@ -6,9 +6,11 @@ experiment:
|
|||
name: "cyclone_test"
|
||||
module:
|
||||
model: "resnet18"
|
||||
weights: null
|
||||
num_outputs: 1
|
||||
in_channels: 3
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: True
|
||||
datamodule:
|
||||
root: "data/cyclone"
|
||||
seed: 0
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 13
|
||||
num_classes: 10
|
||||
datamodule:
|
||||
|
|
|
@ -15,7 +15,6 @@ experiment:
|
|||
module:
|
||||
model: "faster-rcnn"
|
||||
backbone: "resnet50"
|
||||
pretrained: True
|
||||
num_classes: 2
|
||||
learning_rate: 1.2e-4
|
||||
learning_rate_schedule_patience: 6
|
||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 3
|
||||
num_classes: 45
|
||||
datamodule:
|
||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 3
|
||||
num_classes: 17
|
||||
datamodule:
|
||||
|
|
|
@ -21,7 +21,7 @@ Fully-convolutional Network
|
|||
.. autoclass:: FCN
|
||||
|
||||
FC Siamese Networks
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: FCSiamConc
|
||||
.. autoclass:: FCSiamDiff
|
||||
|
@ -34,4 +34,33 @@ RCF Extractor
|
|||
ResNet
|
||||
^^^^^^
|
||||
|
||||
.. autofunction:: resnet18
|
||||
.. autofunction:: resnet50
|
||||
.. autoclass:: ResNet18_Weights
|
||||
.. autoclass:: ResNet50_Weights
|
||||
|
||||
.. csv-table::
|
||||
:widths: 45 10 10 10 15 10 10 10
|
||||
:header-rows: 1
|
||||
:align: center
|
||||
:file: resnet_pretrained_weights.csv
|
||||
|
||||
Vision Transformer
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: vit_small_patch16_224
|
||||
.. autoclass:: ViTSmall16_Weights
|
||||
|
||||
.. csv-table::
|
||||
:widths: 45 10 10 10 15 10 10 10
|
||||
:header-rows: 1
|
||||
:align: center
|
||||
:file: vit_pretrained_weights.csv
|
||||
|
||||
Utility Functions
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_model
|
||||
.. autofunction:: get_model_weights
|
||||
.. autofunction:: get_weight
|
||||
.. autofunction:: list_models
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
|
||||
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.27,93.14,,46.94
|
||||
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,91.8,99.1,60.9,
|
||||
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.7,99.1,63.6,
|
||||
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.81,,,
|
|
|
@ -0,0 +1,3 @@
|
|||
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
|
||||
VITSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,89.9,98.6,61.6,
|
||||
VITSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.5,99.0,62.2,
|
|
13
docs/conf.py
13
docs/conf.py
|
@ -56,14 +56,12 @@ needs_sphinx = "4.0"
|
|||
|
||||
nitpicky = True
|
||||
nitpick_ignore = [
|
||||
# https://github.com/sphinx-doc/sphinx/issues/8127
|
||||
("py:class", ".."),
|
||||
# TODO: can't figure out why this isn't found
|
||||
("py:class", "LightningDataModule"),
|
||||
("py:class", "pytorch_lightning.core.module.LightningModule"),
|
||||
# Undocumented class
|
||||
("py:class", "torchvision.models.resnet.ResNet"),
|
||||
# Undocumented classes
|
||||
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
||||
("py:class", "timm.models.resnet.ResNet"),
|
||||
("py:class", "timm.models.vision_transformer.VisionTransformer"),
|
||||
("py:class", "torchvision.models._api.WeightsEnum"),
|
||||
("py:class", "torchvision.models.resnet.ResNet"),
|
||||
]
|
||||
|
||||
|
||||
|
@ -114,6 +112,7 @@ intersphinx_mapping = {
|
|||
"rasterio": ("https://rasterio.readthedocs.io/en/stable/", None),
|
||||
"rtree": ("https://rtree.readthedocs.io/en/stable/", None),
|
||||
"segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None),
|
||||
"timm": ("https://huggingface.co/docs/timm/main/en/", None),
|
||||
"torch": ("https://pytorch.org/docs/stable", None),
|
||||
"torchvision": ("https://pytorch.org/vision/stable", None),
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ torchgeo
|
|||
tutorials/indices
|
||||
tutorials/trainers
|
||||
tutorials/benchmarking
|
||||
tutorials/pretrained_weights
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
|
|
@ -0,0 +1,254 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
||||
"\n",
|
||||
"Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Pretrained Weights\n",
|
||||
"\n",
|
||||
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n",
|
||||
"\n",
|
||||
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"First, we install TorchGeo."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install torchgeo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports\n",
|
||||
"\n",
|
||||
"Next, we import TorchGeo and any other libraries we need."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import tempfile\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import timm\n",
|
||||
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||
"from pytorch_lightning.loggers import CSVLogger\n",
|
||||
"\n",
|
||||
"from torchgeo.datamodules import EuroSATDataModule\n",
|
||||
"from torchgeo.trainers import ClassificationTask\n",
|
||||
"from torchgeo.models import ResNet50_Weights, VITSmall16_Weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
|
||||
"# skip the expensive training.\n",
|
||||
"in_tests = \"PYTEST_CURRENT_TEST\" in os.environ"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Datamodule\n",
|
||||
"\n",
|
||||
"We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n",
|
||||
"\n",
|
||||
"datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Weights\n",
|
||||
"\n",
|
||||
"Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n",
|
||||
"\n",
|
||||
"To access these weights you can do the following:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"weights = ResNet50_Weights.SENTINEL2_ALL_MOCO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"task = ClassificationTask(\n",
|
||||
" model=\"resnet50\",\n",
|
||||
" loss=\"ce\",\n",
|
||||
" weights=weights,\n",
|
||||
" in_channels=13,\n",
|
||||
" num_classes=10,\n",
|
||||
" learning_rate=0.001,\n",
|
||||
" learning_rate_schedule_patience=5,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_chans = weights.meta[\"in_chans\"]\n",
|
||||
"model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n",
|
||||
"model.load_state_dict(weights.get_state_dict(), strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Training\n",
|
||||
"\n",
|
||||
"To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n",
|
||||
"\n",
|
||||
"checkpoint_callback = ModelCheckpoint(\n",
|
||||
" monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
|
||||
"\n",
|
||||
"csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = pl.Trainer(\n",
|
||||
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
|
||||
" logger=[csv_logger],\n",
|
||||
" default_root_dir=experiment_dir,\n",
|
||||
" min_epochs=1,\n",
|
||||
" max_epochs=10,\n",
|
||||
" fast_dev_run=in_tests,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.fit(model=task, datamodule=datamodule)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "torchEnv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -11,12 +11,12 @@ dependencies:
|
|||
- pycocotools>=2
|
||||
- pyproj>=2.2
|
||||
- python>=3.7
|
||||
- pytorch>=1.9
|
||||
- pytorch>=1.12
|
||||
- pyvista>=0.20
|
||||
- rarfile>=3
|
||||
- rasterio>=1.0.20
|
||||
- shapely>=1.3
|
||||
- torchvision>=0.10
|
||||
- torchvision>=0.13
|
||||
- pip:
|
||||
- black[jupyter]>=21.8
|
||||
- flake8>=3.8
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""TorchGeo pre-trained model repository configuration file.
|
||||
|
||||
* https://pytorch.org/hub/
|
||||
* https://pytorch.org/docs/stable/hub.html
|
||||
"""
|
||||
|
||||
from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
|
||||
|
||||
__all__ = ("resnet18", "resnet50", "vit_small_patch16_224")
|
||||
|
||||
dependencies = ["timm"]
|
|
@ -18,9 +18,9 @@ scikit-learn==0.21.0
|
|||
segmentation-models-pytorch==0.2.0
|
||||
shapely==1.3.0
|
||||
timm==0.4.12
|
||||
torch==1.9.0
|
||||
torch==1.12.0
|
||||
torchmetrics==0.10.0
|
||||
torchvision==0.10.0
|
||||
torchvision==0.13.0
|
||||
|
||||
# datasets
|
||||
h5py==2.6.0
|
||||
|
|
|
@ -57,12 +57,12 @@ install_requires =
|
|||
shapely>=1.3,<3
|
||||
# timm 0.4.12 required by segmentation-models-pytorch
|
||||
timm>=0.4.12,<0.7
|
||||
# torch 1.9+ required by torchvision
|
||||
torch>=1.9,<2
|
||||
# torch 1.12+ required by torchvision
|
||||
torch>=1.12,<2
|
||||
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
|
||||
torchmetrics>=0.10,<0.12
|
||||
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
|
||||
torchvision>=0.10,<0.15
|
||||
# torchvision 0.13+ required for torchvision.models._api.WeightsEnum
|
||||
torchvision>=0.13,<0.15
|
||||
python_requires = ~= 3.7
|
||||
packages = find:
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 14
|
||||
num_classes: 19
|
||||
datamodule:
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 2
|
||||
num_classes: 19
|
||||
datamodule:
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 12
|
||||
num_classes: 19
|
||||
datamodule:
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
experiment:
|
||||
task: "ssl"
|
||||
name: "test_byol"
|
||||
module:
|
||||
model: "byol"
|
||||
backbone: "resnet18"
|
||||
input_channels: 4
|
||||
weights: imagenet
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
datamodule:
|
||||
root: "tests/data/chesapeake/cvpr"
|
||||
download: true
|
||||
train_splits:
|
||||
- "de-test"
|
||||
val_splits:
|
||||
- "de-test"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
batch_size: 1
|
||||
num_workers: 0
|
|
@ -10,7 +10,7 @@ experiment:
|
|||
num_classes: 7
|
||||
num_filters: 1
|
||||
ignore_index: null
|
||||
weights: imagenet
|
||||
weights: null
|
||||
datamodule:
|
||||
root: "tests/data/chesapeake/cvpr"
|
||||
download: true
|
||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
|||
num_classes: 5
|
||||
num_filters: 1
|
||||
ignore_index: null
|
||||
weights: imagenet
|
||||
weights: null
|
||||
datamodule:
|
||||
root: "tests/data/chesapeake/cvpr"
|
||||
download: true
|
||||
|
|
|
@ -2,12 +2,11 @@ experiment:
|
|||
task: cowc_counting
|
||||
module:
|
||||
model: resnet18
|
||||
weights: "random"
|
||||
weights: null
|
||||
num_outputs: 1
|
||||
in_channels: 3
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: True
|
||||
datamodule:
|
||||
root: "tests/data/cowc_counting"
|
||||
download: true
|
||||
|
|
|
@ -2,12 +2,11 @@ experiment:
|
|||
task: "cyclone"
|
||||
module:
|
||||
model: "resnet18"
|
||||
weights: "random"
|
||||
weights: null
|
||||
num_outputs: 1
|
||||
in_channels: 3
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 2
|
||||
pretrained: False
|
||||
datamodule:
|
||||
root: "tests/data/cyclone"
|
||||
download: true
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 13
|
||||
num_classes: 2
|
||||
datamodule:
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 3
|
||||
num_classes: 3
|
||||
datamodule:
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 3
|
||||
num_classes: 17
|
||||
datamodule:
|
||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
|||
model: "resnet18"
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
weights: "random"
|
||||
weights: null
|
||||
in_channels: 3
|
||||
num_classes: 17
|
||||
datamodule:
|
||||
|
|
|
@ -3,7 +3,7 @@ experiment:
|
|||
module:
|
||||
loss: "ce"
|
||||
model: "resnet18"
|
||||
weights: "random"
|
||||
weights: null
|
||||
learning_rate: 1e-3
|
||||
learning_rate_schedule_patience: 6
|
||||
in_channels: 3
|
||||
|
|
Двоичные данные
tests/data/models/resnet50-sentinel2-2.pt.zip
Двоичные данные
tests/data/models/resnet50-sentinel2-2.pt.zip
Двоичный файл не отображается.
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import enum
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.models import (
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ViTSmall16_Weights,
|
||||
get_model,
|
||||
get_model_weights,
|
||||
get_weight,
|
||||
list_models,
|
||||
resnet18,
|
||||
resnet50,
|
||||
vit_small_patch16_224,
|
||||
)
|
||||
|
||||
builders = [resnet18, resnet50, vit_small_patch16_224]
|
||||
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("builder", builders)
|
||||
def test_get_model(builder: Callable[..., nn.Module]) -> None:
|
||||
model = get_model(builder.__name__)
|
||||
assert isinstance(model, nn.Module)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("builder", builders)
|
||||
def test_get_model_weights(builder: Callable[..., nn.Module]) -> None:
|
||||
weights = get_model_weights(builder)
|
||||
assert isinstance(weights, enum.EnumMeta)
|
||||
weights = get_model_weights(builder.__name__)
|
||||
assert isinstance(weights, enum.EnumMeta)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enum", enums)
|
||||
def test_get_weight(enum: WeightsEnum) -> None:
|
||||
for weight in enum:
|
||||
assert weight == get_weight(str(weight))
|
||||
|
||||
|
||||
def test_list_models() -> None:
|
||||
models = [builder.__name__ for builder in builders]
|
||||
assert set(models) == set(list_models())
|
|
@ -1,61 +1,74 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.nn.modules import Module
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
import torchgeo.models.resnet
|
||||
from torchgeo.datasets.utils import extract_archive
|
||||
from torchgeo.models import resnet50
|
||||
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
|
||||
|
||||
def load_state_dict_from_file(
|
||||
file: str,
|
||||
model_dir: Optional[str] = None,
|
||||
map_location: Optional[Any] = None,
|
||||
progress: Optional[bool] = True,
|
||||
check_hash: Optional[bool] = False,
|
||||
file_name: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Mockup of ``torch.hub.load_state_dict_from_url``."""
|
||||
return torch.load(file)
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_class,sensor,bands,in_channels,num_classes",
|
||||
[(resnet50, "sentinel2", "all", 10, 17)],
|
||||
)
|
||||
def test_resnet(
|
||||
monkeypatch: MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
model_class: Module,
|
||||
sensor: str,
|
||||
bands: str,
|
||||
in_channels: int,
|
||||
num_classes: int,
|
||||
) -> None:
|
||||
extract_archive(
|
||||
os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"),
|
||||
str(tmp_path),
|
||||
)
|
||||
class TestResNet18:
|
||||
@pytest.fixture(params=[*ResNet18_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
new_model_urls = {
|
||||
"sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}}
|
||||
}
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls)
|
||||
monkeypatch.setattr(
|
||||
torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file
|
||||
)
|
||||
def test_resnet(self) -> None:
|
||||
resnet18()
|
||||
|
||||
model = model_class(sensor, bands, pretrained=True)
|
||||
x = torch.zeros(1, in_channels, 256, 256)
|
||||
y = model(x)
|
||||
assert isinstance(y, torch.Tensor)
|
||||
assert y.size() == torch.Size([1, 17])
|
||||
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||
resnet18(weights=mocked_weights)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||
resnet18(weights=weights)
|
||||
|
||||
|
||||
class TestResNet50:
|
||||
@pytest.fixture(params=[*ResNet50_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_resnet(self) -> None:
|
||||
resnet50()
|
||||
|
||||
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||
resnet50(weights=mocked_weights)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||
resnet50(weights=weights)
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestViTSmall16:
|
||||
@pytest.fixture(params=[*ViTSmall16_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model(
|
||||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_vit(self) -> None:
|
||||
vit_small_patch16_224()
|
||||
|
||||
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||
vit_small_patch16_224(weights=mocked_weights)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_vit_download(self, weights: WeightsEnum) -> None:
|
||||
vit_small_patch16_224(weights=weights)
|
|
@ -2,21 +2,33 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, cast
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import BYOLTask
|
||||
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
||||
|
||||
from .test_utils import SegmentationTestModel
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestBYOL:
|
||||
def test_custom_augment_fn(self) -> None:
|
||||
backbone = resnet18()
|
||||
|
@ -45,7 +57,7 @@ class TestBYOLTask:
|
|||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -64,30 +76,31 @@ class TestBYOLTask:
|
|||
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
|
||||
|
||||
@pytest.fixture
|
||||
def model_kwargs(self) -> Dict[Any, Any]:
|
||||
return {"backbone": "resnet18", "weights": "random", "in_channels": 3}
|
||||
def model_kwargs(self) -> Dict[str, Any]:
|
||||
return {"backbone": "resnet18", "weights": None, "in_channels": 3}
|
||||
|
||||
def test_invalid_pretrained(
|
||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
||||
@pytest.fixture
|
||||
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
BYOLTask(**model_kwargs)
|
||||
|
||||
def test_weight_enum(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
model_kwargs["backbone"] = "resnet50"
|
||||
match = "Trying to load resnet18 weights into a resnet50"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
model_kwargs["weights"] = mocked_weights
|
||||
BYOLTask(**model_kwargs)
|
||||
|
||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
BYOLTask(**model_kwargs)
|
||||
|
||||
def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["backbone"] = "invalid_backbone"
|
||||
match = "Model type 'invalid_backbone' is not a valid timm model."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
BYOLTask(**model_kwargs)
|
||||
|
||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["weights"] = "invalid_weights"
|
||||
match = "Weight type 'invalid_weights' is not valid."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
def test_weight_str(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = str(mocked_weights)
|
||||
BYOLTask(**model_kwargs)
|
||||
|
|
|
@ -2,14 +2,18 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, cast
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from torch.nn.modules import Module
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datamodules import (
|
||||
BigEarthNetDataModule,
|
||||
|
@ -18,6 +22,7 @@ from torchgeo.datamodules import (
|
|||
So2SatDataModule,
|
||||
UCMercedDataModule,
|
||||
)
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
|
||||
|
||||
from .test_utils import ClassificationTestModel
|
||||
|
@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
|
|||
return ClassificationTestModel(**kwargs)
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestClassificationTask:
|
||||
@pytest.mark.parametrize(
|
||||
"name,classname",
|
||||
|
@ -46,7 +56,7 @@ class TestClassificationTask:
|
|||
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -66,7 +76,7 @@ class TestClassificationTask:
|
|||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -83,49 +93,52 @@ class TestClassificationTask:
|
|||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
@pytest.fixture
|
||||
def model_kwargs(self) -> Dict[Any, Any]:
|
||||
def model_kwargs(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": "resnet18",
|
||||
"in_channels": 13,
|
||||
"loss": "ce",
|
||||
"num_classes": 10,
|
||||
"weights": "random",
|
||||
"weights": None,
|
||||
}
|
||||
|
||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
||||
@pytest.fixture
|
||||
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
with pytest.warns(UserWarning):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_pretrained(
|
||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
||||
def test_weight_enum(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
model_kwargs["model"] = "resnet50"
|
||||
match = "Trying to load resnet18 weights into a resnet50"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
model_kwargs["weights"] = mocked_weights
|
||||
with pytest.warns(UserWarning):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
def test_weight_str(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = str(mocked_weights)
|
||||
with pytest.warns(UserWarning):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
|
||||
model_kwargs["loss"] = "invalid_loss"
|
||||
match = "Loss type 'invalid_loss' is not valid."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["model"] = "invalid_model"
|
||||
match = "Model type 'invalid_model' is not a valid timm model."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["weights"] = "invalid_weights"
|
||||
match = "Weight type 'invalid_weights' is not valid."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
ClassificationTask(**model_kwargs)
|
||||
|
||||
def test_missing_attributes(
|
||||
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
|
||||
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.delattr(EuroSATDataModule, "plot")
|
||||
datamodule = EuroSATDataModule(
|
||||
|
@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask:
|
|||
) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask:
|
|||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask:
|
|||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
@pytest.fixture
|
||||
def model_kwargs(self) -> Dict[Any, Any]:
|
||||
def model_kwargs(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": "resnet18",
|
||||
"in_channels": 14,
|
||||
"loss": "bce",
|
||||
"num_classes": 19,
|
||||
"weights": "random",
|
||||
"weights": None,
|
||||
}
|
||||
|
||||
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
|
||||
model_kwargs["loss"] = "invalid_loss"
|
||||
match = "Loss type 'invalid_loss' is not valid."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
MultiLabelClassificationTask(**model_kwargs)
|
||||
|
||||
def test_missing_attributes(
|
||||
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
|
||||
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.delattr(BigEarthNetDataModule, "plot")
|
||||
datamodule = BigEarthNetDataModule(
|
||||
|
|
|
@ -2,18 +2,30 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, cast
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import RegressionTask
|
||||
|
||||
from .test_utils import RegressionTestModel
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestRegressionTask:
|
||||
@pytest.mark.parametrize(
|
||||
"name,classname",
|
||||
|
@ -25,7 +37,7 @@ class TestRegressionTask:
|
|||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -46,7 +58,7 @@ class TestRegressionTask:
|
|||
def test_no_logger(self) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
|
||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
||||
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule_kwargs = conf_dict["datamodule"]
|
||||
|
@ -63,36 +75,39 @@ class TestRegressionTask:
|
|||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
@pytest.fixture
|
||||
def model_kwargs(self) -> Dict[Any, Any]:
|
||||
def model_kwargs(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": "resnet18",
|
||||
"weights": "random",
|
||||
"weights": None,
|
||||
"num_outputs": 1,
|
||||
"in_channels": 3,
|
||||
}
|
||||
|
||||
def test_invalid_pretrained(
|
||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
||||
) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
model_kwargs["model"] = "resnet50"
|
||||
match = "Trying to load resnet18 weights into a resnet50"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
RegressionTask(**model_kwargs)
|
||||
@pytest.fixture
|
||||
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
model_kwargs["weights"] = checkpoint
|
||||
with pytest.warns(UserWarning):
|
||||
RegressionTask(**model_kwargs)
|
||||
|
||||
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["model"] = "invalid_model"
|
||||
match = "Model type 'invalid_model' is not a valid timm model."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
def test_weight_enum(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = mocked_weights
|
||||
with pytest.warns(UserWarning):
|
||||
RegressionTask(**model_kwargs)
|
||||
|
||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
||||
model_kwargs["weights"] = "invalid_weights"
|
||||
match = "Weight type 'invalid_weights' is not valid."
|
||||
with pytest.raises(ValueError, match=match):
|
||||
def test_weight_str(
|
||||
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||
) -> None:
|
||||
model_kwargs["weights"] = str(mocked_weights)
|
||||
with pytest.warns(UserWarning):
|
||||
RegressionTask(**model_kwargs)
|
||||
|
|
|
@ -3,14 +3,17 @@
|
|||
|
||||
"""TorchGeo models."""
|
||||
|
||||
from .api import get_model, get_model_weights, get_weight, list_models
|
||||
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
|
||||
from .farseg import FarSeg
|
||||
from .fcn import FCN
|
||||
from .fcsiam import FCSiamConc, FCSiamDiff
|
||||
from .rcf import RCF
|
||||
from .resnet import resnet50
|
||||
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
__all__ = (
|
||||
# models
|
||||
"ChangeMixin",
|
||||
"ChangeStar",
|
||||
"ChangeStarFarSeg",
|
||||
|
@ -19,5 +22,16 @@ __all__ = (
|
|||
"FCSiamConc",
|
||||
"FCSiamDiff",
|
||||
"RCF",
|
||||
"resnet18",
|
||||
"resnet50",
|
||||
"vit_small_patch16_224",
|
||||
# weights
|
||||
"ResNet50_Weights",
|
||||
"ResNet18_Weights",
|
||||
"ViTSmall16_Weights",
|
||||
# utilities
|
||||
"get_model",
|
||||
"get_model_weights",
|
||||
"get_weight",
|
||||
"list_models",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""APIs for querying and loading pre-trained model weights.
|
||||
|
||||
See the following references for design details:
|
||||
|
||||
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
|
||||
* https://pytorch.org/vision/stable/models.html
|
||||
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
|
||||
""" # noqa: E501
|
||||
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
_model = {
|
||||
"resnet18": resnet18,
|
||||
"resnet50": resnet50,
|
||||
"vit_small_patch16_224": vit_small_patch16_224,
|
||||
}
|
||||
|
||||
_model_weights = {
|
||||
resnet18: ResNet18_Weights,
|
||||
resnet50: ResNet50_Weights,
|
||||
vit_small_patch16_224: ViTSmall16_Weights,
|
||||
"resnet18": ResNet18_Weights,
|
||||
"resnet50": ResNet50_Weights,
|
||||
"vit_small_patch16_224": ViTSmall16_Weights,
|
||||
}
|
||||
|
||||
|
||||
def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
|
||||
"""Get an instantiated model from its name.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Args:
|
||||
name: Name of the model.
|
||||
*args: Additional arguments passed to the model builder method.
|
||||
**kwargs: Additional keyword arguments passed to the model builder method.
|
||||
|
||||
Returns:
|
||||
An instantiated model.
|
||||
"""
|
||||
model: nn.Module = _model[name](*args, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum:
|
||||
"""Get the weights enum class associated with a given model.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Args:
|
||||
name: Model builder function or the name under which it is registered.
|
||||
|
||||
Returns:
|
||||
The weights enum class associated with the model.
|
||||
"""
|
||||
return _model_weights[name]
|
||||
|
||||
|
||||
def get_weight(name: str) -> WeightsEnum:
|
||||
"""Get the weights enum value by its full name.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Args:
|
||||
name: Name of the weight enum entry.
|
||||
|
||||
Returns:
|
||||
The requested weight enum.
|
||||
"""
|
||||
return eval(name)
|
||||
|
||||
|
||||
def list_models() -> List[str]:
|
||||
"""List the registered models.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Returns:
|
||||
A list of registered models.
|
||||
"""
|
||||
return list(_model.keys())
|
|
@ -9,7 +9,6 @@ from typing import List, cast
|
|||
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from packaging.version import parse
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import (
|
||||
BatchNorm2d,
|
||||
|
@ -62,7 +61,6 @@ class FarSeg(Module):
|
|||
else:
|
||||
raise ValueError(f"unknown backbone: {backbone}.")
|
||||
kwargs = {}
|
||||
if parse(torchvision.__version__) >= parse("0.13"):
|
||||
if backbone_pretrained:
|
||||
kwargs = {
|
||||
"weights": getattr(
|
||||
|
@ -71,8 +69,6 @@ class FarSeg(Module):
|
|||
}
|
||||
else:
|
||||
kwargs = {"weights": None}
|
||||
else:
|
||||
kwargs = {"pretrained": backbone_pretrained}
|
||||
|
||||
self.backbone = getattr(resnet, backbone)(**kwargs)
|
||||
|
||||
|
|
|
@ -3,83 +3,208 @@
|
|||
|
||||
"""Pre-trained ResNet models."""
|
||||
|
||||
from typing import Any, List, Type, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import kornia.augmentation as K
|
||||
import timm
|
||||
import torch.nn as nn
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
|
||||
from timm.models import ResNet
|
||||
from torchvision.models._api import Weights, WeightsEnum
|
||||
|
||||
MODEL_URLS = {
|
||||
"sentinel2": {
|
||||
"all": {
|
||||
"resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt"
|
||||
}
|
||||
}
|
||||
}
|
||||
from ..transforms import AugmentationSequential
|
||||
|
||||
__all__ = ["ResNet50_Weights", "ResNet18_Weights"]
|
||||
|
||||
_zhu_xlab_transforms = AugmentationSequential(
|
||||
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/vision/pull/6883
|
||||
# https://github.com/pytorch/vision/pull/7107
|
||||
# Can be removed once torchvision>=0.15 is required
|
||||
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
||||
|
||||
|
||||
IN_CHANNELS = {"sentinel2": {"all": 10}}
|
||||
class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""ResNet18 weights.
|
||||
|
||||
NUM_CLASSES = {"sentinel2": 17}
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*resnet18* implementation.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
|
||||
SENTINEL2_ALL_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/"
|
||||
"resolve/main/resnet18_sentinel2_all_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 13,
|
||||
"model": "resnet18",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/"
|
||||
"resolve/main/resnet18_sentinel2_rgb_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 3,
|
||||
"model": "resnet18",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_SECO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/"
|
||||
"resolve/main/resnet18_sentinel2_rgb_seco.ckpt"
|
||||
),
|
||||
transforms=nn.Identity(),
|
||||
meta={
|
||||
"dataset": "SeCo Dataset",
|
||||
"in_chans": 3,
|
||||
"model": "resnet18",
|
||||
"publication": "https://arxiv.org/abs/2103.16607",
|
||||
"repo": "https://github.com/ServiceNow/seasonal-contrast",
|
||||
"ssl_method": "seco",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _resnet(
|
||||
sensor: str,
|
||||
bands: str,
|
||||
arch: str,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any,
|
||||
class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""ResNet50 weights.
|
||||
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*resnet50* implementation.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
|
||||
SENTINEL1_ALL_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/"
|
||||
"resolve/main/resnet50_sentinel1_all_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 2,
|
||||
"model": "resnet50",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_ALL_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/"
|
||||
"resolve/main/resnet50_sentinel2_all_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 13,
|
||||
"model": "resnet50",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/"
|
||||
"resolve/main/resnet50_sentinel2_rgb_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 3,
|
||||
"model": "resnet50",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_ALL_DINO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/"
|
||||
"resolve/main/resnet50_sentinel2_all_dino.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 13,
|
||||
"model": "resnet50",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "dino",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_SECO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/"
|
||||
"resolve/main/resnet50_sentinel2_rgb_seco.ckpt"
|
||||
),
|
||||
transforms=nn.Identity(),
|
||||
meta={
|
||||
"dataset": "SeCo Dataset",
|
||||
"in_chans": 3,
|
||||
"model": "resnet50",
|
||||
"publication": "https://arxiv.org/abs/2103.16607",
|
||||
"repo": "https://github.com/ServiceNow/seasonal-contrast",
|
||||
"ssl_method": "seco",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resnet18(
|
||||
weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any
|
||||
) -> ResNet:
|
||||
"""Resnet model.
|
||||
"""ResNet-18 model.
|
||||
|
||||
If you use this model in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/pdf/1512.03385.pdf
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Args:
|
||||
sensor: imagery source which determines number of input channels
|
||||
bands: which spectral bands to consider: "all", "rgb", etc.
|
||||
arch: ResNet version specifying number of layers
|
||||
block: type of network block
|
||||
layers: number of layers per block
|
||||
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
|
||||
progress: if True, displays a progress bar of the download to stderr
|
||||
weights: Pre-trained model weights to use.
|
||||
*args: Additional arguments to pass to :func:`timm.create_model`
|
||||
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`
|
||||
|
||||
Returns:
|
||||
A ResNet-50 model
|
||||
A ResNet-18 model.
|
||||
"""
|
||||
# Initialize a new model
|
||||
model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs)
|
||||
if weights:
|
||||
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||
|
||||
# Replace the first layer with the correct number of input channels
|
||||
model.conv1 = nn.Conv2d(
|
||||
IN_CHANNELS[sensor][bands],
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=2,
|
||||
bias=False,
|
||||
)
|
||||
model: ResNet = timm.create_model("resnet18", *args, **kwargs)
|
||||
|
||||
# Load pretrained weights
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(
|
||||
MODEL_URLS[sensor][bands][arch], progress=progress
|
||||
)
|
||||
model.load_state_dict(state_dict)
|
||||
if weights:
|
||||
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resnet50(
|
||||
sensor: str,
|
||||
bands: str,
|
||||
pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any,
|
||||
weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any
|
||||
) -> ResNet:
|
||||
"""ResNet-50 model.
|
||||
|
||||
|
@ -87,22 +212,23 @@ def resnet50(
|
|||
|
||||
* https://arxiv.org/pdf/1512.03385.pdf
|
||||
|
||||
.. versionchanged:: 0.4
|
||||
Switched to multi-weight support API.
|
||||
|
||||
Args:
|
||||
sensor: imagery source which determines number of input channels
|
||||
bands: which spectral bands to consider: "all", "rgb", etc.
|
||||
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
|
||||
progress: if True, displays a progress bar of the download to stderr
|
||||
weights: Pre-trained model weights to use.
|
||||
*args: Additional arguments to pass to :func:`timm.create_model`.
|
||||
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`.
|
||||
|
||||
Returns:
|
||||
A ResNet-50 model
|
||||
A ResNet-50 model.
|
||||
"""
|
||||
return _resnet(
|
||||
sensor,
|
||||
bands,
|
||||
"resnet50",
|
||||
Bottleneck,
|
||||
[3, 4, 6, 3],
|
||||
pretrained,
|
||||
progress,
|
||||
**kwargs,
|
||||
)
|
||||
if weights:
|
||||
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||
|
||||
model: ResNet = timm.create_model("resnet50", *args, **kwargs)
|
||||
|
||||
if weights:
|
||||
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
|
||||
return model
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Pre-trained Vision Transformer models."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import kornia.augmentation as K
|
||||
import timm
|
||||
from timm.models.vision_transformer import VisionTransformer
|
||||
from torchvision.models._api import Weights, WeightsEnum
|
||||
|
||||
from ..transforms import AugmentationSequential
|
||||
|
||||
__all__ = ["ViTSmall16_Weights"]
|
||||
|
||||
_zhu_xlab_transforms = AugmentationSequential(
|
||||
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/vision/pull/6883
|
||||
# https://github.com/pytorch/vision/pull/7107
|
||||
# Can be removed once torchvision>=0.15 is required
|
||||
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
||||
|
||||
|
||||
class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""Vision Transformer Samll Patch Size 16 weights.
|
||||
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*vit_small_patch16_224* implementation.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
|
||||
SENTINEL2_ALL_MOCO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/"
|
||||
"resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 13,
|
||||
"model": "vit_small_patch16_224",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "moco",
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_ALL_DINO = Weights(
|
||||
url=(
|
||||
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/"
|
||||
"resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth"
|
||||
),
|
||||
transforms=_zhu_xlab_transforms,
|
||||
meta={
|
||||
"dataset": "SSL4EO-S12",
|
||||
"in_chans": 13,
|
||||
"model": "vit_small_patch16_224",
|
||||
"publication": "https://arxiv.org/abs/2211.07044",
|
||||
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||
"ssl_method": "dino",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def vit_small_patch16_224(
|
||||
weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any
|
||||
) -> VisionTransformer:
|
||||
"""Vision Transform (ViT) small patch size 16 model.
|
||||
|
||||
If you use this model in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/abs/2010.11929
|
||||
|
||||
.. versionadded:: 0.4
|
||||
|
||||
Args:
|
||||
weights: Pre-trained model weights to use.
|
||||
*args: Additional arguments to pass to :func:`timm.create_model`.
|
||||
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`.
|
||||
|
||||
Returns:
|
||||
A ViT small 16 model.
|
||||
"""
|
||||
if weights:
|
||||
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||
|
||||
model: VisionTransformer = timm.create_model(
|
||||
"vit_small_patch16_224", *args, **kwargs
|
||||
)
|
||||
|
||||
if weights:
|
||||
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
|
||||
return model
|
|
@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform
|
|||
from torch import Tensor, optim
|
||||
from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from ..models import get_weight
|
||||
from . import utils
|
||||
|
||||
|
||||
|
@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule):
|
|||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||
# Create model
|
||||
in_channels = self.hyperparams["in_channels"]
|
||||
backbone_name = self.hyperparams["backbone"]
|
||||
|
||||
imagenet_pretrained = False
|
||||
custom_pretrained = False
|
||||
if self.hyperparams["weights"] and not os.path.exists(
|
||||
self.hyperparams["weights"]
|
||||
):
|
||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
||||
)
|
||||
else:
|
||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
||||
custom_pretrained = False
|
||||
else:
|
||||
custom_pretrained = True
|
||||
|
||||
# Create the model
|
||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
||||
if backbone_name in valid_models:
|
||||
weights = self.hyperparams["weights"]
|
||||
backbone = timm.create_model(
|
||||
backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained
|
||||
self.hyperparams["backbone"],
|
||||
in_chans=in_channels,
|
||||
pretrained=weights is True,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if weights and weights is not True:
|
||||
if isinstance(weights, WeightsEnum):
|
||||
state_dict = weights.get_state_dict(progress=True)
|
||||
elif os.path.exists(weights):
|
||||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.")
|
||||
|
||||
if custom_pretrained:
|
||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
||||
|
||||
if self.hyperparams["backbone"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hyperparams['backbone']}"
|
||||
)
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
backbone = utils.load_state_dict(backbone, state_dict)
|
||||
|
||||
self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))
|
||||
|
@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule):
|
|||
Keyword Args:
|
||||
in_channels: Number of input channels to model
|
||||
backbone: Name of the timm model to use
|
||||
weights: Either "random" or "imagenet"
|
||||
weights: Either a weight enum, the string representation of a weight enum,
|
||||
True for ImageNet weights, False or None for random weights,
|
||||
or the path to a saved model state dict.
|
||||
learning_rate: Learning rate for optimizer
|
||||
learning_rate_schedule_patience: Patience for learning rate scheduler
|
||||
|
||||
|
|
|
@ -22,8 +22,10 @@ from torchmetrics.classification import (
|
|||
MultilabelAccuracy,
|
||||
MultilabelFBetaScore,
|
||||
)
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from ..datasets.utils import unbind_samples
|
||||
from ..datasets import unbind_samples
|
||||
from ..models import get_weight
|
||||
from . import utils
|
||||
|
||||
|
||||
|
@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule):
|
|||
|
||||
def config_model(self) -> None:
|
||||
"""Configures the model based on kwargs parameters passed to the constructor."""
|
||||
in_channels = self.hyperparams["in_channels"]
|
||||
model = self.hyperparams["model"]
|
||||
|
||||
imagenet_pretrained = False
|
||||
custom_pretrained = False
|
||||
if self.hyperparams["weights"] and not os.path.exists(
|
||||
self.hyperparams["weights"]
|
||||
):
|
||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
||||
)
|
||||
else:
|
||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
||||
custom_pretrained = False
|
||||
else:
|
||||
custom_pretrained = True
|
||||
|
||||
# Create the model
|
||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
||||
if model in valid_models:
|
||||
# Create model
|
||||
weights = self.hyperparams["weights"]
|
||||
self.model = timm.create_model(
|
||||
model,
|
||||
self.hyperparams["model"],
|
||||
num_classes=self.hyperparams["num_classes"],
|
||||
in_chans=in_channels,
|
||||
pretrained=imagenet_pretrained,
|
||||
in_chans=self.hyperparams["in_channels"],
|
||||
pretrained=weights is True,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if weights and weights is not True:
|
||||
if isinstance(weights, WeightsEnum):
|
||||
state_dict = weights.get_state_dict(progress=True)
|
||||
elif os.path.exists(weights):
|
||||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
raise ValueError(f"Model type '{model}' is not a valid timm model.")
|
||||
|
||||
if custom_pretrained:
|
||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
||||
|
||||
if self.hyperparams["model"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hyperparams['model']}"
|
||||
)
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
def config_task(self) -> None:
|
||||
|
@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule):
|
|||
Keyword Args:
|
||||
model: Name of the classification model use
|
||||
loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
|
||||
weights: Either "random" or "imagenet"
|
||||
weights: Either a weight enum, the string representation of a weight enum,
|
||||
True for ImageNet weights, False or None for random weights,
|
||||
or the path to a saved model state dict.
|
||||
num_classes: Number of prediction classes
|
||||
in_channels: Number of input channels to model
|
||||
learning_rate: Learning rate for optimizer
|
||||
|
@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask):
|
|||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Creates `self.hparams` from kwargs
|
||||
self.save_hyperparameters() # type: ignore[operator]
|
||||
self.hyperparams = cast(Dict[str, Any], self.hparams)
|
||||
|
||||
self.config_task()
|
||||
|
||||
self.train_metrics = MetricCollection(
|
||||
{
|
||||
"OverallAccuracy": MultilabelAccuracy(
|
||||
|
|
|
@ -8,18 +8,16 @@ from typing import Any, Dict, List, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from packaging.version import parse
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||
from torchvision.models import resnet as R
|
||||
from torchvision.models.detection import FasterRCNN
|
||||
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
|
||||
from torchvision.models.detection.rpn import AnchorGenerator
|
||||
from torchvision.ops import MultiScaleRoIAlign
|
||||
|
||||
if parse(torchvision.__version__) >= parse("0.13"):
|
||||
from torchvision.models import resnet as R
|
||||
from ..datasets.utils import unbind_samples
|
||||
|
||||
BACKBONE_WEIGHT_MAP = {
|
||||
"resnet18": R.ResNet18_Weights.DEFAULT,
|
||||
|
@ -33,8 +31,6 @@ if parse(torchvision.__version__) >= parse("0.13"):
|
|||
"wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
|
||||
}
|
||||
|
||||
from ..datasets.utils import unbind_samples
|
||||
|
||||
|
||||
class ObjectDetectionTask(pl.LightningModule):
|
||||
"""LightningModule for object detection of images.
|
||||
|
@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule):
|
|||
"backbone_name": self.hyperparams["backbone"],
|
||||
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
|
||||
}
|
||||
if parse(torchvision.__version__) >= parse("0.13"):
|
||||
if backbone_pretrained:
|
||||
kwargs["weights"] = BACKBONE_WEIGHT_MAP[
|
||||
self.hyperparams["backbone"]
|
||||
]
|
||||
else:
|
||||
kwargs["weights"] = None
|
||||
else:
|
||||
kwargs["pretrained"] = backbone_pretrained
|
||||
|
||||
backbone = resnet_fpn_backbone(**kwargs)
|
||||
else:
|
||||
|
|
|
@ -14,8 +14,10 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from ..datasets.utils import unbind_samples
|
||||
from ..datasets import unbind_samples
|
||||
from ..models import get_weight
|
||||
from . import utils
|
||||
|
||||
|
||||
|
@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule):
|
|||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters."""
|
||||
in_channels = self.hyperparams["in_channels"]
|
||||
model = self.hyperparams["model"]
|
||||
|
||||
imagenet_pretrained = False
|
||||
custom_pretrained = False
|
||||
if self.hyperparams["weights"] and not os.path.exists(
|
||||
self.hyperparams["weights"]
|
||||
):
|
||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
||||
)
|
||||
else:
|
||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
||||
custom_pretrained = False
|
||||
else:
|
||||
custom_pretrained = True
|
||||
|
||||
# Create the model
|
||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
||||
if model in valid_models:
|
||||
# Create model
|
||||
weights = self.hyperparams["weights"]
|
||||
self.model = timm.create_model(
|
||||
model,
|
||||
self.hyperparams["model"],
|
||||
num_classes=self.hyperparams["num_outputs"],
|
||||
in_chans=in_channels,
|
||||
pretrained=imagenet_pretrained,
|
||||
in_chans=self.hyperparams["in_channels"],
|
||||
pretrained=weights is True,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if weights and weights is not True:
|
||||
if isinstance(weights, WeightsEnum):
|
||||
state_dict = weights.get_state_dict(progress=True)
|
||||
elif os.path.exists(weights):
|
||||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
raise ValueError(f"Model type '{model}' is not a valid timm model.")
|
||||
|
||||
if custom_pretrained:
|
||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
||||
|
||||
if self.hyperparams["model"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hyperparams['model']}"
|
||||
)
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule):
|
|||
|
||||
Keyword Args:
|
||||
model: Name of the timm model to use
|
||||
weights: Either "random" or "imagenet"
|
||||
weights: Either a weight enum, the string representation of a weight enum,
|
||||
True for ImageNet weights, False or None for random weights,
|
||||
or the path to a saved model state dict.
|
||||
num_outputs: Number of prediction outputs
|
||||
in_channels: Number of input channels to model
|
||||
learning_rate: Learning rate for optimizer
|
||||
|
|
Загрузка…
Ссылка в новой задаче