Enable fine-tuning in DeepMIL (#650)
This commit is contained in:
Parent: eda76357f0
Commit: 914a89383d
@@ -48,6 +48,7 @@ jobs that run in AzureML.
 - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
 - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
 - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
+- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.

 ### Changed
 - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
 import more_itertools as mi

 from pytorch_lightning import LightningModule
-from torch import Tensor, argmax, mode, nn, no_grad, optim, round
+from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
 from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix

 from InnerEye.Common import fixed_paths
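The import swap above replaces `no_grad` with `set_grad_enabled`, which takes a boolean, so the same forward pass can run with or without autograd. A minimal standalone illustration of the difference (plain PyTorch, independent of this repo):

```python
import torch

x = torch.ones(2, requires_grad=True)

# Unlike torch.no_grad(), which always disables autograd,
# torch.set_grad_enabled(mode) toggles it based on a boolean flag.
for is_finetune in (False, True):
    with torch.set_grad_enabled(is_finetune):
        y = (x * 2).sum()
    print(is_finetune, y.requires_grad)  # False False, then True True
```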
@@ -55,7 +55,8 @@ class DeepMILModule(LightningModule):
                  slide_dataset: SlidesDataset = None,
                  tile_size: int = 224,
                  level: int = 1,
-                 class_names: Optional[List[str]] = None) -> None:
+                 class_names: Optional[List[str]] = None,
+                 is_finetune: bool = False) -> None:
         """
         :param label_column: Label key for input batch dictionary.
         :param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
@@ -75,6 +76,7 @@ class DeepMILModule(LightningModule):
         :param tile_size: The size of each tile (default=224).
         :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
         :param class_names: The names of the classes if available (default=None).
+        :param is_finetune: Boolean value to enable/disable finetuning (default=False).
         """
         super().__init__()

@@ -115,6 +117,9 @@ class DeepMILModule(LightningModule):

         self.verbose = verbose

+        # Finetuning attributes
+        self.is_finetune = is_finetune
+
         self.aggregation_fn, self.num_pooling = self.get_pooling()
         self.classifier_fn = self.get_classifier()
         self.loss_fn = self.get_loss()
@@ -196,7 +201,7 @@ class DeepMILModule(LightningModule):
             log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

     def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
-        with no_grad():
+        with set_grad_enabled(self.is_finetune):
             instance_features = self.encoder(instances)                    # N X L x 1 x 1
         attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
         bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
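To see why this one-line change enables fine-tuning: when `is_finetune` is False, the encoder output carries no autograd graph, so only the layers after it receive gradients; when True, the encoder trains too. A minimal sketch of the pattern (the module below is a toy stand-in, not the repo's `DeepMILModule`):

```python
import torch
from torch import nn

class TinyMIL(nn.Module):
    def __init__(self, is_finetune: bool = False) -> None:
        super().__init__()
        self.is_finetune = is_finetune
        self.encoder = nn.Linear(8, 4)     # stand-in for the tile encoder
        self.classifier = nn.Linear(4, 1)

    def forward(self, instances: torch.Tensor) -> torch.Tensor:
        # Same guard as DeepMILModule.forward: no encoder graph unless fine-tuning.
        with torch.set_grad_enabled(self.is_finetune):
            features = self.encoder(instances)
        return self.classifier(features.mean(dim=0, keepdim=True))

model = TinyMIL(is_finetune=False)
model(torch.randn(16, 8)).sum().backward()
print(model.encoder.weight.grad is None)     # True: encoder got no gradients
print(model.classifier.weight.grad is None)  # False: classifier still trains
```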
@@ -139,5 +139,9 @@ class HistoSSLEncoder(TileEncoder):

     def _get_encoder(self) -> Tuple[Callable, int]:
         resnet18_model = resnet18(pretrained=False)
+        num_features = resnet18_model.fc.in_features
         histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
-        return setup_feature_extractor(histossl_encoder, self.input_dim)  # type: ignore
+        histossl_encoder.fc = torch.nn.Sequential()
+        for param in histossl_encoder.parameters():
+            param.requires_grad = False
+        return histossl_encoder, num_features  # type: ignore
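The new `_get_encoder` body turns the SSL-pretrained ResNet-18 into a frozen feature extractor: the classification head is replaced by an empty `nn.Sequential` (an identity), and all weights start with `requires_grad=False`, to be re-enabled later only when fine-tuning. The same pattern in isolation, using a plain torchvision ResNet-18 in place of the repo's weight-loading helper:

```python
import torch
from torchvision.models import resnet18

model = resnet18(pretrained=False)       # weights would come from WEIGHTS_URL here
num_features = model.fc.in_features      # 512 for ResNet-18
model.fc = torch.nn.Sequential()         # empty Sequential acts as an identity head
for param in model.parameters():
    param.requires_grad = False          # frozen by default; undone when fine-tuning

features = model(torch.randn(2, 3, 224, 224))
print(features.shape)                    # torch.Size([2, 512])
```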
@@ -27,6 +27,8 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, Identit
 class BaseMIL(LightningContainer):
     # Model parameters:
     pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
+    is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options:"
+                                          "`False` (default), or `True`.")
     dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
     # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass

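For reference, `param.Boolean` declarations like the one above default to `False` and can be overridden at construction time. A minimal sketch using the `param` library alone (the class below is illustrative, not `BaseMIL`):

```python
import param

class Container(param.Parameterized):
    # mirrors the new BaseMIL flag
    is_finetune = param.Boolean(False, doc="Whether to fine-tune the encoder.")

print(Container().is_finetune)                  # False
print(Container(is_finetune=True).is_finetune)  # True
```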
@@ -62,8 +64,8 @@ class BaseMIL(LightningContainer):
             raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")

         self.encoder = self.get_encoder()
-        self.encoder.cuda()
-        self.encoder.eval()
+        if not self.is_finetune:
+            self.encoder.eval()

     def get_encoder(self) -> TileEncoder:
         if self.encoder_type == ImageNetEncoder.__name__:
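The reason `eval()` becomes conditional: a frozen encoder should use its stored batch-norm running statistics, while a fine-tuned one should stay in training mode so those statistics keep updating. The same check in isolation (torchvision ResNet-18 as a stand-in for the repo's encoder):

```python
from torchvision.models import resnet18

encoder = resnet18(pretrained=False)
is_finetune = False  # mirrors the container flag

if not is_finetune:
    encoder.eval()   # freeze batch-norm statistics and disable dropout

print(encoder.training)  # False when frozen; True when fine-tuning
```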
@@ -95,7 +97,13 @@ class BaseMIL(LightningContainer):
         self.data_module = self.get_data_module()
         # Encoding is done in the datamodule, so here we provide instead a dummy
         # no-op IdentityEncoder to be used inside the model
-        return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+        if self.is_finetune:
+            self.model_encoder = self.encoder
+            for params in self.model_encoder.parameters():
+                params.requires_grad = True
+        else:
+            self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
+        return DeepMILModule(encoder=self.model_encoder,
                              label_column=self.data_module.train_dataset.LABEL_COLUMN,
                              n_classes=self.data_module.train_dataset.N_CLASSES,
                              pooling_layer=self.get_pooling_layer(),
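The branch above decides where encoding happens: when fine-tuning, the real encoder is placed inside `DeepMILModule` with gradients re-enabled, so tiles are encoded (and the encoder updated) on every forward pass; otherwise the datamodule pre-encodes tiles once and the model receives a pass-through `IdentityEncoder`. A toy sketch of the same switch, with `nn.Identity` standing in for the repo's `IdentityEncoder`:

```python
from torch import nn

encoder = nn.Linear(8, 4)  # stand-in for the pre-trained tile encoder
is_finetune = True

if is_finetune:
    model_encoder = encoder
    for params in model_encoder.parameters():
        params.requires_grad = True   # undo the freezing applied at encoder setup
else:
    # tiles arrive pre-encoded from the datamodule; the model only passes them through
    model_encoder = nn.Identity()

print(type(model_encoder).__name__)   # Linear when fine-tuning, Identity otherwise
```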
@@ -14,6 +14,7 @@ from health_azure.utils import CheckpointDownloader
 from health_azure.utils import get_workspace, is_running_in_azure_ml
+from health_ml.networks.layers.attention_layers import GatedAttentionLayer
 from InnerEye.Common import fixed_paths
 from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
 from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
 from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
 from InnerEye.ML.common import get_best_checkpoint_path
@@ -35,12 +36,19 @@ from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule


 class DeepSMILEPanda(BaseMIL):
+    """`is_finetune` sets the fine-tuning mode. When fine-tuning, cache_mode=CacheMode.NONE takes ~30 min/epoch, while
+    cache_mode=CacheMode.MEMORY with precache_location=CacheLocation.CPU takes ~5-10 min/epoch. Fine-tuning with
+    caching completes with batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA.
+    """
     def __init__(self, **kwargs: Any) -> None:
         default_kwargs = dict(
             # declared in BaseMIL:
             pooling_type=GatedAttentionLayer.__name__,
+            # average number of tiles is 56 for PANDA
+            encoding_chunk_size=60,
             cache_mode=CacheMode.MEMORY,
             precache_location=CacheLocation.CPU,
+            is_finetune=False,

             # declared in DatasetParams:
             local_dataset=Path("/tmp/datasets/PANDA_tiles"),
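The docstring's timings hinge on caching: with fine-tuning, tiles must be loaded rather than pre-encoded, and keeping the cached dataset in CPU memory is what brings an epoch from ~30 min down to ~5-10 min. The relevant defaults, collected from the diff above (values verbatim; the surrounding dict is abridged to these keys):

```python
from pathlib import Path

from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation

default_kwargs = dict(
    # average number of tiles is 56 for PANDA
    encoding_chunk_size=60,
    cache_mode=CacheMode.MEMORY,           # cache the transformed dataset...
    precache_location=CacheLocation.CPU,   # ...in CPU memory, before training
    is_finetune=False,                     # fine-tuning stays opt-in
    local_dataset=Path("/tmp/datasets/PANDA_tiles"),
)
```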
@@ -98,17 +106,19 @@ class DeepSMILEPanda(BaseMIL):
             os.chdir(fixed_paths.repository_parent_directory())
             self.downloader.download_checkpoint_if_necessary()
         self.encoder = self.get_encoder()
         self.encoder.cuda()
-        self.encoder.eval()
+        if not self.is_finetune:
+            self.encoder.eval()

     def get_data_module(self) -> PandaTilesDataModule:
         image_key = PandaTilesDataset.IMAGE_COLUMN
-        transform = Compose(
-            [
-                LoadTilesBatchd(image_key, progress=True),
-                EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
-            ]
-        )
+        if self.is_finetune:
+            transform = Compose([LoadTilesBatchd(image_key, progress=True)])
+        else:
+            transform = Compose([
+                LoadTilesBatchd(image_key, progress=True),
+                EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
+            ])

         return PandaTilesDataModule(
             root_path=self.local_dataset,
             max_bag_size=self.max_bag_size,
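The transform switch is the datamodule-side half of the same design: fine-tuning loads raw tiles only, since the trainable encoder runs inside the model, while the frozen path pre-encodes tiles once so training never touches the encoder. A self-contained sketch with hypothetical stand-ins for `LoadTilesBatchd` / `EncodeTilesBatchd` and MONAI's `Compose`:

```python
from typing import Callable, Dict, List

def load_tiles(batch: Dict) -> Dict:          # stand-in for LoadTilesBatchd
    batch["image"] = f"pixels({batch['image']})"
    return batch

def encode_tiles(batch: Dict) -> Dict:        # stand-in for EncodeTilesBatchd
    batch["image"] = f"features({batch['image']})"
    return batch

def compose(transforms: List[Callable]) -> Callable:  # stand-in for Compose
    def apply(batch: Dict) -> Dict:
        for t in transforms:
            batch = t(batch)
        return batch
    return apply

is_finetune = True
transform = compose([load_tiles] if is_finetune else [load_tiles, encode_tiles])
print(transform({"image": "tile_0"}))  # pixels only when fine-tuning
```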
@@ -128,7 +138,13 @@ class DeepSMILEPanda(BaseMIL):
         self.slide_dataset = self.get_slide_dataset()
         self.level = 1
         self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
-        return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+        if self.is_finetune:
+            self.model_encoder = self.encoder
+            for params in self.model_encoder.parameters():
+                params.requires_grad = True
+        else:
+            self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
+        return DeepMILModule(encoder=self.model_encoder,
                              label_column=self.data_module.train_dataset.LABEL_COLUMN,
                              n_classes=self.data_module.train_dataset.N_CLASSES,
                              pooling_layer=self.get_pooling_layer(),
@@ -139,7 +155,8 @@ class DeepSMILEPanda(BaseMIL):
                              slide_dataset=self.get_slide_dataset(),
                              tile_size=self.tile_size,
                              level=self.level,
-                             class_names=self.class_names)
+                             class_names=self.class_names,
+                             is_finetune=self.is_finetune)

     def get_slide_dataset(self) -> PandaDataset:
         return PandaDataset(root=self.extra_local_dataset_paths[0])  # type: ignore
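Putting the pieces together, the flag flows from the container into `DeepMILModule`, and its net effect on the encoder can be sanity-checked directly. A hedged verification sketch in plain PyTorch (`trainable_fraction` is a helper invented here, not part of the repo):

```python
import torch
from torchvision.models import resnet18

def trainable_fraction(model: torch.nn.Module) -> float:
    params = list(model.parameters())
    return sum(p.requires_grad for p in params) / len(params)

encoder = resnet18(pretrained=False)
for p in encoder.parameters():
    p.requires_grad = False            # as in HistoSSLEncoder._get_encoder
print(trainable_fraction(encoder))     # 0.0: frozen by default

for p in encoder.parameters():
    p.requires_grad = True             # as in create_model when is_finetune=True
print(trainable_fraction(encoder))     # 1.0: fully trainable for fine-tuning
```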