From a757cf14fb0dcebcdb91e94a08d7b4759a5bdc79 Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang <61452667+yichiac@users.noreply.github.com> Date: Tue, 27 Aug 2024 08:58:49 -0500 Subject: [PATCH] SatlasPretrain: ResNet50/152 and Swin_V2_T Weights (#2038) * add SENTINEL2_MS_MI_SATLAS * add SENTINEL2_MS_SI_SATLAS * fix style * ruff * Add all SatlasPretrain models * Get local tests passing * Get local tests passing * Fix remote tests * Mock input channels in testing too * Fix license * Fix band order --------- Co-authored-by: Adam J. Stewart --- docs/api/landsat_pretrained_weights.csv | 3 +- docs/api/models.rst | 4 + docs/api/naip_pretrained_weights.csv | 3 +- docs/api/sentinel1_pretrained_weights.csv | 3 +- docs/api/sentinel2_pretrained_weights.csv | 18 +- hubconf.py | 6 +- tests/models/test_api.py | 8 + tests/models/test_resnet.py | 53 +++- tests/models/test_swin.py | 52 +++- torchgeo/models/__init__.py | 17 +- torchgeo/models/api.py | 17 +- torchgeo/models/resnet.py | 159 ++++++++++- torchgeo/models/swin.py | 327 ++++++++++++++++------ torchgeo/transforms/transforms.py | 57 +++- 14 files changed, 628 insertions(+), 99 deletions(-) diff --git a/docs/api/landsat_pretrained_weights.csv b/docs/api/landsat_pretrained_weights.csv index faf3c286d..bfc90b651 100644 --- a/docs/api/landsat_pretrained_weights.csv +++ b/docs/api/landsat_pretrained_weights.csv @@ -29,4 +29,5 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",63.65,46.68,60.01,43.17 ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",66.81,50.16,64.17,47.24 ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,"CC0-1.0",65.04,48.20,62.61,45.46 -Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,`link `__,`link `__,"ODC-BY",,,, +Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.LANDSAT_MI_SATLAS,8--9,11,`link `__,`link `__,ODC-BY,,,, diff --git a/docs/api/models.rst b/docs/api/models.rst index 58e3c5ba0..c61ca59e4 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -47,8 +47,10 @@ ResNet .. autofunction:: resnet18 .. autofunction:: resnet50 +.. autofunction:: resnet152 .. autoclass:: ResNet18_Weights .. autoclass:: ResNet50_Weights +.. autoclass:: ResNet152_Weights Scale-MAE ^^^^^^^^^ @@ -59,7 +61,9 @@ Scale-MAE Swin Transformer ^^^^^^^^^^^^^^^^^^ +.. autofunction:: swin_v2_t .. autofunction:: swin_v2_b +.. autoclass:: Swin_V2_T_Weights .. autoclass:: Swin_V2_B_Weights Vision Transformer diff --git a/docs/api/naip_pretrained_weights.csv b/docs/api/naip_pretrained_weights.csv index e8e8ef14b..7dfe84d21 100644 --- a/docs/api/naip_pretrained_weights.csv +++ b/docs/api/naip_pretrained_weights.csv @@ -1,2 +1,3 @@ Weight,Channels,Source,Citation,License -Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY" +Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS,3,`link `__,`link `__,ODC-BY +Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link `__,`link `__,ODC-BY diff --git a/docs/api/sentinel1_pretrained_weights.csv b/docs/api/sentinel1_pretrained_weights.csv index ceacfc5bd..82ed045f1 100644 --- a/docs/api/sentinel1_pretrained_weights.csv +++ b/docs/api/sentinel1_pretrained_weights.csv @@ -1,4 +1,5 @@ Weight,Channels,Source,Citation,License ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link `__,`link `__,"Apache-2.0" ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,"CC-BY-4.0" -Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link `__,`link `__,"ODC-BY" +Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link `__,`link `__,ODC-BY +Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link `__,`link `__,ODC-BY diff --git a/docs/api/sentinel2_pretrained_weights.csv b/docs/api/sentinel2_pretrained_weights.csv index dc6f01123..e583cfc89 100644 --- a/docs/api/sentinel2_pretrained_weights.csv +++ b/docs/api/sentinel2_pretrained_weights.csv @@ -5,9 +5,23 @@ ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",,,, ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.7,99.1,63.6, ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",91.8,99.1,60.9, +ResNet50_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"CC-BY-4.0",,, ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",87.81,,, +ResNet152_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.5,99.0,62.2, ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",89.9,98.6,61.6, -Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY",,,, -Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link `__,`link `__,"ODC-BY",,,, +Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, diff --git a/hubconf.py b/hubconf.py index a9944f8a2..4ce63b0c7 100644 --- a/hubconf.py +++ b/hubconf.py @@ -12,8 +12,10 @@ from torchgeo.models import ( dofa_large_patch16_224, resnet18, resnet50, + resnet152, scalemae_large_patch16, swin_v2_b, + swin_v2_t, vit_small_patch16_224, ) @@ -22,9 +24,11 @@ __all__ = ( 'dofa_large_patch16_224', 'resnet18', 'resnet50', + 'resnet152', 'scalemae_large_patch16', + 'swin_v2_t', 'swin_v2_b', 'vit_small_patch16_224', ) -dependencies = ['timm'] +dependencies = ['timm', 'torchvision'] diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 6ccdc04f2..6c0bc6790 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -13,8 +13,10 @@ from torchgeo.models import ( DOFALarge16_Weights, ResNet18_Weights, ResNet50_Weights, + ResNet152_Weights, ScaleMAELarge16_Weights, Swin_V2_B_Weights, + Swin_V2_T_Weights, ViTSmall16_Weights, dofa_base_patch16_224, dofa_large_patch16_224, @@ -24,8 +26,10 @@ from torchgeo.models import ( list_models, resnet18, resnet50, + resnet152, scalemae_large_patch16, swin_v2_b, + swin_v2_t, vit_small_patch16_224, ) @@ -34,7 +38,9 @@ builders = [ dofa_large_patch16_224, resnet18, resnet50, + resnet152, scalemae_large_patch16, + swin_v2_t, swin_v2_b, vit_small_patch16_224, ] @@ -43,7 +49,9 @@ enums = [ DOFALarge16_Weights, ResNet18_Weights, ResNet50_Weights, + ResNet152_Weights, ScaleMAELarge16_Weights, + Swin_V2_T_Weights, Swin_V2_B_Weights, ViTSmall16_Weights, ] diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index 03df2e837..24edcf406 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -10,7 +10,14 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum -from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from torchgeo.models import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) class TestResNet18: @@ -44,7 +51,7 @@ class TestResNet18: def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { - 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) } mocked_weights.transforms(sample) @@ -84,10 +91,50 @@ class TestResNet50: def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { - 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) } mocked_weights.transforms(sample) @pytest.mark.slow def test_resnet_download(self, weights: WeightsEnum) -> None: resnet50(weights=weights) + + +class TestResNet152: + @pytest.fixture(params=[*ResNet152_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = timm.create_model('resnet152', in_chans=weights.meta['in_chans']) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_resnet(self) -> None: + resnet152() + + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet152(weights=mocked_weights) + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta['in_chans'] + sample = { + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) + } + mocked_weights.transforms(sample) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet152(weights=weights) diff --git a/tests/models/test_swin.py b/tests/models/test_swin.py index c6d6ef241..043059e9b 100644 --- a/tests/models/test_swin.py +++ b/tests/models/test_swin.py @@ -10,7 +10,52 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum -from torchgeo.models import Swin_V2_B_Weights, swin_v2_b +from torchgeo.models import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t + + +class TestSwin_V2_T: + @pytest.fixture(params=[*Swin_V2_T_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = torchvision.models.swin_v2_t() + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_swin_v2_t(self) -> None: + swin_v2_t() + + def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None: + swin_v2_t(weights=mocked_weights) + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta['in_chans'] + sample = { + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) + } + mocked_weights.transforms(sample) + + @pytest.mark.slow + def test_swin_v2_t_download(self, weights: WeightsEnum) -> None: + swin_v2_t(weights=weights) class TestSwin_V2_B: @@ -28,6 +73,11 @@ class TestSwin_V2_B: ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = torchvision.models.swin_v2_b() + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) torch.save(model.state_dict(), path) try: monkeypatch.setattr(weights.value, 'url', str(path)) diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 20a4bc2f9..af4b2d5f1 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -18,9 +18,16 @@ from .farseg import FarSeg from .fcn import FCN from .fcsiam import FCSiamConc, FCSiamDiff from .rcf import RCF -from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .resnet import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16 -from .swin import Swin_V2_B_Weights, swin_v2_b +from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t from .vit import ViTSmall16_Weights, vit_small_patch16_224 __all__ = ( @@ -40,16 +47,20 @@ __all__ = ( 'RCF', 'resnet18', 'resnet50', + 'resnet152', 'ScaleMAE', 'scalemae_large_patch16', + 'swin_v2_t', 'swin_v2_b', 'vit_small_patch16_224', # weights 'DOFABase16_Weights', 'DOFALarge16_Weights', - 'ResNet50_Weights', 'ResNet18_Weights', + 'ResNet50_Weights', + 'ResNet152_Weights', 'ScaleMAELarge16_Weights', + 'Swin_V2_T_Weights', 'Swin_V2_B_Weights', 'ViTSmall16_Weights', # utilities diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index eca54b595..6e06db82a 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -22,9 +22,16 @@ from .dofa import ( dofa_base_patch16_224, dofa_large_patch16_224, ) -from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 +from .resnet import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16 -from .swin import Swin_V2_B_Weights, swin_v2_b +from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t from .vit import ViTSmall16_Weights, vit_small_patch16_224 _model = { @@ -32,7 +39,9 @@ _model = { 'dofa_large_patch16_224': dofa_large_patch16_224, 'resnet18': resnet18, 'resnet50': resnet50, + 'resnet152': resnet152, 'scalemae_large_patch16': scalemae_large_patch16, + 'swin_v2_t': swin_v2_t, 'swin_v2_b': swin_v2_b, 'vit_small_patch16_224': vit_small_patch16_224, } @@ -42,14 +51,18 @@ _model_weights = { dofa_large_patch16_224: DOFALarge16_Weights, resnet18: ResNet18_Weights, resnet50: ResNet50_Weights, + resnet152: ResNet152_Weights, scalemae_large_patch16: ScaleMAELarge16_Weights, + swin_v2_t: Swin_V2_T_Weights, swin_v2_b: Swin_V2_B_Weights, vit_small_patch16_224: ViTSmall16_Weights, 'dofa_base_patch16_224': DOFABase16_Weights, 'dofa_large_patch16_224': DOFALarge16_Weights, 'resnet18': ResNet18_Weights, 'resnet50': ResNet50_Weights, + 'resnet152': ResNet152_Weights, 'scalemae_large_patch16': ScaleMAELarge16_Weights, + 'swin_v2_t': Swin_V2_T_Weights, 'swin_v2_b': Swin_V2_B_Weights, 'vit_small_patch16_224': ViTSmall16_Weights, } diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index bbb9b73d5..7429251c7 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -11,6 +11,13 @@ import torch from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum +from .swin import ( + _satlas_bands, + _satlas_sentinel2_bands, + _satlas_sentinel2_transforms, + _satlas_transforms, +) + # https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286 # Normalization by channel-wise band statistics _mean_s1 = torch.tensor([-12.59, -20.26]) @@ -113,7 +120,7 @@ Weights.__deepcopy__ = lambda *args, **kwargs: args[0] class ResNet18_Weights(WeightsEnum): # type: ignore[misc] - """ResNet18 weights. + """ResNet-18 weights. For `timm `_ *resnet18* implementation. @@ -292,7 +299,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] class ResNet50_Weights(WeightsEnum): # type: ignore[misc] - """ResNet50 weights. + """ResNet-50 weights. For `timm `_ *resnet50* implementation. @@ -508,6 +515,32 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] }, ) + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_ms-da5413d2.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_rgb-e79bb7fe.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + SENTINEL2_RGB_MOCO = Weights( url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', transforms=_ssl4eo_s12_transforms_s2_10k, @@ -534,6 +567,94 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] }, ) + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_ms-1f454cc6.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_rgb-45fc6972.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + + +class ResNet152_Weights(WeightsEnum): # type: ignore[misc] + """ResNet-152 weights. + + For `timm `_ + *resnet152* implementation. + + .. versionadded:: 0.6 + """ + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_ms-fd35b4bb.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_rgb-67563ac5.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_ms-4500c6cb.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_rgb-f4d24c3c.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + def resnet18( weights: ResNet18_Weights | None = None, *args: Any, **kwargs: Any @@ -602,3 +723,37 @@ def resnet50( assert not unexpected_keys return model + + +def resnet152( + weights: ResNet152_Weights | None = None, *args: Any, **kwargs: Any +) -> ResNet: + """ResNet-152 model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/pdf/1512.03385.pdf + + .. versionadded:: 0.6 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. + + Returns: + A ResNet-152 model. + """ + if weights: + kwargs['in_chans'] = weights.meta['in_chans'] + + model: ResNet = timm.create_model('resnet152', *args, **kwargs) + + if weights: + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= {'fc.weight', 'fc.bias'} + assert not unexpected_keys + + return model diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index f29c2ffab..314000126 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -8,35 +8,38 @@ from typing import Any import kornia.augmentation as K import torch import torchvision -from kornia.contrib import Lambda from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum -# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 -# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. -# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. -_satlas_transforms = K.AugmentationSequential( - K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None -) +import torchgeo.transforms.transforms as T -# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. -# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). -_std = torch.tensor( - [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0] -) -_mean = torch.zeros_like(_std) -_sentinel2_ms_satlas_transforms = K.AugmentationSequential( - K.Normalize(mean=_mean, std=_std), - K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), +# All Satlas transforms include: +# https://github.com/allenai/satlas/blob/main/satlas/cmd/model/train.py#L49 +# +# Information about sensor-specific normalization can be found at: +# https://github.com/allenai/satlas/blob/main/Normalization.md + +_satlas_bands = ('B04', 'B03', 'B02') +_satlas_transforms = K.AugmentationSequential( + K.CenterCrop(256), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None, ) -# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). -_landsat_satlas_transforms = K.AugmentationSequential( +_satlas_sentinel2_bands = (*_satlas_bands, 'B05', 'B06', 'B07', 'B08', 'B11', 'B12') +_std = torch.tensor([255, 255, 255, 8160, 8160, 8160, 8160, 8160, 8160]) +_satlas_sentinel2_transforms = K.AugmentationSequential( + K.CenterCrop(256), + K.Normalize(mean=torch.tensor(0), std=_std), + T._Clamp(p=1, min=0, max=1), + data_keys=None, +) + +_satlas_landsat_bands = tuple(f'B{i:02}' for i in range(1, 12)) +_satlas_landsat_transforms = K.AugmentationSequential( + K.CenterCrop(256), K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), - K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), + T._Clamp(p=1, min=0, max=1), data_keys=None, ) @@ -46,6 +49,68 @@ _landsat_satlas_transforms = K.AugmentationSequential( Weights.__deepcopy__ = lambda *args, **kwargs: args[0] +class Swin_V2_T_Weights(WeightsEnum): # type: ignore[misc] + """Swin Transformer v2 Tiny weights. + + For `torchvision `_ + *swin_v2_t* implementation. + + .. versionadded:: 0.6 + """ + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_ms-d8c659e3.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_rgb-424d91f4.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_ms-bc68e396.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_rgb-0c1a96e0.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] """Swin Transformer v2 Base weights. @@ -55,81 +120,175 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] .. versionadded:: 0.6 """ + NAIP_RGB_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_mi-326d69e1.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ('R', 'G', 'B'), + }, + ) + NAIP_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_si-e4169eb1.pth', transforms=_satlas_transforms, meta={ - 'dataset': 'Satlas', + 'dataset': 'SatlasPretrain', 'in_chans': 3, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', + 'bands': ('R', 'G', 'B'), }, ) - SENTINEL2_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', - transforms=_satlas_transforms, + LANDSAT_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_mi-6b4a1cda.pth', + transforms=_satlas_landsat_transforms, meta={ - 'dataset': 'Satlas', - 'in_chans': 3, - 'model': 'swin_v2_b', - 'publication': 'https://arxiv.org/abs/2211.15660', - 'repo': 'https://github.com/allenai/satlas', - }, - ) - - SENTINEL2_MS_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', - transforms=_sentinel2_ms_satlas_transforms, - meta={ - 'dataset': 'Satlas', - 'in_chans': 9, - 'model': 'swin_v2_b', - 'publication': 'https://arxiv.org/abs/2211.15660', - 'repo': 'https://github.com/allenai/satlas', - 'bands': ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12'], - }, - ) - - SENTINEL1_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', - transforms=_satlas_transforms, - meta={ - 'dataset': 'Satlas', - 'in_chans': 2, - 'model': 'swin_v2_b', - 'publication': 'https://arxiv.org/abs/2211.15660', - 'repo': 'https://github.com/allenai/satlas', - 'bands': ['VH', 'VV'], - }, - ) - - LANDSAT_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', - transforms=_landsat_satlas_transforms, - meta={ - 'dataset': 'Satlas', + 'dataset': 'SatlasPretrain', 'in_chans': 11, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', - 'bands': [ - 'B01', - 'B02', - 'B03', - 'B04', - 'B05', - 'B06', - 'B07', - 'B08', - 'B09', - 'B10', - 'B11', - ], + 'bands': _satlas_landsat_bands, }, ) + LANDSAT_SI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_si-4af978f6.pth', + transforms=_satlas_landsat_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 11, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_landsat_bands, + }, + ) + + SENTINEL1_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_mi-f6c43d97.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 2, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ('VH', 'VV'), + }, + ) + + SENTINEL1_SI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_si-3981c153.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 2, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ('VH', 'VV'), + }, + ) + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_ms-39c86721.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_rgb-4efa210c.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_ms-fe22a12c.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_rgb-156a98d5.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + +def swin_v2_t( + weights: Swin_V2_T_Weights | None = None, *args: Any, **kwargs: Any +) -> SwinTransformer: + """Swin Transformer v2 tiny model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2111.09883 + + .. versionadded:: 0.6 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to + pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. + **kwargs: Additional keywork arguments to + pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. + + Returns: + A Swin Transformer Tiny model. + """ + model: SwinTransformer = torchvision.models.swin_v2_t(weights=None, *args, **kwargs) + + if weights: + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27 + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= set() + assert not unexpected_keys + + return model + def swin_v2_b( weights: Swin_V2_B_Weights | None = None, *args: Any, **kwargs: Any @@ -155,6 +314,16 @@ def swin_v2_b( model: SwinTransformer = torchvision.models.swin_v2_b(weights=None, *args, **kwargs) if weights: - model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27 + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= set() + assert not unexpected_keys return model diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 15a1d7c59..d8f80bdca 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,7 +8,7 @@ from typing import Any import kornia.augmentation as K import torch from einops import rearrange -from kornia.contrib import Lambda, extract_tensor_patches +from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices from kornia.geometry.boxes import Boxes from torch import Tensor @@ -25,7 +25,7 @@ class AugmentationSequential(Module): def __init__( self, - *args: K.base._AugmentationBase | K.ImageSequential | Lambda, + *args: K.base._AugmentationBase | K.ImageSequential, data_keys: list[str], **kwargs: Any, ) -> None: @@ -53,7 +53,7 @@ class AugmentationSequential(Module): else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] + self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) def forward(self, batch: dict[str, Any]) -> dict[str, Any]: """Perform augmentations and update data dict. @@ -272,3 +272,54 @@ class _ExtractPatches(K.GeometricAugmentationBase2D): out = rearrange(out, 'b t c h w -> (b t) c h w') return out + + +class _Clamp(K.IntensityAugmentationBase2D): + """Clamp images to a specific range.""" + + def __init__( + self, + p: float = 0.5, + p_batch: float = 1, + min: float = 0, + max: float = 1, + same_on_batch: bool = False, + keepdim: bool = False, + ) -> None: + """Initialize a new _Clamp instance. + + Args: + p: Probability for applying an augmentation. This param controls the + augmentation probabilities element-wise for a batch. + p_batch: Probability for applying an augmentation to a batch. This param + controls the augmentation probabilities batch-wise. + min: Minimum value to clamp to. + max: Maximum value to clamp to. + same_on_batch: Apply the same transformation across the batch. + keepdim: Whether to keep the output shape the same as input ``True`` + or broadcast it to the batch form ``False``. + """ + super().__init__( + p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim + ) + self.flags = {'min': min, 'max': max} + + def apply_transform( + self, + input: Tensor, + params: dict[str, Tensor], + flags: dict[str, Any], + transform: Tensor | None = None, + ) -> Tensor: + """Apply the transform. + + Args: + input: the input tensor + params: generated parameters + flags: static parameters + transform: the geometric transformation tensor + + Returns: + the augmented input + """ + return torch.clamp(input, self.flags['min'], self.flags['max'])