diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py
index b7629cb36..6d15931ef 100644
--- a/tests/trainers/test_simclr.py
+++ b/tests/trainers/test_simclr.py
@@ -73,13 +73,13 @@ class TestSimCLRTask:
     def test_version_warnings(self) -> None:
         with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
-            SimCLRTask(version=1, layers=3)
+            SimCLRTask(version=1, layers=3, memory_bank_size=0)
         with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
-            SimCLRTask(version=1, memory_bank_size=10)
+            SimCLRTask(version=1, layers=2, memory_bank_size=10)
         with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
-            SimCLRTask(version=2, layers=2)
+            SimCLRTask(version=2, layers=2, memory_bank_size=10)
         with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
-            SimCLRTask(version=2, memory_bank_size=0)
+            SimCLRTask(version=2, layers=3, memory_bank_size=0)

     @pytest.fixture
     def weights(self) -> WeightsEnum:
diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py
index 3a44c047a..878034a77 100644
--- a/torchgeo/trainers/base.py
+++ b/torchgeo/trainers/base.py
@@ -36,9 +36,13 @@ class BaseTask(LightningModule, ABC):
         """
         super().__init__()
         self.save_hyperparameters(ignore=ignore)
+        self.configure_models()
         self.configure_losses()
         self.configure_metrics()
-        self.configure_models()
+
+    @abstractmethod
+    def configure_models(self) -> None:
+        """Initialize the model."""

     def configure_losses(self) -> None:
         """Initialize the loss criterion."""
@@ -46,10 +50,6 @@ class BaseTask(LightningModule, ABC):
     def configure_metrics(self) -> None:
         """Initialize the performance metrics."""

-    @abstractmethod
-    def configure_models(self) -> None:
-        """Initialize the model."""
-
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py
index d7a133e44..3aa06be84 100644
--- a/torchgeo/trainers/classification.py
+++ b/torchgeo/trainers/classification.py
@@ -75,6 +75,35 @@ class ClassificationTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model."""
+        weights = self.weights
+
+        # Create model
+        self.model = timm.create_model(
+            self.hparams["model"],
+            num_classes=self.hparams["num_classes"],
+            in_chans=self.hparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            utils.load_state_dict(self.model, state_dict)
+
+        # Freeze backbone and unfreeze classifier head
+        if self.hparams["freeze_backbone"]:
+            for param in self.model.parameters():
+                param.requires_grad = False
+            for param in self.model.get_classifier().parameters():
+                param.requires_grad = True
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
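Note: the base.py hunk above flips the initialization order so that `configure_models` runs before `configure_losses` and `configure_metrics`. The order matters because some subclasses (moco.py and simclr.py below) build a loss whose memory-bank shape depends on `output_dim`, which `configure_models` resolves from the backbone. A minimal sketch of the resulting hook order, assuming nothing beyond what the diff shows; the `Task` name is illustrative, not part of torchgeo:

```python
from abc import ABC, abstractmethod


class Task(ABC):
    """Illustrative stand-in for BaseTask showing the new hook order."""

    def __init__(self) -> None:
        self.hparams: dict[str, object] = {"output_dim": None}
        self.configure_models()  # may resolve hparams["output_dim"]
        self.configure_losses()  # can now read the concrete output_dim

    @abstractmethod
    def configure_models(self) -> None:
        """Initialize the model."""

    def configure_losses(self) -> None:
        """Initialize the loss criterion."""
```

Moving the abstract `configure_models` declaration above `configure_losses` in the class body (and moving each concrete override accordingly, as the remaining hunks do) simply mirrors the new call order.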
@@ -134,35 +163,6 @@ class ClassificationTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model."""
-        weights = self.weights
-
-        # Create model
-        self.model = timm.create_model(
-            self.hparams["model"],
-            num_classes=self.hparams["num_classes"],
-            in_chans=self.hparams["in_channels"],
-            pretrained=weights is True,
-        )
-
-        # Load weights
-        if weights and weights is not True:
-            if isinstance(weights, WeightsEnum):
-                state_dict = weights.get_state_dict(progress=True)
-            elif os.path.exists(weights):
-                _, state_dict = utils.extract_backbone(weights)
-            else:
-                state_dict = get_weight(weights).get_state_dict(progress=True)
-            utils.load_state_dict(self.model, state_dict)
-
-        # Freeze backbone and unfreeze classifier head
-        if self.hparams["freeze_backbone"]:
-            for param in self.model.parameters():
-                param.requires_grad = False
-            for param in self.model.get_classifier().parameters():
-                param.requires_grad = True
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py
index e8e3212cc..8b1e59cf3 100644
--- a/torchgeo/trainers/moco.py
+++ b/torchgeo/trainers/moco.py
@@ -226,14 +226,6 @@ class MoCoTask(BaseTask):
         self.augmentation1 = augmentation1 or aug1
         self.augmentation2 = augmentation2 or aug2

-    def configure_losses(self) -> None:
-        """Initialize the loss criterion."""
-        self.criterion = NTXentLoss(
-            self.hparams["temperature"],
-            self.hparams["memory_bank_size"],
-            self.hparams["gather_distributed"],
-        )
-
     def configure_models(self) -> None:
         """Initialize the model."""
         model: str = self.hparams["model"]
@@ -282,6 +274,22 @@
         # Initialize moving average of output
         self.avg_output_std = 0.0

+    def configure_losses(self) -> None:
+        """Initialize the loss criterion."""
+        try:
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                (self.hparams["memory_bank_size"], self.hparams["output_dim"]),
+                self.hparams["gather_distributed"],
+            )
+        except TypeError:
+            # lightly 1.4.24 and older
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                self.hparams["memory_bank_size"],
+                self.hparams["gather_distributed"],
+            )
+
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py
index 9f1ffdc57..7ab19d215 100644
--- a/torchgeo/trainers/regression.py
+++ b/torchgeo/trainers/regression.py
@@ -79,6 +79,34 @@ class RegressionTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model."""
+        # Create model
+        weights = self.weights
+        self.model = timm.create_model(
+            self.hparams["model"],
+            num_classes=self.hparams["num_outputs"],
+            in_chans=self.hparams["in_channels"],
+            pretrained=weights is True,
+        )
+
+        # Load weights
+        if weights and weights is not True:
+            if isinstance(weights, WeightsEnum):
+                state_dict = weights.get_state_dict(progress=True)
+            elif os.path.exists(weights):
+                _, state_dict = utils.extract_backbone(weights)
+            else:
+                state_dict = get_weight(weights).get_state_dict(progress=True)
+            utils.load_state_dict(self.model, state_dict)
+
+        # Freeze backbone and unfreeze classifier head
+        if self.hparams["freeze_backbone"]:
+            for param in self.model.parameters():
+                param.requires_grad = False
+            for param in self.model.get_classifier().parameters():
+                param.requires_grad = True
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.

@@ -117,34 +145,6 @@ class RegressionTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model."""
-        # Create model
-        weights = self.weights
-        self.model = timm.create_model(
-            self.hparams["model"],
-            num_classes=self.hparams["num_outputs"],
-            in_chans=self.hparams["in_channels"],
-            pretrained=weights is True,
-        )
-
-        # Load weights
-        if weights and weights is not True:
-            if isinstance(weights, WeightsEnum):
-                state_dict = weights.get_state_dict(progress=True)
-            elif os.path.exists(weights):
-                _, state_dict = utils.extract_backbone(weights)
-            else:
-                state_dict = get_weight(weights).get_state_dict(progress=True)
-            utils.load_state_dict(self.model, state_dict)
-
-        # Freeze backbone and unfreeze classifier head
-        if self.hparams["freeze_backbone"]:
-            for param in self.model.parameters():
-                param.requires_grad = False
-            for param in self.model.get_classifier().parameters():
-                param.requires_grad = True
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index 7b7be93a0..1249f3f28 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -90,6 +90,63 @@ class SemanticSegmentationTask(BaseTask):
         self.weights = weights
         super().__init__(ignore="weights")

+    def configure_models(self) -> None:
+        """Initialize the model.
+
+        Raises:
+            ValueError: If *model* is invalid.
+        """
+        model: str = self.hparams["model"]
+        backbone: str = self.hparams["backbone"]
+        weights = self.weights
+        in_channels: int = self.hparams["in_channels"]
+        num_classes: int = self.hparams["num_classes"]
+        num_filters: int = self.hparams["num_filters"]
+
+        if model == "unet":
+            self.model = smp.Unet(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
+            )
+        elif model == "deeplabv3+":
+            self.model = smp.DeepLabV3Plus(
+                encoder_name=backbone,
+                encoder_weights="imagenet" if weights is True else None,
+                in_channels=in_channels,
+                classes=num_classes,
+            )
+        elif model == "fcn":
+            self.model = FCN(
+                in_channels=in_channels, classes=num_classes, num_filters=num_filters
+            )
+        else:
+            raise ValueError(
+                f"Model type '{model}' is not valid. "
+                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+            )
+
+        if model != "fcn":
+            if weights and weights is not True:
+                if isinstance(weights, WeightsEnum):
+                    state_dict = weights.get_state_dict(progress=True)
+                elif os.path.exists(weights):
+                    _, state_dict = utils.extract_backbone(weights)
+                else:
+                    state_dict = get_weight(weights).get_state_dict(progress=True)
+                self.model.encoder.load_state_dict(state_dict)
+
+        # Freeze backbone
+        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
+            for param in self.model.encoder.parameters():
+                param.requires_grad = False
+
+        # Freeze decoder
+        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
+            for param in self.model.decoder.parameters():
+                param.requires_grad = False
+
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
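The `try`/`except TypeError` introduced in moco.py above (and repeated in simclr.py below) papers over a breaking change in lightly's `NTXentLoss`: per the in-diff comment, releases after 1.4.24 expect `memory_bank_size` as a `(size, dim)` tuple, while 1.4.24 and older take a plain int. A standalone sketch of the same fallback; the `build_criterion` helper is illustrative, not part of the codebase:

```python
from lightly.loss import NTXentLoss


def build_criterion(
    temperature: float, bank_size: int, output_dim: int, gather_distributed: bool
) -> NTXentLoss:
    """Build NTXentLoss across lightly versions (illustrative helper)."""
    try:
        # Newer lightly: memory_bank_size is a (size, dim) tuple
        return NTXentLoss(temperature, (bank_size, output_dim), gather_distributed)
    except TypeError:
        # lightly 1.4.24 and older: memory_bank_size is a plain int
        return NTXentLoss(temperature, bank_size, gather_distributed)
```

Catching `TypeError` rather than pinning a lightly version keeps torchgeo working on both sides of the change.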
@@ -155,63 +212,6 @@ class SemanticSegmentationTask(BaseTask):
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")

-    def configure_models(self) -> None:
-        """Initialize the model.
-
-        Raises:
-            ValueError: If *model* is invalid.
-        """
-        model: str = self.hparams["model"]
-        backbone: str = self.hparams["backbone"]
-        weights = self.weights
-        in_channels: int = self.hparams["in_channels"]
-        num_classes: int = self.hparams["num_classes"]
-        num_filters: int = self.hparams["num_filters"]
-
-        if model == "unet":
-            self.model = smp.Unet(
-                encoder_name=backbone,
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=in_channels,
-                classes=num_classes,
-            )
-        elif model == "deeplabv3+":
-            self.model = smp.DeepLabV3Plus(
-                encoder_name=backbone,
-                encoder_weights="imagenet" if weights is True else None,
-                in_channels=in_channels,
-                classes=num_classes,
-            )
-        elif model == "fcn":
-            self.model = FCN(
-                in_channels=in_channels, classes=num_classes, num_filters=num_filters
-            )
-        else:
-            raise ValueError(
-                f"Model type '{model}' is not valid. "
-                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
-            )
-
-        if model != "fcn":
-            if weights and weights is not True:
-                if isinstance(weights, WeightsEnum):
-                    state_dict = weights.get_state_dict(progress=True)
-                elif os.path.exists(weights):
-                    _, state_dict = utils.extract_backbone(weights)
-                else:
-                    state_dict = get_weight(weights).get_state_dict(progress=True)
-                self.model.encoder.load_state_dict(state_dict)
-
-        # Freeze backbone
-        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
-            for param in self.model.encoder.parameters():
-                param.requires_grad = False
-
-        # Freeze decoder
-        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
-            for param in self.model.decoder.parameters():
-                param.requires_grad = False
-
     def training_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
     ) -> Tensor:
diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py
index 924b49a54..b753dabcb 100644
--- a/torchgeo/trainers/simclr.py
+++ b/torchgeo/trainers/simclr.py
@@ -142,19 +142,9 @@ class SimCLRTask(BaseTask):
             size, grayscale_weights
         )

-    def configure_losses(self) -> None:
-        """Initialize the loss criterion."""
-        self.criterion = NTXentLoss(
-            self.hparams["temperature"],
-            self.hparams["memory_bank_size"],
-            self.hparams["gather_distributed"],
-        )
-
     def configure_models(self) -> None:
         """Initialize the model."""
         weights = self.weights
-        hidden_dim: int = self.hparams["hidden_dim"]
-        output_dim: int = self.hparams["output_dim"]

         # Create backbone
         self.backbone = timm.create_model(
@@ -176,13 +166,16 @@ class SimCLRTask(BaseTask):

         # Create projection head
         input_dim = self.backbone.num_features
-        if hidden_dim is None:
-            hidden_dim = input_dim
-        if output_dim is None:
-            output_dim = input_dim
+        if self.hparams["hidden_dim"] is None:
+            self.hparams["hidden_dim"] = input_dim
+        if self.hparams["output_dim"] is None:
+            self.hparams["output_dim"] = input_dim

         self.projection_head = SimCLRProjectionHead(
-            input_dim, hidden_dim, output_dim, self.hparams["layers"]
+            input_dim,
+            self.hparams["hidden_dim"],
+            self.hparams["output_dim"],
+            self.hparams["layers"],
         )

         # Initialize moving average of output
@@ -192,6 +185,22 @@
         # v1+: add global batch norm
         # v2: add selective kernels, channel-wise attention mechanism

+    def configure_losses(self) -> None:
+        """Initialize the loss criterion."""
+        try:
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                (self.hparams["memory_bank_size"], self.hparams["output_dim"]),
+                self.hparams["gather_distributed"],
+            )
+        except TypeError:
+            # lightly 1.4.24 and older
+            self.criterion = NTXentLoss(
+                self.hparams["temperature"],
+                self.hparams["memory_bank_size"],
+                self.hparams["gather_distributed"],
+            )
+
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
         """Forward pass of the model.
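Note that simclr.py now writes the resolved `hidden_dim`/`output_dim` back into `self.hparams` instead of keeping them in locals, so `configure_losses` (which runs after `configure_models`) sees a concrete `output_dim` when sizing the memory bank. For reference, the warning-free keyword combinations implied by the updated tests, should you want to construct the task directly (a usage sketch; the values come from the test file above, not from the trainer's defaults):

```python
from torchgeo.trainers import SimCLRTask

# v1: exactly 2 projection layers and no memory bank -> no version warnings
task_v1 = SimCLRTask(version=1, layers=2, memory_bank_size=0)

# v2: 3+ projection layers and a memory bank -> no version warnings
task_v2 = SimCLRTask(version=2, layers=3, memory_bank_size=10)
```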