* add resnet50 weights from decur

* update decur rn50 weights

* Alphabetical order

* Remove unneeded ignores

* update data transform for loading resnet50 sentinel 1 weights

* update formatting

* update formatting

* update ssl4eo-s12 transforms and naming

* update formatting

* update transforms definition comments

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Yi Wang 2024-08-19 17:53:16 +02:00 коммит произвёл GitHub
Родитель 067ae1af75
Коммит 2d0557f645
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 88 добавлений и 8 удалений

Просмотреть файл

@ -1,3 +1,4 @@
Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"

1 Weight Channels Source Citation License
2 ResNet50_Weights.SENTINEL1_ALL_DECUR 2 `link <https://github.com/zhu-xlab/DeCUR>`__ `link <https://arxiv.org/abs/2309.05300>`__ Apache-2.0
3 ResNet50_Weights.SENTINEL1_ALL_MOCO 2 `link <https://github.com/zhu-xlab/SSL4EO-S12>`__ `link <https://arxiv.org/abs/2211.07044>`__ CC-BY-4.0
4 Swin_V2_B_Weights.SENTINEL1_SI_SATLAS 2 `link <https://github.com/allenai/satlas>`__ `link <https://arxiv.org/abs/2211.15660>`__ ODC-BY

Просмотреть файл

@ -2,6 +2,7 @@ Weight,Channels,Source,Citation,License,BigEarthNet,EuroSAT,So2Sat,OSCD
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,,
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,,
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.27,93.14,,46.94
ResNet50_Weights.SENTINEL2_ALL_DECUR,13,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0",,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",,,

Не удается отобразить этот файл, потому что он имеет неправильное количество полей в строке 7.

Просмотреть файл

@ -11,16 +11,68 @@ import torch
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum
# https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286
# Normalization by channel-wise band statistics
_mean_s1 = torch.tensor([-12.59, -20.26])
_std_s1 = torch.tensor([5.26, 5.91])
_ssl4eo_s12_transforms_s1 = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_mean_s1, std=_std_s1),
data_keys=None,
)
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
# Normalization either by 10K (for S2 uint16 input) or channel-wise with band statistics
_ssl4eo_s12_transforms_s2_10k = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=None,
)
_mean_s2 = torch.tensor(
[
1612.9,
1397.6,
1322.3,
1373.1,
1561.0,
2108.4,
2390.7,
2318.7,
2581.0,
837.7,
22.0,
2195.2,
1537.4,
]
)
_std_s2 = torch.tensor(
[
791.0,
854.3,
878.7,
1144.9,
1127.5,
1164.2,
1276.0,
1249.5,
1345.9,
577.5,
47.5,
1340.0,
1142.9,
]
)
_ssl4eo_s12_transforms_s2_stats = K.AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_mean_s2, std=_std_s2),
data_keys=None,
)
# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py
_min = torch.tensor([3, 2, 0])
@ -201,7 +253,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
@ -214,7 +266,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 3,
@ -391,9 +443,22 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)
SENTINEL1_ALL_DECUR = Weights(
url='https://huggingface.co/torchgeo/decur/resolve/9328eeb90c686a88b30f8526ed757b4bc0f12027/rn50_ssl4eo-s12_sar_decur_ep100-f0e69ba2.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2309.05300',
'repo': 'https://github.com/zhu-xlab/DeCUR',
'ssl_method': 'decur',
},
)
SENTINEL1_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
@ -404,9 +469,22 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)
SENTINEL2_ALL_DECUR = Weights(
url='https://huggingface.co/torchgeo/decur/resolve/eba7ae5945d482a4319be046d34b552db5dd9950/rn50_ssl4eo-s12_ms_decur_ep100-fc6b09ff.pth',
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'resnet50',
'publication': 'https://arxiv.org/abs/2309.05300',
'repo': 'https://github.com/zhu-xlab/DeCUR',
'ssl_method': 'decur',
},
)
SENTINEL2_ALL_DINO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
@ -419,7 +497,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
@ -432,7 +510,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
SENTINEL2_RGB_MOCO = Weights(
url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth',
transforms=_zhu_xlab_transforms,
transforms=_ssl4eo_s12_transforms_s2_10k,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 3,