using plain dataloaders
This commit is contained in:
Родитель
e7863112cd
Коммит
e27144ebfc
|
@ -151,7 +151,7 @@ class SSLContainer(LightningContainer):
|
||||||
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
|
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
|
||||||
dataset_name=self.ssl_training_dataset_name.value,
|
dataset_name=self.ssl_training_dataset_name.value,
|
||||||
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
|
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
|
||||||
num_samples=self.data_module.num_train_samples,
|
num_samples=len(self.data_module.dataset_train),
|
||||||
batch_size=self.data_module.batch_size,
|
batch_size=self.data_module.batch_size,
|
||||||
gpus=self.num_gpus_per_node(),
|
gpus=self.num_gpus_per_node(),
|
||||||
num_nodes=self.num_nodes,
|
num_nodes=self.num_nodes,
|
||||||
|
@ -172,8 +172,8 @@ class SSLContainer(LightningContainer):
|
||||||
f"{SSLTrainingType.BYOL.value}. "
|
f"{SSLTrainingType.BYOL.value}. "
|
||||||
f"Found {self.ssl_training_type.value}")
|
f"Found {self.ssl_training_type.value}")
|
||||||
model.hparams.update({'ssl_type': self.ssl_training_type.value,
|
model.hparams.update({'ssl_type': self.ssl_training_type.value,
|
||||||
"num_classes": self.data_module.num_classes})
|
"num_classes": 0})
|
||||||
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
|
self.encoder_output_dim = -1 # get_encoder_output_dim(model, self.data_module)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_data_module(self) -> InnerEyeDataModuleTypes:
|
def get_data_module(self) -> InnerEyeDataModuleTypes:
|
||||||
|
@ -184,9 +184,8 @@ class SSLContainer(LightningContainer):
|
||||||
if hasattr(self, "data_module"):
|
if hasattr(self, "data_module"):
|
||||||
return self.data_module
|
return self.data_module
|
||||||
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
|
encoder_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=True)
|
||||||
linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
|
# linear_data_module = self._create_ssl_data_modules(is_ssl_encoder_module=False)
|
||||||
return CombinedDataModule(encoder_data_module, linear_data_module,
|
return encoder_data_module
|
||||||
self.use_balanced_binary_loss_for_linear_head)
|
|
||||||
|
|
||||||
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
|
def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisionDataModule:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -124,7 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
||||||
Moves batch to device.
|
Moves batch to device.
|
||||||
:param device: device to move the batch to.
|
:param device: device to move the batch to.
|
||||||
"""
|
"""
|
||||||
_, x, y = batch
|
x, y = batch
|
||||||
return x.to(device), y.to(device)
|
return x.to(device), y.to(device)
|
||||||
|
|
||||||
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
||||||
|
|
Загрузка…
Ссылка в новой задаче