using plain dataloaders
This commit is contained in:
Родитель
e7863112cd
Коммит
e27144ebfc
|
@ -151,7 +151,7 @@ class SSLContainer(LightningContainer):
|
|||
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
|
||||
dataset_name=self.ssl_training_dataset_name.value,
|
||||
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
|
||||
num_samples=self.data_module.num_train_samples,
|
||||
num_samples=len(self.data_module.dataset_train),
|
||||
batch_size=self.data_module.batch_size,
|
||||
gpus=self.num_gpus_per_node(),
|
||||
num_nodes=self.num_nodes,
|
||||
|
@ -172,8 +172,8 @@ class SSLContainer(LightningContainer):
|
|||
f"{SSLTrainingType.BYOL.value}. "
|
||||
f"Found {self.ssl_training_type.value}")
|
||||
model.hparams.update({'ssl_type': self.ssl_training_type.value,
|
||||
"num_classes": self.data_module.num_classes})
|
||||
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
|
||||
"num_classes": 0})
|
||||
self.encoder_output_dim = -1 # get_encoder_output_dim(model, self.data_module)
|
||||
return model
|
||||
|
||||
def get_data_module(self) -> InnerEyeDataModuleTypes:
|
||||
|
@ -184,9 +184,8 @@ class SSLContainer(LightningContainer):
|
|||
if hasattr(self, "data_module"):
|
||||
return self.data_module
|
||||
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
|
||||
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
|
||||
return CombinedDataModule(encoder_data_module, linear_data_module,
|
||||
self.use_balanced_binary_loss_for_linear_head)
|
||||
# linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
|
||||
return encoder_data_module
|
||||
|
||||
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
|
||||
"""
|
||||
|
|
|
@ -124,7 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
|||
Moves batch to device.
|
||||
:param device: device to move the batch to.
|
||||
"""
|
||||
_, x, y = batch
|
||||
x, y = batch
|
||||
return x.to(device), y.to(device)
|
||||
|
||||
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче