зеркало из https://github.com/microsoft/torchgeo.git
SatlasPretrain: ResNet50/152 and Swin_V2_T Weights (#2038)
* add SENTINEL2_MS_MI_SATLAS * add SENTINEL2_MS_SI_SATLAS * fix style * ruff * Add all SatlasPretrain models * Get local tests passing * Get local tests passing * Fix remote tests * Mock input channels in testing too * Fix license * Fix band order --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
2d6e27ebd0
Коммит
a757cf14fb
|
@ -29,4 +29,5 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/
|
|||
ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",63.65,46.68,60.01,43.17
|
||||
ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",66.81,50.16,64.17,47.24
|
||||
ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html>`__,"CC0-1.0",65.04,48.20,62.61,45.46
|
||||
Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
|
||||
Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_B_Weights.LANDSAT_MI_SATLAS,8--9,11,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
|
|
Не удается отобразить этот файл, потому что он имеет неправильное количество полей в строке 32.
|
|
@ -47,8 +47,10 @@ ResNet
|
|||
|
||||
.. autofunction:: resnet18
|
||||
.. autofunction:: resnet50
|
||||
.. autofunction:: resnet152
|
||||
.. autoclass:: ResNet18_Weights
|
||||
.. autoclass:: ResNet50_Weights
|
||||
.. autoclass:: ResNet152_Weights
|
||||
|
||||
Scale-MAE
|
||||
^^^^^^^^^
|
||||
|
@ -59,7 +61,9 @@ Scale-MAE
|
|||
Swin Transformer
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: swin_v2_t
|
||||
.. autofunction:: swin_v2_b
|
||||
.. autoclass:: Swin_V2_T_Weights
|
||||
.. autoclass:: Swin_V2_B_Weights
|
||||
|
||||
Vision Transformer
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
Weight,Channels,Source,Citation,License
|
||||
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
|
||||
Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
|
||||
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
|
||||
|
|
|
|
@ -1,4 +1,5 @@
|
|||
Weight,Channels,Source,Citation,License
|
||||
ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
|
||||
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
|
||||
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
|
||||
Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
|
||||
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
|
||||
|
|
|
|
@ -5,9 +5,23 @@ ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seas
|
|||
ResNet50_Weights.SENTINEL2_ALL_DECUR,13,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0",,,,
|
||||
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.7,99.1,63.6,
|
||||
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",91.8,99.1,60.9,
|
||||
ResNet50_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet50_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,
|
||||
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.81,,,
|
||||
ResNet152_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet152_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlaspretrain_models>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.5,99.0,62.2,
|
||||
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",89.9,98.6,61.6,
|
||||
Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
|
||||
Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
|
||||
Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY,,,,
|
||||
|
|
Не удается отобразить этот файл, потому что он имеет неправильное количество полей в строке 8.
|
|
@ -12,8 +12,10 @@ from torchgeo.models import (
|
|||
dofa_large_patch16_224,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
scalemae_large_patch16,
|
||||
swin_v2_b,
|
||||
swin_v2_t,
|
||||
vit_small_patch16_224,
|
||||
)
|
||||
|
||||
|
@ -22,9 +24,11 @@ __all__ = (
|
|||
'dofa_large_patch16_224',
|
||||
'resnet18',
|
||||
'resnet50',
|
||||
'resnet152',
|
||||
'scalemae_large_patch16',
|
||||
'swin_v2_t',
|
||||
'swin_v2_b',
|
||||
'vit_small_patch16_224',
|
||||
)
|
||||
|
||||
dependencies = ['timm']
|
||||
dependencies = ['timm', 'torchvision']
|
||||
|
|
|
@ -13,8 +13,10 @@ from torchgeo.models import (
|
|||
DOFALarge16_Weights,
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ResNet152_Weights,
|
||||
ScaleMAELarge16_Weights,
|
||||
Swin_V2_B_Weights,
|
||||
Swin_V2_T_Weights,
|
||||
ViTSmall16_Weights,
|
||||
dofa_base_patch16_224,
|
||||
dofa_large_patch16_224,
|
||||
|
@ -24,8 +26,10 @@ from torchgeo.models import (
|
|||
list_models,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
scalemae_large_patch16,
|
||||
swin_v2_b,
|
||||
swin_v2_t,
|
||||
vit_small_patch16_224,
|
||||
)
|
||||
|
||||
|
@ -34,7 +38,9 @@ builders = [
|
|||
dofa_large_patch16_224,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
scalemae_large_patch16,
|
||||
swin_v2_t,
|
||||
swin_v2_b,
|
||||
vit_small_patch16_224,
|
||||
]
|
||||
|
@ -43,7 +49,9 @@ enums = [
|
|||
DOFALarge16_Weights,
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ResNet152_Weights,
|
||||
ScaleMAELarge16_Weights,
|
||||
Swin_V2_T_Weights,
|
||||
Swin_V2_B_Weights,
|
||||
ViTSmall16_Weights,
|
||||
]
|
||||
|
|
|
@ -10,7 +10,14 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
from torchgeo.models import (
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ResNet152_Weights,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
)
|
||||
|
||||
|
||||
class TestResNet18:
|
||||
|
@ -44,7 +51,7 @@ class TestResNet18:
|
|||
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
|
||||
c = mocked_weights.meta['in_chans']
|
||||
sample = {
|
||||
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
|
||||
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
|
||||
}
|
||||
mocked_weights.transforms(sample)
|
||||
|
||||
|
@ -84,10 +91,50 @@ class TestResNet50:
|
|||
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
|
||||
c = mocked_weights.meta['in_chans']
|
||||
sample = {
|
||||
'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
|
||||
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
|
||||
}
|
||||
mocked_weights.transforms(sample)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||
resnet50(weights=weights)
|
||||
|
||||
|
||||
class TestResNet152:
|
||||
@pytest.fixture(params=[*ResNet152_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: MonkeyPatch,
|
||||
weights: WeightsEnum,
|
||||
load_state_dict_from_url: None,
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f'{weights}.pth'
|
||||
model = timm.create_model('resnet152', in_chans=weights.meta['in_chans'])
|
||||
torch.save(model.state_dict(), path)
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, 'url', str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, 'url', str(path))
|
||||
return weights
|
||||
|
||||
def test_resnet(self) -> None:
|
||||
resnet152()
|
||||
|
||||
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||
resnet152(weights=mocked_weights)
|
||||
|
||||
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
|
||||
c = mocked_weights.meta['in_chans']
|
||||
sample = {
|
||||
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
|
||||
}
|
||||
mocked_weights.transforms(sample)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_resnet_download(self, weights: WeightsEnum) -> None:
|
||||
resnet152(weights=weights)
|
||||
|
|
|
@ -10,7 +10,52 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
|
||||
from torchgeo.models import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
|
||||
|
||||
|
||||
class TestSwin_V2_T:
|
||||
@pytest.fixture(params=[*Swin_V2_T_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: MonkeyPatch,
|
||||
weights: WeightsEnum,
|
||||
load_state_dict_from_url: None,
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f'{weights}.pth'
|
||||
model = torchvision.models.swin_v2_t()
|
||||
num_channels = weights.meta['in_chans']
|
||||
out_channels = model.features[0][0].out_channels
|
||||
model.features[0][0] = torch.nn.Conv2d(
|
||||
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, 'url', str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, 'url', str(path))
|
||||
return weights
|
||||
|
||||
def test_swin_v2_t(self) -> None:
|
||||
swin_v2_t()
|
||||
|
||||
def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None:
|
||||
swin_v2_t(weights=mocked_weights)
|
||||
|
||||
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
|
||||
c = mocked_weights.meta['in_chans']
|
||||
sample = {
|
||||
'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256)
|
||||
}
|
||||
mocked_weights.transforms(sample)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swin_v2_t_download(self, weights: WeightsEnum) -> None:
|
||||
swin_v2_t(weights=weights)
|
||||
|
||||
|
||||
class TestSwin_V2_B:
|
||||
|
@ -28,6 +73,11 @@ class TestSwin_V2_B:
|
|||
) -> WeightsEnum:
|
||||
path = tmp_path / f'{weights}.pth'
|
||||
model = torchvision.models.swin_v2_b()
|
||||
num_channels = weights.meta['in_chans']
|
||||
out_channels = model.features[0][0].out_channels
|
||||
model.features[0][0] = torch.nn.Conv2d(
|
||||
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, 'url', str(path))
|
||||
|
|
|
@ -18,9 +18,16 @@ from .farseg import FarSeg
|
|||
from .fcn import FCN
|
||||
from .fcsiam import FCSiamConc, FCSiamDiff
|
||||
from .rcf import RCF
|
||||
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
from .resnet import (
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ResNet152_Weights,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
)
|
||||
from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16
|
||||
from .swin import Swin_V2_B_Weights, swin_v2_b
|
||||
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
|
||||
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
__all__ = (
|
||||
|
@ -40,16 +47,20 @@ __all__ = (
|
|||
'RCF',
|
||||
'resnet18',
|
||||
'resnet50',
|
||||
'resnet152',
|
||||
'ScaleMAE',
|
||||
'scalemae_large_patch16',
|
||||
'swin_v2_t',
|
||||
'swin_v2_b',
|
||||
'vit_small_patch16_224',
|
||||
# weights
|
||||
'DOFABase16_Weights',
|
||||
'DOFALarge16_Weights',
|
||||
'ResNet50_Weights',
|
||||
'ResNet18_Weights',
|
||||
'ResNet50_Weights',
|
||||
'ResNet152_Weights',
|
||||
'ScaleMAELarge16_Weights',
|
||||
'Swin_V2_T_Weights',
|
||||
'Swin_V2_B_Weights',
|
||||
'ViTSmall16_Weights',
|
||||
# utilities
|
||||
|
|
|
@ -22,9 +22,16 @@ from .dofa import (
|
|||
dofa_base_patch16_224,
|
||||
dofa_large_patch16_224,
|
||||
)
|
||||
from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
from .resnet import (
|
||||
ResNet18_Weights,
|
||||
ResNet50_Weights,
|
||||
ResNet152_Weights,
|
||||
resnet18,
|
||||
resnet50,
|
||||
resnet152,
|
||||
)
|
||||
from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16
|
||||
from .swin import Swin_V2_B_Weights, swin_v2_b
|
||||
from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t
|
||||
from .vit import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
_model = {
|
||||
|
@ -32,7 +39,9 @@ _model = {
|
|||
'dofa_large_patch16_224': dofa_large_patch16_224,
|
||||
'resnet18': resnet18,
|
||||
'resnet50': resnet50,
|
||||
'resnet152': resnet152,
|
||||
'scalemae_large_patch16': scalemae_large_patch16,
|
||||
'swin_v2_t': swin_v2_t,
|
||||
'swin_v2_b': swin_v2_b,
|
||||
'vit_small_patch16_224': vit_small_patch16_224,
|
||||
}
|
||||
|
@ -42,14 +51,18 @@ _model_weights = {
|
|||
dofa_large_patch16_224: DOFALarge16_Weights,
|
||||
resnet18: ResNet18_Weights,
|
||||
resnet50: ResNet50_Weights,
|
||||
resnet152: ResNet152_Weights,
|
||||
scalemae_large_patch16: ScaleMAELarge16_Weights,
|
||||
swin_v2_t: Swin_V2_T_Weights,
|
||||
swin_v2_b: Swin_V2_B_Weights,
|
||||
vit_small_patch16_224: ViTSmall16_Weights,
|
||||
'dofa_base_patch16_224': DOFABase16_Weights,
|
||||
'dofa_large_patch16_224': DOFALarge16_Weights,
|
||||
'resnet18': ResNet18_Weights,
|
||||
'resnet50': ResNet50_Weights,
|
||||
'resnet152': ResNet152_Weights,
|
||||
'scalemae_large_patch16': ScaleMAELarge16_Weights,
|
||||
'swin_v2_t': Swin_V2_T_Weights,
|
||||
'swin_v2_b': Swin_V2_B_Weights,
|
||||
'vit_small_patch16_224': ViTSmall16_Weights,
|
||||
}
|
||||
|
|
|
@ -11,6 +11,13 @@ import torch
|
|||
from timm.models import ResNet
|
||||
from torchvision.models._api import Weights, WeightsEnum
|
||||
|
||||
from .swin import (
|
||||
_satlas_bands,
|
||||
_satlas_sentinel2_bands,
|
||||
_satlas_sentinel2_transforms,
|
||||
_satlas_transforms,
|
||||
)
|
||||
|
||||
# https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286
|
||||
# Normalization by channel-wise band statistics
|
||||
_mean_s1 = torch.tensor([-12.59, -20.26])
|
||||
|
@ -113,7 +120,7 @@ Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
|||
|
||||
|
||||
class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""ResNet18 weights.
|
||||
"""ResNet-18 weights.
|
||||
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*resnet18* implementation.
|
||||
|
@ -292,7 +299,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
|
|||
|
||||
|
||||
class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""ResNet50 weights.
|
||||
"""ResNet-50 weights.
|
||||
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*resnet50* implementation.
|
||||
|
@ -508,6 +515,32 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
|||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_ms-da5413d2.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_rgb-e79bb7fe.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_MOCO = Weights(
|
||||
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
|
||||
transforms=_ssl4eo_s12_transforms_s2_10k,
|
||||
|
@ -534,6 +567,94 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
|
|||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_ms-1f454cc6.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_rgb-45fc6972.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ResNet152_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""ResNet-152 weights.
|
||||
|
||||
For `timm <https://github.com/rwightman/pytorch-image-models>`_
|
||||
*resnet152* implementation.
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
SENTINEL2_MI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_ms-fd35b4bb.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_rgb-67563ac5.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_ms-4500c6cb.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_rgb-f4d24c3c.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'resnet50',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlaspretrain_models',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def resnet18(
|
||||
weights: ResNet18_Weights | None = None, *args: Any, **kwargs: Any
|
||||
|
@ -602,3 +723,37 @@ def resnet50(
|
|||
assert not unexpected_keys
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def resnet152(
|
||||
weights: ResNet152_Weights | None = None, *args: Any, **kwargs: Any
|
||||
) -> ResNet:
|
||||
"""ResNet-152 model.
|
||||
|
||||
If you use this model in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/pdf/1512.03385.pdf
|
||||
|
||||
.. versionadded:: 0.6
|
||||
|
||||
Args:
|
||||
weights: Pre-trained model weights to use.
|
||||
*args: Additional arguments to pass to :func:`timm.create_model`.
|
||||
**kwargs: Additional keywork arguments to pass to :func:`timm.create_model`.
|
||||
|
||||
Returns:
|
||||
A ResNet-152 model.
|
||||
"""
|
||||
if weights:
|
||||
kwargs['in_chans'] = weights.meta['in_chans']
|
||||
|
||||
model: ResNet = timm.create_model('resnet152', *args, **kwargs)
|
||||
|
||||
if weights:
|
||||
missing_keys, unexpected_keys = model.load_state_dict(
|
||||
weights.get_state_dict(progress=True), strict=False
|
||||
)
|
||||
assert set(missing_keys) <= {'fc.weight', 'fc.bias'}
|
||||
assert not unexpected_keys
|
||||
|
||||
return model
|
||||
|
|
|
@ -8,35 +8,38 @@ from typing import Any
|
|||
import kornia.augmentation as K
|
||||
import torch
|
||||
import torchvision
|
||||
from kornia.contrib import Lambda
|
||||
from torchvision.models import SwinTransformer
|
||||
from torchvision.models._api import Weights, WeightsEnum
|
||||
|
||||
# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42
|
||||
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
||||
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
|
||||
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255.
|
||||
_satlas_transforms = K.AugmentationSequential(
|
||||
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None
|
||||
)
|
||||
import torchgeo.transforms.transforms as T
|
||||
|
||||
# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255).
|
||||
# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images.
|
||||
# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1).
|
||||
_std = torch.tensor(
|
||||
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
|
||||
)
|
||||
_mean = torch.zeros_like(_std)
|
||||
_sentinel2_ms_satlas_transforms = K.AugmentationSequential(
|
||||
K.Normalize(mean=_mean, std=_std),
|
||||
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
|
||||
# All Satlas transforms include:
|
||||
# https://github.com/allenai/satlas/blob/main/satlas/cmd/model/train.py#L49
|
||||
#
|
||||
# Information about sensor-specific normalization can be found at:
|
||||
# https://github.com/allenai/satlas/blob/main/Normalization.md
|
||||
|
||||
_satlas_bands = ('B04', 'B03', 'B02')
|
||||
_satlas_transforms = K.AugmentationSequential(
|
||||
K.CenterCrop(256),
|
||||
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
|
||||
data_keys=None,
|
||||
)
|
||||
|
||||
# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1).
|
||||
_landsat_satlas_transforms = K.AugmentationSequential(
|
||||
_satlas_sentinel2_bands = (*_satlas_bands, 'B05', 'B06', 'B07', 'B08', 'B11', 'B12')
|
||||
_std = torch.tensor([255, 255, 255, 8160, 8160, 8160, 8160, 8160, 8160])
|
||||
_satlas_sentinel2_transforms = K.AugmentationSequential(
|
||||
K.CenterCrop(256),
|
||||
K.Normalize(mean=torch.tensor(0), std=_std),
|
||||
T._Clamp(p=1, min=0, max=1),
|
||||
data_keys=None,
|
||||
)
|
||||
|
||||
_satlas_landsat_bands = tuple(f'B{i:02}' for i in range(1, 12))
|
||||
_satlas_landsat_transforms = K.AugmentationSequential(
|
||||
K.CenterCrop(256),
|
||||
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
|
||||
K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))),
|
||||
T._Clamp(p=1, min=0, max=1),
|
||||
data_keys=None,
|
||||
)
|
||||
|
||||
|
@ -46,6 +49,68 @@ _landsat_satlas_transforms = K.AugmentationSequential(
|
|||
Weights.__deepcopy__ = lambda *args, **kwargs: args[0]
|
||||
|
||||
|
||||
class Swin_V2_T_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""Swin Transformer v2 Tiny weights.
|
||||
|
||||
For `torchvision <https://github.com/pytorch/vision>`_
|
||||
*swin_v2_t* implementation.
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
SENTINEL2_MI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_ms-d8c659e3.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'swin_v2_t',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_rgb-424d91f4.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_t',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_ms-bc68e396.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'swin_v2_t',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_rgb-0c1a96e0.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_t',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
||||
"""Swin Transformer v2 Base weights.
|
||||
|
||||
|
@ -55,81 +120,175 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
|
|||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
NAIP_RGB_MI_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_mi-326d69e1.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ('R', 'G', 'B'),
|
||||
},
|
||||
)
|
||||
|
||||
NAIP_RGB_SI_SATLAS = Weights(
|
||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth',
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_si-e4169eb1.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'Satlas',
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ('R', 'G', 'B'),
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_RGB_SI_SATLAS = Weights(
|
||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth',
|
||||
transforms=_satlas_transforms,
|
||||
LANDSAT_MI_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_mi-6b4a1cda.pth',
|
||||
transforms=_satlas_landsat_transforms,
|
||||
meta={
|
||||
'dataset': 'Satlas',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MS_SI_SATLAS = Weights(
|
||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth',
|
||||
transforms=_sentinel2_ms_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'Satlas',
|
||||
'in_chans': 9,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12'],
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL1_SI_SATLAS = Weights(
|
||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'Satlas',
|
||||
'in_chans': 2,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ['VH', 'VV'],
|
||||
},
|
||||
)
|
||||
|
||||
LANDSAT_SI_SATLAS = Weights(
|
||||
url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth',
|
||||
transforms=_landsat_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'Satlas',
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 11,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': [
|
||||
'B01',
|
||||
'B02',
|
||||
'B03',
|
||||
'B04',
|
||||
'B05',
|
||||
'B06',
|
||||
'B07',
|
||||
'B08',
|
||||
'B09',
|
||||
'B10',
|
||||
'B11',
|
||||
],
|
||||
'bands': _satlas_landsat_bands,
|
||||
},
|
||||
)
|
||||
|
||||
LANDSAT_SI_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_si-4af978f6.pth',
|
||||
transforms=_satlas_landsat_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 11,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_landsat_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL1_MI_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_mi-f6c43d97.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 2,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ('VH', 'VV'),
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL1_SI_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_si-3981c153.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 2,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': ('VH', 'VV'),
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_ms-39c86721.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_MI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_rgb-4efa210c.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_MS_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_ms-fe22a12c.pth',
|
||||
transforms=_satlas_sentinel2_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 9,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_sentinel2_bands,
|
||||
},
|
||||
)
|
||||
|
||||
SENTINEL2_SI_RGB_SATLAS = Weights(
|
||||
url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_rgb-156a98d5.pth',
|
||||
transforms=_satlas_transforms,
|
||||
meta={
|
||||
'dataset': 'SatlasPretrain',
|
||||
'in_chans': 3,
|
||||
'model': 'swin_v2_b',
|
||||
'publication': 'https://arxiv.org/abs/2211.15660',
|
||||
'repo': 'https://github.com/allenai/satlas',
|
||||
'bands': _satlas_bands,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def swin_v2_t(
|
||||
weights: Swin_V2_T_Weights | None = None, *args: Any, **kwargs: Any
|
||||
) -> SwinTransformer:
|
||||
"""Swin Transformer v2 tiny model.
|
||||
|
||||
If you use this model in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/abs/2111.09883
|
||||
|
||||
.. versionadded:: 0.6
|
||||
|
||||
Args:
|
||||
weights: Pre-trained model weights to use.
|
||||
*args: Additional arguments to
|
||||
pass to :class:`torchvision.models.swin_transformer.SwinTransformer`.
|
||||
**kwargs: Additional keywork arguments to
|
||||
pass to :class:`torchvision.models.swin_transformer.SwinTransformer`.
|
||||
|
||||
Returns:
|
||||
A Swin Transformer Tiny model.
|
||||
"""
|
||||
model: SwinTransformer = torchvision.models.swin_v2_t(weights=None, *args, **kwargs)
|
||||
|
||||
if weights:
|
||||
num_channels = weights.meta['in_chans']
|
||||
out_channels = model.features[0][0].out_channels
|
||||
# https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27
|
||||
model.features[0][0] = torch.nn.Conv2d(
|
||||
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
|
||||
)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(
|
||||
weights.get_state_dict(progress=True), strict=False
|
||||
)
|
||||
assert set(missing_keys) <= set()
|
||||
assert not unexpected_keys
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def swin_v2_b(
|
||||
weights: Swin_V2_B_Weights | None = None, *args: Any, **kwargs: Any
|
||||
|
@ -155,6 +314,16 @@ def swin_v2_b(
|
|||
model: SwinTransformer = torchvision.models.swin_v2_b(weights=None, *args, **kwargs)
|
||||
|
||||
if weights:
|
||||
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
num_channels = weights.meta['in_chans']
|
||||
out_channels = model.features[0][0].out_channels
|
||||
# https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27
|
||||
model.features[0][0] = torch.nn.Conv2d(
|
||||
num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4)
|
||||
)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(
|
||||
weights.get_state_dict(progress=True), strict=False
|
||||
)
|
||||
assert set(missing_keys) <= set()
|
||||
assert not unexpected_keys
|
||||
|
||||
return model
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
import kornia.augmentation as K
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from kornia.contrib import Lambda, extract_tensor_patches
|
||||
from kornia.contrib import extract_tensor_patches
|
||||
from kornia.geometry import crop_by_indices
|
||||
from kornia.geometry.boxes import Boxes
|
||||
from torch import Tensor
|
||||
|
@ -25,7 +25,7 @@ class AugmentationSequential(Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
*args: K.base._AugmentationBase | K.ImageSequential | Lambda,
|
||||
*args: K.base._AugmentationBase | K.ImageSequential,
|
||||
data_keys: list[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
|
@ -53,7 +53,7 @@ class AugmentationSequential(Module):
|
|||
else:
|
||||
keys.append(key)
|
||||
|
||||
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type]
|
||||
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)
|
||||
|
||||
def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Perform augmentations and update data dict.
|
||||
|
@ -272,3 +272,54 @@ class _ExtractPatches(K.GeometricAugmentationBase2D):
|
|||
out = rearrange(out, 'b t c h w -> (b t) c h w')
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class _Clamp(K.IntensityAugmentationBase2D):
|
||||
"""Clamp images to a specific range."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
p: float = 0.5,
|
||||
p_batch: float = 1,
|
||||
min: float = 0,
|
||||
max: float = 1,
|
||||
same_on_batch: bool = False,
|
||||
keepdim: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new _Clamp instance.
|
||||
|
||||
Args:
|
||||
p: Probability for applying an augmentation. This param controls the
|
||||
augmentation probabilities element-wise for a batch.
|
||||
p_batch: Probability for applying an augmentation to a batch. This param
|
||||
controls the augmentation probabilities batch-wise.
|
||||
min: Minimum value to clamp to.
|
||||
max: Maximum value to clamp to.
|
||||
same_on_batch: Apply the same transformation across the batch.
|
||||
keepdim: Whether to keep the output shape the same as input ``True``
|
||||
or broadcast it to the batch form ``False``.
|
||||
"""
|
||||
super().__init__(
|
||||
p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
|
||||
)
|
||||
self.flags = {'min': min, 'max': max}
|
||||
|
||||
def apply_transform(
|
||||
self,
|
||||
input: Tensor,
|
||||
params: dict[str, Tensor],
|
||||
flags: dict[str, Any],
|
||||
transform: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Apply the transform.
|
||||
|
||||
Args:
|
||||
input: the input tensor
|
||||
params: generated parameters
|
||||
flags: static parameters
|
||||
transform: the geometric transformation tensor
|
||||
|
||||
Returns:
|
||||
the augmented input
|
||||
"""
|
||||
return torch.clamp(input, self.flags['min'], self.flags['max'])
|
||||
|
|
Загрузка…
Ссылка в новой задаче