Mirror of https://github.com/microsoft/torchgeo.git
Resolve multiple SimCLR and MoCo lightly warnings (#1931)
* BaseTask: configure models before anything else
* Memory bank: specify feature dimension
* SimCLRTask: test one warning at a time
* blacken
* Support older and newer lightly
Parent: 5646d76bfb
Commit: 642d2daf29
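The key compatibility change: newer releases of lightly expect NTXentLoss to receive the memory bank size as a (size, dimension) tuple, so the bank knows its feature dimension, while lightly 1.4.24 and older accept only a plain integer and raise TypeError when handed a tuple. The trainers below therefore try the tuple form first and fall back on TypeError. A minimal sketch of the pattern, assuming only that lightly is installed; the make_criterion helper is hypothetical, for illustration:

from lightly.loss import NTXentLoss


def make_criterion(
    temperature: float,
    memory_bank_size: int,
    output_dim: int,
    gather_distributed: bool,
) -> NTXentLoss:
    """Hypothetical helper: build NTXentLoss across lightly versions."""
    try:
        # Newer lightly: the bank is described by a (size, feature_dim) tuple.
        return NTXentLoss(
            temperature, (memory_bank_size, output_dim), gather_distributed
        )
    except TypeError:
        # lightly 1.4.24 and older: the bank size is a plain int.
        return NTXentLoss(temperature, memory_bank_size, gather_distributed)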
tests/trainers/test_simclr.py
@@ -73,13 +73,13 @@ class TestSimCLRTask:

     def test_version_warnings(self) -> None:
         with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
-            SimCLRTask(version=1, layers=3)
+            SimCLRTask(version=1, layers=3, memory_bank_size=0)
         with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
-            SimCLRTask(version=1, memory_bank_size=10)
+            SimCLRTask(version=1, layers=2, memory_bank_size=10)
         with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
-            SimCLRTask(version=2, layers=2)
+            SimCLRTask(version=2, layers=2, memory_bank_size=10)
         with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
-            SimCLRTask(version=2, memory_bank_size=0)
+            SimCLRTask(version=2, layers=3, memory_bank_size=0)

     @pytest.fixture
     def weights(self) -> WeightsEnum:
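Each SimCLRTask call above now pins every warning-relevant argument (layers and memory_bank_size together), so each pytest.warns context sees exactly the one warning it matches instead of tripping several at once.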
torchgeo/trainers/base.py
@@ -36,9 +36,13 @@ class BaseTask(LightningModule, ABC):
         """
         super().__init__()
         self.save_hyperparameters(ignore=ignore)
+        self.configure_models()
         self.configure_losses()
         self.configure_metrics()
-        self.configure_models()

+    @abstractmethod
+    def configure_models(self) -> None:
+        """Initialize the model."""
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion."""
@@ -46,10 +50,6 @@ class BaseTask(LightningModule, ABC):
     def configure_metrics(self) -> None:
         """Initialize the performance metrics."""

-    @abstractmethod
-    def configure_models(self) -> None:
-        """Initialize the model."""
-
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
torchgeo/trainers/classification.py
@@ -75,6 +75,35 @@ class ClassificationTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model."""
+        weights = self.weights
+
+        # Create model
+        self.model = timm.create_model(
+            self.hparams["model"],
+            num_classes=self.hparams["num_classes"],
+            in_chans=self.hparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            utils.load_state_dict(self.model, state_dict)
+
+        # Freeze backbone and unfreeze classifier head
+        if self.hparams["freeze_backbone"]:
+            for param in self.model.parameters():
+                param.requires_grad = False
+            for param in self.model.get_classifier().parameters():
+                param.requires_grad = True
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -134,35 +163,6 @@ class ClassificationTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model."""
-        weights = self.weights
-
-        # Create model
-        self.model = timm.create_model(
-            self.hparams["model"],
-            num_classes=self.hparams["num_classes"],
-            in_chans=self.hparams["in_channels"],
-            pretrained=weights is True,
-        )
-
-        # Load weights
-        if weights and weights is not True:
-            if isinstance(weights, WeightsEnum):
-                state_dict = weights.get_state_dict(progress=True)
-            elif os.path.exists(weights):
-                _, state_dict = utils.extract_backbone(weights)
-            else:
-                state_dict = get_weight(weights).get_state_dict(progress=True)
-            utils.load_state_dict(self.model, state_dict)
-
-        # Freeze backbone and unfreeze classifier head
-        if self.hparams["freeze_backbone"]:
-            for param in self.model.parameters():
-                param.requires_grad = False
-            for param in self.model.get_classifier().parameters():
-                param.requires_grad = True
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
torchgeo/trainers/moco.py
@@ -226,14 +226,6 @@ class MoCoTask(BaseTask):
         self.augmentation1 = augmentation1 or aug1
         self.augmentation2 = augmentation2 or aug2

-    def configure_losses(self) -> None:
-        """Initialize the loss criterion."""
-        self.criterion = NTXentLoss(
-            self.hparams["temperature"],
-            self.hparams["memory_bank_size"],
-            self.hparams["gather_distributed"],
-        )
-
     def configure_models(self) -> None:
         """Initialize the model."""
         model: str = self.hparams["model"]
@@ -282,6 +274,22 @@ class MoCoTask(BaseTask):
         # Initialize moving average of output
         self.avg_output_std = 0.0

+    def configure_losses(self) -> None:
+        """Initialize the loss criterion."""
+        try:
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                (self.hparams["memory_bank_size"], self.hparams["output_dim"]),
+                self.hparams["gather_distributed"],
+            )
+        except TypeError:
+            # lightly 1.4.24 and older
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                self.hparams["memory_bank_size"],
+                self.hparams["gather_distributed"],
+            )
+
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
torchgeo/trainers/regression.py
@@ -79,6 +79,34 @@ class RegressionTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model."""
+        # Create model
+        weights = self.weights
+        self.model = timm.create_model(
+            self.hparams["model"],
+            num_classes=self.hparams["num_outputs"],
+            in_chans=self.hparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            utils.load_state_dict(self.model, state_dict)
+
+        # Freeze backbone and unfreeze classifier head
+        if self.hparams["freeze_backbone"]:
+            for param in self.model.parameters():
+                param.requires_grad = False
+            for param in self.model.get_classifier().parameters():
+                param.requires_grad = True
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -117,34 +145,6 @@ class RegressionTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model."""
-        # Create model
-        weights = self.weights
-        self.model = timm.create_model(
-            self.hparams["model"],
-            num_classes=self.hparams["num_outputs"],
-            in_chans=self.hparams["in_channels"],
-            pretrained=weights is True,
-        )
-
-        # Load weights
-        if weights and weights is not True:
-            if isinstance(weights, WeightsEnum):
-                state_dict = weights.get_state_dict(progress=True)
-            elif os.path.exists(weights):
-                _, state_dict = utils.extract_backbone(weights)
-            else:
-                state_dict = get_weight(weights).get_state_dict(progress=True)
-            utils.load_state_dict(self.model, state_dict)
-
-        # Freeze backbone and unfreeze classifier head
-        if self.hparams["freeze_backbone"]:
-            for param in self.model.parameters():
-                param.requires_grad = False
-            for param in self.model.get_classifier().parameters():
-                param.requires_grad = True
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
torchgeo/trainers/segmentation.py
@@ -90,6 +90,63 @@ class SemanticSegmentationTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model.
+
+        Raises:
+            ValueError: If *model* is invalid.
+        """
+        model: str = self.hparams["model"]
+        backbone: str = self.hparams["backbone"]
+        weights = self.weights
+        in_channels: int = self.hparams["in_channels"]
+        num_classes: int = self.hparams["num_classes"]
+        num_filters: int = self.hparams["num_filters"]
+
+        if model == "unet":
+            self.model = smp.Unet(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
+            )
+        elif model == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
+            )
+        elif model == "fcn":
+            self.model = FCN(
+                in_channels=in_channels, classes=num_classes, num_filters=num_filters
+            )
+        else:
+            raise ValueError(
+                f"Model type '{model}' is not valid. "
+                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+            )
+
+        if model != "fcn":
+            if weights and weights is not True:
+                if isinstance(weights, WeightsEnum):
+                    state_dict = weights.get_state_dict(progress=True)
+                elif os.path.exists(weights):
+                    _, state_dict = utils.extract_backbone(weights)
+                else:
+                    state_dict = get_weight(weights).get_state_dict(progress=True)
+                self.model.encoder.load_state_dict(state_dict)
+
+        # Freeze backbone
+        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
+            for param in self.model.encoder.parameters():
+                param.requires_grad = False
+
+        # Freeze decoder
+        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
+            for param in self.model.decoder.parameters():
+                param.requires_grad = False
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
@@ -155,63 +212,6 @@ class SemanticSegmentationTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model.
-
-        Raises:
-            ValueError: If *model* is invalid.
-        """
-        model: str = self.hparams["model"]
-        backbone: str = self.hparams["backbone"]
-        weights = self.weights
-        in_channels: int = self.hparams["in_channels"]
-        num_classes: int = self.hparams["num_classes"]
-        num_filters: int = self.hparams["num_filters"]
-
-        if model == "unet":
-            self.model = smp.Unet(
-                encoder_name=backbone,
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=in_channels,
-                classes=num_classes,
-            )
-        elif model == "deeplabv3+":
-            self.model = smp.DeepLabV3Plus(
-                encoder_name=backbone,
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=in_channels,
-                classes=num_classes,
-            )
-        elif model == "fcn":
-            self.model = FCN(
-                in_channels=in_channels, classes=num_classes, num_filters=num_filters
-            )
-        else:
-            raise ValueError(
-                f"Model type '{model}' is not valid. "
-                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
-            )
-
-        if model != "fcn":
-            if weights and weights is not True:
-                if isinstance(weights, WeightsEnum):
-                    state_dict = weights.get_state_dict(progress=True)
-                elif os.path.exists(weights):
-                    _, state_dict = utils.extract_backbone(weights)
-                else:
-                    state_dict = get_weight(weights).get_state_dict(progress=True)
-                self.model.encoder.load_state_dict(state_dict)
-
-        # Freeze backbone
-        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
-            for param in self.model.encoder.parameters():
-                param.requires_grad = False
-
-        # Freeze decoder
-        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
-            for param in self.model.decoder.parameters():
-                param.requires_grad = False
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
torchgeo/trainers/simclr.py
@@ -142,19 +142,9 @@ class SimCLRTask(BaseTask):
             size, grayscale_weights
         )

-    def configure_losses(self) -> None:
-        """Initialize the loss criterion."""
-        self.criterion = NTXentLoss(
-            self.hparams["temperature"],
-            self.hparams["memory_bank_size"],
-            self.hparams["gather_distributed"],
-        )
-
     def configure_models(self) -> None:
         """Initialize the model."""
         weights = self.weights
-        hidden_dim: int = self.hparams["hidden_dim"]
-        output_dim: int = self.hparams["output_dim"]

         # Create backbone
         self.backbone = timm.create_model(
@@ -176,13 +166,16 @@ class SimCLRTask(BaseTask):

         # Create projection head
         input_dim = self.backbone.num_features
-        if hidden_dim is None:
-            hidden_dim = input_dim
-        if output_dim is None:
-            output_dim = input_dim
+        if self.hparams["hidden_dim"] is None:
+            self.hparams["hidden_dim"] = input_dim
+        if self.hparams["output_dim"] is None:
+            self.hparams["output_dim"] = input_dim

         self.projection_head = SimCLRProjectionHead(
-            input_dim, hidden_dim, output_dim, self.hparams["layers"]
+            input_dim,
+            self.hparams["hidden_dim"],
+            self.hparams["output_dim"],
+            self.hparams["layers"],
         )

         # Initialize moving average of output
@@ -192,6 +185,22 @@ class SimCLRTask(BaseTask):
         # v1+: add global batch norm
         # v2: add selective kernels, channel-wise attention mechanism

+    def configure_losses(self) -> None:
+        """Initialize the loss criterion."""
+        try:
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                (self.hparams["memory_bank_size"], self.hparams["output_dim"]),
+                self.hparams["gather_distributed"],
+            )
+        except TypeError:
+            # lightly 1.4.24 and older
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                self.hparams["memory_bank_size"],
+                self.hparams["gather_distributed"],
+            )
+
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
         """Forward pass of the model.