diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index 6e714c8c..0bf82654 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -151,7 +151,7 @@ class SSLContainer(LightningContainer): model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value, dataset_name=self.ssl_training_dataset_name.value, use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet, - num_samples=self.data_module.num_train_samples, + num_samples=len(self.data_module.dataset_train), batch_size=self.data_module.batch_size, gpus=self.num_gpus_per_node(), num_nodes=self.num_nodes, @@ -172,8 +172,8 @@ class SSLContainer(LightningContainer): f"{SSLTrainingType.BYOL.value}. " f"Found {self.ssl_training_type.value}") model.hparams.update({'ssl_type': self.ssl_training_type.value, - "num_classes": self.data_module.num_classes}) - self.encoder_output_dim = get_encoder_output_dim(model, self.data_module) + "num_classes": 0}) + self.encoder_output_dim = -1 # get_encoder_output_dim(model, self.data_module) return model def get_data_module(self) -> InnerEyeDataModuleTypes: @@ -184,9 +184,8 @@ class SSLContainer(LightningContainer): if hasattr(self, "data_module"): return self.data_module encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True) - linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False) - return CombinedDataModule(encoder_data_module, linear_data_module, - self.use_balanced_binary_loss_for_linear_head) + # linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False) + return encoder_data_module def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule: """ diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 2aa77096..ecaf2fe9 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -124,7 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator): Moves batch to device. :param device: device to move the batch to. """ - _, x, y = batch + x, y = batch return x.to(device), y.to(device) def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: