SatlasPretrain: ResNet50/152 and Swin_V2_T Weights (#2038)

* add SENTINEL2_MS_MI_SATLAS

* add SENTINEL2_MS_SI_SATLAS

* fix style

* ruff

* Add all SatlasPretrain models

* Get local tests passing

* Get local tests passing

* Fix remote tests

* Mock input channels in testing too

* Fix license

* Fix band order

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Yi-Chia Chang 2024-08-27 08:58:49 -05:00 committed by GitHub
Parent 2d6e27ebd0
Commit a757cf14fb
No key matching this signature was found
GPG Key ID: B5690EEEBB952194
14 changed files: 628 additions and 99 deletions

View File

@ -29,4 +29,5 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/
ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",63.65,46.68,60.01,43.17
ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",66.81,50.16,64.17,47.24
ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",65.04,48.20,62.61,45.46
Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_B_Weights.LANDSAT_MI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,

This file cannot be rendered because it has a wrong number of fields in line 32.
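The render failure above comes down to a row whose field count differs from the header. A minimal sketch of such a check using Python's standard `csv` module; the `find_bad_rows` helper and the shortened sample header are hypothetical and are not part of this commit:

```python
import csv
import io


def find_bad_rows(text: str) -> list[tuple[int, int]]:
    """Return (line_number, field_count) for rows whose width differs from the header."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    return [
        (reader.line_num, len(row))
        for row in reader
        if row and len(row) != len(header)
    ]


# Hypothetical, shortened table: the second data row lacks the Landsat column.
sample = (
    'Weight,Landsat,Channels,License\n'
    'Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,ODC-BY\n'
    'Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,"ODC-BY"\n'
)
bad_rows = find_bad_rows(sample)
print(bad_rows)
```

Running a check like this over the documentation CSV files before committing would flag a short row by line number, which is the same information the renderer reports.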

View File

@ -47,8 +47,10 @@ ResNet
.. autofunction:: resnet18
.. autofunction:: resnet50
.. autofunction:: resnet152
.. autoclass:: ResNet18_Weights
.. autoclass:: ResNet50_Weights
.. autoclass:: ResNet152_Weights
Scale-MAE
^^^^^^^^^
@ -59,7 +61,9 @@ Scale-MAE
Swin Transformer
^^^^^^^^^^^^^^^^^^
.. autofunction:: swin_v2_t
.. autofunction:: swin_v2_b
.. autoclass:: Swin_V2_T_Weights
.. autoclass:: Swin_V2_B_Weights
Vision Transformer

View File

@ -1,2 +1,3 @@
Weight,Channels,Source,Citation,License
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY


View File

@ -1,4 +1,5 @@
Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY


View File

@ -5,9 +5,23 @@ ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seas
ResNet50_Weights.SENTINEL2_ALL_DECUR,13,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0",,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet50_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.81,,,
ResNet152_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet152_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.5,99.0,62.2,
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",89.9,98.6,61.6,
Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,

This file cannot be rendered because it has a wrong number of fields in line 8.

View File

@ -12,8 +12,10 @@ from torchgeo.models import (
dofa_large_patch16_224,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_b,
swin_v2_t,
vit_small_patch16_224,
)
@ -22,9 +24,11 @@ __all__ = (
'dofa_large_patch16_224',
'resnet18',
'resnet50',
'resnet152',
'scalemae_large_patch16',
'swin_v2_t',
'swin_v2_b',
'vit_small_patch16_224',
)
dependencies = ['timm']
dependencies = ['timm', 'torchvision']

View File

@ -13,8 +13,10 @@ from torchgeo.models import (
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
ScaleMAELarge16_Weights,
Swin_V2_B_Weights,
Swin_V2_T_Weights,
ViTSmall16_Weights,
dofa_base_patch16_224,
dofa_large_patch16_224,
@ -24,8 +26,10 @@ from torchgeo.models import (
list_models,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_b,
swin_v2_t,
vit_small_patch16_224,
)
@ -34,7 +38,9 @@ builders = [
dofa_large_patch16_224,
resnet18,
resnet50,
resnet152,
scalemae_large_patch16,
swin_v2_t,
swin_v2_b,
vit_small_patch16_224,
]
@ -43,7 +49,9 @@ enums = [
DOFALarge16_Weights,
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
ScaleMAELarge16_Weights,
Swin_V2_T_Weights,
Swin_V2_B_Weights,
ViTSmall16_Weights,
]

View File

@ -10,7 +10,14 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from torchgeo.models import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)
class TestResNet18:
@ -44,7 +51,7 @@ class TestResNet18:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)
@ -84,10 +91,50 @@ class TestResNet50:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)
@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet50(weights=weights)
class TestResNet152:
@pytest.fixture(params=[*ResNet152_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model('resnet152', in_chans=weights.meta['in_chans'])
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights
def test_resnet(self) -> None:
resnet152()
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet152(weights=mocked_weights)
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)
@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet152(weights=weights)

View File

@ -10,7 +10,52 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchvision.models._api import WeightsEnum
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.models import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
class TestSwin_V2_T:
@pytest.fixture(params=[*Swin_V2_T_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_t()
num_channels = weights.meta['in_chans']
out_channels = model.features[0][0].out_channels
model.features[0][0] = torch.nn.Conv2d(
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights
def test_swin_v2_t(self) -> None:
swin_v2_t()
def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None:
swin_v2_t(weights=mocked_weights)
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta['in_chans']
sample = {
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
}
mocked_weights.transforms(sample)
@pytest.mark.slow
def test_swin_v2_t_download(self, weights: WeightsEnum) -> None:
swin_v2_t(weights=weights)
class TestSwin_V2_B:
@ -28,6 +73,11 @@ class TestSwin_V2_B:
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = torchvision.models.swin_v2_b()
num_channels = weights.meta['in_chans']
out_channels = model.features[0][0].out_channels
model.features[0][0] = torch.nn.Conv2d(
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))

View File

@ -18,9 +18,16 @@ from .farseg import FarSeg
from .fcn import FCN
from .fcsiam import FCSiamConc, FCSiamDiff
from .rcf import RCF
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .resnet import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)
from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16
from .swin import Swin_V2_B_Weights, swin_v2_b
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
from .vit import ViTSmall16_Weights, vit_small_patch16_224
__all__ = (
@ -40,16 +47,20 @@ __all__ = (
'RCF',
'resnet18',
'resnet50',
'resnet152',
'ScaleMAE',
'scalemae_large_patch16',
'swin_v2_t',
'swin_v2_b',
'vit_small_patch16_224',
# weights
'DOFABase16_Weights',
'DOFALarge16_Weights',
'ResNet50_Weights',
'ResNet18_Weights',
'ResNet50_Weights',
'ResNet152_Weights',
'ScaleMAELarge16_Weights',
'Swin_V2_T_Weights',
'Swin_V2_B_Weights',
'ViTSmall16_Weights',
# utilities

View File

@ -22,9 +22,16 @@ from .dofa import (
dofa_base_patch16_224,
dofa_large_patch16_224,
)
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .resnet import (
ResNet18_Weights,
ResNet50_Weights,
ResNet152_Weights,
resnet18,
resnet50,
resnet152,
)
from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16
from .swin import Swin_V2_B_Weights, swin_v2_b
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
from .vit import ViTSmall16_Weights, vit_small_patch16_224
_model = {
@ -32,7 +39,9 @@ _model = {
'dofa_large_patch16_224': dofa_large_patch16_224,
'resnet18': resnet18,
'resnet50': resnet50,
'resnet152': resnet152,
'scalemae_large_patch16': scalemae_large_patch16,
'swin_v2_t': swin_v2_t,
'swin_v2_b': swin_v2_b,
'vit_small_patch16_224': vit_small_patch16_224,
}
@ -42,14 +51,18 @@ _model_weights = {
dofa_large_patch16_224: DOFALarge16_Weights,
resnet18: ResNet18_Weights,
resnet50: ResNet50_Weights,
resnet152: ResNet152_Weights,
scalemae_large_patch16: ScaleMAELarge16_Weights,
swin_v2_t: Swin_V2_T_Weights,
swin_v2_b: Swin_V2_B_Weights,
vit_small_patch16_224: ViTSmall16_Weights,
'dofa_base_patch16_224': DOFABase16_Weights,
'dofa_large_patch16_224': DOFALarge16_Weights,
'resnet18': ResNet18_Weights,
'resnet50': ResNet50_Weights,
'resnet152': ResNet152_Weights,
'scalemae_large_patch16': ScaleMAELarge16_Weights,
'swin_v2_t': Swin_V2_T_Weights,
'swin_v2_b': Swin_V2_B_Weights,
'vit_small_patch16_224': ViTSmall16_Weights,
}

View File

@ -11,6 +11,13 @@ import torch
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum
from .swin import (
_satlas_bands,
_satlas_sentinel2_bands,
_satlas_sentinel2_transforms,
_satlas_transforms,
)
# https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286
# Normalization by channel-wise band statistics
_mean_s1 = torch.tensor([-12.59, -20.26])
@ -113,7 +120,7 @@ Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet18 weights.
"""ResNet-18 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet18* implementation.
@ -292,7 +299,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet50 weights.
"""ResNet-50 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet50* implementation.
@ -508,6 +515,32 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)
SENTINEL2_MI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_ms-da5413d2.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_MI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_rgb-e79bb7fe.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_bands,
},
)
SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
transforms=_ssl4eo_s12_transforms_s2_10k,
@ -534,6 +567,94 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)
SENTINEL2_SI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_ms-1f454cc6.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_SI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_rgb-45fc6972.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_bands,
},
)
class ResNet152_Weights(WeightsEnum): # type: ignore[misc]
"""ResNet-152 weights.
For `timm <https://github.com/rwightman/pytorch-image-models>`_
*resnet152* implementation.
.. versionadded:: 0.6
"""
SENTINEL2_MI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_ms-fd35b4bb.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'resnet152',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_MI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_rgb-67563ac5.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'resnet152',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_bands,
},
)
SENTINEL2_SI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_ms-4500c6cb.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'resnet152',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_SI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_rgb-f4d24c3c.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'resnet152',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlaspretrain_models',
'bands': _satlas_bands,
},
)
def resnet18(
weights: ResNet18_Weights | None = None, *args: Any, **kwargs: Any
@ -602,3 +723,37 @@ def resnet50(
assert not unexpected_keys
return model
def resnet152(
weights: ResNet152_Weights | None = None, *args: Any, **kwargs: Any
) -> ResNet:
"""ResNet-152 model.
If you use this model in your research, please cite the following paper:
* https://arxiv.org/pdf/1512.03385.pdf
.. versionadded:: 0.6
Args:
weights: Pre-trained model weights to use.
*args: Additional arguments to pass to :func:`timm.create_model`.
**kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.
Returns:
A ResNet-152 model.
"""
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
model: ResNet = timm.create_model('resnet152', *args, **kwargs)
if weights:
missing_keys, unexpected_keys = model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= {'fc.weight', 'fc.bias'}
assert not unexpected_keys
return model
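The two assertions in `resnet152` accept a checkpoint that lacks the classification head but reject any other mismatch. The same set logic in isolation, as a sketch; `check_load_result` is a hypothetical helper, and the key lists are made-up examples rather than real `load_state_dict` output:

```python
def check_load_result(
    missing_keys: list[str],
    unexpected_keys: list[str],
    allowed_missing: frozenset[str] = frozenset({'fc.weight', 'fc.bias'}),
) -> bool:
    """Return True when only allowed keys are missing and nothing is unexpected."""
    return set(missing_keys) <= allowed_missing and not unexpected_keys


# Backbone-only checkpoint: the head is absent, which is tolerated.
ok_backbone_only = check_load_result(['fc.weight', 'fc.bias'], [])
# Full checkpoint: nothing missing, nothing unexpected.
ok_full = check_load_result([], [])
# A missing backbone layer is not tolerated.
bad_missing = check_load_result(['conv1.weight'], [])
# Extra keys in the checkpoint are not tolerated either.
bad_unexpected = check_load_result([], ['head.weight'])
```

Loading with `strict=False` and then asserting on the returned key lists keeps the tolerance explicit: only the named head parameters may be absent.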

View File

@ -8,35 +8,38 @@ from typing import Any
import kornia.augmentation as K
import torch
import torchvision
from kornia.contrib import Lambda
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
)
import torchgeo.transforms.transforms as T
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
_std = torch.tensor(
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
)
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
# All Satlas transforms include:
# https://github.com/allenai/satlas/blob/main/satlas/cmd/model/train.py#L49
#
# Information about sensor-specific normalization can be found at:
# https://github.com/allenai/satlas/blob/main/Normalization.md
_satlas_bands = ('B04', 'B03', 'B02')
_satlas_transforms = K.AugmentationSequential(
K.CenterCrop(256),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
data_keys=None,
)
# Satlas Landsat imagery is 16-bit; each pixel N is normalized as (N - 4000) / 16320 and clipped to (0, 1).
_landsat_satlas_transforms = K.AugmentationSequential(
_satlas_sentinel2_bands = (*_satlas_bands, 'B05', 'B06', 'B07', 'B08', 'B11', 'B12')
_std = torch.tensor([255, 255, 255, 8160, 8160, 8160, 8160, 8160, 8160])
_satlas_sentinel2_transforms = K.AugmentationSequential(
K.CenterCrop(256),
K.Normalize(mean=torch.tensor(0), std=_std),
T._Clamp(p=1, min=0, max=1),
data_keys=None,
)
_satlas_landsat_bands = tuple(f'B{i:02}' for i in range(1, 12))
_satlas_landsat_transforms = K.AugmentationSequential(
K.CenterCrop(256),
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
T._Clamp(p=1, min=0, max=1),
data_keys=None,
)
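The transforms above reduce to per-band arithmetic: subtract a mean, divide by a standard deviation, then clamp to [0, 1]. A plain-Python sketch of that arithmetic, without Kornia and without the 256-pixel center crop; `normalize` is a hypothetical helper and the pixel values are invented for illustration:

```python
def normalize(value: float, mean: float, std: float, clamp: bool = True) -> float:
    """Compute (value - mean) / std and optionally clamp the result to [0, 1]."""
    scaled = (value - mean) / std
    return min(max(scaled, 0.0), 1.0) if clamp else scaled


# Band order and divisors as listed in the diff for Sentinel-2 multispectral input.
sentinel2_bands = ('B04', 'B03', 'B02', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12')
sentinel2_std = (255, 255, 255, 8160, 8160, 8160, 8160, 8160, 8160)

# Invented pixel: half of each band's divisor, so every band maps to 0.5.
pixel = [s / 2 for s in sentinel2_std]
sentinel2_out = {
    band: normalize(v, 0, s)
    for band, v, s in zip(sentinel2_bands, pixel, sentinel2_std)
}

# Landsat: 11 bands named B01..B11, normalized as (N - 4000) / 16320.
landsat_bands = tuple(f'B{i:02}' for i in range(1, 12))
landsat_low = normalize(0, 4000, 16320)       # below the offset, clamps to 0.0
landsat_high = normalize(20320, 4000, 16320)  # 16320 / 16320 = 1.0
```

The first three Sentinel-2 channels follow the `('B04', 'B03', 'B02')` order from `_satlas_bands`, so inputs need to be stacked in that order before normalization.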
@ -46,6 +49,68 @@ _landsat_satlas_transforms = K.AugmentationSequential(
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
class Swin_V2_T_Weights(WeightsEnum): # type: ignore[misc]
"""Swin Transformer v2 Tiny weights.
For `torchvision <https://github.com/pytorch/vision>`_
*swin_v2_t* implementation.
.. versionadded:: 0.6
"""
SENTINEL2_MI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_ms-d8c659e3.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'swin_v2_t',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_MI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_rgb-424d91f4.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_t',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_bands,
},
)
SENTINEL2_SI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_ms-bc68e396.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'swin_v2_t',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_SI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_rgb-0c1a96e0.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_t',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_bands,
},
)
class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
"""Swin Transformer v2 Base weights.
@ -55,81 +120,175 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.6
"""
NAIP_RGB_MI_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_mi-326d69e1.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ('R', 'G', 'B'),
},
)
NAIP_RGB_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_si-e4169eb1.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ('R', 'G', 'B'),
},
)
SENTINEL2_RGB_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
transforms=_satlas_transforms,
LANDSAT_MI_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_mi-6b4a1cda.pth',
transforms=_satlas_landsat_transforms,
meta={
'dataset': 'Satlas',
'in_chans': 3,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
},
)
SENTINEL2_MS_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
transforms=_sentinel2_ms_satlas_transforms,
meta={
'dataset': 'Satlas',
'in_chans': 9,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12'],
},
)
SENTINEL1_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'Satlas',
'in_chans': 2,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ['VH', 'VV'],
},
)
LANDSAT_SI_SATLAS = Weights(
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
transforms=_landsat_satlas_transforms,
meta={
'dataset': 'Satlas',
'dataset': 'SatlasPretrain',
'in_chans': 11,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': [
'B01',
'B02',
'B03',
'B04',
'B05',
'B06',
'B07',
'B08',
'B09',
'B10',
'B11',
],
'bands': _satlas_landsat_bands,
},
)
LANDSAT_SI_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_si-4af978f6.pth',
transforms=_satlas_landsat_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 11,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_landsat_bands,
},
)
SENTINEL1_MI_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_mi-f6c43d97.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 2,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ('VH', 'VV'),
},
)
SENTINEL1_SI_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_si-3981c153.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 2,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': ('VH', 'VV'),
},
)
SENTINEL2_MI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_ms-39c86721.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_MI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_rgb-4efa210c.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_bands,
},
)
SENTINEL2_SI_MS_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_ms-fe22a12c.pth',
transforms=_satlas_sentinel2_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 9,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_sentinel2_bands,
},
)
SENTINEL2_SI_RGB_SATLAS = Weights(
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_rgb-156a98d5.pth',
transforms=_satlas_transforms,
meta={
'dataset': 'SatlasPretrain',
'in_chans': 3,
'model': 'swin_v2_b',
'publication': 'https://arxiv.org/abs/2211.15660',
'repo': 'https://github.com/allenai/satlas',
'bands': _satlas_bands,
},
)
+def swin_v2_t(
+    weights: Swin_V2_T_Weights | None = None, *args: Any, **kwargs: Any
+) -> SwinTransformer:
+    """Swin Transformer v2 tiny model.
+
+    If you use this model in your research, please cite the following paper:
+
+    * https://arxiv.org/abs/2111.09883
+
+    .. versionadded:: 0.6
+
+    Args:
+        weights: Pre-trained model weights to use.
+        *args: Additional arguments to
+            pass to :class:`torchvision.models.swin_transformer.SwinTransformer`.
+        **kwargs: Additional keyword arguments to
+            pass to :class:`torchvision.models.swin_transformer.SwinTransformer`.
+
+    Returns:
+        A Swin Transformer Tiny model.
+    """
+    model: SwinTransformer = torchvision.models.swin_v2_t(weights=None, *args, **kwargs)
+
+    if weights:
+        num_channels = weights.meta['in_chans']
+        out_channels = model.features[0][0].out_channels
+        # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27
+        model.features[0][0] = torch.nn.Conv2d(
+            num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
+        )
+        missing_keys, unexpected_keys = model.load_state_dict(
+            weights.get_state_dict(progress=True), strict=False
+        )
+        assert set(missing_keys) <= set()
+        assert not unexpected_keys
+
+    return model
def swin_v2_b(
weights: Swin_V2_B_Weights | None = None, *args: Any, **kwargs: Any
@ -155,6 +314,16 @@ def swin_v2_b(
     model: SwinTransformer = torchvision.models.swin_v2_b(weights=None, *args, **kwargs)
 
     if weights:
-        model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+        num_channels = weights.meta['in_chans']
+        out_channels = model.features[0][0].out_channels
+        # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27
+        model.features[0][0] = torch.nn.Conv2d(
+            num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
+        )
+        missing_keys, unexpected_keys = model.load_state_dict(
+            weights.get_state_dict(progress=True), strict=False
+        )
+        assert set(missing_keys) <= set()
+        assert not unexpected_keys
 
     return model
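
The reason the first convolution is swapped before the checkpoint is loaded can be illustrated without PyTorch. The sketch below is a pure-Python stand-in, not torchgeo or torchvision code: `diff_state_dicts` and `replace_stem` are hypothetical helpers, and the shapes (96 output channels, a 9-channel checkpoint) are illustrative assumptions. It compares parameter names and shapes the way a non-strict load would, since a shape mismatch is an error even when `strict=False`.

```python
# Pure-Python stand-in for the checks around load_state_dict(strict=False).
# Parameter names and shapes are illustrative assumptions, not read from a
# real checkpoint. Conv weights are (out_channels, in_channels, kH, kW).


def diff_state_dicts(model_shapes, checkpoint_shapes):
    """Return (missing, unexpected, mismatched) parameter names."""
    model_keys, ckpt_keys = set(model_shapes), set(checkpoint_shapes)
    missing = sorted(model_keys - ckpt_keys)      # in the model, not the checkpoint
    unexpected = sorted(ckpt_keys - model_keys)   # in the checkpoint, not the model
    mismatched = sorted(
        name
        for name in model_keys & ckpt_keys
        if model_shapes[name] != checkpoint_shapes[name]
    )
    return missing, unexpected, mismatched


def replace_stem(model_shapes, in_chans):
    """Mimic swapping the first convolution for one with `in_chans` inputs."""
    out_channels = model_shapes['features.0.0.weight'][0]
    updated = dict(model_shapes)
    updated['features.0.0.weight'] = (out_channels, in_chans, 4, 4)
    return updated


# A default RGB stem versus a hypothetical 9-band checkpoint.
default_model = {'features.0.0.weight': (96, 3, 4, 4), 'features.0.0.bias': (96,)}
checkpoint = {'features.0.0.weight': (96, 9, 4, 4), 'features.0.0.bias': (96,)}

before = diff_state_dicts(default_model, checkpoint)
after = diff_state_dicts(replace_stem(default_model, in_chans=9), checkpoint)

print(before)
print(after)
```

With the stem left at three input channels the weight shapes disagree. Once the stem matches `in_chans`, nothing is missing, unexpected, or mismatched, which is the condition the two assertions in the diff enforce.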


@ -8,7 +8,7 @@ from typing import Any
 import kornia.augmentation as K
 import torch
 from einops import rearrange
-from kornia.contrib import Lambda, extract_tensor_patches
+from kornia.contrib import extract_tensor_patches
 from kornia.geometry import crop_by_indices
 from kornia.geometry.boxes import Boxes
 from torch import Tensor
@ -25,7 +25,7 @@ class AugmentationSequential(Module):
     def __init__(
         self,
-        *args: K.base._AugmentationBase | K.ImageSequential | Lambda,
+        *args: K.base._AugmentationBase | K.ImageSequential,
         data_keys: list[str],
         **kwargs: Any,
     ) -> None:
@ -53,7 +53,7 @@ class AugmentationSequential(Module):
             else:
                 keys.append(key)
 
-        self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)  # type: ignore[arg-type]
+        self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)
 
     def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Perform augmentations and update data dict.
@ -272,3 +272,54 @@ class _ExtractPatches(K.GeometricAugmentationBase2D):
out = rearrange(out, 'b t c h w -> (b t) c h w')
return out
+class _Clamp(K.IntensityAugmentationBase2D):
+    """Clamp images to a specific range."""
+
+    def __init__(
+        self,
+        p: float = 0.5,
+        p_batch: float = 1,
+        min: float = 0,
+        max: float = 1,
+        same_on_batch: bool = False,
+        keepdim: bool = False,
+    ) -> None:
+        """Initialize a new _Clamp instance.
+
+        Args:
+            p: Probability for applying an augmentation. This param controls the
+                augmentation probabilities element-wise for a batch.
+            p_batch: Probability for applying an augmentation to a batch. This param
+                controls the augmentation probabilities batch-wise.
+            min: Minimum value to clamp to.
+            max: Maximum value to clamp to.
+            same_on_batch: Apply the same transformation across the batch.
+            keepdim: Whether to keep the output shape the same as input ``True``
+                or broadcast it to the batch form ``False``.
+        """
+        super().__init__(
+            p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
+        )
+        self.flags = {'min': min, 'max': max}
+
+    def apply_transform(
+        self,
+        input: Tensor,
+        params: dict[str, Tensor],
+        flags: dict[str, Any],
+        transform: Tensor | None = None,
+    ) -> Tensor:
+        """Apply the transform.
+
+        Args:
+            input: the input tensor
+            params: generated parameters
+            flags: static parameters
+            transform: the geometric transformation tensor
+
+        Returns:
+            the augmented input
+        """
+        return torch.clamp(input, self.flags['min'], self.flags['max'])
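
To make the role of `p` concrete, here is a minimal pure-Python sketch of clamping with a per-sample application probability. It is not kornia or torchgeo code: `clamp` and `clamp_batch` are hypothetical helpers, and the per-sample behaviour assumes `p` works element-wise as the docstring above describes.

```python
import random


def clamp(sample, lo=0.0, hi=1.0):
    """Limit every value in one sample to the closed range [lo, hi]."""
    return [min(max(value, lo), hi) for value in sample]


def clamp_batch(batch, p=0.5, lo=0.0, hi=1.0, rng=None):
    """Clamp each sample independently with probability p (assumed semantics)."""
    rng = rng or random.Random()
    # rng.random() is in [0, 1), so p=1.0 always applies and p=0.0 never does.
    return [
        clamp(sample, lo, hi) if rng.random() < p else list(sample)
        for sample in batch
    ]


batch = [[-0.2, 0.5, 1.3], [2.0, -1.0, 0.0]]

always = clamp_batch(batch, p=1.0)  # every sample is clamped
never = clamp_batch(batch, p=0.0)   # every sample passes through unchanged

print(always)
print(never)
```

Under that assumption, the default `p=0.5` would clamp only about half of the samples, so a caller who wants clamping applied to every image would pass `p=1` explicitly.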