This commit is contained in:
Harshita Sharma 2022-02-09 10:01:57 +00:00 коммит произвёл GitHub
Родитель eda76357f0
Коммит 914a89383d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 52 добавлений и 17 удалений

Просмотреть файл

@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.

Просмотреть файл

@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
import more_itertools as mi
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, no_grad, optim, round
from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix
from InnerEye.Common import fixed_paths
@ -55,7 +55,8 @@ class DeepMILModule(LightningModule):
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1,
class_names: Optional[List[str]] = None) -> None:
class_names: Optional[List[str]] = None,
is_finetune: bool = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
@ -75,6 +76,7 @@ class DeepMILModule(LightningModule):
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
:param is_finetune: Boolean value to enable/disable finetuning (default=False).
"""
super().__init__()
@ -115,6 +117,9 @@ class DeepMILModule(LightningModule):
self.verbose = verbose
# Finetuning attributes
self.is_finetune = is_finetune
self.aggregation_fn, self.num_pooling = self.get_pooling()
self.classifier_fn = self.get_classifier()
self.loss_fn = self.get_loss()
@ -196,7 +201,7 @@ class DeepMILModule(LightningModule):
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with no_grad():
with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)

Просмотреть файл

@ -139,5 +139,9 @@ class HistoSSLEncoder(TileEncoder):
def _get_encoder(self) -> Tuple[Callable, int]:
resnet18_model = resnet18(pretrained=False)
num_features = resnet18_model.fc.in_features
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
return setup_feature_extractor(histossl_encoder, self.input_dim) # type: ignore
histossl_encoder.fc = torch.nn.Sequential()
for param in histossl_encoder.parameters():
param.requires_grad = False
return histossl_encoder, num_features # type: ignore

Просмотреть файл

@ -27,6 +27,8 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, Identit
class BaseMIL(LightningContainer):
# Model parameters:
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options:"
"`False` (default), or `True`.")
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
@ -62,8 +64,8 @@ class BaseMIL(LightningContainer):
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")
self.encoder = self.get_encoder()
self.encoder.cuda()
self.encoder.eval()
if not self.is_finetune:
self.encoder.eval()
def get_encoder(self) -> TileEncoder:
if self.encoder_type == ImageNetEncoder.__name__:
@ -95,7 +97,13 @@ class BaseMIL(LightningContainer):
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
# no-op IdentityEncoder to be used inside the model
return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
if self.is_finetune:
self.model_encoder = self.encoder
for params in self.model_encoder.parameters():
params.requires_grad = True
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),

Просмотреть файл

@ -14,6 +14,7 @@ from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.common import get_best_checkpoint_path
@ -35,12 +36,19 @@ from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
class DeepSMILEPanda(BaseMIL):
"""`is_finetune` sets the fine-tuning mode. If this is set, setting cache_mode=CacheMode.NONE takes ~30 min/epoch and
cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.CPU takes ~[5-10] min/epoch.
Fine-tuning with caching completes using batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA.
"""
def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.CPU,
is_finetune=False,
# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
@ -98,17 +106,19 @@ class DeepSMILEPanda(BaseMIL):
os.chdir(fixed_paths.repository_parent_directory())
self.downloader.download_checkpoint_if_necessary()
self.encoder = self.get_encoder()
self.encoder.cuda()
self.encoder.eval()
if not self.is_finetune:
self.encoder.eval()
def get_data_module(self) -> PandaTilesDataModule:
image_key = PandaTilesDataset.IMAGE_COLUMN
transform = Compose(
[
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
]
)
if self.is_finetune:
transform = Compose([LoadTilesBatchd(image_key, progress=True)])
else:
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
])
return PandaTilesDataModule(
root_path=self.local_dataset,
max_bag_size=self.max_bag_size,
@ -128,7 +138,13 @@ class DeepSMILEPanda(BaseMIL):
self.slide_dataset = self.get_slide_dataset()
self.level = 1
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
if self.is_finetune:
self.model_encoder = self.encoder
for params in self.model_encoder.parameters():
params.requires_grad = True
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
@ -139,7 +155,8 @@ class DeepSMILEPanda(BaseMIL):
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level,
class_names=self.class_names)
class_names=self.class_names,
is_finetune=self.is_finetune)
def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore