ENH: Enable flexible finetuning of deepmil submodules (#549)

Add fine-tuning flags `tune_encoder`, `tune_pooling`, and `tune_classifier` to flexibly fine-tune each part of the network.
This commit is contained in:
Kenza Bouzid 2022-08-04 13:44:16 +01:00 committed by GitHub
Parent 62f3f57e4f
Commit 122bcc1be3
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
12 changed files: 210 additions and 53 deletions

View File

@@ -44,7 +44,7 @@ addition, you can turn on fine-tuning of the encoder, which will improve the res
```shell
conda activate HimlHisto
python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune --cluster=<your_cluster_name>
python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder --cluster=<your_cluster_name>
```
Then the script will output "Successfully queued run number ..." and a line prefixed "Run URL: ...". Open that

View File

@@ -89,6 +89,7 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
"RSLEX_DIRECT_VOLUME_MOUNT": "true",
"RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
"DATASET_MOUNT_CACHE_SIZE": "1",
"AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
}
PathOrString = Union[Path, str]

View File

@@ -84,11 +84,11 @@ python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/
endef
define DEEPSMILEPANDASLIDES_ARGS
--model=health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune
--model=health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder
endef
define DEEPSMILEPANDATILES_ARGS
--model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
--model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
endef
define TCGACRCKSSLMIL_ARGS

View File

@@ -71,12 +71,25 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
"generating outputs.")
maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
"maximised (otherwise minimised).")
tune_classifier: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_extra_val_epoch = True # Enable running an additional validation step to save tiles/slides thumbnails
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
self.validate()
def validate(self) -> None:
super().validate()
if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
raise ValueError(
"At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
"arguments `tune_encoder`, `tune_pooling`, `tune_classifier`. Otherwise, activate inference only "
"mode via `run_inference_only` flag."
)
@property
def cache_dir(self) -> Path:
@@ -190,8 +203,8 @@ class BaseMILTiles(BaseMIL):
def setup(self) -> None:
super().setup()
# Fine-tuning requires tiles to be loaded on-the-fly, hence, caching is disabled by default.
# When is_finetune and is_caching are both set, below lines should disable caching automatically.
if self.is_finetune:
# When tune_encoder and is_caching are both set, below lines should disable caching automatically.
if self.tune_encoder:
self.is_caching = False
if not self.is_caching:
self.cache_mode = CacheMode.NONE
@@ -224,6 +237,7 @@ class BaseMILTiles(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
tune_classifier=self.tune_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
@@ -264,6 +278,7 @@ class BaseMILSlides(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
tune_classifier=self.tune_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
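
The `validate` override above rejects configurations where every submodule is frozen while training is still requested. A minimal standalone sketch of that check, using a hypothetical `TuningFlags` dataclass rather than the repository's `BaseMIL`:

```python
from dataclasses import dataclass


@dataclass
class TuningFlags:
    """Illustrative stand-in for the three fine-tuning flags on BaseMIL."""
    tune_encoder: bool = False
    tune_pooling: bool = True
    tune_classifier: bool = True
    run_inference_only: bool = False

    def validate(self) -> None:
        # Training with every submodule frozen would update nothing, so reject
        # that combination unless the run is inference-only.
        if (not any([self.tune_encoder, self.tune_pooling, self.tune_classifier])
                and not self.run_inference_only):
            raise ValueError("At least one of the encoder, pooling or classifier should be fine tuned.")


TuningFlags(tune_classifier=True).validate()       # valid: the classifier is tuned
TuningFlags(tune_pooling=False, tune_classifier=False,
            run_inference_only=True).validate()    # valid: inference-only mode
```

In the hunks above, `tune_classifier` is forwarded explicitly through `create_model`, while `tune_encoder` and `tune_pooling` travel inside `EncoderParams` and `PoolingParams`.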

View File

@@ -37,7 +37,7 @@ class DeepSMILECrck(BaseMILTiles):
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
encoding_chunk_size=60,
is_finetune=False,
tune_encoder=False,
is_caching=True,
num_top_slides=0,
azure_datasets=[TCGA_CRCK_DATASET_ID],

View File

@@ -33,7 +33,7 @@ class BaseDeepSMILEPanda(BaseMIL):
pool_type=AttentionLayer.__name__,
num_transformer_pool_layers=4,
num_transformer_pool_heads=4,
is_finetune=False,
tune_encoder=False,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
max_bag_size=56,
@@ -55,7 +55,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
""" DeepSMILETilesPanda is derived from BaseMILTiles and BaseDeepSMILEPanda to inherit common behaviors from both
tiles basemil and panda specific configuration.
`is_finetune` sets the fine-tuning mode. For fine-tuning, batch_size = 2
`tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning the encoder, batch_size = 2
runs on multiple GPUs with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
"""

View File

@@ -50,9 +50,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
"""
Configuration for PANDA experiments from Myronenko et al. 2021:
(https://link.springer.com/chapter/10.1007/978-3-030-87237-3_32)
`is_finetune` sets the fine-tuning mode. For fine-tuning,
batch_size = 2 runs on 8 GPUs with
~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
`tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning, batch_size = 2 runs on 8 GPUs
with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
"""
def __init__(self, **kwargs: Any) -> None:
@@ -77,7 +76,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
self.l_rate = 3e-5
self.weight_decay = 0.1
# Params specific to fine-tuning
if self.is_finetune:
if self.tune_encoder:
self.batch_size = 2
super().setup()

View File

@@ -8,13 +8,13 @@ from pytorch_lightning.utilities.warnings import rank_zero_warn
from pathlib import Path
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import AUROC, F1, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa
from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams, set_module_gradients_enabled
from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.utils.naming import MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
@@ -39,6 +39,7 @@ class BaseDeepMILModule(LightningModule):
n_classes: int,
class_weights: Optional[Tensor] = None,
class_names: Optional[Sequence[str]] = None,
tune_classifier: bool = True,
dropout_rate: Optional[float] = None,
verbose: bool = False,
ssl_ckpt_run_id: Optional[str] = None,
@@ -53,6 +54,7 @@
set to 1.
:param class_weights: Tensor containing class weights (default=None).
:param class_names: The names of the classes if available (default=None).
:param tune_classifier: Whether to tune the classifier (default=True).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param verbose: if True statements about memory usage are output at each step.
:param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint
@@ -75,6 +77,7 @@
self.dropout_rate = dropout_rate
self.encoder_params = encoder_params
self.pooling_params = pooling_params
self.optimizer_params = optimizer_params
self.save_hyperparameters()
@@ -84,6 +87,7 @@
# This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
# consuming validation outputs
self.run_extra_val_epoch = False
self.tune_classifier = tune_classifier
# Model components
self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)
@@ -98,9 +102,10 @@
self.val_metrics = self.get_metrics()
self.test_metrics = self.get_metrics()
def get_classifier(self) -> Callable:
def get_classifier(self) -> nn.Module:
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
set_module_gradients_enabled(classifier_layer, self.tune_classifier)
if self.dropout_rate is None:
return classifier_layer
elif 0 <= self.dropout_rate < 1:
@@ -164,25 +169,42 @@
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
should_enable_encoder_grad = torch.is_grad_enabled() and self.encoder_params.is_finetune
with set_grad_enabled(should_enable_encoder_grad):
if self.encoder_params.encoding_chunk_size > 0:
embeddings = []
chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
for chunk in chunks:
chunk_embeddings = self.encoder(chunk)
embeddings.append(chunk_embeddings)
instance_features = torch.cat(embeddings)
else:
instance_features = self.encoder(instances) # N X L x 1 x 1
def get_instance_features(self, instances: Tensor) -> Tensor:
if not self.encoder_params.tune_encoder:
self.encoder.eval()
if self.encoder_params.encoding_chunk_size > 0:
embeddings = []
chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
for chunk in chunks:
chunk_embeddings = self.encoder(chunk)
embeddings.append(chunk_embeddings)
instance_features = torch.cat(embeddings)
else:
instance_features = self.encoder(instances) # N X L x 1 x 1
return instance_features
def get_attentions_and_bag_features(self, instance_features: Tensor) -> Tuple[Tensor, Tensor]:
if not self.pooling_params.tune_pooling:
self.aggregation_fn.eval()
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(1, -1)
return attentions, bag_features
def get_bag_logit(self, bag_features: Tensor) -> Tensor:
if not self.tune_classifier:
self.classifier_fn.eval()
bag_logit = self.classifier_fn(bag_features)
return bag_logit
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
instance_features = self.get_instance_features(instances)
attentions, bag_features = self.get_attentions_and_bag_features(instance_features)
bag_logit = self.get_bag_logit(bag_features)
return bag_logit, attentions
def configure_optimizers(self) -> optim.Optimizer:
return optim.Adam(self.parameters(), lr=self.optimizer_params.l_rate,
return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()),
lr=self.optimizer_params.l_rate,
weight_decay=self.optimizer_params.weight_decay,
betas=self.optimizer_params.adam_betas)
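
The refactored `forward` now delegates to three helpers, each of which switches its submodule to `eval()` when the corresponding tuning flag is off. A minimal sketch of the chunked-encoding step under stated assumptions (a toy two-layer encoder and a made-up `encoding_chunk_size`; the real module uses the configured ImageNet/SSL encoder):

```python
import torch
from torch import Tensor, nn

# Toy encoder standing in for the real tile encoder.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 8))
tune_encoder = False          # mirrors EncoderParams.tune_encoder
encoding_chunk_size = 4       # mirrors EncoderParams.encoding_chunk_size

# Freeze the encoder parameters when it is not being tuned.
for p in encoder.parameters():
    p.requires_grad = tune_encoder


def get_instance_features(instances: Tensor) -> Tensor:
    if not tune_encoder:
        # eval() also fixes dropout/batch-norm behaviour, which requires_grad alone does not.
        encoder.eval()
    if encoding_chunk_size > 0:
        # Encode the bag in chunks to bound peak memory, then concatenate.
        chunks = torch.split(instances, encoding_chunk_size)
        return torch.cat([encoder(chunk) for chunk in chunks])
    return encoder(instances)


instances = torch.randn(10, 3, 16, 16)   # a bag of 10 tiles
features = get_instance_features(instances)
print(features.shape)                    # torch.Size([10, 8])
print(features.grad_fn)                  # None: no graph is built for a frozen encoder
```

Because frozen parameters have `requires_grad=False`, `configure_optimizers` can hand Adam only `filter(lambda p: p.requires_grad, self.parameters())`, so frozen submodules accumulate no optimizer state.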

View File

@@ -28,14 +28,24 @@ from health_ml.networks.layers.attention_layers import (
)
def set_module_gradients_enabled(model: nn.Module, tuning_flag: bool) -> None:
"""Given a model, enable or disable gradients for all parameters.
:param model: A PyTorch model.
:param tuning_flag: A boolean indicating whether to enable or disable gradients for the model parameters.
"""
for params in model.parameters():
params.requires_grad = tuning_flag
class EncoderParams(param.Parameterized):
"""Parameters class to group all encoder specific attributes for deepmil module. """
encoder_type: str = param.String(doc="Name of the encoder class to use.")
tile_size: int = param.Integer(default=224, bounds=(1, None), doc="Tile width/height, in pixels.")
n_channels: int = param.Integer(default=3, bounds=(1, None), doc="Number of channels in the tile.")
is_finetune: bool = param.Boolean(
False, doc="If True, fine-tune the encoder during training. If False (default), " "keep the encoder frozen."
tune_encoder: bool = param.Boolean(
False, doc="If True, fine-tune the encoder during training. If False (default), keep the encoder frozen."
)
is_caching: bool = param.Boolean(
default=False,
@@ -74,10 +84,12 @@ class EncoderParams(param.Parameterized):
elif self.encoder_type == SSLEncoder.__name__:
assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
downloader = CheckpointDownloader(run_id=ssl_ckpt_run_id,
download_dir=outputs_folder,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
downloader = CheckpointDownloader(
run_id=ssl_ckpt_run_id,
download_dir=outputs_folder,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR),
)
encoder = SSLEncoder(
pl_checkpoint_path=downloader.local_checkpoint_path,
tile_size=self.tile_size,
@@ -85,12 +97,7 @@
)
else:
raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
if self.is_finetune:
for params in encoder.parameters():
params.requires_grad = True
else:
encoder.eval()
set_module_gradients_enabled(encoder, tuning_flag=self.tune_encoder)
return encoder
@@ -106,9 +113,11 @@
default=4, doc="If transformer pooling is chosen, this defines the number of encoding layers.",
)
num_transformer_pool_heads: int = param.Integer(
4,
doc="If transformer pooling is chosen, this defines the number\
of attention heads.",
default=4, doc="If transformer pooling is chosen, this defines the number of attention heads.",
)
tune_pooling: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the pooling layer during training. If False, keep the pooling layer frozen.",
)
def get_pooling_layer(self, num_encoding: int) -> Tuple[nn.Module, int]:
@@ -139,6 +148,7 @@
)
self.pool_out_dim = 1 # currently this is hardcoded in forward of the TransformerPooling
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
raise ValueError(f"Unsupported pooling type: {self.pool_type}")
num_features = num_encoding * self.pool_out_dim
set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
return pooling_layer, num_features
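
`set_module_gradients_enabled` is the single switch used to freeze or unfreeze the encoder, pooling layer, and (via `get_classifier`) the classifier. A small illustrative sketch, re-implementing the helper on toy layers (layer shapes are arbitrary assumptions):

```python
import torch
from torch import nn, optim


def set_module_gradients_enabled(model: nn.Module, tuning_flag: bool) -> None:
    # Same logic as the helper added in deepmil_utils.py.
    for params in model.parameters():
        params.requires_grad = tuning_flag


pooling = nn.Linear(8, 8)      # stands in for the attention/transformer pooling layer
classifier = nn.Linear(8, 2)   # stands in for the bag-level classifier

set_module_gradients_enabled(pooling, tuning_flag=False)     # freeze pooling
set_module_gradients_enabled(classifier, tuning_flag=True)   # tune classifier

trainable = [p for m in (pooling, classifier) for p in m.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=1e-3)

# Only the classifier's weight and bias are optimised; the pooling layer stays fixed.
assert len(trainable) == 2
assert all(not p.requires_grad for p in pooling.parameters())
```

This replaces the previous per-encoder freezing logic and makes the same mechanism available to the pooling layer and classifier.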

View File

@@ -348,7 +348,7 @@ class DeepMILOutputsHandler:
# Writing completed successfully; delete temporary back-up
if self.previous_validation_outputs_dir.exists():
shutil.rmtree(self.previous_validation_outputs_dir)
shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)
def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
"""Render and save test epoch outputs.

View File

@@ -21,7 +21,6 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
is_finetune=False,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=ImageNetEncoder.__name__,
@@ -62,7 +61,6 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
pool_hidden_dim=16,
num_transformer_pool_layers=1,
num_transformer_pool_heads=1,
is_finetune=True,
class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
# Encoder parameters
encoder_type=ImageNetEncoder.__name__,

View File

@@ -5,6 +5,7 @@
import logging
import os
import shutil
from pytorch_lightning import Trainer
import torch
import pytest
from pathlib import Path
@@ -13,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Typ
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
from health_cpath.datamodules.panda_module import PandaTilesDataModule
from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.BaseMIL import BaseMILTiles
@@ -35,12 +37,13 @@ from health_ml.utils.common_utils import is_gpu_available
no_gpu = not is_gpu_available()
def get_supervised_imagenet_encoder_params() -> EncoderParams:
return EncoderParams(encoder_type=ImageNetEncoder.__name__)
def get_supervised_imagenet_encoder_params(tune_encoder: bool = True) -> EncoderParams:
return EncoderParams(encoder_type=ImageNetEncoder.__name__, tune_encoder=tune_encoder)
def get_attention_pooling_layer_params(pool_out_dim: int = 1) -> PoolingParams:
return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5)
def get_attention_pooling_layer_params(pool_out_dim: int = 1, tune_pooling: bool = True) -> PoolingParams:
return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5,
tune_pooling=tune_pooling)
def _test_lightningmodule(
@@ -164,8 +167,8 @@ def mock_panda_slides_root_dir(
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("max_bag_size", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
@@ -425,3 +428,112 @@ def test_class_weights_multiclass() -> None:
# TODO: the test should reflect actual weighted loss operation for the class weights after
# batch_size > 1 is implemented.
assert allclose(loss_weighted, loss_unweighted)
def test_wrong_tuning_options() -> None:
with pytest.raises(ValueError,
match=r"At least one of the encoder, pooling or classifier should be fine tuned"):
_ = MockDeepSMILETilesPanda(
tmp_path=Path("foo"),
tune_encoder=False,
tune_pooling=False,
tune_classifier=False
)
def _get_datamodule(tmp_path: Path) -> PandaTilesDataModule:
tiles_generator = MockPandaTilesGenerator(
dest_data_path=tmp_path,
mock_type=MockHistoDataType.FAKE,
n_tiles=4,
n_slides=10,
n_channels=3,
tile_size=28,
img_size=224,
)
tiles_generator.generate_mock_histo_data()
datamodule = PandaTilesDataModule(root_path=tmp_path, batch_size=2, max_bag_size=4)
return datamodule
@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_finetuning_options(
tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
module = TilesDeepMILModule(
n_classes=1,
label_column=DEFAULT_LABEL_COLUMN,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
tune_classifier=tune_classifier,
)
assert module.encoder_params.tune_encoder == tune_encoder
assert module.pooling_params.tune_pooling == tune_pooling
assert module.tune_classifier == tune_classifier
for params in module.encoder.parameters():
assert params.requires_grad == tune_encoder
for params in module.aggregation_fn.parameters():
assert params.requires_grad == tune_pooling
for params in module.classifier_fn.parameters():
assert params.requires_grad == tune_classifier
instances = torch.randn(4, 3, 224, 224)
def _assert_existing_gradients_fn(tensor: Tensor, tuning_flag: bool) -> None:
assert tensor.requires_grad == tuning_flag
if tuning_flag:
assert tensor.grad_fn is not None
else:
assert tensor.grad_fn is None
with torch.enable_grad():
instance_features = module.get_instance_features(instances)
_assert_existing_gradients_fn(instance_features, tuning_flag=tune_encoder)
assert module.encoder.training == tune_encoder
attentions, bag_features = module.get_attentions_and_bag_features(instance_features)
_assert_existing_gradients_fn(attentions, tuning_flag=tune_pooling)
_assert_existing_gradients_fn(bag_features, tuning_flag=tune_pooling)
assert module.aggregation_fn.training == tune_pooling
bag_logit = module.get_bag_logit(bag_features)
# bag_logit gradients are required for pooling layer gradients computation, hence
# "tuning_flag=tune_classifier or tune_pooling"
_assert_existing_gradients_fn(bag_logit, tuning_flag=tune_classifier or tune_pooling)
assert module.classifier_fn.training == tune_classifier
@pytest.mark.parametrize("tune_classifier", [False, True])
@pytest.mark.parametrize("tune_pooling", [False, True])
@pytest.mark.parametrize("tune_encoder", [False, True])
def test_training_with_different_finetuning_options(
tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
) -> None:
if any([tune_encoder, tune_pooling, tune_classifier]):
module = TilesDeepMILModule(
n_classes=6,
label_column=MockPandaTilesGenerator.ISUP_GRADE,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
tune_classifier=tune_classifier,
)
def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
for param in module.parameters():
if tuning_flag:
assert param.grad is not None
else:
assert param.grad is None
with patch.object(module, "validation_step"):
trainer = Trainer(max_epochs=1)
trainer.fit(module, datamodule=_get_datamodule(tmp_path))
_assert_existing_gradients(module.classifier_fn, tuning_flag=tune_classifier)
_assert_existing_gradients(module.aggregation_fn, tuning_flag=tune_pooling)
_assert_existing_gradients(module.encoder, tuning_flag=tune_encoder)
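
In `test_finetuning_options`, the bag logit is expected to carry a gradient function whenever `tune_classifier or tune_pooling` holds. That follows from a general autograd rule: an output records a `grad_fn` if any tensor in its computation requires gradients, even when the final layer itself is frozen. A standalone illustration with toy layers (not the DeepMIL module):

```python
import torch
from torch import nn

pooling = nn.Linear(4, 4)      # trainable, like tune_pooling=True
classifier = nn.Linear(4, 2)
for p in classifier.parameters():
    p.requires_grad = False    # frozen, like tune_classifier=False

with torch.enable_grad():
    bag_features = pooling(torch.randn(1, 4))
    bag_logit = classifier(bag_features)

# The logit still has a grad_fn because gradients must flow back through the
# frozen classifier to reach the trainable pooling layer.
assert bag_features.grad_fn is not None
assert bag_logit.grad_fn is not None
assert all(not p.requires_grad for p in classifier.parameters())
```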