From 914a89383d3275cb3304bde9c6cdd6050357f3f7 Mon Sep 17 00:00:00 2001 From: Harshita Sharma <61745616+harshita-s@users.noreply.github.com> Date: Wed, 9 Feb 2022 10:01:57 +0000 Subject: [PATCH] Enable fine-tuning in Deepmil (#650) --- CHANGELOG.md | 1 + InnerEye/ML/Histopathology/models/deepmil.py | 11 ++++-- InnerEye/ML/Histopathology/models/encoders.py | 6 ++- .../histo_configs/classification/BaseMIL.py | 14 +++++-- .../classification/DeepSMILEPanda.py | 37 ++++++++++++++----- 5 files changed, 52 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19801749..aa0c76e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ jobs that run in AzureML. - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup. +- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task. ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 6052fbef..afd88c17 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt import more_itertools as mi from pytorch_lightning import LightningModule -from torch import Tensor, argmax, mode, nn, no_grad, optim, round +from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix from InnerEye.Common import fixed_paths @@ -55,7 +55,8 @@ class DeepMILModule(LightningModule): slide_dataset: SlidesDataset = None, tile_size: int = 224, level: int = 1, - class_names: Optional[List[str]] = None) -> None: + class_names: Optional[List[str]] = None, + is_finetune: bool = False) -> None: """ :param label_column: Label key for input batch dictionary. :param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1. @@ -75,6 +76,7 @@ class DeepMILModule(LightningModule): :param tile_size: The size of each tile (default=224). :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1). :param class_names: The names of the classes if available (default=None). + :param is_finetune: Boolean value to enable/disable finetuning (default=False). """ super().__init__() @@ -115,6 +117,9 @@ class DeepMILModule(LightningModule): self.verbose = verbose + # Finetuning attributes + self.is_finetune = is_finetune + self.aggregation_fn, self.num_pooling = self.get_pooling() self.classifier_fn = self.get_classifier() self.loss_fn = self.get_loss() @@ -196,7 +201,7 @@ class DeepMILModule(LightningModule): log_on_epoch(self, f'{stage}/{metric_name}', metric_object) def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore - with no_grad(): + with set_grad_enabled(self.is_finetune): instance_features = self.encoder(instances) # N X L x 1 x 1 attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim) diff --git a/InnerEye/ML/Histopathology/models/encoders.py b/InnerEye/ML/Histopathology/models/encoders.py index 04f454bb..57025429 100644 --- a/InnerEye/ML/Histopathology/models/encoders.py +++ b/InnerEye/ML/Histopathology/models/encoders.py @@ -139,5 +139,9 @@ class HistoSSLEncoder(TileEncoder): def _get_encoder(self) -> Tuple[Callable, int]: resnet18_model = resnet18(pretrained=False) + num_features = resnet18_model.fc.in_features histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model) - return setup_feature_extractor(histossl_encoder, self.input_dim) # type: ignore + histossl_encoder.fc = torch.nn.Sequential() + for param in histossl_encoder.parameters(): + param.requires_grad = False + return histossl_encoder, num_features # type: ignore diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 5e5ee165..4406a3de 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -27,6 +27,8 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, Identit class BaseMIL(LightningContainer): # Model parameters: pooling_type: str = param.String(doc="Name of the pooling layer class to use.") + is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options:" + "`False` (default), or `True`.") dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.") # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass @@ -62,8 +64,8 @@ class BaseMIL(LightningContainer): raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.") self.encoder = self.get_encoder() - self.encoder.cuda() - self.encoder.eval() + if not self.is_finetune: + self.encoder.eval() def get_encoder(self) -> TileEncoder: if self.encoder_type == ImageNetEncoder.__name__: @@ -95,7 +97,13 @@ class BaseMIL(LightningContainer): self.data_module = self.get_data_module() # Encoding is done in the datamodule, so here we provide instead a dummy # no-op IdentityEncoder to be used inside the model - return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + if self.is_finetune: + self.model_encoder = self.encoder + for params in self.model_encoder.parameters(): + params.requires_grad = True + else: + self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,)) + return DeepMILModule(encoder=self.model_encoder, label_column=self.data_module.train_dataset.LABEL_COLUMN, n_classes=self.data_module.train_dataset.N_CLASSES, pooling_layer=self.get_pooling_layer(), diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 68d952cf..9296d649 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -14,6 +14,7 @@ from health_azure.utils import CheckpointDownloader from health_azure.utils import get_workspace, is_running_in_azure_ml from health_ml.networks.layers.attention_layers import GatedAttentionLayer from InnerEye.Common import fixed_paths +from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset from InnerEye.ML.common import get_best_checkpoint_path @@ -35,12 +36,19 @@ from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule class DeepSMILEPanda(BaseMIL): + """`is_finetune` sets the fine-tuning mode. If this is set, setting cache_mode=CacheMode.NONE takes ~30 min/epoch and + cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.CPU takes ~[5-10] min/epoch. + Fine-tuning with caching completes using batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA. + """ def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, # average number of tiles is 56 for PANDA encoding_chunk_size=60, + cache_mode=CacheMode.MEMORY, + precache_location=CacheLocation.CPU, + is_finetune=False, # declared in DatasetParams: local_dataset=Path("/tmp/datasets/PANDA_tiles"), @@ -98,17 +106,19 @@ class DeepSMILEPanda(BaseMIL): os.chdir(fixed_paths.repository_parent_directory()) self.downloader.download_checkpoint_if_necessary() self.encoder = self.get_encoder() - self.encoder.cuda() - self.encoder.eval() + if not self.is_finetune: + self.encoder.eval() def get_data_module(self) -> PandaTilesDataModule: image_key = PandaTilesDataset.IMAGE_COLUMN - transform = Compose( - [ - LoadTilesBatchd(image_key, progress=True), - EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size), - ] - ) + if self.is_finetune: + transform = Compose([LoadTilesBatchd(image_key, progress=True)]) + else: + transform = Compose([ + LoadTilesBatchd(image_key, progress=True), + EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size) + ]) + return PandaTilesDataModule( root_path=self.local_dataset, max_bag_size=self.max_bag_size, @@ -128,7 +138,13 @@ class DeepSMILEPanda(BaseMIL): self.slide_dataset = self.get_slide_dataset() self.level = 1 self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"] - return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + if self.is_finetune: + self.model_encoder = self.encoder + for params in self.model_encoder.parameters(): + params.requires_grad = True + else: + self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,)) + return DeepMILModule(encoder=self.model_encoder, label_column=self.data_module.train_dataset.LABEL_COLUMN, n_classes=self.data_module.train_dataset.N_CLASSES, pooling_layer=self.get_pooling_layer(), @@ -139,7 +155,8 @@ class DeepSMILEPanda(BaseMIL): slide_dataset=self.get_slide_dataset(), tile_size=self.tile_size, level=self.level, - class_names=self.class_names) + class_names=self.class_names, + is_finetune=self.is_finetune) def get_slide_dataset(self) -> PandaDataset: return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore