Mirror of https://github.com/microsoft/hi-ml.git

ENH: Enable flexible finetuning of deepmil submodules (#549)

Add fine-tuning flags `tune_encoder`, `tune_pooling`, and `tune_classifier` to flexibly fine-tune each part of the network.

This commit is contained in:
Parent: 62f3f57e4f
Commit: 122bcc1be3
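As a quick orientation, here is a minimal sketch of how the three new flags are meant to combine on a DeepMIL container. The flag names and the `ValueError` behaviour follow the diff below; the import path and the choice of container are assumptions for illustration only.

```python
# Illustrative sketch only: flag semantics follow this PR, the import path is assumed.
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILETilesPanda

# Freeze the encoder, keep training the attention pooling layer and the classifier head.
config = DeepSMILETilesPanda(tune_encoder=False, tune_pooling=True, tune_classifier=True)

# BaseMIL.validate() (added below) rejects a configuration in which nothing is trainable,
# unless inference-only mode is requested:
#   DeepSMILETilesPanda(tune_encoder=False, tune_pooling=False, tune_classifier=False)
#   -> ValueError: "At least one of the encoder, pooling or classifier should be fine tuned..."
```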
@@ -44,7 +44,7 @@ addition, you can turn on fine-tuning of the encoder, which will improve the res
```shell
conda activate HimlHisto
-python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune --cluster=<your_cluster_name>
+python ../hi-ml/src/health_ml/runner.py --model health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder --cluster=<your_cluster_name>
```

Then the script will output "Successfully queued run number ..." and a line prefixed "Run URL: ...". Open that
@@ -89,6 +89,7 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
    "RSLEX_DIRECT_VOLUME_MOUNT": "true",
    "RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
    "DATASET_MOUNT_CACHE_SIZE": "1",
    "AZUREML_COMPUTE_USE_COMMON_RUNTIME": "false",
}

PathOrString = Union[Path, str]
@@ -84,11 +84,11 @@ python hi-ml/src/health_ml/runner.py --mount_in_azureml --conda_env=hi-ml-cpath/
endef

define DEEPSMILEPANDASLIDES_ARGS
---model=health_cpath.SlidesPandaImageNetMILBenchmark --is_finetune
+--model=health_cpath.SlidesPandaImageNetMILBenchmark --tune_encoder
endef

define DEEPSMILEPANDATILES_ARGS
---model=health_cpath.TilesPandaImageNetMIL --is_finetune --batch_size=2
+--model=health_cpath.TilesPandaImageNetMIL --tune_encoder --batch_size=2
endef

define TCGACRCKSSLMIL_ARGS
@@ -71,12 +71,25 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
                                            "generating outputs.")
    maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
                                                            "maximised (otherwise minimised).")
+    tune_classifier: bool = param.Boolean(
+        default=True,
+        doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.run_extra_val_epoch = True  # Enable running an additional validation step to save tiles/slides thumbnails
        self.best_checkpoint_filename = "checkpoint_max_val_auroc"
        self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
+        self.validate()

+    def validate(self) -> None:
+        super().validate()
+        if not any([self.tune_encoder, self.tune_pooling, self.tune_classifier]) and not self.run_inference_only:
+            raise ValueError(
+                "At least one of the encoder, pooling or classifier should be fine tuned. Turn on one of the tune "
+                "arguments `tune_encoder`, `tune_pooling`, `tune_classifier`. Otherwise, activate inference only "
+                "mode via `run_inference_only` flag."
+            )

    @property
    def cache_dir(self) -> Path:

@@ -190,8 +203,8 @@ class BaseMILTiles(BaseMIL):
    def setup(self) -> None:
        super().setup()
        # Fine-tuning requires tiles to be loaded on-the-fly, hence, caching is disabled by default.
-        # When is_finetune and is_caching are both set, below lines should disable caching automatically.
-        if self.is_finetune:
+        # When tune_encoder and is_caching are both set, below lines should disable caching automatically.
+        if self.tune_encoder:
            self.is_caching = False
        if not self.is_caching:
            self.cache_mode = CacheMode.NONE

@@ -224,6 +237,7 @@ class BaseMILTiles(BaseMIL):
            n_classes=self.data_module.train_dataset.n_classes,
            class_names=self.class_names,
            class_weights=self.data_module.class_weights,
+            tune_classifier=self.tune_classifier,
            dropout_rate=self.dropout_rate,
            outputs_folder=self.outputs_folder,
            ssl_ckpt_run_id=self.ssl_ckpt_run_id,

@@ -264,6 +278,7 @@ class BaseMILSlides(BaseMIL):
            n_classes=self.data_module.train_dataset.n_classes,
            class_names=self.class_names,
            class_weights=self.data_module.class_weights,
+            tune_classifier=self.tune_classifier,
            dropout_rate=self.dropout_rate,
            outputs_folder=self.outputs_folder,
            ssl_ckpt_run_id=self.ssl_ckpt_run_id,
@@ -37,7 +37,7 @@ class DeepSMILECrck(BaseMILTiles):
            num_transformer_pool_layers=4,
            num_transformer_pool_heads=4,
            encoding_chunk_size=60,
-            is_finetune=False,
+            tune_encoder=False,
            is_caching=True,
            num_top_slides=0,
            azure_datasets=[TCGA_CRCK_DATASET_ID],
@@ -33,7 +33,7 @@ class BaseDeepSMILEPanda(BaseMIL):
            pool_type=AttentionLayer.__name__,
            num_transformer_pool_layers=4,
            num_transformer_pool_heads=4,
-            is_finetune=False,
+            tune_encoder=False,
            # average number of tiles is 56 for PANDA
            encoding_chunk_size=60,
            max_bag_size=56,

@@ -55,7 +55,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
    """ DeepSMILETilesPanda is derived from BaseMILTiles and BaseDeepSMILEPanda to inherit common behaviors from both
    tiles basemil and panda specific configuration.

-    `is_finetune` sets the fine-tuning mode. `is_finetune` sets the fine-tuning mode. For fine-tuning, batch_size = 2
+    `tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning the encoder, batch_size = 2
    runs on multiple GPUs with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
    """
@@ -50,9 +50,8 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
    """
    Configuration for PANDA experiments from Myronenko et al. 2021:
    (https://link.springer.com/chapter/10.1007/978-3-030-87237-3_32)
-    `is_finetune` sets the fine-tuning mode. For fine-tuning,
-    batch_size = 2 runs on 8 GPUs with
-    ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
+    `tune_encoder` sets the fine-tuning mode of the encoder. For fine-tuning, batch_size = 2 runs on 8 GPUs
+    with ~ 6:24 min/epoch (train) and ~ 00:50 min/epoch (validation).
    """

    def __init__(self, **kwargs: Any) -> None:

@@ -77,7 +76,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
        self.l_rate = 3e-5
        self.weight_decay = 0.1
        # Params specific to fine-tuning
-        if self.is_finetune:
+        if self.tune_encoder:
            self.batch_size = 2
        super().setup()
@@ -8,13 +8,13 @@ from pytorch_lightning.utilities.warnings import rank_zero_warn
from pathlib import Path

from pytorch_lightning import LightningModule
-from torch import Tensor, argmax, mode, nn, optim, round, set_grad_enabled
+from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import AUROC, F1, Accuracy, ConfusionMatrix, Precision, Recall, CohenKappa

from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
-from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
+from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams, set_module_gradients_enabled

from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.utils.naming import MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey

@@ -39,6 +39,7 @@ class BaseDeepMILModule(LightningModule):
                 n_classes: int,
                 class_weights: Optional[Tensor] = None,
                 class_names: Optional[Sequence[str]] = None,
+                 tune_classifier: bool = True,
                 dropout_rate: Optional[float] = None,
                 verbose: bool = False,
                 ssl_ckpt_run_id: Optional[str] = None,

@@ -53,6 +54,7 @@ class BaseDeepMILModule(LightningModule):
            set to 1.
        :param class_weights: Tensor containing class weights (default=None).
        :param class_names: The names of the classes if available (default=None).
+        :param tune_classifier: Whether to tune the classifier (default=True).
        :param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
        :param verbose: if True statements about memory usage are output at each step.
        :param ssl_ckpt_run_id: Optional parameter to provide the AML run id from where to download the checkpoint

@@ -75,6 +77,7 @@ class BaseDeepMILModule(LightningModule):

        self.dropout_rate = dropout_rate
        self.encoder_params = encoder_params
        self.pooling_params = pooling_params
        self.optimizer_params = optimizer_params

        self.save_hyperparameters()

@@ -84,6 +87,7 @@ class BaseDeepMILModule(LightningModule):
        # This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
        # consuming validation outputs
        self.run_extra_val_epoch = False
+        self.tune_classifier = tune_classifier

        # Model components
        self.encoder = encoder_params.get_encoder(ssl_ckpt_run_id, outputs_folder)

@@ -98,9 +102,10 @@ class BaseDeepMILModule(LightningModule):
        self.val_metrics = self.get_metrics()
        self.test_metrics = self.get_metrics()

-    def get_classifier(self) -> Callable:
+    def get_classifier(self) -> nn.Module:
        classifier_layer = nn.Linear(in_features=self.num_pooling,
                                     out_features=self.n_classes)
+        set_module_gradients_enabled(classifier_layer, self.tune_classifier)
        if self.dropout_rate is None:
            return classifier_layer
        elif 0 <= self.dropout_rate < 1:

@@ -164,25 +169,42 @@ class BaseDeepMILModule(LightningModule):
        else:
            log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

-    def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
-        should_enable_encoder_grad = torch.is_grad_enabled() and self.encoder_params.is_finetune
-        with set_grad_enabled(should_enable_encoder_grad):
-            if self.encoder_params.encoding_chunk_size > 0:
-                embeddings = []
-                chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
-                for chunk in chunks:
-                    chunk_embeddings = self.encoder(chunk)
-                    embeddings.append(chunk_embeddings)
-                instance_features = torch.cat(embeddings)
-            else:
-                instance_features = self.encoder(instances)  # N X L x 1 x 1
+    def get_instance_features(self, instances: Tensor) -> Tensor:
+        if not self.encoder_params.tune_encoder:
+            self.encoder.eval()
+        if self.encoder_params.encoding_chunk_size > 0:
+            embeddings = []
+            chunks = torch.split(instances, self.encoder_params.encoding_chunk_size)
+            for chunk in chunks:
+                chunk_embeddings = self.encoder(chunk)
+                embeddings.append(chunk_embeddings)
+            instance_features = torch.cat(embeddings)
+        else:
+            instance_features = self.encoder(instances)  # N X L x 1 x 1
+        return instance_features
+
+    def get_attentions_and_bag_features(self, instance_features: Tensor) -> Tuple[Tensor, Tensor]:
+        if not self.pooling_params.tune_pooling:
+            self.aggregation_fn.eval()
+        attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
+        bag_features = bag_features.view(1, -1)
+        return attentions, bag_features
+
+    def get_bag_logit(self, bag_features: Tensor) -> Tensor:
+        if not self.tune_classifier:
+            self.classifier_fn.eval()
+        bag_logit = self.classifier_fn(bag_features)
+        return bag_logit
+
+    def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
+        instance_features = self.get_instance_features(instances)
+        attentions, bag_features = self.get_attentions_and_bag_features(instance_features)
+        bag_logit = self.get_bag_logit(bag_features)
+        return bag_logit, attentions

    def configure_optimizers(self) -> optim.Optimizer:
-        return optim.Adam(self.parameters(), lr=self.optimizer_params.l_rate,
+        return optim.Adam(filter(lambda p: p.requires_grad, self.parameters()),
+                          lr=self.optimizer_params.l_rate,
                          weight_decay=self.optimizer_params.weight_decay,
                          betas=self.optimizer_params.adam_betas)
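The optimizer change above pairs with the gradient freezing elsewhere in this PR: parameters whose `requires_grad` was switched off never reach Adam. Below is a minimal, self-contained sketch of that pattern using toy `torch.nn` modules as stand-ins for the encoder, pooling, and classifier; it mirrors the `configure_optimizers()` filter and the `eval()` toggling shown above, but it is not the repository's own code.

```python
import torch
from torch import nn, optim

# Toy stand-ins for the encoder / pooling / classifier submodules.
encoder = nn.Linear(8, 4)
pooling = nn.Linear(4, 4)
classifier = nn.Linear(4, 2)

# Freeze the encoder, as set_module_gradients_enabled(encoder, tuning_flag=False) would.
for p in encoder.parameters():
    p.requires_grad = False

# Mirror the new configure_optimizers(): only trainable parameters reach the optimizer.
model = nn.Sequential(encoder, pooling, classifier)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
assert all(p.requires_grad for group in optimizer.param_groups for p in group["params"])

# Frozen submodules are also switched to eval() in the new forward helpers, so that
# layers such as BatchNorm and Dropout behave deterministically while frozen.
encoder.eval()
```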
@@ -28,14 +28,24 @@ from health_ml.networks.layers.attention_layers import (
)


+def set_module_gradients_enabled(model: nn.Module, tuning_flag: bool) -> None:
+    """Given a model, enable or disable gradients for all parameters.
+
+    :param model: A PyTorch model.
+    :param tuning_flag: A boolean indicating whether to enable or disable gradients for the model parameters.
+    """
+    for params in model.parameters():
+        params.requires_grad = tuning_flag
+
+
class EncoderParams(param.Parameterized):
    """Parameters class to group all encoder specific attributes for deepmil module. """

    encoder_type: str = param.String(doc="Name of the encoder class to use.")
    tile_size: int = param.Integer(default=224, bounds=(1, None), doc="Tile width/height, in pixels.")
    n_channels: int = param.Integer(default=3, bounds=(1, None), doc="Number of channels in the tile.")
-    is_finetune: bool = param.Boolean(
-        False, doc="If True, fine-tune the encoder during training. If False (default), " "keep the encoder frozen."
+    tune_encoder: bool = param.Boolean(
+        False, doc="If True, fine-tune the encoder during training. If False (default), keep the encoder frozen."
    )
    is_caching: bool = param.Boolean(
        default=False,
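A short usage sketch of the `set_module_gradients_enabled` helper introduced above, assuming the `hi-ml-cpath` package is installed so the import resolves; the toy encoder is a hypothetical stand-in, any `nn.Module` works the same way.

```python
import torch
from torch import nn
from health_cpath.utils.deepmil_utils import set_module_gradients_enabled

# Hypothetical stand-in for an encoder; any nn.Module can be frozen or unfrozen this way.
toy_encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten())

# Freeze: what get_encoder() does when EncoderParams.tune_encoder is False.
set_module_gradients_enabled(toy_encoder, tuning_flag=False)
assert not any(p.requires_grad for p in toy_encoder.parameters())

# Unfreeze again for fine-tuning.
set_module_gradients_enabled(toy_encoder, tuning_flag=True)
assert all(p.requires_grad for p in toy_encoder.parameters())
```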
@@ -74,10 +84,12 @@ class EncoderParams(param.Parameterized):

        elif self.encoder_type == SSLEncoder.__name__:
            assert ssl_ckpt_run_id and outputs_folder, "SSLEncoder requires ssl_ckpt_run_id and outputs_folder"
-            downloader = CheckpointDownloader(run_id=ssl_ckpt_run_id,
-                                              download_dir=outputs_folder,
-                                              checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
-                                              remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR))
+            downloader = CheckpointDownloader(
+                run_id=ssl_ckpt_run_id,
+                download_dir=outputs_folder,
+                checkpoint_filename=LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
+                remote_checkpoint_dir=Path(DEFAULT_AML_CHECKPOINT_DIR),
+            )
            encoder = SSLEncoder(
                pl_checkpoint_path=downloader.local_checkpoint_path,
                tile_size=self.tile_size,

@@ -85,12 +97,7 @@ class EncoderParams(param.Parameterized):
            )
        else:
            raise ValueError(f"Unsupported encoder type: {self.encoder_type}")

-        if self.is_finetune:
-            for params in encoder.parameters():
-                params.requires_grad = True
-        else:
-            encoder.eval()
+        set_module_gradients_enabled(encoder, tuning_flag=self.tune_encoder)
        return encoder


@@ -106,9 +113,11 @@ class PoolingParams(param.Parameterized):
        default=4, doc="If transformer pooling is chosen, this defines the number of encoding layers.",
    )
    num_transformer_pool_heads: int = param.Integer(
-        4,
-        doc="If transformer pooling is chosen, this defines the number\
-            of attention heads.",
+        default=4, doc="If transformer pooling is chosen, this defines the number of attention heads.",
    )
+    tune_pooling: bool = param.Boolean(
+        default=True,
+        doc="If True (default), fine-tune the pooling layer during training. If False, keep the pooling layer frozen.",
+    )

    def get_pooling_layer(self, num_encoding: int) -> Tuple[nn.Module, int]:

@@ -139,6 +148,7 @@ class PoolingParams(param.Parameterized):
            )
            self.pool_out_dim = 1  # currently this is hardcoded in forward of the TransformerPooling
        else:
-            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
+            raise ValueError(f"Unsupported pooling type: {self.pool_type}")
        num_features = num_encoding * self.pool_out_dim
+        set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
        return pooling_layer, num_features
@@ -348,7 +348,7 @@ class DeepMILOutputsHandler:

            # Writing completed successfully; delete temporary back-up
            if self.previous_validation_outputs_dir.exists():
-                shutil.rmtree(self.previous_validation_outputs_dir)
+                shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)

    def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
        """Render and save test epoch outputs.
@@ -21,7 +21,6 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
-            is_finetune=False,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=ImageNetEncoder.__name__,

@@ -62,7 +61,6 @@ class MockDeepSMILESlidesPanda(DeepSMILESlidesPanda):
            pool_hidden_dim=16,
            num_transformer_pool_layers=1,
            num_transformer_pool_heads=1,
-            is_finetune=True,
            class_names=["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"],
            # Encoder parameters
            encoder_type=ImageNetEncoder.__name__,
@@ -5,6 +5,7 @@
import logging
import os
import shutil
+from pytorch_lightning import Trainer
import torch
import pytest
from pathlib import Path

@@ -13,6 +14,7 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Typ

from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
+from health_cpath.datamodules.panda_module import PandaTilesDataModule

from health_ml.networks.layers.attention_layers import AttentionLayer
from health_cpath.configs.classification.BaseMIL import BaseMILTiles

@@ -35,12 +37,13 @@ from health_ml.utils.common_utils import is_gpu_available
no_gpu = not is_gpu_available()


-def get_supervised_imagenet_encoder_params() -> EncoderParams:
-    return EncoderParams(encoder_type=ImageNetEncoder.__name__)
+def get_supervised_imagenet_encoder_params(tune_encoder: bool = True) -> EncoderParams:
+    return EncoderParams(encoder_type=ImageNetEncoder.__name__, tune_encoder=tune_encoder)


-def get_attention_pooling_layer_params(pool_out_dim: int = 1) -> PoolingParams:
-    return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5)
+def get_attention_pooling_layer_params(pool_out_dim: int = 1, tune_pooling: bool = True) -> PoolingParams:
+    return PoolingParams(pool_type=AttentionLayer.__name__, pool_out_dim=pool_out_dim, pool_hidden_dim=5,
+                         tune_pooling=tune_pooling)


def _test_lightningmodule(

@@ -164,8 +167,8 @@ def mock_panda_slides_root_dir(


@pytest.mark.parametrize("n_classes", [1, 3])
-@pytest.mark.parametrize("batch_size", [1, 15])
-@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("max_bag_size", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(

@@ -425,3 +428,112 @@ def test_class_weights_multiclass() -> None:
    # TODO: the test should reflect actual weighted loss operation for the class weights after
    # batch_size > 1 is implemented.
    assert allclose(loss_weighted, loss_unweighted)
+
+
+def test_wrong_tuning_options() -> None:
+    with pytest.raises(ValueError,
+                       match=r"At least one of the encoder, pooling or classifier should be fine tuned"):
+        _ = MockDeepSMILETilesPanda(
+            tmp_path=Path("foo"),
+            tune_encoder=False,
+            tune_pooling=False,
+            tune_classifier=False
+        )
+
+
+def _get_datamodule(tmp_path: Path) -> PandaTilesDataModule:
+    tiles_generator = MockPandaTilesGenerator(
+        dest_data_path=tmp_path,
+        mock_type=MockHistoDataType.FAKE,
+        n_tiles=4,
+        n_slides=10,
+        n_channels=3,
+        tile_size=28,
+        img_size=224,
+    )
+    tiles_generator.generate_mock_histo_data()
+
+    datamodule = PandaTilesDataModule(root_path=tmp_path, batch_size=2, max_bag_size=4)
+    return datamodule
+
+
+@pytest.mark.parametrize("tune_classifier", [False, True])
+@pytest.mark.parametrize("tune_pooling", [False, True])
+@pytest.mark.parametrize("tune_encoder", [False, True])
+def test_finetuning_options(
+    tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
+) -> None:
+    module = TilesDeepMILModule(
+        n_classes=1,
+        label_column=DEFAULT_LABEL_COLUMN,
+        encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
+        pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
+        tune_classifier=tune_classifier,
+    )
+
+    assert module.encoder_params.tune_encoder == tune_encoder
+    assert module.pooling_params.tune_pooling == tune_pooling
+    assert module.tune_classifier == tune_classifier
+
+    for params in module.encoder.parameters():
+        assert params.requires_grad == tune_encoder
+    for params in module.aggregation_fn.parameters():
+        assert params.requires_grad == tune_pooling
+    for params in module.classifier_fn.parameters():
+        assert params.requires_grad == tune_classifier
+
+    instances = torch.randn(4, 3, 224, 224)
+
+    def _assert_existing_gradients_fn(tensor: Tensor, tuning_flag: bool) -> None:
+        assert tensor.requires_grad == tuning_flag
+        if tuning_flag:
+            assert tensor.grad_fn is not None
+        else:
+            assert tensor.grad_fn is None
+
+    with torch.enable_grad():
+        instance_features = module.get_instance_features(instances)
+        _assert_existing_gradients_fn(instance_features, tuning_flag=tune_encoder)
+        assert module.encoder.training == tune_encoder
+
+        attentions, bag_features = module.get_attentions_and_bag_features(instances)
+        _assert_existing_gradients_fn(attentions, tuning_flag=tune_pooling)
+        _assert_existing_gradients_fn(bag_features, tuning_flag=tune_pooling)
+        assert module.aggregation_fn.training == tune_pooling
+
+        bag_logit = module.get_bag_logit(bag_features)
+        # bag_logit gradients are required for pooling layer gradients computation, hence
+        # "tuning_flag=tune_classifier or tune_pooling"
+        _assert_existing_gradients_fn(bag_logit, tuning_flag=tune_classifier or tune_pooling)
+        assert module.classifier_fn.training == tune_classifier
+
+
+@pytest.mark.parametrize("tune_classifier", [False, True])
+@pytest.mark.parametrize("tune_pooling", [False, True])
+@pytest.mark.parametrize("tune_encoder", [False, True])
+def test_training_with_different_finetuning_options(
+    tune_encoder: bool, tune_pooling: bool, tune_classifier: bool, tmp_path: Path
+) -> None:
+    if any([tune_encoder, tune_pooling, tune_classifier]):
+        module = TilesDeepMILModule(
+            n_classes=6,
+            label_column=MockPandaTilesGenerator.ISUP_GRADE,
+            encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
+            pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
+            tune_classifier=tune_classifier,
+        )
+
+        def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
+            for param in module.parameters():
+                if tuning_flag:
+                    assert param.grad is not None
+                else:
+                    assert param.grad is None
+
+        with patch.object(module, "validation_step"):
+            trainer = Trainer(max_epochs=1)
+            trainer.fit(module, datamodule=_get_datamodule(tmp_path))
+
+        _assert_existing_gradients(module.classifier_fn, tuning_flag=tune_classifier)
+        _assert_existing_gradients(module.aggregation_fn, tuning_flag=tune_pooling)
+        _assert_existing_gradients(module.encoder, tuning_flag=tune_encoder)