зеркало из https://github.com/microsoft/torchgeo.git
Add Multi-Weight Support API (#917)
* load pretrained weights * change name millionaid * restructure and additional weights * rename sentinel1 weights * add vit small weights * forgot to add vit.py * struggling with test * wrong name failing test * feedback on tests * increase test coverage * fix failing test * fix failing test * fix failing test and add vit tests * fix failing vit test * torchgeo.models.utils * forgot utils file * typo num channels * nitpick docs, version torchvision * another try min dependencies * add documentation table * expand pytests to test pretrained weights on tasks * reverse changes to byol task * add tests to init pretrained weights from config * forgot to add the conf files * change path * increase test coverage * vit tests all pass locally including slow * now remote * fix tests another one * add a draft tutorial * run black on tutorial notebook * Tutorial typo fixes * Lower min torch/vision versions * Fix bad rebase * Remove dead code * Flake8 fixes * Consistent in_chans * Black fixes * bison > yacs * Remove one more reference * Download modified weights from hugging face * Add entrypoints * Add torch.hub support * progress arg is required * Fix model loading for resnet18 * Add transforms, update tests * VIT -> ViT * add seco weights * Fix type hints * Link to timm docs * Fix pydocstyle * Try to fix timm docs link * Fix tests * Nuke ignores * Ignore timm links * Add model API methods * Add to __init__ and document * Test model API functions * fix tests * Use correct documentation link for intersphinx * Typos * Fix Windows tests * meth -> func * Explicit function scope * weight-specific filename * Support enums in classification trainer * Update other trainers too * Fix regression tests * Fix classification tests * Fix byol tests * Fix types * progress_bar is required arg * Test weight enums * Fix pickling * Fix regression tests * Improve coverage of classification tests * Improve coverage of BYOL tests * Update resnet table * Update ViT table * Update get_state_dict usage * Remove unused YAML files * Update table widths * Documentation improvements * Tweak tables * Try to fix Windows tests * Revert "Try to fix Windows tests" This reverts commit1325b13ff7
. * Monkeypatch everything * Revert "Monkeypatch everything" This reverts commite3e8d7d042
. * Revert "Revert "Monkeypatch everything"" This reverts commit9b27bd705b
. * Patch things not at the source * Fix missing import Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
0d04e06791
Коммит
60eb61b5fa
|
@ -10,7 +10,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 14
|
in_channels: 14
|
||||||
num_classes: 19
|
num_classes: 19
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -6,9 +6,11 @@ experiment:
|
||||||
name: cowc_counting_test
|
name: cowc_counting_test
|
||||||
module:
|
module:
|
||||||
model: resnet18
|
model: resnet18
|
||||||
|
weights: null
|
||||||
|
num_outputs: 1
|
||||||
|
in_channels: 3
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 2
|
learning_rate_schedule_patience: 2
|
||||||
pretrained: True
|
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "data/cowc_counting"
|
root: "data/cowc_counting"
|
||||||
seed: 0
|
seed: 0
|
||||||
|
|
|
@ -6,9 +6,11 @@ experiment:
|
||||||
name: "cyclone_test"
|
name: "cyclone_test"
|
||||||
module:
|
module:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
|
weights: null
|
||||||
|
num_outputs: 1
|
||||||
|
in_channels: 3
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 2
|
learning_rate_schedule_patience: 2
|
||||||
pretrained: True
|
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "data/cyclone"
|
root: "data/cyclone"
|
||||||
seed: 0
|
seed: 0
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 13
|
in_channels: 13
|
||||||
num_classes: 10
|
num_classes: 10
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -15,7 +15,6 @@ experiment:
|
||||||
module:
|
module:
|
||||||
model: "faster-rcnn"
|
model: "faster-rcnn"
|
||||||
backbone: "resnet50"
|
backbone: "resnet50"
|
||||||
pretrained: True
|
|
||||||
num_classes: 2
|
num_classes: 2
|
||||||
learning_rate: 1.2e-4
|
learning_rate: 1.2e-4
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
num_classes: 45
|
num_classes: 45
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
num_classes: 17
|
num_classes: 17
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -21,7 +21,7 @@ Fully-convolutional Network
|
||||||
.. autoclass:: FCN
|
.. autoclass:: FCN
|
||||||
|
|
||||||
FC Siamese Networks
|
FC Siamese Networks
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
.. autoclass:: FCSiamConc
|
.. autoclass:: FCSiamConc
|
||||||
.. autoclass:: FCSiamDiff
|
.. autoclass:: FCSiamDiff
|
||||||
|
@ -34,4 +34,33 @@ RCF Extractor
|
||||||
ResNet
|
ResNet
|
||||||
^^^^^^
|
^^^^^^
|
||||||
|
|
||||||
|
.. autofunction:: resnet18
|
||||||
.. autofunction:: resnet50
|
.. autofunction:: resnet50
|
||||||
|
.. autoclass:: ResNet18_Weights
|
||||||
|
.. autoclass:: ResNet50_Weights
|
||||||
|
|
||||||
|
.. csv-table::
|
||||||
|
:widths: 45 10 10 10 15 10 10 10
|
||||||
|
:header-rows: 1
|
||||||
|
:align: center
|
||||||
|
:file: resnet_pretrained_weights.csv
|
||||||
|
|
||||||
|
Vision Transformer
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autofunction:: vit_small_patch16_224
|
||||||
|
.. autoclass:: ViTSmall16_Weights
|
||||||
|
|
||||||
|
.. csv-table::
|
||||||
|
:widths: 45 10 10 10 15 10 10 10
|
||||||
|
:header-rows: 1
|
||||||
|
:align: center
|
||||||
|
:file: vit_pretrained_weights.csv
|
||||||
|
|
||||||
|
Utility Functions
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autofunction:: get_model
|
||||||
|
.. autofunction:: get_model_weights
|
||||||
|
.. autofunction:: get_weight
|
||||||
|
.. autofunction:: list_models
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
|
||||||
|
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||||
|
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||||
|
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.27,93.14,,46.94
|
||||||
|
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||||
|
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,91.8,99.1,60.9,
|
||||||
|
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,,,,
|
||||||
|
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.7,99.1,63.6,
|
||||||
|
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,87.81,,,
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
|
||||||
|
VITSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,89.9,98.6,61.6,
|
||||||
|
VITSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,90.5,99.0,62.2,
|
|
13
docs/conf.py
13
docs/conf.py
|
@ -56,14 +56,12 @@ needs_sphinx = "4.0"
|
||||||
|
|
||||||
nitpicky = True
|
nitpicky = True
|
||||||
nitpick_ignore = [
|
nitpick_ignore = [
|
||||||
# https://github.com/sphinx-doc/sphinx/issues/8127
|
# Undocumented classes
|
||||||
("py:class", ".."),
|
|
||||||
# TODO: can't figure out why this isn't found
|
|
||||||
("py:class", "LightningDataModule"),
|
|
||||||
("py:class", "pytorch_lightning.core.module.LightningModule"),
|
|
||||||
# Undocumented class
|
|
||||||
("py:class", "torchvision.models.resnet.ResNet"),
|
|
||||||
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
||||||
|
("py:class", "timm.models.resnet.ResNet"),
|
||||||
|
("py:class", "timm.models.vision_transformer.VisionTransformer"),
|
||||||
|
("py:class", "torchvision.models._api.WeightsEnum"),
|
||||||
|
("py:class", "torchvision.models.resnet.ResNet"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,6 +112,7 @@ intersphinx_mapping = {
|
||||||
"rasterio": ("https://rasterio.readthedocs.io/en/stable/", None),
|
"rasterio": ("https://rasterio.readthedocs.io/en/stable/", None),
|
||||||
"rtree": ("https://rtree.readthedocs.io/en/stable/", None),
|
"rtree": ("https://rtree.readthedocs.io/en/stable/", None),
|
||||||
"segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None),
|
"segmentation_models_pytorch": ("https://smp.readthedocs.io/en/stable/", None),
|
||||||
|
"timm": ("https://huggingface.co/docs/timm/main/en/", None),
|
||||||
"torch": ("https://pytorch.org/docs/stable", None),
|
"torch": ("https://pytorch.org/docs/stable", None),
|
||||||
"torchvision": ("https://pytorch.org/vision/stable", None),
|
"torchvision": ("https://pytorch.org/vision/stable", None),
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ torchgeo
|
||||||
tutorials/indices
|
tutorials/indices
|
||||||
tutorials/trainers
|
tutorials/trainers
|
||||||
tutorials/benchmarking
|
tutorials/benchmarking
|
||||||
|
tutorials/pretrained_weights
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
|
@ -0,0 +1,254 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
||||||
|
"\n",
|
||||||
|
"Licensed under the MIT License."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Pretrained Weights\n",
|
||||||
|
"\n",
|
||||||
|
"In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial.\n",
|
||||||
|
"\n",
|
||||||
|
"It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Setup\n",
|
||||||
|
"\n",
|
||||||
|
"First, we install TorchGeo."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install torchgeo"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Imports\n",
|
||||||
|
"\n",
|
||||||
|
"Next, we import TorchGeo and any other libraries we need."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%matplotlib inline\n",
|
||||||
|
"\n",
|
||||||
|
"import os\n",
|
||||||
|
"import csv\n",
|
||||||
|
"import tempfile\n",
|
||||||
|
"\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pytorch_lightning as pl\n",
|
||||||
|
"import timm\n",
|
||||||
|
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
|
||||||
|
"from pytorch_lightning.loggers import CSVLogger\n",
|
||||||
|
"\n",
|
||||||
|
"from torchgeo.datamodules import EuroSATDataModule\n",
|
||||||
|
"from torchgeo.trainers import ClassificationTask\n",
|
||||||
|
"from torchgeo.models import ResNet50_Weights, VITSmall16_Weights"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# we set a flag to check to see whether the notebook is currently being run by PyTest, if this is the case then we'll\n",
|
||||||
|
"# skip the expensive training.\n",
|
||||||
|
"in_tests = \"PYTEST_CURRENT_TEST\" in os.environ"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Datamodule\n",
|
||||||
|
"\n",
|
||||||
|
"We will utilize TorchGeo's datamodules from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to organize the dataloader setup."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"root = os.path.join(tempfile.gettempdir(), \"eurosat\")\n",
|
||||||
|
"\n",
|
||||||
|
"datamodule = EuroSATDataModule(root=root, batch_size=64, num_workers=4)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Weights\n",
|
||||||
|
"\n",
|
||||||
|
"Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n",
|
||||||
|
"\n",
|
||||||
|
"To access these weights you can do the following:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"weights = ResNet50_Weights.SENTINEL2_ALL_MOCO"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"task = ClassificationTask(\n",
|
||||||
|
" model=\"resnet50\",\n",
|
||||||
|
" loss=\"ce\",\n",
|
||||||
|
" weights=weights,\n",
|
||||||
|
" in_channels=13,\n",
|
||||||
|
" num_classes=10,\n",
|
||||||
|
" learning_rate=0.001,\n",
|
||||||
|
" learning_rate_schedule_patience=5,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"in_chans = weights.meta[\"in_chans\"]\n",
|
||||||
|
"model = timm.create_model(\"resnet50\", in_chans=in_chans, num_classes=10)\n",
|
||||||
|
"model.load_state_dict(weights.get_state_dict(), strict=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Training\n",
|
||||||
|
"\n",
|
||||||
|
"To train our pretrained model on the EuroSAT dataset we will make use of PyTorch Lightning's [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html). For a more elaborate explanation of how TorchGeo uses PyTorch Lightning, check out [this tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"experiment_dir = os.path.join(tempfile.gettempdir(), \"eurosat_results\")\n",
|
||||||
|
"\n",
|
||||||
|
"checkpoint_callback = ModelCheckpoint(\n",
|
||||||
|
" monitor=\"val_loss\", dirpath=experiment_dir, save_top_k=1, save_last=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"early_stopping_callback = EarlyStopping(monitor=\"val_loss\", min_delta=0.00, patience=10)\n",
|
||||||
|
"\n",
|
||||||
|
"csv_logger = CSVLogger(save_dir=experiment_dir, name=\"pretrained_weights_logs\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"trainer = pl.Trainer(\n",
|
||||||
|
" callbacks=[checkpoint_callback, early_stopping_callback],\n",
|
||||||
|
" logger=[csv_logger],\n",
|
||||||
|
" default_root_dir=experiment_dir,\n",
|
||||||
|
" min_epochs=1,\n",
|
||||||
|
" max_epochs=10,\n",
|
||||||
|
" fast_dev_run=in_tests,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"trainer.fit(model=task, datamodule=datamodule)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "torchEnv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.10"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4,
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -11,12 +11,12 @@ dependencies:
|
||||||
- pycocotools>=2
|
- pycocotools>=2
|
||||||
- pyproj>=2.2
|
- pyproj>=2.2
|
||||||
- python>=3.7
|
- python>=3.7
|
||||||
- pytorch>=1.9
|
- pytorch>=1.12
|
||||||
- pyvista>=0.20
|
- pyvista>=0.20
|
||||||
- rarfile>=3
|
- rarfile>=3
|
||||||
- rasterio>=1.0.20
|
- rasterio>=1.0.20
|
||||||
- shapely>=1.3
|
- shapely>=1.3
|
||||||
- torchvision>=0.10
|
- torchvision>=0.13
|
||||||
- pip:
|
- pip:
|
||||||
- black[jupyter]>=21.8
|
- black[jupyter]>=21.8
|
||||||
- flake8>=3.8
|
- flake8>=3.8
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""TorchGeo pre-trained model repository configuration file.
|
||||||
|
|
||||||
|
* https://pytorch.org/hub/
|
||||||
|
* https://pytorch.org/docs/stable/hub.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
from torchgeo.models import resnet18, resnet50, vit_small_patch16_224
|
||||||
|
|
||||||
|
__all__ = ("resnet18", "resnet50", "vit_small_patch16_224")
|
||||||
|
|
||||||
|
dependencies = ["timm"]
|
|
@ -18,9 +18,9 @@ scikit-learn==0.21.0
|
||||||
segmentation-models-pytorch==0.2.0
|
segmentation-models-pytorch==0.2.0
|
||||||
shapely==1.3.0
|
shapely==1.3.0
|
||||||
timm==0.4.12
|
timm==0.4.12
|
||||||
torch==1.9.0
|
torch==1.12.0
|
||||||
torchmetrics==0.10.0
|
torchmetrics==0.10.0
|
||||||
torchvision==0.10.0
|
torchvision==0.13.0
|
||||||
|
|
||||||
# datasets
|
# datasets
|
||||||
h5py==2.6.0
|
h5py==2.6.0
|
||||||
|
|
|
@ -57,12 +57,12 @@ install_requires =
|
||||||
shapely>=1.3,<3
|
shapely>=1.3,<3
|
||||||
# timm 0.4.12 required by segmentation-models-pytorch
|
# timm 0.4.12 required by segmentation-models-pytorch
|
||||||
timm>=0.4.12,<0.7
|
timm>=0.4.12,<0.7
|
||||||
# torch 1.9+ required by torchvision
|
# torch 1.12+ required by torchvision
|
||||||
torch>=1.9,<2
|
torch>=1.12,<2
|
||||||
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
|
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
|
||||||
torchmetrics>=0.10,<0.12
|
torchmetrics>=0.10,<0.12
|
||||||
# torchvision 0.10+ required for torchvision.utils.draw_segmentation_masks
|
# torchvision 0.13+ required for torchvision.models._api.WeightsEnum
|
||||||
torchvision>=0.10,<0.15
|
torchvision>=0.13,<0.15
|
||||||
python_requires = ~= 3.7
|
python_requires = ~= 3.7
|
||||||
packages = find:
|
packages = find:
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 14
|
in_channels: 14
|
||||||
num_classes: 19
|
num_classes: 19
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 2
|
in_channels: 2
|
||||||
num_classes: 19
|
num_classes: 19
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 12
|
in_channels: 12
|
||||||
num_classes: 19
|
num_classes: 19
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
experiment:
|
|
||||||
task: "ssl"
|
|
||||||
name: "test_byol"
|
|
||||||
module:
|
|
||||||
model: "byol"
|
|
||||||
backbone: "resnet18"
|
|
||||||
input_channels: 4
|
|
||||||
weights: imagenet
|
|
||||||
learning_rate: 1e-3
|
|
||||||
learning_rate_schedule_patience: 6
|
|
||||||
datamodule:
|
|
||||||
root: "tests/data/chesapeake/cvpr"
|
|
||||||
download: true
|
|
||||||
train_splits:
|
|
||||||
- "de-test"
|
|
||||||
val_splits:
|
|
||||||
- "de-test"
|
|
||||||
test_splits:
|
|
||||||
- "de-test"
|
|
||||||
batch_size: 1
|
|
||||||
num_workers: 0
|
|
|
@ -10,7 +10,7 @@ experiment:
|
||||||
num_classes: 7
|
num_classes: 7
|
||||||
num_filters: 1
|
num_filters: 1
|
||||||
ignore_index: null
|
ignore_index: null
|
||||||
weights: imagenet
|
weights: null
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "tests/data/chesapeake/cvpr"
|
root: "tests/data/chesapeake/cvpr"
|
||||||
download: true
|
download: true
|
||||||
|
|
|
@ -10,7 +10,7 @@ experiment:
|
||||||
num_classes: 5
|
num_classes: 5
|
||||||
num_filters: 1
|
num_filters: 1
|
||||||
ignore_index: null
|
ignore_index: null
|
||||||
weights: imagenet
|
weights: null
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "tests/data/chesapeake/cvpr"
|
root: "tests/data/chesapeake/cvpr"
|
||||||
download: true
|
download: true
|
||||||
|
|
|
@ -2,12 +2,11 @@ experiment:
|
||||||
task: cowc_counting
|
task: cowc_counting
|
||||||
module:
|
module:
|
||||||
model: resnet18
|
model: resnet18
|
||||||
weights: "random"
|
weights: null
|
||||||
num_outputs: 1
|
num_outputs: 1
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 2
|
learning_rate_schedule_patience: 2
|
||||||
pretrained: True
|
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "tests/data/cowc_counting"
|
root: "tests/data/cowc_counting"
|
||||||
download: true
|
download: true
|
||||||
|
|
|
@ -2,12 +2,11 @@ experiment:
|
||||||
task: "cyclone"
|
task: "cyclone"
|
||||||
module:
|
module:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
weights: "random"
|
weights: null
|
||||||
num_outputs: 1
|
num_outputs: 1
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 2
|
learning_rate_schedule_patience: 2
|
||||||
pretrained: False
|
|
||||||
datamodule:
|
datamodule:
|
||||||
root: "tests/data/cyclone"
|
root: "tests/data/cyclone"
|
||||||
download: true
|
download: true
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 13
|
in_channels: 13
|
||||||
num_classes: 2
|
num_classes: 2
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
num_classes: 3
|
num_classes: 3
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
num_classes: 17
|
num_classes: 17
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -5,7 +5,7 @@ experiment:
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
weights: "random"
|
weights: null
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
num_classes: 17
|
num_classes: 17
|
||||||
datamodule:
|
datamodule:
|
||||||
|
|
|
@ -3,7 +3,7 @@ experiment:
|
||||||
module:
|
module:
|
||||||
loss: "ce"
|
loss: "ce"
|
||||||
model: "resnet18"
|
model: "resnet18"
|
||||||
weights: "random"
|
weights: null
|
||||||
learning_rate: 1e-3
|
learning_rate: 1e-3
|
||||||
learning_rate_schedule_patience: 6
|
learning_rate_schedule_patience: 6
|
||||||
in_channels: 3
|
in_channels: 3
|
||||||
|
|
Двоичные данные
tests/data/models/resnet50-sentinel2-2.pt.zip
Двоичные данные
tests/data/models/resnet50-sentinel2-2.pt.zip
Двоичный файл не отображается.
|
@ -0,0 +1,50 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
|
from torchgeo.models import (
|
||||||
|
ResNet18_Weights,
|
||||||
|
ResNet50_Weights,
|
||||||
|
ViTSmall16_Weights,
|
||||||
|
get_model,
|
||||||
|
get_model_weights,
|
||||||
|
get_weight,
|
||||||
|
list_models,
|
||||||
|
resnet18,
|
||||||
|
resnet50,
|
||||||
|
vit_small_patch16_224,
|
||||||
|
)
|
||||||
|
|
||||||
|
builders = [resnet18, resnet50, vit_small_patch16_224]
|
||||||
|
enums = [ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("builder", builders)
|
||||||
|
def test_get_model(builder: Callable[..., nn.Module]) -> None:
|
||||||
|
model = get_model(builder.__name__)
|
||||||
|
assert isinstance(model, nn.Module)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("builder", builders)
|
||||||
|
def test_get_model_weights(builder: Callable[..., nn.Module]) -> None:
|
||||||
|
weights = get_model_weights(builder)
|
||||||
|
assert isinstance(weights, enum.EnumMeta)
|
||||||
|
weights = get_model_weights(builder.__name__)
|
||||||
|
assert isinstance(weights, enum.EnumMeta)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enum", enums)
|
||||||
|
def test_get_weight(enum: WeightsEnum) -> None:
|
||||||
|
for weight in enum:
|
||||||
|
assert weight == get_weight(str(weight))
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_models() -> None:
|
||||||
|
models = [builder.__name__ for builder in builders]
|
||||||
|
assert set(models) == set(list_models())
|
|
@ -1,61 +1,74 @@
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import timm
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from _pytest.fixtures import SubRequest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from torch.nn.modules import Module
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
import torchgeo.models.resnet
|
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||||
from torchgeo.datasets.utils import extract_archive
|
|
||||||
from torchgeo.models import resnet50
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_file(
|
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
file: str,
|
state_dict: Dict[str, Any] = torch.load(url)
|
||||||
model_dir: Optional[str] = None,
|
return state_dict
|
||||||
map_location: Optional[Any] = None,
|
|
||||||
progress: Optional[bool] = True,
|
|
||||||
check_hash: Optional[bool] = False,
|
|
||||||
file_name: Optional[str] = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Mockup of ``torch.hub.load_state_dict_from_url``."""
|
|
||||||
return torch.load(file)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
class TestResNet18:
|
||||||
"model_class,sensor,bands,in_channels,num_classes",
|
@pytest.fixture(params=[*ResNet18_Weights])
|
||||||
[(resnet50, "sentinel2", "all", 10, 17)],
|
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||||
)
|
return request.param
|
||||||
def test_resnet(
|
|
||||||
monkeypatch: MonkeyPatch,
|
|
||||||
tmp_path: Path,
|
|
||||||
model_class: Module,
|
|
||||||
sensor: str,
|
|
||||||
bands: str,
|
|
||||||
in_channels: int,
|
|
||||||
num_classes: int,
|
|
||||||
) -> None:
|
|
||||||
extract_archive(
|
|
||||||
os.path.join("tests", "data", "models", "resnet50-sentinel2-2.pt.zip"),
|
|
||||||
str(tmp_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
new_model_urls = {
|
@pytest.fixture
|
||||||
"sentinel2": {"all": {"resnet50": str(tmp_path / "resnet50-sentinel2-2.pt")}}
|
def mocked_weights(
|
||||||
}
|
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||||
|
) -> WeightsEnum:
|
||||||
|
path = tmp_path / f"{weights}.pth"
|
||||||
|
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
monkeypatch.setattr(torchgeo.models.resnet, "MODEL_URLS", new_model_urls)
|
def test_resnet(self) -> None:
|
||||||
monkeypatch.setattr(
|
resnet18()
|
||||||
torchgeo.models.resnet, "load_state_dict_from_url", load_state_dict_from_file
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model_class(sensor, bands, pretrained=True)
|
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||||
x = torch.zeros(1, in_channels, 256, 256)
|
resnet18(weights=mocked_weights)
|
||||||
y = model(x)
|
|
||||||
assert isinstance(y, torch.Tensor)
|
@pytest.mark.slow
|
||||||
assert y.size() == torch.Size([1, 17])
|
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||||
|
resnet18(weights=weights)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResNet50:
|
||||||
|
@pytest.fixture(params=[*ResNet50_Weights])
|
||||||
|
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mocked_weights(
|
||||||
|
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||||
|
) -> WeightsEnum:
|
||||||
|
path = tmp_path / f"{weights}.pth"
|
||||||
|
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def test_resnet(self) -> None:
|
||||||
|
resnet50()
|
||||||
|
|
||||||
|
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||||
|
resnet50(weights=mocked_weights)
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||||
|
resnet50(weights=weights)
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import timm
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from _pytest.fixtures import SubRequest
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
|
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
|
||||||
|
|
||||||
|
|
||||||
|
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
state_dict: Dict[str, Any] = torch.load(url)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class TestViTSmall16:
|
||||||
|
@pytest.fixture(params=[*ViTSmall16_Weights])
|
||||||
|
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mocked_weights(
|
||||||
|
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||||
|
) -> WeightsEnum:
|
||||||
|
path = tmp_path / f"{weights}.pth"
|
||||||
|
model = timm.create_model(
|
||||||
|
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||||
|
)
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def test_vit(self) -> None:
|
||||||
|
vit_small_patch16_224()
|
||||||
|
|
||||||
|
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||||
|
vit_small_patch16_224(weights=mocked_weights)
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_vit_download(self, weights: WeightsEnum) -> None:
|
||||||
|
vit_small_patch16_224(weights=weights)
|
|
@ -2,21 +2,33 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Type, cast
|
from typing import Any, Dict, Type, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import timm
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchvision
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch_lightning import LightningDataModule, Trainer
|
from pytorch_lightning import LightningDataModule, Trainer
|
||||||
from torchvision.models import resnet18
|
from torchvision.models import resnet18
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
from torchgeo.datamodules import ChesapeakeCVPRDataModule
|
||||||
|
from torchgeo.models import ResNet18_Weights
|
||||||
from torchgeo.trainers import BYOLTask
|
from torchgeo.trainers import BYOLTask
|
||||||
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
||||||
|
|
||||||
from .test_utils import SegmentationTestModel
|
from .test_utils import SegmentationTestModel
|
||||||
|
|
||||||
|
|
||||||
|
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
state_dict: Dict[str, Any] = torch.load(url)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class TestBYOL:
|
class TestBYOL:
|
||||||
def test_custom_augment_fn(self) -> None:
|
def test_custom_augment_fn(self) -> None:
|
||||||
backbone = resnet18()
|
backbone = resnet18()
|
||||||
|
@ -45,7 +57,7 @@ class TestBYOLTask:
|
||||||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -64,30 +76,31 @@ class TestBYOLTask:
|
||||||
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
|
trainer.predict(model=model, dataloaders=datamodule.val_dataloader())
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_kwargs(self) -> Dict[Any, Any]:
|
def model_kwargs(self) -> Dict[str, Any]:
|
||||||
return {"backbone": "resnet18", "weights": "random", "in_channels": 3}
|
return {"backbone": "resnet18", "weights": None, "in_channels": 3}
|
||||||
|
|
||||||
def test_invalid_pretrained(
|
@pytest.fixture
|
||||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||||
|
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
|
||||||
|
path = tmp_path / f"{weights}.pth"
|
||||||
|
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||||
|
model_kwargs["weights"] = checkpoint
|
||||||
|
BYOLTask(**model_kwargs)
|
||||||
|
|
||||||
|
def test_weight_enum(
|
||||||
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
) -> None:
|
) -> None:
|
||||||
model_kwargs["weights"] = checkpoint
|
model_kwargs["weights"] = mocked_weights
|
||||||
model_kwargs["backbone"] = "resnet50"
|
|
||||||
match = "Trying to load resnet18 weights into a resnet50"
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
BYOLTask(**model_kwargs)
|
BYOLTask(**model_kwargs)
|
||||||
|
|
||||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
def test_weight_str(
|
||||||
model_kwargs["weights"] = checkpoint
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
BYOLTask(**model_kwargs)
|
) -> None:
|
||||||
|
model_kwargs["weights"] = str(mocked_weights)
|
||||||
def test_invalid_backbone(self, model_kwargs: Dict[Any, Any]) -> None:
|
|
||||||
model_kwargs["backbone"] = "invalid_backbone"
|
|
||||||
match = "Model type 'invalid_backbone' is not a valid timm model."
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
BYOLTask(**model_kwargs)
|
|
||||||
|
|
||||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
|
||||||
model_kwargs["weights"] = "invalid_weights"
|
|
||||||
match = "Weight type 'invalid_weights' is not valid."
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
BYOLTask(**model_kwargs)
|
BYOLTask(**model_kwargs)
|
||||||
|
|
|
@ -2,14 +2,18 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Type, cast
|
from typing import Any, Dict, Type, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import timm
|
import timm
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch_lightning import LightningDataModule, Trainer
|
from pytorch_lightning import LightningDataModule, Trainer
|
||||||
from torch.nn.modules import Module
|
from torch.nn.modules import Module
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
from torchgeo.datamodules import (
|
from torchgeo.datamodules import (
|
||||||
BigEarthNetDataModule,
|
BigEarthNetDataModule,
|
||||||
|
@ -18,6 +22,7 @@ from torchgeo.datamodules import (
|
||||||
So2SatDataModule,
|
So2SatDataModule,
|
||||||
UCMercedDataModule,
|
UCMercedDataModule,
|
||||||
)
|
)
|
||||||
|
from torchgeo.models import ResNet18_Weights
|
||||||
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
|
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
|
||||||
|
|
||||||
from .test_utils import ClassificationTestModel
|
from .test_utils import ClassificationTestModel
|
||||||
|
@ -27,6 +32,11 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
|
||||||
return ClassificationTestModel(**kwargs)
|
return ClassificationTestModel(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
state_dict: Dict[str, Any] = torch.load(url)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class TestClassificationTask:
|
class TestClassificationTask:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name,classname",
|
"name,classname",
|
||||||
|
@ -46,7 +56,7 @@ class TestClassificationTask:
|
||||||
|
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -66,7 +76,7 @@ class TestClassificationTask:
|
||||||
def test_no_logger(self) -> None:
|
def test_no_logger(self) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -83,49 +93,52 @@ class TestClassificationTask:
|
||||||
trainer.fit(model=model, datamodule=datamodule)
|
trainer.fit(model=model, datamodule=datamodule)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_kwargs(self) -> Dict[Any, Any]:
|
def model_kwargs(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"model": "resnet18",
|
"model": "resnet18",
|
||||||
"in_channels": 13,
|
"in_channels": 13,
|
||||||
"loss": "ce",
|
"loss": "ce",
|
||||||
"num_classes": 10,
|
"num_classes": 10,
|
||||||
"weights": "random",
|
"weights": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
@pytest.fixture
|
||||||
|
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||||
|
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||||
|
path = tmp_path / f"{weights}.pth"
|
||||||
|
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||||
|
torch.save(model.state_dict(), path)
|
||||||
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||||
model_kwargs["weights"] = checkpoint
|
model_kwargs["weights"] = checkpoint
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
ClassificationTask(**model_kwargs)
|
ClassificationTask(**model_kwargs)
|
||||||
|
|
||||||
def test_invalid_pretrained(
|
def test_weight_enum(
|
||||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
) -> None:
|
) -> None:
|
||||||
model_kwargs["weights"] = checkpoint
|
model_kwargs["weights"] = mocked_weights
|
||||||
model_kwargs["model"] = "resnet50"
|
with pytest.warns(UserWarning):
|
||||||
match = "Trying to load resnet18 weights into a resnet50"
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
ClassificationTask(**model_kwargs)
|
ClassificationTask(**model_kwargs)
|
||||||
|
|
||||||
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
|
def test_weight_str(
|
||||||
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
|
) -> None:
|
||||||
|
model_kwargs["weights"] = str(mocked_weights)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
ClassificationTask(**model_kwargs)
|
||||||
|
|
||||||
|
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
|
||||||
model_kwargs["loss"] = "invalid_loss"
|
model_kwargs["loss"] = "invalid_loss"
|
||||||
match = "Loss type 'invalid_loss' is not valid."
|
match = "Loss type 'invalid_loss' is not valid."
|
||||||
with pytest.raises(ValueError, match=match):
|
with pytest.raises(ValueError, match=match):
|
||||||
ClassificationTask(**model_kwargs)
|
ClassificationTask(**model_kwargs)
|
||||||
|
|
||||||
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
|
|
||||||
model_kwargs["model"] = "invalid_model"
|
|
||||||
match = "Model type 'invalid_model' is not a valid timm model."
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
ClassificationTask(**model_kwargs)
|
|
||||||
|
|
||||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
|
||||||
model_kwargs["weights"] = "invalid_weights"
|
|
||||||
match = "Weight type 'invalid_weights' is not valid."
|
|
||||||
with pytest.raises(ValueError, match=match):
|
|
||||||
ClassificationTask(**model_kwargs)
|
|
||||||
|
|
||||||
def test_missing_attributes(
|
def test_missing_attributes(
|
||||||
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
|
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.delattr(EuroSATDataModule, "plot")
|
monkeypatch.delattr(EuroSATDataModule, "plot")
|
||||||
datamodule = EuroSATDataModule(
|
datamodule = EuroSATDataModule(
|
||||||
|
@ -150,7 +163,7 @@ class TestMultiLabelClassificationTask:
|
||||||
) -> None:
|
) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -170,7 +183,7 @@ class TestMultiLabelClassificationTask:
|
||||||
def test_no_logger(self) -> None:
|
def test_no_logger(self) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -187,23 +200,23 @@ class TestMultiLabelClassificationTask:
|
||||||
trainer.fit(model=model, datamodule=datamodule)
|
trainer.fit(model=model, datamodule=datamodule)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_kwargs(self) -> Dict[Any, Any]:
|
def model_kwargs(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"model": "resnet18",
|
"model": "resnet18",
|
||||||
"in_channels": 14,
|
"in_channels": 14,
|
||||||
"loss": "bce",
|
"loss": "bce",
|
||||||
"num_classes": 19,
|
"num_classes": 19,
|
||||||
"weights": "random",
|
"weights": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
|
def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
|
||||||
model_kwargs["loss"] = "invalid_loss"
|
model_kwargs["loss"] = "invalid_loss"
|
||||||
match = "Loss type 'invalid_loss' is not valid."
|
match = "Loss type 'invalid_loss' is not valid."
|
||||||
with pytest.raises(ValueError, match=match):
|
with pytest.raises(ValueError, match=match):
|
||||||
MultiLabelClassificationTask(**model_kwargs)
|
MultiLabelClassificationTask(**model_kwargs)
|
||||||
|
|
||||||
def test_missing_attributes(
|
def test_missing_attributes(
|
||||||
self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch
|
self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.delattr(BigEarthNetDataModule, "plot")
|
monkeypatch.delattr(BigEarthNetDataModule, "plot")
|
||||||
datamodule = BigEarthNetDataModule(
|
datamodule = BigEarthNetDataModule(
|
||||||
|
|
|
@ -2,18 +2,30 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Type, cast
|
from typing import Any, Dict, Type, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import timm
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch_lightning import LightningDataModule, Trainer
|
from pytorch_lightning import LightningDataModule, Trainer
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
|
from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
|
||||||
|
from torchgeo.models import ResNet18_Weights
|
||||||
from torchgeo.trainers import RegressionTask
|
from torchgeo.trainers import RegressionTask
|
||||||
|
|
||||||
from .test_utils import RegressionTestModel
|
from .test_utils import RegressionTestModel
|
||||||
|
|
||||||
|
|
||||||
|
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
state_dict: Dict[str, Any] = torch.load(url)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class TestRegressionTask:
|
class TestRegressionTask:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name,classname",
|
"name,classname",
|
||||||
|
@ -25,7 +37,7 @@ class TestRegressionTask:
|
||||||
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -46,7 +58,7 @@ class TestRegressionTask:
|
||||||
def test_no_logger(self) -> None:
|
def test_no_logger(self) -> None:
|
||||||
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
|
conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml"))
|
||||||
conf_dict = OmegaConf.to_object(conf.experiment)
|
conf_dict = OmegaConf.to_object(conf.experiment)
|
||||||
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
|
conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict)
|
||||||
|
|
||||||
# Instantiate datamodule
|
# Instantiate datamodule
|
||||||
datamodule_kwargs = conf_dict["datamodule"]
|
datamodule_kwargs = conf_dict["datamodule"]
|
||||||
|
@ -63,36 +75,39 @@ class TestRegressionTask:
|
||||||
trainer.fit(model=model, datamodule=datamodule)
|
trainer.fit(model=model, datamodule=datamodule)
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model_kwargs(self) -> Dict[Any, Any]:
|
def model_kwargs(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"model": "resnet18",
|
"model": "resnet18",
|
||||||
"weights": "random",
|
"weights": None,
|
||||||
"num_outputs": 1,
|
"num_outputs": 1,
|
||||||
"in_channels": 3,
|
"in_channels": 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_invalid_pretrained(
|
@pytest.fixture
|
||||||
self, model_kwargs: Dict[Any, Any], checkpoint: str
|
def mocked_weights(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> WeightsEnum:
|
||||||
) -> None:
|
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
|
||||||
model_kwargs["weights"] = checkpoint
|
path = tmp_path / f"{weights}.pth"
|
||||||
model_kwargs["model"] = "resnet50"
|
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||||
match = "Trying to load resnet18 weights into a resnet50"
|
torch.save(model.state_dict(), path)
|
||||||
with pytest.raises(ValueError, match=match):
|
monkeypatch.setattr(weights, "url", str(path))
|
||||||
RegressionTask(**model_kwargs)
|
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||||
|
return weights
|
||||||
|
|
||||||
def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
|
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||||
model_kwargs["weights"] = checkpoint
|
model_kwargs["weights"] = checkpoint
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
RegressionTask(**model_kwargs)
|
RegressionTask(**model_kwargs)
|
||||||
|
|
||||||
def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
|
def test_weight_enum(
|
||||||
model_kwargs["model"] = "invalid_model"
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
match = "Model type 'invalid_model' is not a valid timm model."
|
) -> None:
|
||||||
with pytest.raises(ValueError, match=match):
|
model_kwargs["weights"] = mocked_weights
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
RegressionTask(**model_kwargs)
|
RegressionTask(**model_kwargs)
|
||||||
|
|
||||||
def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
|
def test_weight_str(
|
||||||
model_kwargs["weights"] = "invalid_weights"
|
self, model_kwargs: Dict[str, Any], mocked_weights: WeightsEnum
|
||||||
match = "Weight type 'invalid_weights' is not valid."
|
) -> None:
|
||||||
with pytest.raises(ValueError, match=match):
|
model_kwargs["weights"] = str(mocked_weights)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
RegressionTask(**model_kwargs)
|
RegressionTask(**model_kwargs)
|
||||||
|
|
|
@ -3,14 +3,17 @@
|
||||||
|
|
||||||
"""TorchGeo models."""
|
"""TorchGeo models."""
|
||||||
|
|
||||||
|
from .api import get_model, get_model_weights, get_weight, list_models
|
||||||
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
|
from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg
|
||||||
from .farseg import FarSeg
|
from .farseg import FarSeg
|
||||||
from .fcn import FCN
|
from .fcn import FCN
|
||||||
from .fcsiam import FCSiamConc, FCSiamDiff
|
from .fcsiam import FCSiamConc, FCSiamDiff
|
||||||
from .rcf import RCF
|
from .rcf import RCF
|
||||||
from .resnet import resnet50
|
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||||
|
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
# models
|
||||||
"ChangeMixin",
|
"ChangeMixin",
|
||||||
"ChangeStar",
|
"ChangeStar",
|
||||||
"ChangeStarFarSeg",
|
"ChangeStarFarSeg",
|
||||||
|
@ -19,5 +22,16 @@ __all__ = (
|
||||||
"FCSiamConc",
|
"FCSiamConc",
|
||||||
"FCSiamDiff",
|
"FCSiamDiff",
|
||||||
"RCF",
|
"RCF",
|
||||||
|
"resnet18",
|
||||||
"resnet50",
|
"resnet50",
|
||||||
|
"vit_small_patch16_224",
|
||||||
|
# weights
|
||||||
|
"ResNet50_Weights",
|
||||||
|
"ResNet18_Weights",
|
||||||
|
"ViTSmall16_Weights",
|
||||||
|
# utilities
|
||||||
|
"get_model",
|
||||||
|
"get_model_weights",
|
||||||
|
"get_weight",
|
||||||
|
"list_models",
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""APIs for querying and loading pre-trained model weights.
|
||||||
|
|
||||||
|
See the following references for design details:
|
||||||
|
|
||||||
|
* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
|
||||||
|
* https://pytorch.org/vision/stable/models.html
|
||||||
|
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
|
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||||
|
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||||
|
|
||||||
|
_model = {
|
||||||
|
"resnet18": resnet18,
|
||||||
|
"resnet50": resnet50,
|
||||||
|
"vit_small_patch16_224": vit_small_patch16_224,
|
||||||
|
}
|
||||||
|
|
||||||
|
_model_weights = {
|
||||||
|
resnet18: ResNet18_Weights,
|
||||||
|
resnet50: ResNet50_Weights,
|
||||||
|
vit_small_patch16_224: ViTSmall16_Weights,
|
||||||
|
"resnet18": ResNet18_Weights,
|
||||||
|
"resnet50": ResNet50_Weights,
|
||||||
|
"vit_small_patch16_224": ViTSmall16_Weights,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
|
||||||
|
"""Get an instantiated model from its name.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the model.
|
||||||
|
*args: Additional arguments passed to the model builder method.
|
||||||
|
**kwargs: Additional keyword arguments passed to the model builder method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instantiated model.
|
||||||
|
"""
|
||||||
|
model: nn.Module = _model[name](*args, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum:
|
||||||
|
"""Get the weights enum class associated with a given model.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Model builder function or the name under which it is registered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The weights enum class associated with the model.
|
||||||
|
"""
|
||||||
|
return _model_weights[name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight(name: str) -> WeightsEnum:
|
||||||
|
"""Get the weights enum value by its full name.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the weight enum entry.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The requested weight enum.
|
||||||
|
"""
|
||||||
|
return eval(name)
|
||||||
|
|
||||||
|
|
||||||
|
def list_models() -> List[str]:
|
||||||
|
"""List the registered models.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of registered models.
|
||||||
|
"""
|
||||||
|
return list(_model.keys())
|
|
@ -9,7 +9,6 @@ from typing import List, cast
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from packaging.version import parse
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.modules import (
|
from torch.nn.modules import (
|
||||||
BatchNorm2d,
|
BatchNorm2d,
|
||||||
|
@ -62,7 +61,6 @@ class FarSeg(Module):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown backbone: {backbone}.")
|
raise ValueError(f"unknown backbone: {backbone}.")
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if parse(torchvision.__version__) >= parse("0.13"):
|
|
||||||
if backbone_pretrained:
|
if backbone_pretrained:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"weights": getattr(
|
"weights": getattr(
|
||||||
|
@ -71,8 +69,6 @@ class FarSeg(Module):
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
kwargs = {"weights": None}
|
kwargs = {"weights": None}
|
||||||
else:
|
|
||||||
kwargs = {"pretrained": backbone_pretrained}
|
|
||||||
|
|
||||||
self.backbone = getattr(resnet, backbone)(**kwargs)
|
self.backbone = getattr(resnet, backbone)(**kwargs)
|
||||||
|
|
||||||
|
|
|
@ -3,83 +3,208 @@
|
||||||
|
|
||||||
"""Pre-trained ResNet models."""
|
"""Pre-trained ResNet models."""
|
||||||
|
|
||||||
from typing import Any, List, Type, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import kornia.augmentation as K
|
||||||
|
import timm
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.hub import load_state_dict_from_url
|
from timm.models import ResNet
|
||||||
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
|
from torchvision.models._api import Weights, WeightsEnum
|
||||||
|
|
||||||
MODEL_URLS = {
|
from ..transforms import AugmentationSequential
|
||||||
"sentinel2": {
|
|
||||||
"all": {
|
__all__ = ["ResNet50_Weights", "ResNet18_Weights"]
|
||||||
"resnet50": "https://zenodo.org/record/5610000/files/resnet50-sentinel2.pt"
|
|
||||||
}
|
_zhu_xlab_transforms = AugmentationSequential(
|
||||||
}
|
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
|
||||||
}
|
)
|
||||||
|
|
||||||
|
# https://github.com/pytorch/vision/pull/6883
|
||||||
|
# https://github.com/pytorch/vision/pull/7107
|
||||||
|
# Can be removed once torchvision>=0.15 is required
|
||||||
|
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
||||||
|
|
||||||
|
|
||||||
IN_CHANNELS = {"sentinel2": {"all": 10}}
|
class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
|
"""ResNet18 weights.
|
||||||
|
|
||||||
NUM_CLASSES = {"sentinel2": 17}
|
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||||
|
*resnet18* implementation.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
"""
|
||||||
|
|
||||||
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/"
|
||||||
|
"resolve/main/resnet18_sentinel2_all_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 13,
|
||||||
|
"model": "resnet18",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_RGB_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_moco/"
|
||||||
|
"resolve/main/resnet18_sentinel2_rgb_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 3,
|
||||||
|
"model": "resnet18",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_RGB_SECO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/"
|
||||||
|
"resolve/main/resnet18_sentinel2_rgb_seco.ckpt"
|
||||||
|
),
|
||||||
|
transforms=nn.Identity(),
|
||||||
|
meta={
|
||||||
|
"dataset": "SeCo Dataset",
|
||||||
|
"in_chans": 3,
|
||||||
|
"model": "resnet18",
|
||||||
|
"publication": "https://arxiv.org/abs/2103.16607",
|
||||||
|
"repo": "https://github.com/ServiceNow/seasonal-contrast",
|
||||||
|
"ssl_method": "seco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _resnet(
|
class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
sensor: str,
|
"""ResNet50 weights.
|
||||||
bands: str,
|
|
||||||
arch: str,
|
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||||
block: Type[Union[BasicBlock, Bottleneck]],
|
*resnet50* implementation.
|
||||||
layers: List[int],
|
|
||||||
pretrained: bool,
|
.. versionadded:: 0.4
|
||||||
progress: bool,
|
"""
|
||||||
**kwargs: Any,
|
|
||||||
|
SENTINEL1_ALL_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/"
|
||||||
|
"resolve/main/resnet50_sentinel1_all_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 2,
|
||||||
|
"model": "resnet50",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/"
|
||||||
|
"resolve/main/resnet50_sentinel2_all_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 13,
|
||||||
|
"model": "resnet50",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_RGB_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/"
|
||||||
|
"resolve/main/resnet50_sentinel2_rgb_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 3,
|
||||||
|
"model": "resnet50",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_ALL_DINO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet50_sentinel2_all_dino/"
|
||||||
|
"resolve/main/resnet50_sentinel2_all_dino.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 13,
|
||||||
|
"model": "resnet50",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "dino",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_RGB_SECO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/"
|
||||||
|
"resolve/main/resnet50_sentinel2_rgb_seco.ckpt"
|
||||||
|
),
|
||||||
|
transforms=nn.Identity(),
|
||||||
|
meta={
|
||||||
|
"dataset": "SeCo Dataset",
|
||||||
|
"in_chans": 3,
|
||||||
|
"model": "resnet50",
|
||||||
|
"publication": "https://arxiv.org/abs/2103.16607",
|
||||||
|
"repo": "https://github.com/ServiceNow/seasonal-contrast",
|
||||||
|
"ssl_method": "seco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet18(
|
||||||
|
weights: Optional[ResNet18_Weights] = None, *args: Any, **kwargs: Any
|
||||||
) -> ResNet:
|
) -> ResNet:
|
||||||
"""Resnet model.
|
"""ResNet-18 model.
|
||||||
|
|
||||||
If you use this model in your research, please cite the following paper:
|
If you use this model in your research, please cite the following paper:
|
||||||
|
|
||||||
* https://arxiv.org/pdf/1512.03385.pdf
|
* https://arxiv.org/pdf/1512.03385.pdf
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sensor: imagery source which determines number of input channels
|
weights: Pre-trained model weights to use.
|
||||||
bands: which spectral bands to consider: "all", "rgb", etc.
|
*args: Additional arguments to pass to :func:`timm.create_model`
|
||||||
arch: ResNet version specifying number of layers
|
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`
|
||||||
block: type of network block
|
|
||||||
layers: number of layers per block
|
|
||||||
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
|
|
||||||
progress: if True, displays a progress bar of the download to stderr
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ResNet-50 model
|
A ResNet-18 model.
|
||||||
"""
|
"""
|
||||||
# Initialize a new model
|
if weights:
|
||||||
model = ResNet(block, layers, NUM_CLASSES[sensor], **kwargs)
|
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||||
|
|
||||||
# Replace the first layer with the correct number of input channels
|
model: ResNet = timm.create_model("resnet18", *args, **kwargs)
|
||||||
model.conv1 = nn.Conv2d(
|
|
||||||
IN_CHANNELS[sensor][bands],
|
|
||||||
out_channels=64,
|
|
||||||
kernel_size=7,
|
|
||||||
stride=1,
|
|
||||||
padding=2,
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load pretrained weights
|
if weights:
|
||||||
if pretrained:
|
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||||
state_dict = load_state_dict_from_url(
|
|
||||||
MODEL_URLS[sensor][bands][arch], progress=progress
|
|
||||||
)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def resnet50(
|
def resnet50(
|
||||||
sensor: str,
|
weights: Optional[ResNet50_Weights] = None, *args: Any, **kwargs: Any
|
||||||
bands: str,
|
|
||||||
pretrained: bool = False,
|
|
||||||
progress: bool = True,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ResNet:
|
) -> ResNet:
|
||||||
"""ResNet-50 model.
|
"""ResNet-50 model.
|
||||||
|
|
||||||
|
@ -87,22 +212,23 @@ def resnet50(
|
||||||
|
|
||||||
* https://arxiv.org/pdf/1512.03385.pdf
|
* https://arxiv.org/pdf/1512.03385.pdf
|
||||||
|
|
||||||
|
.. versionchanged:: 0.4
|
||||||
|
Switched to multi-weight support API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sensor: imagery source which determines number of input channels
|
weights: Pre-trained model weights to use.
|
||||||
bands: which spectral bands to consider: "all", "rgb", etc.
|
*args: Additional arguments to pass to :func:`timm.create_model`.
|
||||||
pretrained: if True, returns a model pre-trained on ``sensor`` imagery
|
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`.
|
||||||
progress: if True, displays a progress bar of the download to stderr
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ResNet-50 model
|
A ResNet-50 model.
|
||||||
"""
|
"""
|
||||||
return _resnet(
|
if weights:
|
||||||
sensor,
|
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||||
bands,
|
|
||||||
"resnet50",
|
model: ResNet = timm.create_model("resnet50", *args, **kwargs)
|
||||||
Bottleneck,
|
|
||||||
[3, 4, 6, 3],
|
if weights:
|
||||||
pretrained,
|
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||||
progress,
|
|
||||||
**kwargs,
|
return model
|
||||||
)
|
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""Pre-trained Vision Transformer models."""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import kornia.augmentation as K
|
||||||
|
import timm
|
||||||
|
from timm.models.vision_transformer import VisionTransformer
|
||||||
|
from torchvision.models._api import Weights, WeightsEnum
|
||||||
|
|
||||||
|
from ..transforms import AugmentationSequential
|
||||||
|
|
||||||
|
__all__ = ["ViTSmall16_Weights"]
|
||||||
|
|
||||||
|
_zhu_xlab_transforms = AugmentationSequential(
|
||||||
|
K.Resize(256), K.CenterCrop(224), data_keys=["image"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://github.com/pytorch/vision/pull/6883
|
||||||
|
# https://github.com/pytorch/vision/pull/7107
|
||||||
|
# Can be removed once torchvision>=0.15 is required
|
||||||
|
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
||||||
|
|
||||||
|
|
||||||
|
class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
|
||||||
|
"""Vision Transformer Samll Patch Size 16 weights.
|
||||||
|
|
||||||
|
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||||
|
*vit_small_patch16_224* implementation.
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
"""
|
||||||
|
|
||||||
|
SENTINEL2_ALL_MOCO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/"
|
||||||
|
"resolve/main/vit_small_patch16_224_sentinel2_all_moco.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 13,
|
||||||
|
"model": "vit_small_patch16_224",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "moco",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
SENTINEL2_ALL_DINO = Weights(
|
||||||
|
url=(
|
||||||
|
"https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/"
|
||||||
|
"resolve/main/vit_small_patch16_224_sentinel2_all_dino.pth"
|
||||||
|
),
|
||||||
|
transforms=_zhu_xlab_transforms,
|
||||||
|
meta={
|
||||||
|
"dataset": "SSL4EO-S12",
|
||||||
|
"in_chans": 13,
|
||||||
|
"model": "vit_small_patch16_224",
|
||||||
|
"publication": "https://arxiv.org/abs/2211.07044",
|
||||||
|
"repo": "https://github.com/zhu-xlab/SSL4EO-S12",
|
||||||
|
"ssl_method": "dino",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def vit_small_patch16_224(
|
||||||
|
weights: Optional[ViTSmall16_Weights] = None, *args: Any, **kwargs: Any
|
||||||
|
) -> VisionTransformer:
|
||||||
|
"""Vision Transform (ViT) small patch size 16 model.
|
||||||
|
|
||||||
|
If you use this model in your research, please cite the following paper:
|
||||||
|
|
||||||
|
* https://arxiv.org/abs/2010.11929
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Pre-trained model weights to use.
|
||||||
|
*args: Additional arguments to pass to :func:`timm.create_model`.
|
||||||
|
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ViT small 16 model.
|
||||||
|
"""
|
||||||
|
if weights:
|
||||||
|
kwargs["in_chans"] = weights.meta["in_chans"]
|
||||||
|
|
||||||
|
model: VisionTransformer = timm.create_model(
|
||||||
|
"vit_small_patch16_224", *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if weights:
|
||||||
|
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||||
|
|
||||||
|
return model
|
|
@ -17,7 +17,9 @@ from kornia.geometry import transform as KorniaTransform
|
||||||
from torch import Tensor, optim
|
from torch import Tensor, optim
|
||||||
from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential
|
from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
|
from ..models import get_weight
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -323,41 +325,23 @@ class BYOLTask(pl.LightningModule):
|
||||||
|
|
||||||
def config_task(self) -> None:
|
def config_task(self) -> None:
|
||||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||||
|
# Create model
|
||||||
in_channels = self.hyperparams["in_channels"]
|
in_channels = self.hyperparams["in_channels"]
|
||||||
backbone_name = self.hyperparams["backbone"]
|
weights = self.hyperparams["weights"]
|
||||||
|
|
||||||
imagenet_pretrained = False
|
|
||||||
custom_pretrained = False
|
|
||||||
if self.hyperparams["weights"] and not os.path.exists(
|
|
||||||
self.hyperparams["weights"]
|
|
||||||
):
|
|
||||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
|
||||||
custom_pretrained = False
|
|
||||||
else:
|
|
||||||
custom_pretrained = True
|
|
||||||
|
|
||||||
# Create the model
|
|
||||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
|
||||||
if backbone_name in valid_models:
|
|
||||||
backbone = timm.create_model(
|
backbone = timm.create_model(
|
||||||
backbone_name, in_chans=in_channels, pretrained=imagenet_pretrained
|
self.hyperparams["backbone"],
|
||||||
|
in_chans=in_channels,
|
||||||
|
pretrained=weights is True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
if weights and weights is not True:
|
||||||
|
if isinstance(weights, WeightsEnum):
|
||||||
|
state_dict = weights.get_state_dict(progress=True)
|
||||||
|
elif os.path.exists(weights):
|
||||||
|
_, state_dict = utils.extract_backbone(weights)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model type '{backbone_name}' is not a valid timm model.")
|
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||||
|
|
||||||
if custom_pretrained:
|
|
||||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
|
||||||
|
|
||||||
if self.hyperparams["backbone"] != name:
|
|
||||||
raise ValueError(
|
|
||||||
f"Trying to load {name} weights into a "
|
|
||||||
f"{self.hyperparams['backbone']}"
|
|
||||||
)
|
|
||||||
backbone = utils.load_state_dict(backbone, state_dict)
|
backbone = utils.load_state_dict(backbone, state_dict)
|
||||||
|
|
||||||
self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))
|
self.model = BYOL(backbone, in_channels=in_channels, image_size=(256, 256))
|
||||||
|
@ -368,7 +352,9 @@ class BYOLTask(pl.LightningModule):
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
in_channels: Number of input channels to model
|
in_channels: Number of input channels to model
|
||||||
backbone: Name of the timm model to use
|
backbone: Name of the timm model to use
|
||||||
weights: Either "random" or "imagenet"
|
weights: Either a weight enum, the string representation of a weight enum,
|
||||||
|
True for ImageNet weights, False or None for random weights,
|
||||||
|
or the path to a saved model state dict.
|
||||||
learning_rate: Learning rate for optimizer
|
learning_rate: Learning rate for optimizer
|
||||||
learning_rate_schedule_patience: Patience for learning rate scheduler
|
learning_rate_schedule_patience: Patience for learning rate scheduler
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,10 @@ from torchmetrics.classification import (
|
||||||
MultilabelAccuracy,
|
MultilabelAccuracy,
|
||||||
MultilabelFBetaScore,
|
MultilabelFBetaScore,
|
||||||
)
|
)
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
from ..datasets.utils import unbind_samples
|
from ..datasets import unbind_samples
|
||||||
|
from ..models import get_weight
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,44 +45,23 @@ class ClassificationTask(pl.LightningModule):
|
||||||
|
|
||||||
def config_model(self) -> None:
|
def config_model(self) -> None:
|
||||||
"""Configures the model based on kwargs parameters passed to the constructor."""
|
"""Configures the model based on kwargs parameters passed to the constructor."""
|
||||||
in_channels = self.hyperparams["in_channels"]
|
# Create model
|
||||||
model = self.hyperparams["model"]
|
weights = self.hyperparams["weights"]
|
||||||
|
|
||||||
imagenet_pretrained = False
|
|
||||||
custom_pretrained = False
|
|
||||||
if self.hyperparams["weights"] and not os.path.exists(
|
|
||||||
self.hyperparams["weights"]
|
|
||||||
):
|
|
||||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
|
||||||
custom_pretrained = False
|
|
||||||
else:
|
|
||||||
custom_pretrained = True
|
|
||||||
|
|
||||||
# Create the model
|
|
||||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
|
||||||
if model in valid_models:
|
|
||||||
self.model = timm.create_model(
|
self.model = timm.create_model(
|
||||||
model,
|
self.hyperparams["model"],
|
||||||
num_classes=self.hyperparams["num_classes"],
|
num_classes=self.hyperparams["num_classes"],
|
||||||
in_chans=in_channels,
|
in_chans=self.hyperparams["in_channels"],
|
||||||
pretrained=imagenet_pretrained,
|
pretrained=weights is True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
if weights and weights is not True:
|
||||||
|
if isinstance(weights, WeightsEnum):
|
||||||
|
state_dict = weights.get_state_dict(progress=True)
|
||||||
|
elif os.path.exists(weights):
|
||||||
|
_, state_dict = utils.extract_backbone(weights)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model type '{model}' is not a valid timm model.")
|
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||||
|
|
||||||
if custom_pretrained:
|
|
||||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
|
||||||
|
|
||||||
if self.hyperparams["model"] != name:
|
|
||||||
raise ValueError(
|
|
||||||
f"Trying to load {name} weights into a "
|
|
||||||
f"{self.hyperparams['model']}"
|
|
||||||
)
|
|
||||||
self.model = utils.load_state_dict(self.model, state_dict)
|
self.model = utils.load_state_dict(self.model, state_dict)
|
||||||
|
|
||||||
def config_task(self) -> None:
|
def config_task(self) -> None:
|
||||||
|
@ -102,7 +83,9 @@ class ClassificationTask(pl.LightningModule):
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
model: Name of the classification model use
|
model: Name of the classification model use
|
||||||
loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
|
loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
|
||||||
weights: Either "random" or "imagenet"
|
weights: Either a weight enum, the string representation of a weight enum,
|
||||||
|
True for ImageNet weights, False or None for random weights,
|
||||||
|
or the path to a saved model state dict.
|
||||||
num_classes: Number of prediction classes
|
num_classes: Number of prediction classes
|
||||||
in_channels: Number of input channels to model
|
in_channels: Number of input channels to model
|
||||||
learning_rate: Learning rate for optimizer
|
learning_rate: Learning rate for optimizer
|
||||||
|
@ -321,12 +304,6 @@ class MultiLabelClassificationTask(ClassificationTask):
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# Creates `self.hparams` from kwargs
|
|
||||||
self.save_hyperparameters() # type: ignore[operator]
|
|
||||||
self.hyperparams = cast(Dict[str, Any], self.hparams)
|
|
||||||
|
|
||||||
self.config_task()
|
|
||||||
|
|
||||||
self.train_metrics = MetricCollection(
|
self.train_metrics = MetricCollection(
|
||||||
{
|
{
|
||||||
"OverallAccuracy": MultilabelAccuracy(
|
"OverallAccuracy": MultilabelAccuracy(
|
||||||
|
|
|
@ -8,18 +8,16 @@ from typing import Any, Dict, List, cast
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
from packaging.version import parse
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||||
|
from torchvision.models import resnet as R
|
||||||
from torchvision.models.detection import FasterRCNN
|
from torchvision.models.detection import FasterRCNN
|
||||||
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
|
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
|
||||||
from torchvision.models.detection.rpn import AnchorGenerator
|
from torchvision.models.detection.rpn import AnchorGenerator
|
||||||
from torchvision.ops import MultiScaleRoIAlign
|
from torchvision.ops import MultiScaleRoIAlign
|
||||||
|
|
||||||
if parse(torchvision.__version__) >= parse("0.13"):
|
from ..datasets.utils import unbind_samples
|
||||||
from torchvision.models import resnet as R
|
|
||||||
|
|
||||||
BACKBONE_WEIGHT_MAP = {
|
BACKBONE_WEIGHT_MAP = {
|
||||||
"resnet18": R.ResNet18_Weights.DEFAULT,
|
"resnet18": R.ResNet18_Weights.DEFAULT,
|
||||||
|
@ -33,8 +31,6 @@ if parse(torchvision.__version__) >= parse("0.13"):
|
||||||
"wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
|
"wide_resnet101_2": R.Wide_ResNet101_2_Weights.DEFAULT,
|
||||||
}
|
}
|
||||||
|
|
||||||
from ..datasets.utils import unbind_samples
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectDetectionTask(pl.LightningModule):
|
class ObjectDetectionTask(pl.LightningModule):
|
||||||
"""LightningModule for object detection of images.
|
"""LightningModule for object detection of images.
|
||||||
|
@ -62,15 +58,12 @@ class ObjectDetectionTask(pl.LightningModule):
|
||||||
"backbone_name": self.hyperparams["backbone"],
|
"backbone_name": self.hyperparams["backbone"],
|
||||||
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
|
"trainable_layers": self.hyperparams.get("trainable_layers", 3),
|
||||||
}
|
}
|
||||||
if parse(torchvision.__version__) >= parse("0.13"):
|
|
||||||
if backbone_pretrained:
|
if backbone_pretrained:
|
||||||
kwargs["weights"] = BACKBONE_WEIGHT_MAP[
|
kwargs["weights"] = BACKBONE_WEIGHT_MAP[
|
||||||
self.hyperparams["backbone"]
|
self.hyperparams["backbone"]
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
kwargs["weights"] = None
|
kwargs["weights"] = None
|
||||||
else:
|
|
||||||
kwargs["pretrained"] = backbone_pretrained
|
|
||||||
|
|
||||||
backbone = resnet_fpn_backbone(**kwargs)
|
backbone = resnet_fpn_backbone(**kwargs)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -14,8 +14,10 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
|
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
|
||||||
|
from torchvision.models._api import WeightsEnum
|
||||||
|
|
||||||
from ..datasets.utils import unbind_samples
|
from ..datasets import unbind_samples
|
||||||
|
from ..models import get_weight
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,44 +37,23 @@ class RegressionTask(pl.LightningModule):
|
||||||
|
|
||||||
def config_task(self) -> None:
|
def config_task(self) -> None:
|
||||||
"""Configures the task based on kwargs parameters."""
|
"""Configures the task based on kwargs parameters."""
|
||||||
in_channels = self.hyperparams["in_channels"]
|
# Create model
|
||||||
model = self.hyperparams["model"]
|
weights = self.hyperparams["weights"]
|
||||||
|
|
||||||
imagenet_pretrained = False
|
|
||||||
custom_pretrained = False
|
|
||||||
if self.hyperparams["weights"] and not os.path.exists(
|
|
||||||
self.hyperparams["weights"]
|
|
||||||
):
|
|
||||||
if self.hyperparams["weights"] not in ["imagenet", "random"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Weight type '{self.hyperparams['weights']}' is not valid."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
imagenet_pretrained = self.hyperparams["weights"] == "imagenet"
|
|
||||||
custom_pretrained = False
|
|
||||||
else:
|
|
||||||
custom_pretrained = True
|
|
||||||
|
|
||||||
# Create the model
|
|
||||||
valid_models = timm.list_models(pretrained=imagenet_pretrained)
|
|
||||||
if model in valid_models:
|
|
||||||
self.model = timm.create_model(
|
self.model = timm.create_model(
|
||||||
model,
|
self.hyperparams["model"],
|
||||||
num_classes=self.hyperparams["num_outputs"],
|
num_classes=self.hyperparams["num_outputs"],
|
||||||
in_chans=in_channels,
|
in_chans=self.hyperparams["in_channels"],
|
||||||
pretrained=imagenet_pretrained,
|
pretrained=weights is True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
if weights and weights is not True:
|
||||||
|
if isinstance(weights, WeightsEnum):
|
||||||
|
state_dict = weights.get_state_dict(progress=True)
|
||||||
|
elif os.path.exists(weights):
|
||||||
|
_, state_dict = utils.extract_backbone(weights)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model type '{model}' is not a valid timm model.")
|
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||||
|
|
||||||
if custom_pretrained:
|
|
||||||
name, state_dict = utils.extract_backbone(self.hyperparams["weights"])
|
|
||||||
|
|
||||||
if self.hyperparams["model"] != name:
|
|
||||||
raise ValueError(
|
|
||||||
f"Trying to load {name} weights into a "
|
|
||||||
f"{self.hyperparams['model']}"
|
|
||||||
)
|
|
||||||
self.model = utils.load_state_dict(self.model, state_dict)
|
self.model = utils.load_state_dict(self.model, state_dict)
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
@ -80,7 +61,9 @@ class RegressionTask(pl.LightningModule):
|
||||||
|
|
||||||
Keyword Args:
|
Keyword Args:
|
||||||
model: Name of the timm model to use
|
model: Name of the timm model to use
|
||||||
weights: Either "random" or "imagenet"
|
weights: Either a weight enum, the string representation of a weight enum,
|
||||||
|
True for ImageNet weights, False or None for random weights,
|
||||||
|
or the path to a saved model state dict.
|
||||||
num_outputs: Number of prediction outputs
|
num_outputs: Number of prediction outputs
|
||||||
in_channels: Number of input channels to model
|
in_channels: Number of input channels to model
|
||||||
learning_rate: Learning rate for optimizer
|
learning_rate: Learning rate for optimizer
|
||||||
|
|
Загрузка…
Ссылка в новой задаче