зеркало из https://github.com/microsoft/hi-ml.git
BUG: DeepMIL model uses training data even when running in eval mode (#889)
Instantiating the DeepMIL classifier in `create_model` always relies on the training dataset and module, which is not available when doing evaluation.
This commit is contained in:
Родитель
26ff7d0d94
Коммит
e8f6b44638
|
@ -78,7 +78,7 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
|
|||
SDK_NAME = "innereye"
|
||||
SDK_VERSION = "2.0"
|
||||
|
||||
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04"
|
||||
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.3-cudnn8-ubuntu20.04:20230509.v1"
|
||||
DEFAULT_DOCKER_SHM_SIZE = "100g"
|
||||
|
||||
# hyperparameter search args
|
||||
|
|
|
@ -13,6 +13,7 @@ from pathlib import Path
|
|||
from monai.transforms import Compose
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
import torch
|
||||
|
||||
from health_azure.utils import create_from_matching_params
|
||||
from health_cpath.preprocessing.loading import LoadingParams
|
||||
|
@ -20,6 +21,7 @@ from health_cpath.utils.callbacks import LossAnalysisCallback, LossCallbackParam
|
|||
from health_cpath.utils.wsi_utils import TilingParams
|
||||
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
from health_ml.experiment_config import RunnerMode
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
|
||||
|
@ -102,6 +104,7 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
|
|||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.data_module: Optional[HistoDataModule] = None
|
||||
self.run_extra_val_epoch = True # Enable running an additional validation step to save tiles/slides thumbnails
|
||||
metric_optim = "max" if self.maximise_primary_metric else "min"
|
||||
self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
|
||||
|
@ -163,7 +166,7 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
|
|||
return set()
|
||||
|
||||
def get_outputs_handler(self) -> DeepMILOutputsHandler:
|
||||
n_classes = self.data_module.train_dataset.n_classes
|
||||
n_classes = self.get_num_classes()
|
||||
outputs_handler = DeepMILOutputsHandler(
|
||||
outputs_root=self.outputs_folder,
|
||||
n_classes=n_classes,
|
||||
|
@ -241,14 +244,53 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
|
|||
def get_encoder_params(self) -> EncoderParams:
|
||||
return create_from_matching_params(self, EncoderParams)
|
||||
|
||||
def get_data_module_for_runner_mode(self) -> HistoDataModule:
|
||||
"""Gets the data module that the model is using from the private data_module field. If it is not set yet, it
|
||||
is set from either the training or the evaluation data module, depending on the runner mode.
|
||||
|
||||
:return: A data module for either training or evaluation, depending on the runner mode.
|
||||
"""
|
||||
if self.data_module is None:
|
||||
if self.runner_mode == RunnerMode.TRAIN:
|
||||
self.data_module = self.get_data_module()
|
||||
elif self.runner_mode == RunnerMode.EVAL_FULL:
|
||||
self.data_module = self.get_eval_data_module()
|
||||
else:
|
||||
raise ValueError(f"Unknown runner mode {self.runner_mode}")
|
||||
return self.data_module
|
||||
|
||||
def get_num_classes(self) -> int:
|
||||
"""Gets the number of classes that the model handles. In training mode, this is the number of classes in the
|
||||
training dataset. In evaluation mode, this is the number of classes in the test dataset.
|
||||
|
||||
This method has a side effect: It creates the data module if it is not yet set in the `data_module` attribute.
|
||||
"""
|
||||
if self.runner_mode == RunnerMode.EVAL_FULL:
|
||||
return self.get_data_module_for_runner_mode().test_dataset.n_classes
|
||||
else:
|
||||
return self.get_data_module_for_runner_mode().train_dataset.n_classes
|
||||
|
||||
def get_class_weights(self) -> torch.Tensor:
|
||||
"""Gets the class weights that the model should use. In training mode, this is the class weights from the
|
||||
training data module. In evaluation mode, this is a tensor with all ones (the class weights will be loaded
|
||||
from the checkpoint, so their value does not matter).
|
||||
|
||||
:return: A tensor if the model is used for training, None otherwise.
|
||||
"""
|
||||
if self.runner_mode == RunnerMode.EVAL_FULL:
|
||||
num_classes = self.get_num_classes()
|
||||
return torch.ones(2 if num_classes == 1 else num_classes)
|
||||
else:
|
||||
return self.get_data_module_for_runner_mode().class_weights
|
||||
|
||||
def create_model(self) -> DeepMILModule:
|
||||
self.data_module = self.get_data_module()
|
||||
num_classes = self.get_num_classes()
|
||||
outputs_handler = self.get_outputs_handler()
|
||||
deepmil_module = DeepMILModule(
|
||||
label_column=self.get_label_column(),
|
||||
n_classes=self.data_module.train_dataset.n_classes,
|
||||
n_classes=num_classes,
|
||||
class_names=self.class_names,
|
||||
class_weights=self.data_module.class_weights,
|
||||
class_weights=self.get_class_weights(),
|
||||
outputs_folder=self.outputs_folder,
|
||||
encoder_params=self.get_encoder_params(),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
|
@ -265,6 +307,9 @@ class BaseMIL(LightningContainer, LoadingParams, EncoderParams, PoolingParams, C
|
|||
def get_data_module(self) -> HistoDataModule:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_eval_data_module(self) -> HistoDataModule:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_slides_dataset(self) -> Optional[SlidesDataset]:
|
||||
return None
|
||||
|
||||
|
@ -305,7 +350,7 @@ class BaseMILTiles(BaseMIL):
|
|||
)
|
||||
|
||||
def get_label_column(self) -> str:
|
||||
return self.data_module.train_dataset.label_column
|
||||
return self.get_data_module_for_runner_mode().train_dataset.label_column
|
||||
|
||||
def setup(self) -> None:
|
||||
super().setup()
|
||||
|
|
|
@ -2,23 +2,20 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from copy import deepcopy
|
||||
import os
|
||||
from pytorch_lightning import Trainer
|
||||
import torch
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
|
||||
from unittest.mock import MagicMock, patch, DEFAULT
|
||||
|
||||
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
|
||||
import pytest
|
||||
import torch
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from torch import Tensor, allclose, argmax, nn, rand, randint, randn, round, stack
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
|
||||
from health_cpath.datamodules.panda_module import PandaTilesDataModule
|
||||
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
|
||||
from health_cpath.configs.classification.BaseMIL import BaseMIL, BaseMILTiles
|
||||
|
||||
from health_cpath.configs.classification.DeepSMILECrck import DeepSMILECrck, TcgaCrckSSLMIL
|
||||
from health_cpath.configs.classification.DeepSMILEPanda import (
|
||||
BaseDeepSMILEPanda,
|
||||
|
@ -26,7 +23,10 @@ from health_cpath.configs.classification.DeepSMILEPanda import (
|
|||
SlidesPandaSSLMIL,
|
||||
TilesPandaSSLMIL,
|
||||
)
|
||||
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import SlidesPandaSSLMILBenchmark
|
||||
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_binary, innereye_ssl_checkpoint_crck_4ws
|
||||
from health_cpath.datamodules.base_module import HistoDataModule, TilesDataModule
|
||||
from health_cpath.datamodules.panda_module import PandaTilesDataModule
|
||||
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDataset
|
||||
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
|
||||
from health_cpath.models.deepmil import DeepMILModule
|
||||
|
@ -40,15 +40,16 @@ from health_cpath.models.encoders import (
|
|||
)
|
||||
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
|
||||
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey
|
||||
from testhisto.mocks.base_data_generator import MockHistoDataType
|
||||
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
|
||||
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_ml.eval_runner import EvalRunner
|
||||
from health_ml.experiment_config import ExperimentConfig, RunnerMode
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, TransformerPoolingBenchmark
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_cpath.configs.run_ids import innereye_ssl_checkpoint_crck_4ws, innereye_ssl_checkpoint_binary
|
||||
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
|
||||
from torch.utils.data import DataLoader
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
|
||||
from testhisto.mocks.base_data_generator import MockHistoDataType
|
||||
from testhisto.mocks.container import MockDeepSMILESlidesPanda, MockDeepSMILETilesPanda
|
||||
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
|
||||
from testhisto.models.test_encoders import TEST_SSL_RUN_ID
|
||||
|
||||
no_gpu = not is_gpu_available()
|
||||
|
||||
|
@ -707,7 +708,10 @@ def test_checkpoint_name(
|
|||
def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
|
||||
container = MockDeepSMILETilesPanda(tmp_path=mock_panda_tiles_root_dir)
|
||||
container.setup()
|
||||
container.data_module = MagicMock()
|
||||
num_classes = 6
|
||||
container.data_module = MagicMock(
|
||||
class_weights=torch.ones(num_classes), train_dataset=MagicMock(n_classes=num_classes)
|
||||
)
|
||||
container.create_lightning_module_and_store()
|
||||
assert not container.model._on_extra_val_epoch
|
||||
assert (
|
||||
|
@ -828,3 +832,76 @@ def test_encoder_checkpointning(
|
|||
custom_forward.return_value = torch.zeros(container.max_bag_size, feature_dim)
|
||||
_, _ = model_ckpt_enc(sample[SlideKey.IMAGE][0])
|
||||
custom_forward.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runner_mode", [RunnerMode.TRAIN, RunnerMode.EVAL_FULL])
|
||||
def test_create_model_in_eval_mode(tmp_path: Path, runner_mode: RunnerMode) -> None:
|
||||
"""When the runner is in eval mode, the call to create_model should not instantiate the training data loader."""
|
||||
container = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
|
||||
eval_runner = EvalRunner(
|
||||
container=container, experiment_config=ExperimentConfig(mode=runner_mode), project_root=tmp_path
|
||||
)
|
||||
get_data_module = MagicMock(
|
||||
return_value=MagicMock(class_weights=torch.ones(2), train_dataset=MagicMock(n_classes=1))
|
||||
)
|
||||
get_eval_data_module = MagicMock(
|
||||
return_value=MagicMock(class_weights=torch.ones(2), test_dataset=MagicMock(n_classes=1))
|
||||
)
|
||||
with patch.multiple(
|
||||
"health_cpath.configs.classification.BaseMIL.BaseMIL",
|
||||
get_data_module=get_data_module,
|
||||
get_eval_data_module=get_eval_data_module,
|
||||
get_label_column=DEFAULT,
|
||||
):
|
||||
eval_runner.setup()
|
||||
if runner_mode == RunnerMode.TRAIN:
|
||||
get_data_module.assert_called_once()
|
||||
get_eval_data_module.assert_not_called()
|
||||
else:
|
||||
get_data_module.assert_not_called()
|
||||
get_eval_data_module.assert_called_once()
|
||||
|
||||
|
||||
def test_run_model_in_eval_mode(tmp_path: Path) -> None:
|
||||
"""When the runner is in eval mode, ensure a checkpoint can be loaded correctly."""
|
||||
# Mock a training dataset with two classes, but this is represented as n_classes=1 in the run we are trying
|
||||
# to repro here
|
||||
one_class_dataset = MagicMock(n_classes=1)
|
||||
class_weights = torch.ones(2)
|
||||
get_data_module = MagicMock(return_value=MagicMock(class_weights=class_weights, train_dataset=one_class_dataset))
|
||||
with patch.multiple(
|
||||
"health_cpath.datamodules.base_module.HistoDataModule",
|
||||
get_splits=MagicMock(return_value=(MagicMock(), MagicMock(), one_class_dataset)),
|
||||
_get_dataloader=MagicMock(return_value=DataLoader(one_class_dataset)),
|
||||
):
|
||||
eval_data_module = HistoDataModule(tmp_path)
|
||||
with patch.multiple(
|
||||
"health_cpath.configs.classification.BaseMIL.BaseMIL",
|
||||
get_data_module=get_data_module,
|
||||
get_label_column=MagicMock(return_value="label"),
|
||||
get_eval_data_module=MagicMock(return_value=eval_data_module),
|
||||
):
|
||||
container_orig = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
|
||||
model = container_orig.create_model()
|
||||
# Create a trainer to save a checkpoint only
|
||||
trainer = Trainer()
|
||||
trainer.strategy.model = model
|
||||
checkpoint = tmp_path / "model.ckpt"
|
||||
trainer.save_checkpoint(checkpoint)
|
||||
# Now create the model again and run it in eval mode, loading the checkpoint
|
||||
container = BaseMIL(encoder_type="Resnet18", pool_type="MeanPoolingLayer")
|
||||
container.src_checkpoint = CheckpointParser(str(checkpoint))
|
||||
eval_runner = EvalRunner(
|
||||
container=container,
|
||||
experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL),
|
||||
project_root=tmp_path,
|
||||
)
|
||||
eval_runner.setup()
|
||||
# Before bug fix:
|
||||
# Error(s) in loading state_dict for DeepMILModule:
|
||||
# Unexpected key(s) in state_dict: "loss_fn.pos_weight", "loss_fn_no_reduction.pos_weight".
|
||||
# size mismatch for classifier_fn.weight: copying a param with shape torch.Size([1, 512]) from checkpoint,
|
||||
# the shape in current model is torch.Size([2, 512]).
|
||||
# size mismatch for classifier_fn.bias: copying a param with shape torch.Size([1]) from checkpoint,
|
||||
# the shape in current model is torch.Size([2]).
|
||||
eval_runner.run()
|
||||
|
|
|
@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
|
||||
from health_azure.utils import create_from_matching_params
|
||||
from health_ml.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, WorkflowParams
|
||||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.experiment_config import ExperimentConfig, RunnerMode
|
||||
from health_ml.utils.checkpoint_utils import get_best_checkpoint_path
|
||||
from health_ml.utils.lr_scheduler import SchedulerWithWarmUp
|
||||
from health_ml.utils.model_util import create_optimizer
|
||||
|
@ -32,8 +32,10 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar
|
|||
super().__init__(**kwargs)
|
||||
self._model: Optional[LightningModule] = None
|
||||
self._model_name = type(self).__name__
|
||||
self.num_nodes = 1
|
||||
self.trained_weights_path: Optional[Path] = None
|
||||
# Number of nodes and the runner mode are read from the ExperimentConfig, and will be copied here
|
||||
self.num_nodes = 1
|
||||
self.runner_mode = RunnerMode.TRAIN
|
||||
|
||||
def validate(self) -> None:
|
||||
WorkflowParams.validate(self)
|
||||
|
|
|
@ -68,6 +68,7 @@ class RunnerBase:
|
|||
self.container = container
|
||||
self.experiment_config = experiment_config
|
||||
self.container.num_nodes = self.experiment_config.num_nodes
|
||||
self.container.runner_mode = self.experiment_config.mode
|
||||
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
|
||||
self.storing_logger: Optional[StoringLogger] = None
|
||||
self._has_setup_run = False
|
||||
|
|
|
@ -37,7 +37,7 @@ class DatasetSplits:
|
|||
:return: a Set of elements that appear in more than one collection
|
||||
"""
|
||||
intersection = set()
|
||||
for col1, col2 in combinations(map(set, collections), 2):
|
||||
for col1, col2 in combinations(map(set, collections), 2): # type: ignore
|
||||
intersection |= col1 & col2
|
||||
return intersection
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_split_by_subject_ids() -> None:
|
|||
test_df, test_ids, train_ids, val_ids = _get_test_df()
|
||||
splits = DatasetSplits.from_subject_ids(test_df, train_ids, test_ids, val_ids, subject_column=CSV_SUBJECT_HEADER)
|
||||
|
||||
for x, y in zip([splits.train, splits.test, splits.val], [train_ids, test_ids, val_ids]):
|
||||
for x, y in zip([splits.train, splits.test, splits.val], [train_ids, test_ids, val_ids]): # type: ignore
|
||||
pd.testing.assert_frame_equal(x, test_df[test_df.subject.isin(y)])
|
||||
|
||||
|
||||
|
@ -66,7 +66,7 @@ def _check_is_partition(total: pd.DataFrame, parts: Iterable[pd.DataFrame], colu
|
|||
total = set(total[column].unique())
|
||||
parts = [set(part[column].unique()) for part in parts]
|
||||
assert total == set.union(*parts)
|
||||
for part1, part2 in combinations(parts, 2):
|
||||
for part1, part2 in combinations(parts, 2): # type: ignore
|
||||
assert part1.isdisjoint(part2)
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче