This commit is contained in:
Anton Schwaighofer 2022-02-09 15:05:45 +00:00
Родитель e7863112cd
Коммит e27144ebfc
2 изменённых файлов: 6 добавлений и 7 удалений

Просмотреть файл

@ -151,7 +151,7 @@ class SSLContainer(LightningContainer):
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
dataset_name=self.ssl_training_dataset_name.value,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
num_samples=self.data_module.num_train_samples,
num_samples=len(self.data_module.dataset_train),
batch_size=self.data_module.batch_size,
gpus=self.num_gpus_per_node(),
num_nodes=self.num_nodes,
@ -172,8 +172,8 @@ class SSLContainer(LightningContainer):
f"{SSLTrainingType.BYOL.value}. "
f"Found {self.ssl_training_type.value}")
model.hparams.update({'ssl_type': self.ssl_training_type.value,
"num_classes": self.data_module.num_classes})
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
"num_classes": 0})
self.encoder_output_dim = -1 # get_encoder_output_dim(model, self.data_module)
return model
def get_data_module(self) -> InnerEyeDataModuleTypes:
@ -184,9 +184,8 @@ class SSLContainer(LightningContainer):
if hasattr(self, "data_module"):
return self.data_module
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return CombinedDataModule(encoder_data_module, linear_data_module,
self.use_balanced_binary_loss_for_linear_head)
# linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return encoder_data_module
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
"""

Просмотреть файл

@ -124,7 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
Moves batch to device.
:param device: device to move the batch to.
"""
_, x, y = batch
x, y = batch
return x.to(device), y.to(device)
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: