Resolve multiple SimCLR and MoCo lightly warnings (#1931)

* BaseTask: configure models before anything else

* Memory bank: specify feature dimension

* SimCLRTask: test one warning at a time

* blacken

* Support older and newer lightly (see the compatibility sketch below)
This commit is contained in:
Adam J. Stewart 2024-03-20 16:24:06 +01:00 committed by GitHub
Parent 5646d76bfb
Commit 642d2daf29
No key found matching this signature
GPG key ID: B5690EEEBB952194
7 changed files with 163 additions and 146 deletions
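The key compatibility change referenced in the commit message: newer lightly releases expect the NTXentLoss memory bank size as a (size, feature_dim) tuple, while lightly 1.4.24 and older only accept a plain int. A minimal sketch of the shim the diffs below apply in MoCoTask and SimCLRTask (the hyperparameter values here are illustrative, not the project defaults):

from lightly.loss import NTXentLoss

temperature = 0.07
memory_bank_size = 64000  # number of negatives kept in the bank
output_dim = 128          # feature dimension of the projected embeddings

try:
    # newer lightly: the bank is specified as (size, feature_dim)
    criterion = NTXentLoss(temperature, (memory_bank_size, output_dim), gather_distributed=False)
except TypeError:
    # lightly 1.4.24 and older: the bank is specified as a plain int
    criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed=False)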


@@ -73,13 +73,13 @@ class TestSimCLRTask:
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
SimCLRTask(version=1, layers=3)
SimCLRTask(version=1, layers=3, memory_bank_size=0)
with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
SimCLRTask(version=1, memory_bank_size=10)
SimCLRTask(version=1, layers=2, memory_bank_size=10)
with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
SimCLRTask(version=2, layers=2)
SimCLRTask(version=2, layers=2, memory_bank_size=10)
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
SimCLRTask(version=2, memory_bank_size=0)
SimCLRTask(version=2, layers=3, memory_bank_size=0)
@pytest.fixture
def weights(self) -> WeightsEnum:
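The test change above follows the commit-message note to test one warning at a time: each call now pins the argument that is not under test to a value that cannot warn, so every pytest.warns block sees exactly one warning. A self-contained sketch of the pattern, using a toy make_task stand-in rather than the real SimCLRTask constructor:

import warnings

import pytest


def make_task(version: int, layers: int, memory_bank_size: int) -> None:
    # Toy stand-in for the constructor's validation logic.
    if version == 1 and layers > 2:
        warnings.warn("SimCLR v1 only uses 2 layers", UserWarning)
    if version == 1 and memory_bank_size > 0:
        warnings.warn("SimCLR v1 does not use a memory bank", UserWarning)


def test_one_warning_at_a_time() -> None:
    # memory_bank_size=0 keeps the memory-bank warning silent while the
    # layer-count warning is asserted in isolation.
    with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
        make_task(version=1, layers=3, memory_bank_size=0)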


@@ -36,9 +36,13 @@ class BaseTask(LightningModule, ABC):
"""
super().__init__()
self.save_hyperparameters(ignore=ignore)
self.configure_models()
self.configure_losses()
self.configure_metrics()
self.configure_models()
@abstractmethod
def configure_models(self) -> None:
"""Initialize the model."""
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
@@ -46,10 +50,6 @@ class BaseTask(LightningModule, ABC):
def configure_metrics(self) -> None:
"""Initialize the performance metrics."""
@abstractmethod
def configure_models(self) -> None:
"""Initialize the model."""
def configure_optimizers(
self,
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":


@@ -75,6 +75,35 @@ class ClassificationTask(BaseTask):
self.weights = weights
super().__init__(ignore="weights")
def configure_models(self) -> None:
"""Initialize the model."""
weights = self.weights
# Create model
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_classes"],
in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
def configure_losses(self) -> None:
"""Initialize the loss criterion.
@@ -134,35 +163,6 @@ class ClassificationTask(BaseTask):
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
def configure_models(self) -> None:
"""Initialize the model."""
weights = self.weights
# Create model
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_classes"],
in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:


@@ -226,14 +226,6 @@ class MoCoTask(BaseTask):
self.augmentation1 = augmentation1 or aug1
self.augmentation2 = augmentation2 or aug2
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
self.criterion = NTXentLoss(
self.hparams["temperature"],
self.hparams["memory_bank_size"],
self.hparams["gather_distributed"],
)
def configure_models(self) -> None:
"""Initialize the model."""
model: str = self.hparams["model"]
@@ -282,6 +274,22 @@ class MoCoTask(BaseTask):
# Initialize moving average of output
self.avg_output_std = 0.0
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
try:
self.criterion = NTXentLoss(
self.hparams["temperature"],
(self.hparams["memory_bank_size"], self.hparams["output_dim"]),
self.hparams["gather_distributed"],
)
except TypeError:
# lightly 1.4.24 and older
self.criterion = NTXentLoss(
self.hparams["temperature"],
self.hparams["memory_bank_size"],
self.hparams["gather_distributed"],
)
def configure_optimizers(
self,
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":


@@ -79,6 +79,34 @@ class RegressionTask(BaseTask):
self.weights = weights
super().__init__(ignore="weights")
def configure_models(self) -> None:
"""Initialize the model."""
# Create model
weights = self.weights
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_outputs"],
in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
def configure_losses(self) -> None:
"""Initialize the loss criterion.
@@ -117,34 +145,6 @@ class RegressionTask(BaseTask):
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
def configure_models(self) -> None:
"""Initialize the model."""
# Create model
weights = self.weights
self.model = timm.create_model(
self.hparams["model"],
num_classes=self.hparams["num_outputs"],
in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
# Load weights
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:


@@ -90,6 +90,63 @@ class SemanticSegmentationTask(BaseTask):
self.weights = weights
super().__init__(ignore="weights")
def configure_models(self) -> None:
"""Initialize the model.
Raises:
ValueError: If *model* is invalid.
"""
model: str = self.hparams["model"]
backbone: str = self.hparams["backbone"]
weights = self.weights
in_channels: int = self.hparams["in_channels"]
num_classes: int = self.hparams["num_classes"]
num_filters: int = self.hparams["num_filters"]
if model == "unet":
self.model = smp.Unet(
encoder_name=backbone,
encoder_weights="imagenet" if weights is True else None,
in_channels=in_channels,
classes=num_classes,
)
elif model == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=backbone,
encoder_weights="imagenet" if weights is True else None,
in_channels=in_channels,
classes=num_classes,
)
elif model == "fcn":
self.model = FCN(
in_channels=in_channels, classes=num_classes, num_filters=num_filters
)
else:
raise ValueError(
f"Model type '{model}' is not valid. "
"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)
if model != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)
# Freeze backbone
if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False
# Freeze decoder
if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False
def configure_losses(self) -> None:
"""Initialize the loss criterion.
@@ -155,63 +212,6 @@ class SemanticSegmentationTask(BaseTask):
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
def configure_models(self) -> None:
"""Initialize the model.
Raises:
ValueError: If *model* is invalid.
"""
model: str = self.hparams["model"]
backbone: str = self.hparams["backbone"]
weights = self.weights
in_channels: int = self.hparams["in_channels"]
num_classes: int = self.hparams["num_classes"]
num_filters: int = self.hparams["num_filters"]
if model == "unet":
self.model = smp.Unet(
encoder_name=backbone,
encoder_weights="imagenet" if weights is True else None,
in_channels=in_channels,
classes=num_classes,
)
elif model == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
encoder_name=backbone,
encoder_weights="imagenet" if weights is True else None,
in_channels=in_channels,
classes=num_classes,
)
elif model == "fcn":
self.model = FCN(
in_channels=in_channels, classes=num_classes, num_filters=num_filters
)
else:
raise ValueError(
f"Model type '{model}' is not valid. "
"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)
if model != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
elif os.path.exists(weights):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder.load_state_dict(state_dict)
# Freeze backbone
if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False
# Freeze decoder
if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:


@@ -142,19 +142,9 @@ class SimCLRTask(BaseTask):
size, grayscale_weights
)
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
self.criterion = NTXentLoss(
self.hparams["temperature"],
self.hparams["memory_bank_size"],
self.hparams["gather_distributed"],
)
def configure_models(self) -> None:
"""Initialize the model."""
weights = self.weights
hidden_dim: int = self.hparams["hidden_dim"]
output_dim: int = self.hparams["output_dim"]
# Create backbone
self.backbone = timm.create_model(
@@ -176,13 +166,16 @@ class SimCLRTask(BaseTask):
# Create projection head
input_dim = self.backbone.num_features
if hidden_dim is None:
hidden_dim = input_dim
if output_dim is None:
output_dim = input_dim
if self.hparams["hidden_dim"] is None:
self.hparams["hidden_dim"] = input_dim
if self.hparams["output_dim"] is None:
self.hparams["output_dim"] = input_dim
self.projection_head = SimCLRProjectionHead(
input_dim, hidden_dim, output_dim, self.hparams["layers"]
input_dim,
self.hparams["hidden_dim"],
self.hparams["output_dim"],
self.hparams["layers"],
)
# Initialize moving average of output
@@ -192,6 +185,22 @@ class SimCLRTask(BaseTask):
# v1+: add global batch norm
# v2: add selective kernels, channel-wise attention mechanism
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
try:
self.criterion = NTXentLoss(
self.hparams["temperature"],
(self.hparams["memory_bank_size"], self.hparams["output_dim"]),
self.hparams["gather_distributed"],
)
except TypeError:
# lightly 1.4.24 and older
self.criterion = NTXentLoss(
self.hparams["temperature"],
self.hparams["memory_bank_size"],
self.hparams["gather_distributed"],
)
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Forward pass of the model.