Generalize SSL functionality to work on other datasets (#555)
This PR contains some changes needed to make the SSLContainer compatible with new datasets and allow a user to run by simply creating a new augmentation config or defining a child class * _get_transforms has been changed to accept new datasets without the need to touch the class * get_cxr_ssl_transform has been changed to avoid the hidden channel expansion and make that optional. It has been also renamed to get_ssl_transform because it has nothing specific to cxr * drop_last is now set as parameter of the InnerEyeVisionDataModule and the SSLContainer - that means it can be changed when initializing a new SSLContainer * documentation about bringing your own SSL model has been updated
This commit is contained in:
Родитель
5b7d571209
Коммит
521c004357
|
@ -25,6 +25,7 @@ jobs that run in AzureML.
|
|||
|
||||
### Changed
|
||||
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
|
||||
- ([#555](https://github.com/microsoft/InnerEye-DeepLearning/pull/555)) Make the SSLContainer compatible with new datasets
|
||||
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
|
||||
- ([#536](https://github.com/microsoft/InnerEye-DeepLearning/pull/536)) Inference will not run on the validation set by default, this can be turned on
|
||||
via the `--inference_on_val_set` flag.
|
||||
|
|
|
@ -29,6 +29,7 @@ class InnerEyeVisionDataModule(VisionDataModule):
|
|||
num_workers: int = 6,
|
||||
batch_size: int = 32,
|
||||
seed: int = 42,
|
||||
drop_last: bool = True,
|
||||
*args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Wrapper around VisionDatamodule to load torchvision dataset into a pytorch-lightning module.
|
||||
|
@ -42,16 +43,17 @@ class InnerEyeVisionDataModule(VisionDataModule):
|
|||
:param val_transforms: transforms to use at validation time
|
||||
:param data_dir: data directory where to find the data
|
||||
:param val_split: proportion of training dataset to use for validation
|
||||
:param num_workers: number of processes for dataloaders.
|
||||
:param batch_size: batch size for training & validation.
|
||||
:param num_workers: number of processes for dataloaders
|
||||
:param batch_size: batch size for training & validation
|
||||
:param seed: random seed for dataset splitting
|
||||
:param drop_last: bool, if true it drops the last incomplete batch
|
||||
"""
|
||||
data_dir = data_dir if data_dir is not None else os.getcwd()
|
||||
super().__init__(data_dir=data_dir,
|
||||
val_split=val_split,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
drop_last=drop_last,
|
||||
train_transforms=train_transforms,
|
||||
val_transforms=val_transforms,
|
||||
seed=seed,
|
||||
|
|
|
@ -10,16 +10,17 @@ import torch
|
|||
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config
|
||||
|
||||
|
||||
def get_cxr_ssl_transforms(config: CfgNode,
|
||||
return_two_views_per_sample: bool,
|
||||
use_training_augmentations_for_validation: bool = False) -> Tuple[Any, Any]:
|
||||
def get_ssl_transforms_from_config(config: CfgNode,
|
||||
return_two_views_per_sample: bool,
|
||||
use_training_augmentations_for_validation: bool = False,
|
||||
expand_channels: bool = True) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Returns training and validation transforms for CXR.
|
||||
Transformations are constructed in the following way:
|
||||
1. Construct the pipeline of augmentations in create_chest_xray_transform (e.g. resize, flip, affine) as defined
|
||||
1. Construct the pipeline of augmentations in create_transform_from_config (e.g. resize, flip, affine) as defined
|
||||
by the config.
|
||||
2. If we just want to construct the transformation pipeline for a classification model or for the linear evaluator
|
||||
of the SSL module, return this pipeline.
|
||||
|
@ -29,14 +30,18 @@ def get_cxr_ssl_transforms(config: CfgNode,
|
|||
|
||||
:param config: configuration defining which augmentations to apply as well as their intensities.
|
||||
:param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
|
||||
are called on. If False, simply return one transformed version of the sample.
|
||||
are called on. If False, simply return one transformed version of the sample centered and cropped.
|
||||
:param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
|
||||
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
|
||||
(no augmentations)
|
||||
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
"""
|
||||
train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
|
||||
val_transforms = create_cxr_transforms_from_config(config,
|
||||
apply_augmentations=use_training_augmentations_for_validation)
|
||||
train_transforms = create_transforms_from_config(config, apply_augmentations=True,
|
||||
expand_channels=expand_channels)
|
||||
val_transforms = create_transforms_from_config(config,
|
||||
apply_augmentations=use_training_augmentations_for_validation,
|
||||
expand_channels=expand_channels)
|
||||
if return_two_views_per_sample:
|
||||
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
|
||||
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
|
||||
|
|
|
@ -17,7 +17,7 @@ from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import CheXpert, Covi
|
|||
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
|
||||
InnerEyeCIFARTrainTransform, \
|
||||
get_cxr_ssl_transforms
|
||||
get_ssl_transforms_from_config
|
||||
from InnerEye.ML.SSL.encoders import get_encoder_output_dim
|
||||
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
|
||||
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
||||
|
@ -96,6 +96,7 @@ class SSLContainer(LightningContainer):
|
|||
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
|
||||
doc="Learning rate for linear head training during "
|
||||
"SSL training.")
|
||||
drop_last = param.Boolean(default=True, doc="If True drops the last incomplete batch")
|
||||
|
||||
def setup(self) -> None:
|
||||
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
|
||||
|
@ -166,8 +167,8 @@ class SSLContainer(LightningContainer):
|
|||
f"Found {self.ssl_training_type.value}")
|
||||
model.hparams.update({'ssl_type': self.ssl_training_type.value,
|
||||
"num_classes": self.data_module.num_classes})
|
||||
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
|
||||
|
||||
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
|
||||
return model
|
||||
|
||||
def get_data_module(self) -> InnerEyeDataModuleTypes:
|
||||
|
@ -186,7 +187,7 @@ class SSLContainer(LightningContainer):
|
|||
"""
|
||||
Returns torch lightning data module for encoder or linear head
|
||||
|
||||
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear heard. If true,
|
||||
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
|
||||
:return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
|
||||
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
|
||||
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
|
||||
|
@ -209,7 +210,8 @@ class SSLContainer(LightningContainer):
|
|||
data_dir=str(datamodule_args.dataset_path),
|
||||
batch_size=batch_size_per_gpu,
|
||||
num_workers=self.num_workers,
|
||||
seed=self.random_seed)
|
||||
seed=self.random_seed,
|
||||
drop_last=self.drop_last)
|
||||
dm.prepare_data()
|
||||
dm.setup()
|
||||
return dm
|
||||
|
@ -223,8 +225,10 @@ class SSLContainer(LightningContainer):
|
|||
examples.
|
||||
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
|
||||
pipeline to return.
|
||||
:param is_ssl_encoder_module: if True the transformation pipeline will yield two version of the image it is
|
||||
applied on. If False, return only one transformation.
|
||||
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
|
||||
applied on and it applies the training transformations also at validation time. Note that if your transformation
|
||||
does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one
|
||||
transformation.
|
||||
:return: training transformation pipeline and validation transformation pipeline.
|
||||
"""
|
||||
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
|
||||
|
@ -232,16 +236,28 @@ class SSLContainer(LightningContainer):
|
|||
SSLDatasetName.CheXpert.value,
|
||||
SSLDatasetName.Covid.value]:
|
||||
assert augmentation_config is not None
|
||||
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
|
||||
return_two_views_per_sample=is_ssl_encoder_module,
|
||||
use_training_augmentations_for_validation=is_ssl_encoder_module)
|
||||
train_transforms, val_transforms = get_ssl_transforms_from_config(
|
||||
augmentation_config,
|
||||
return_two_views_per_sample=is_ssl_encoder_module,
|
||||
use_training_augmentations_for_validation=is_ssl_encoder_module
|
||||
)
|
||||
elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
|
||||
train_transforms = \
|
||||
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
|
||||
val_transforms = \
|
||||
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
|
||||
elif augmentation_config:
|
||||
train_transforms, val_transforms = get_ssl_transforms_from_config(
|
||||
augmentation_config,
|
||||
return_two_views_per_sample=is_ssl_encoder_module,
|
||||
use_training_augmentations_for_validation=is_ssl_encoder_module,
|
||||
expand_channels=False,
|
||||
)
|
||||
logging.warning(f"Dataset {dataset_name} unknown. The config will be consumed by "
|
||||
f"get_ssl_transforms() to create the augmentation pipeline, make sure "
|
||||
f"the transformations in your configs are compatible. ")
|
||||
else:
|
||||
raise ValueError(f"Dataset {dataset_name} unknown.")
|
||||
raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")
|
||||
|
||||
return train_transforms, val_transforms
|
||||
|
||||
|
|
|
@ -86,16 +86,22 @@ class ImageTransformationPipeline:
|
|||
return self.transform_image(data)
|
||||
|
||||
|
||||
def create_cxr_transforms_from_config(config: CfgNode,
|
||||
apply_augmentations: bool) -> ImageTransformationPipeline:
|
||||
def create_transforms_from_config(config: CfgNode,
|
||||
apply_augmentations: bool,
|
||||
expand_channels: bool = True) -> ImageTransformationPipeline:
|
||||
"""
|
||||
Defines the image transformations pipeline used in Chest-Xray datasets. Can be used for other types of
|
||||
images data, type of augmentations to use and strength are expected to be defined in the config.
|
||||
Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
|
||||
images but it can be used for other types of images data, type of augmentations to use and strength are
|
||||
expected to be defined in the config. The channel expansion is needed for gray images.
|
||||
:param config: config yaml file fixing strength and type of augmentation to apply
|
||||
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
|
||||
disable augmentations i.e. only resize and center crop the image.
|
||||
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
"""
|
||||
transforms: List[Any] = [ExpandChannels()]
|
||||
transforms: List[Any] = []
|
||||
if expand_channels:
|
||||
transforms.append(ExpandChannels())
|
||||
if apply_augmentations:
|
||||
if config.augmentation.use_random_affine:
|
||||
transforms.append(RandomAffine(
|
||||
|
|
|
@ -23,7 +23,8 @@ from InnerEye.ML.SSL.encoders import SSLEncoder
|
|||
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
|
||||
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
|
||||
from InnerEye.ML.SSL.utils import create_ssl_image_classifier, load_yaml_augmentation_config
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config
|
||||
|
||||
from InnerEye.ML.common import ModelExecutionMode
|
||||
|
||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
|
||||
|
@ -137,9 +138,9 @@ class CovidModel(ScalarModelBase):
|
|||
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
|
||||
train_transforms = Compose(
|
||||
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
|
||||
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=True)])
|
||||
val_transforms = Compose(
|
||||
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
|
||||
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=False)])
|
||||
|
||||
return ModelTransformsPerExecutionMode(train=train_transforms,
|
||||
val=val_transforms,
|
||||
|
|
|
@ -7,13 +7,14 @@ import random
|
|||
import PIL
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip, \
|
||||
RandomResizedCrop, Resize, ToTensor
|
||||
from torchvision.transforms import (CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip,
|
||||
RandomResizedCrop, Resize, ToTensor)
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
|
||||
from InnerEye.ML.augmentations.image_transforms import (AddGaussianNoise, ElasticTransform,
|
||||
ExpandChannels, RandomGamma)
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline, \
|
||||
create_cxr_transforms_from_config
|
||||
create_transforms_from_config
|
||||
|
||||
from Tests.SSL.test_data_modules import cxr_augmentation_config
|
||||
|
||||
|
@ -31,7 +32,6 @@ test_2d_image_as_ZCHW_tensor = test_2d_image_as_CHW_tensor.unsqueeze(0)
|
|||
test_4d_scan_as_tensor = torch.ones([5, 4, *image_size]) * 255.
|
||||
test_4d_scan_as_tensor[..., 10:15, 10:20] = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
|
||||
def test_torchvision_on_various_input(use_different_transformation_per_channel: bool) -> None:
|
||||
"""
|
||||
|
@ -107,17 +107,16 @@ def test_custom_tf_on_various_input(use_different_transformation_per_channel: bo
|
|||
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel
|
||||
|
||||
|
||||
def test_create_transform_pipeline_from_config() -> None:
|
||||
@pytest.mark.parametrize("expand_channels", [True, False])
|
||||
def test_create_transform_pipeline_from_config(expand_channels: bool) -> None:
|
||||
"""
|
||||
Tests that the pipeline returned by create_transform_pipeline_from_config returns the expected transformation.
|
||||
"""
|
||||
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
|
||||
transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=True,
|
||||
expand_channels=expand_channels)
|
||||
fake_cxr_as_array = np.ones([256, 256]) * 255.
|
||||
fake_cxr_as_array[100:150, 100:200] = 1
|
||||
fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
|
||||
|
||||
all_transforms = [ExpandChannels(),
|
||||
RandomAffine(degrees=180, translate=(0, 0), shear=40),
|
||||
all_transforms = [RandomAffine(degrees=180, translate=(0, 0), shear=40),
|
||||
RandomResizedCrop(scale=(0.4, 1.0), size=256),
|
||||
RandomHorizontalFlip(p=0.5),
|
||||
RandomGamma(scale=(0.5, 1.5)),
|
||||
|
@ -128,23 +127,28 @@ def test_create_transform_pipeline_from_config() -> None:
|
|||
AddGaussianNoise(std=0.05, p_apply=0.5)
|
||||
]
|
||||
|
||||
if expand_channels:
|
||||
all_transforms.insert(0, ExpandChannels())
|
||||
# expand channels is used for single-channel input images
|
||||
fake_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
|
||||
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
|
||||
image = ToTensor()(fake_image).reshape([1, 1, 256, 256])
|
||||
else:
|
||||
fake_3d_array = np.dstack([fake_cxr_as_array, fake_cxr_as_array, fake_cxr_as_array])
|
||||
fake_image = PIL.Image.fromarray(fake_3d_array.astype(np.uint8)).convert("RGB")
|
||||
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
|
||||
image = ToTensor()(fake_image).reshape([1, 3, 256, 256])
|
||||
|
||||
np.random.seed(3)
|
||||
torch.manual_seed(3)
|
||||
random.seed(3)
|
||||
|
||||
transformed_image = transformation_pipeline(fake_cxr_image)
|
||||
transformed_image = transformation_pipeline(fake_image)
|
||||
assert isinstance(transformed_image, torch.Tensor)
|
||||
# Expected pipeline
|
||||
image = np.ones([256, 256]) * 255.
|
||||
image[100:150, 100:200] = 1
|
||||
image = PIL.Image.fromarray(image).convert("L")
|
||||
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
|
||||
image = ToTensor()(image).reshape([1, 1, 256, 256])
|
||||
|
||||
# Expected pipeline
|
||||
np.random.seed(3)
|
||||
torch.manual_seed(3)
|
||||
random.seed(3)
|
||||
|
||||
expected_transformed = image
|
||||
for t in all_transforms:
|
||||
expected_transformed = t(expected_transformed)
|
||||
|
@ -154,10 +158,14 @@ def test_create_transform_pipeline_from_config() -> None:
|
|||
assert torch.isclose(expected_transformed, transformed_image).all()
|
||||
|
||||
# Test the evaluation pipeline
|
||||
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
|
||||
transformation_pipeline = create_transforms_from_config(cxr_augmentation_config, apply_augmentations=False,
|
||||
expand_channels=expand_channels)
|
||||
transformed_image = transformation_pipeline(image)
|
||||
assert isinstance(transformed_image, torch.Tensor)
|
||||
all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
|
||||
all_transforms = [Resize(size=256), CenterCrop(size=224)]
|
||||
if expand_channels:
|
||||
all_transforms.insert(0, ExpandChannels())
|
||||
|
||||
expected_transformed = image
|
||||
for t in all_transforms:
|
||||
expected_transformed = t(expected_transformed)
|
||||
|
|
|
@ -16,7 +16,7 @@ from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFA
|
|||
from InnerEye.ML.SSL.datamodules_and_datasets.cxr_datasets import RSNAKaggleCXR
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
|
||||
InnerEyeCIFARTrainTransform, get_cxr_ssl_transforms
|
||||
InnerEyeCIFARTrainTransform, get_ssl_transforms_from_config
|
||||
from InnerEye.ML.SSL.lightning_containers.ssl_container import SSLContainer, SSLDatasetName
|
||||
from InnerEye.ML.SSL.utils import SSLDataModuleType, load_yaml_augmentation_config
|
||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_encoder_augmentation_cxr
|
||||
|
@ -32,8 +32,8 @@ def test_weights_innereye_module() -> None:
|
|||
"""
|
||||
Tests if weights in CXR data module are correctly initialized
|
||||
"""
|
||||
transforms = get_cxr_ssl_transforms(cxr_augmentation_config,
|
||||
return_two_views_per_sample=True)
|
||||
transforms = get_ssl_transforms_from_config(cxr_augmentation_config,
|
||||
return_two_views_per_sample=True)
|
||||
data_module = InnerEyeVisionDataModule(dataset_cls=RSNAKaggleCXR,
|
||||
return_index=False,
|
||||
train_transforms=transforms[0],
|
||||
|
@ -70,7 +70,8 @@ def test_innereye_vision_module() -> None:
|
|||
data_dir=None,
|
||||
batch_size=5,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
num_workers=0,
|
||||
drop_last=True)
|
||||
data_module.prepare_data()
|
||||
data_module.setup()
|
||||
assert len(data_module.dataset_train) == 45000
|
||||
|
@ -179,8 +180,8 @@ def test_combined_data_module() -> None:
|
|||
"""
|
||||
Tests the behavior of CombinedDataModule
|
||||
"""
|
||||
_, val_transform = get_cxr_ssl_transforms(cxr_augmentation_config,
|
||||
return_two_views_per_sample=False)
|
||||
_, val_transform = get_ssl_transforms_from_config(cxr_augmentation_config,
|
||||
return_two_views_per_sample=False)
|
||||
|
||||
# Datamodule expected to have 12 training batches - 3 val
|
||||
long_data_module = InnerEyeVisionDataModule(dataset_cls=RSNAKaggleCXR,
|
||||
|
|
|
@ -117,21 +117,29 @@ with the following available arguments:
|
|||
* `random_seed`: seed for the run,
|
||||
* `num_epochs`: number of epochs to train for.
|
||||
|
||||
In case you wish to first test your model locally, here some optional arguments that can be useful:
|
||||
* `local_dataset`: path to local dataset, if passed the azure dataset will be ignored
|
||||
* `is_debug_model`: if True it will only run on the first batch of each epoch
|
||||
* `drop_last`: if False (True by default) it will keep the last batch also if incomplete
|
||||
|
||||
### Creating your own datamodules:
|
||||
|
||||
To use this code with your own data, you will need to:
|
||||
|
||||
1. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
|
||||
1. Define your own Lightening Container that inherits from `SSLContainer` as described in the paragraph above.
|
||||
2. Create a dataset class that reads your new dataset, inheriting from both `VisionDataset`
|
||||
and `InnerEyeDataClassBaseWithReturnIndex`. See for example how we constructed `RSNAKaggleCXR`
|
||||
class. WARNING: the first positional argument of your dataset class constructor MUST be the data directory ("root"),
|
||||
as VisionDataModule expects this in the prepare_data step.
|
||||
2. Add a member to the `SSLDatasetName` Enum with your new dataset and update the `_SSLDataClassMappings` member of the
|
||||
class so that the code knows which data class to associate to your new dataset name.
|
||||
3. Update the `_get_transforms` methods to add the transform specific to your new dataset. To simplify this step, we
|
||||
have defined a series of standard transforms parametrized by an augmentation yaml file in `SSL/transforms_utils.py` (
|
||||
see next paragraph for more details). You could for example construct a transform pipeline similar to the one created
|
||||
with `get_cxr_ssl_transforms` for our CXR examples.
|
||||
4. Update all necessary parameters in the model config (cf. previous paragraph)
|
||||
3. In your own container update the `_SSLDataClassMappings` member of the class so that the code knows which data class
|
||||
to associate to your new dataset name.
|
||||
4. Create a yaml configuration file that contains the augmentations specific to your dataset. The yaml file will be
|
||||
consumed by the `create_transforms_from_config` function defined in the
|
||||
`InnerEye.ML.augmentations.transform_pipeline` module (see next paragraph for more details). Alternatively, overwrite
|
||||
the `_get_transforms` method. To simplify this step, we have defined a series of standard operations in
|
||||
`SSL/transforms_utils.py` . You could for example construct a transform pipeline similar to the one created
|
||||
inside `create_transform_from_config` inside your own method.
|
||||
5. Update all necessary parameters in the model config (cf. previous paragraph)
|
||||
|
||||
Once all these steps are updated, the code in the base SSLContainer class will take care of creating the corresponding
|
||||
datamodules for SSL training and linear head monitoring.
|
||||
|
|
Загрузка…
Ссылка в новой задаче