This commit is contained in:
Anton Schwaighofer 2022-02-09 15:05:45 +00:00
Родитель e7863112cd
Коммит e27144ebfc
2 изменённых файлов: 6 добавлений и 7 удалений

Просмотреть файл

@ -151,7 +151,7 @@ class SSLContainer(LightningContainer):
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value, model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
dataset_name=self.ssl_training_dataset_name.value, dataset_name=self.ssl_training_dataset_name.value,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet, use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
num_samples=self.data_module.num_train_samples, num_samples=len(self.data_module.dataset_train),
batch_size=self.data_module.batch_size, batch_size=self.data_module.batch_size,
gpus=self.num_gpus_per_node(), gpus=self.num_gpus_per_node(),
num_nodes=self.num_nodes, num_nodes=self.num_nodes,
@ -172,8 +172,8 @@ class SSLContainer(LightningContainer):
f"{SSLTrainingType.BYOL.value}. " f"{SSLTrainingType.BYOL.value}. "
f"Found {self.ssl_training_type.value}") f"Found {self.ssl_training_type.value}")
model.hparams.update({'ssl_type': self.ssl_training_type.value, model.hparams.update({'ssl_type': self.ssl_training_type.value,
"num_classes": self.data_module.num_classes}) "num_classes": 0})
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module) self.encoder_output_dim = -1 # get_encoder_output_dim(model, self.data_module)
return model return model
def get_data_module(self) -> InnerEyeDataModuleTypes: def get_data_module(self) -> InnerEyeDataModuleTypes:
@ -184,9 +184,8 @@ class SSLContainer(LightningContainer):
if hasattr(self, "data_module"): if hasattr(self, "data_module"):
return self.data_module return self.data_module
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True) encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False) # linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
return CombinedDataModule(encoder_data_module, linear_data_module, return encoder_data_module
self.use_balanced_binary_loss_for_linear_head)
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule: def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
""" """

Просмотреть файл

@ -124,7 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
Moves batch to device. Moves batch to device.
:param device: device to move the batch to. :param device: device to move the batch to.
""" """
_, x, y = batch x, y = batch
return x.to(device), y.to(device) return x.to(device), y.to(device)
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: