BUG: DeepMIL model uses training data even when running in eval mode (#889)

Instantiating the DeepMIL classifier in `create_model` always relies on
the training dataset and data module, which are not available when running
evaluation.
This commit is contained in:
Anton Schwaighofer 2023-05-25 17:55:20 +01:00 коммит произвёл GitHub
Родитель 26ff7d0d94
Коммит e8f6b44638
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 154 добавлений и 29 удалений

Просмотреть файл

@ -78,7 +78,7 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
SDK_NAME = "innereye"
SDK_VERSION = "2.0"
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04"
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:20230509.v1"
DEFAULT_DOCKER_SHM_SIZE = "100g"
# hyperparameter search args

Просмотреть файл

@ -13,6 +13,7 @@ from pathlib import Path
from monai.transforms import Compose
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
import torch
from health_azure.utils import create_from_matching_params
from health_cpath.preprocessing.loading import LoadingParams
@ -20,6 +21,7 @@ from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParam
from health_cpath.utils.wsi_utils import TilingParams
from health_ml.deep_learning_config import OptimizerParams
from health_ml.experiment_config import RunnerMode
from health_ml.lightning_container import LightningContainer
from health_ml.utils.checkpoint_utils import CheckpointParser
@ -102,6 +104,7 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.data_module: Optional[HistoDataModule] = None
self.run_extra_val_epoch = True # Enable running an additional validation step to save tiles/slides thumbnails
metric_optim = "max" if self.maximise_primary_metric else "min"
self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
@ -163,7 +166,7 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
return set()
def get_outputs_handler(self) -> DeepMILOutputsHandler:
n_classes = self.data_module.train_dataset.n_classes
n_classes = self.get_num_classes()
outputs_handler = DeepMILOutputsHandler(
outputs_root=self.outputs_folder,
n_classes=n_classes,
@ -241,14 +244,53 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
# Builds the EncoderParams object for this container by handing the container itself to
# create_from_matching_params (presumably this copies over the parameter values that both share -
# confirm against health_azure.utils).
def get_encoder_params(self) -> EncoderParams:
return create_from_matching_params(self, EncoderParams)
def get_data_module_for_runner_mode(self) -> HistoDataModule:
    """Return the data module for the current runner mode, creating and caching it on first use.

    If the `data_module` attribute is already populated, it is returned unchanged. Otherwise the attribute is
    filled from `get_data_module()` when the runner mode is `RunnerMode.TRAIN`, or from `get_eval_data_module()`
    when the runner mode is `RunnerMode.EVAL_FULL`.

    :return: The data module stored in the `data_module` attribute.
    :raises ValueError: If no data module is set yet and the runner mode is neither TRAIN nor EVAL_FULL.
    """
    # Fast path: a data module was already created (or injected), so reuse it.
    if self.data_module is not None:
        return self.data_module
    mode = self.runner_mode
    if mode == RunnerMode.TRAIN:
        self.data_module = self.get_data_module()
    elif mode == RunnerMode.EVAL_FULL:
        self.data_module = self.get_eval_data_module()
    else:
        raise ValueError(f"Unknown runner mode {self.runner_mode}")
    return self.data_module
def get_num_classes(self) -> int:
    """Return the number of classes reported by the dataset that matches the runner mode.

    With `RunnerMode.EVAL_FULL` the value is read from the test dataset of the data module, in all other cases
    from its training dataset.

    Side effect: the data module is created and stored in the `data_module` attribute if it is not set yet.

    :return: The `n_classes` value of the selected dataset.
    """
    data_module = self.get_data_module_for_runner_mode()
    use_test_split = self.runner_mode == RunnerMode.EVAL_FULL
    dataset = data_module.test_dataset if use_test_split else data_module.train_dataset
    return dataset.n_classes
def get_class_weights(self) -> torch.Tensor:
    """Return the class weights that the model should be created with.

    In evaluation mode (`RunnerMode.EVAL_FULL`) this is a tensor with all ones: the class weights will be loaded
    from the checkpoint, so their value does not matter, only the size of the tensor. In all other modes the
    weights are taken from the data module for the current runner mode.

    :return: In evaluation mode, a tensor of ones; otherwise the `class_weights` attribute of the data module.
    """
    if self.runner_mode == RunnerMode.EVAL_FULL:
        num_classes = self.get_num_classes()
        # A dataset that reports n_classes == 1 still describes two classes (binary case), hence two weights.
        return torch.ones(2 if num_classes == 1 else num_classes)
    return self.get_data_module_for_runner_mode().class_weights
def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
num_classes = self.get_num_classes()
outputs_handler = self.get_outputs_handler()
deepmil_module = DeepMILModule(
label_column=self.get_label_column(),
n_classes=self.data_module.train_dataset.n_classes,
n_classes=num_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
class_weights=self.get_class_weights(),
outputs_folder=self.outputs_folder,
encoder_params=self.get_encoder_params(),
pooling_params=create_from_matching_params(self, PoolingParams),
@ -265,6 +307,9 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
# Returns the data module used for training. The base class provides no implementation and raises
# NotImplementedError. get_data_module_for_runner_mode calls this when the runner mode is RunnerMode.TRAIN.
def get_data_module(self) -> HistoDataModule:
raise NotImplementedError
# Returns the data module used for evaluation. The base class provides no implementation and raises
# NotImplementedError. get_data_module_for_runner_mode calls this when the runner mode is RunnerMode.EVAL_FULL.
def get_eval_data_module(self) -> HistoDataModule:
raise NotImplementedError
# Base implementation: no slides dataset is provided, so this always returns None.
def get_slides_dataset(self) -> Optional[SlidesDataset]:
return None
@ -305,7 +350,7 @@ class BaseMILTiles(BaseMIL):
)
def get_label_column(self) -> str:
return self.data_module.train_dataset.label_column
return self.get_data_module_for_runner_mode().train_dataset.label_column
def setup(self) -> None:
super().setup()

Просмотреть файл

@ -2,23 +2,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from copy import deepcopy
import os
from pytorch_lightning import Trainer
import torch
import pytest
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock, patch
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from unittest.mock import MagicMock, patch, DEFAULT
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
import pytest
import torch
from pytorch_lightning import LightningDataModule, Trainer
from torch import Tensor, allclose, argmax, nn, rand, randint, randn, round, stack
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
from health_cpath.datamodules.panda_module import PandaTilesDataModule
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
from health_cpath.configs.classification.DeepSMILEPanda import (
BaseDeepSMILEPanda,
@ -26,7 +23,10 @@ from health_cpath.configs.classification.DeepSMILEPanda import (
SlidesPandaSSLMIL,
TilesPandaSSLMIL,
)
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary, innereye_ssl_checkpoint_crck_4ws
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
from health_cpath.datamodules.panda_module import PandaTilesDataModule
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
from health_cpath.models.deepmil import DeepMILModule
@ -40,15 +40,16 @@ from health_cpath.models.encoders import (
)
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
from health_ml.utils.common_utils import is_gpu_available
from health_ml.eval_runner import EvalRunner
from health_ml.experiment_config import ExperimentConfig, RunnerMode
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
from torch.utils.data import DataLoader
from health_ml.utils.common_utils import is_gpu_available
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.container import MockDeepSMILESlidesPanda, MockDeepSMILETilesPanda
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
no_gpu = not is_gpu_available()
@ -707,7 +708,10 @@ def test_checkpoint_name(
def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
container.setup()
container.data_module = MagicMock()
num_classes = 6
container.data_module = MagicMock(
class_weights=torch.ones(num_classes), train_dataset=MagicMock(n_classes=num_classes)
)
container.create_lightning_module_and_store()
assert not container.model._on_extra_val_epoch
assert (
@ -828,3 +832,76 @@ def test_encoder_checkpointning(
custom_forward.return_value = torch.zeros(container.max_bag_size, feature_dim)
_, _ = model_ckpt_enc(sample[SlideKey.IMAGE][0])
custom_forward.assert_called_once()
@pytest.mark.parametrize("runner_mode", [RunnerMode.TRAIN, RunnerMode.EVAL_FULL])
def test_create_model_in_eval_mode(tmp_path: Path, runner_mode: RunnerMode) -> None:
    """Runner setup must only request the data module that belongs to the runner mode.

    In training mode only `get_data_module` may be called, in evaluation mode only `get_eval_data_module`.
    """
    container = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
    eval_runner = EvalRunner(
        container=container,
        experiment_config=ExperimentConfig(mode=runner_mode),
        project_root=tmp_path,
    )
    # Stand-ins for the two data modules: one only exposes a training dataset, the other only a test dataset.
    train_module = MagicMock(class_weights=torch.ones(2), train_dataset=MagicMock(n_classes=1))
    eval_module = MagicMock(class_weights=torch.ones(2), test_dataset=MagicMock(n_classes=1))
    get_data_module = MagicMock(return_value=train_module)
    get_eval_data_module = MagicMock(return_value=eval_module)
    patched_methods = dict(
        get_data_module=get_data_module,
        get_eval_data_module=get_eval_data_module,
        get_label_column=DEFAULT,
    )
    with patch.multiple("health_cpath.configs.classification.BaseMIL.BaseMIL", **patched_methods):
        eval_runner.setup()
    # The mocks are plain locals, so they can be inspected after the patch has been reverted.
    if runner_mode != RunnerMode.TRAIN:
        get_data_module.assert_not_called()
        get_eval_data_module.assert_called_once()
    else:
        get_data_module.assert_called_once()
        get_eval_data_module.assert_not_called()
def test_run_model_in_eval_mode(tmp_path: Path) -> None:
"""When the runner is in eval mode, ensure a checkpoint can be loaded correctly."""
# Mock a training dataset with two classes, but this is represented as n_classes=1 in the run we are trying
# to repro here
one_class_dataset = MagicMock(n_classes=1)
class_weights = torch.ones(2)
get_data_module = MagicMock(return_value=MagicMock(class_weights=class_weights, train_dataset=one_class_dataset))
# Build the evaluation data module while get_splits and _get_dataloader are patched on HistoDataModule: the
# mocked splits carry the one-class mock dataset in third position (presumably the test split - confirm
# against HistoDataModule), and the mocked data loader wraps that same mock dataset.
with patch.multiple(
"health_cpath.datamodules.base_module.HistoDataModule",
get_splits=MagicMock(return_value=(MagicMock(), MagicMock(), one_class_dataset)),
_get_dataloader=MagicMock(return_value=DataLoader(one_class_dataset)),
):
eval_data_module = HistoDataModule(tmp_path)
# Patch the BaseMIL hooks so that the container gets the mocked training data module, a fixed label column,
# and the evaluation data module built above.
with patch.multiple(
"health_cpath.configs.classification.BaseMIL.BaseMIL",
get_data_module=get_data_module,
get_label_column=MagicMock(return_value="label"),
get_eval_data_module=MagicMock(return_value=eval_data_module),
):
# First container: runner_mode is left at its default, the model is created directly via create_model.
container_orig = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
model = container_orig.create_model()
# Create a trainer to save a checkpoint only
trainer = Trainer()
trainer.strategy.model = model
checkpoint = tmp_path / "model.ckpt"
trainer.save_checkpoint(checkpoint)
# Now create the model again and run it in eval mode, loading the checkpoint
container = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
container.src_checkpoint = CheckpointParser(str(checkpoint))
eval_runner = EvalRunner(
container=container,
experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL),
project_root=tmp_path,
)
eval_runner.setup()
# Before bug fix:
# Error(s) in loading state_dict for DeepMILModule:
# Unexpected key(s) in state_dict: "loss_fn.pos_weight", "loss_fn_no_reduction.pos_weight".
# size mismatch for classifier_fn.weight: copying a param with shape torch.Size([1, 512]) from checkpoint,
# the shape in current model is torch.Size([2, 512]).
# size mismatch for classifier_fn.bias: copying a param with shape torch.Size([1]) from checkpoint,
# the shape in current model is torch.Size([2]).
# The test passes if running the evaluation does not raise; no further assertions are made here.
eval_runner.run()

Просмотреть файл

@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from health_azure.utils import create_from_matching_params
from health_ml.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, WorkflowParams
from health_ml.experiment_config import ExperimentConfig
from health_ml.experiment_config import ExperimentConfig, RunnerMode
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
from health_ml.utils.lr_scheduler import SchedulerWithWarmUp
from health_ml.utils.model_util import create_optimizer
@ -32,8 +32,10 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar
super().__init__(**kwargs)
self._model: Optional[LightningModule] = None
self._model_name = type(self).__name__
self.num_nodes = 1
self.trained_weights_path: Optional[Path] = None
# Number of nodes and the runner mode are read from the ExperimentConfig, and will be copied here
self.num_nodes = 1
self.runner_mode = RunnerMode.TRAIN
def validate(self) -> None:
WorkflowParams.validate(self)

Просмотреть файл

@ -68,6 +68,7 @@ class RunnerBase:
self.container = container
self.experiment_config = experiment_config
self.container.num_nodes = self.experiment_config.num_nodes
self.container.runner_mode = self.experiment_config.mode
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
self.storing_logger: Optional[StoringLogger] = None
self._has_setup_run = False

Просмотреть файл

@ -37,7 +37,7 @@ class DatasetSplits:
:return: a Set of elements that appear in more than one collection
"""
intersection = set()
for col1, col2 in combinations(map(set, collections), 2):
for col1, col2 in combinations(map(set, collections), 2): # type: ignore
intersection |= col1 & col2
return intersection

Просмотреть файл

@ -24,7 +24,7 @@ def test_split_by_subject_ids() -> None:
test_df, test_ids, train_ids, val_ids = _get_test_df()
splits = DatasetSplits.from_subject_ids(test_df, train_ids, test_ids, val_ids, subject_column=CSV_SUBJECT_HEADER)
for x, y in zip([splits.train, splits.test, splits.val], [train_ids, test_ids, val_ids]):
for x, y in zip([splits.train, splits.test, splits.val], [train_ids, test_ids, val_ids]): # type: ignore
pd.testing.assert_frame_equal(x, test_df[test_df.subject.isin(y)])
@ -66,7 +66,7 @@ def _check_is_partition(total: pd.DataFrame, parts: Iterable[pd.DataFrame], colu
total = set(total[column].unique())
parts = [set(part[column].unique()) for part in parts]
assert total == set.union(*parts)
for part1, part2 in combinations(parts, 2):
for part1, part2 in combinations(parts, 2): # type: ignore
assert part1.isdisjoint(part2)