ENH: Add pl_replace_sampler_ddp flag for small validation sets (#648)

- Add pl_replace_sampler_ddp to allow disabling validation data distribution, so that every process
  validates on the entire validation set (the distributed reduce then yields the same mean)
- Log extra validation outputs with a different prefix
- Refactor ClassifierParams 
- Add specific batch size for inference
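
A minimal sketch (with made-up loss values) of why this helps for small validation sets: if every rank sees the whole set, the cross-rank reduce averages identical means, whereas a DistributedSampler pads shards to equal length by duplicating samples, which biases the reduced mean.

# Hypothetical numbers: three validation samples, world_size = 2.
losses = [0.1, 0.2, 0.7]
full_mean = sum(losses) / len(losses)

# pl_replace_sampler_ddp=False: every rank validates on the full set,
# so the reduce averages identical per-rank means.
per_rank_means = [full_mean, full_mean]
assert sum(per_rank_means) / 2 == full_mean

# pl_replace_sampler_ddp=True: DistributedSampler pads shards to equal
# length by duplicating samples (here 0.1), skewing the reduced mean.
rank0, rank1 = [0.1, 0.7], [0.2, 0.1]
padded_mean = (sum(rank0) / 2 + sum(rank1) / 2) / 2
assert padded_mean != full_mean  # 0.275 vs. 0.333...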

Note: the AML run hit a random [broken pipe
error](https://ml.azure.com/experiments/id/c4ad55a7-ce24-46a5-be8e-0fb478edf47e/runs/refs_pull_648_merge_1667920946_f585577d?wsid=/subscriptions/a85ceddd-892e-4637-ae4b-67d15ddf5f2b/resourceGroups/health-ml/providers/Microsoft.MachineLearningServices/workspaces/hi-ml&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#outputsAndLogs).
Kenza Bouzid 2022-11-11 13:12:07 +00:00, committed by GitHub
Parent e30a0d1f6d
Commit 1034324c51
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
20 changed files: 523 additions and 200 deletions

.github/workflows/cpath-pr.yml (vendored)
View file

@ -142,7 +142,7 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tcgacrckimagenetmil_aml
smoke_test_tcgacrcksslmil_aml:
smoke_test_tcgacrcksslmil:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@ -157,7 +157,7 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tcgacrcksslmil_aml
smoke_test_crck_simclr_aml:
smoke_test_crck_simclr:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@ -172,7 +172,7 @@ jobs:
cd ${{ env.folder }}
make smoke_test_crck_simclr_aml
smoke_test_crck_flexible_finetuning_aml:
smoke_test_crck_flexible_finetuning:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
@ -217,6 +217,36 @@ jobs:
cd ${{ env.folder }}
make smoke_test_slides_panda_loss_analysis_aml
smoke_test_slides_panda_no_ddp_sampler:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_slides_panda_no_ddp_sampler_aml
smoke_test_tiles_panda_no_ddp_sampler:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Prepare Conda environment
uses: ./.github/actions/prepare_cpath_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_tiles_panda_no_ddp_sampler_aml
publish:
runs-on: ubuntu-20.04
needs: [
@ -225,10 +255,14 @@ jobs:
pytest,
smoke_test_slidespandaimagenetmil,
smoke_test_tilespandaimagenetmil,
smoke_test_crck_simclr_aml,
smoke_test_crck_flexible_finetuning_aml,
smoke_test_tcgacrckimagenetmil,
smoke_test_tcgacrcksslmil,
smoke_test_crck_simclr,
smoke_test_crck_flexible_finetuning,
smoke_test_crck_loss_analysis,
smoke_test_slides_panda_loss_analysis,
smoke_test_slides_panda_no_ddp_sampler,
smoke_test_tiles_panda_no_ddp_sampler,
]
steps:
- uses: actions/checkout@v3

View file

@ -82,7 +82,7 @@ pytest_coverage:
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667832811_4e855804
# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
@ -190,6 +190,10 @@ define LOSS_ANALYSIS_ARGS
--num_slides_scatter=2 --num_slides_heatmap=2 --save_tile_ids=True
endef
define DDP_SAMPLER_ARGS
--pl_replace_sampler_ddp=False
endef
# The following test takes around 5 minutes
smoke_test_slidespandaimagenetmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
@ -259,6 +263,22 @@ smoke_test_slides_panda_loss_analysis_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local
smoke_test_slides_panda_no_ddp_sampler_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml
smoke_test_slides_panda_no_ddp_sampler_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke_test_tiles_panda_no_ddp_sampler_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
smoke_test_tiles_panda_no_ddp_sampler_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml

View file

@ -26,22 +26,22 @@ from health_cpath.datamodules.base_module import CacheLocation, CacheMode, Histo
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.models.deepmil import TilesDeepMILModule, SlidesDeepMILModule, BaseDeepMILModule
from health_cpath.models.transforms import EncodeTilesBatchd, LoadTilesBatchd
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.output_utils import DeepMILOutputsHandler
from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
from health_cpath.utils.tiles_selection_utils import TilesSelector
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackParams):
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
full DeepMIL model depending on the type of dataset (tiles/slides based).
"""
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
class_names: Optional[Sequence[str]] = param.List(None, item_type=str, doc="List of class names. If `None`, "
"defaults to `('0', '1', ...)`.")
# Data module parameters:
batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
batch_size_inf: int = param.Integer(16, bounds=(1, None), doc="Number of slides per batch during inference.")
max_bag_size: int = param.Integer(1000, bounds=(0, None),
doc="Upper bound on number of tiles in each loaded bag during training stage. "
"If 0 (default), will return all samples in each bag. "
@ -72,13 +72,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
"generating outputs.")
maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
"maximised (otherwise minimised).")
tune_classifier: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
pretrained_classifier: bool = param.Boolean(
default=False,
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
"initiliaze classifier with random weights.")
max_num_workers: int = param.Integer(10, bounds=(0, None),
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
@ -94,6 +87,10 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
self.validate()
if not self.pl_replace_sampler_ddp and self.max_num_gpus > 1:
logging.info(
"Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
)
def validate(self) -> None:
super().validate()
@ -154,6 +151,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
val_plot_options=self.get_val_plot_options(),
test_plot_options=self.get_test_plot_options(),
wsi_has_mask=self.wsi_has_mask,
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
)
if self.num_top_slides > 0:
outputs_handler.tiles_selector = TilesSelector(
@ -176,7 +174,11 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
num_slides_scatter=self.num_slides_scatter,
num_slides_heatmap=self.num_slides_heatmap,
save_tile_ids=self.save_tile_ids,
log_exceptions=self.log_exceptions))
log_exceptions=self.log_exceptions,
val_set_is_dist=(
self.pl_replace_sampler_ddp and self.max_num_gpus > 1),
)
)
return callbacks
def get_checkpoint_to_test(self) -> Path:
@ -288,15 +290,14 @@ class BaseMILTiles(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
tune_classifier=self.tune_classifier,
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_folder=self.outputs_folder,
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
analyse_loss=self.analyse_loss,
)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module
@ -331,15 +332,14 @@ class BaseMILSlides(BaseMIL):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
tune_classifier=self.tune_classifier,
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss)
analyse_loss=self.analyse_loss,
)
deepmil_module.transfer_weights(self.trained_weights_path)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module

View file

@ -63,6 +63,7 @@ class DeepSMILECrck(BaseMILTiles):
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size_inf=self.max_bag_size_inf,
transforms_dict=self.get_transforms_dict(TcgaCrck_TilesDataset.IMAGE_COLUMN),
cache_mode=self.cache_mode,
@ -72,6 +73,7 @@ class DeepSMILECrck(BaseMILTiles):
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
seed=self.get_effective_random_seed(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_test_plot_options(self) -> Set[PlotOption]:

View file

@ -65,6 +65,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
# declared in BaseMILTiles:
is_caching=False,
batch_size=8,
batch_size_inf=8,
azure_datasets=[PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID])
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
@ -75,8 +76,9 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
def get_data_module(self) -> PandaTilesDataModule:
return PandaTilesDataModule(
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
transforms_dict=self.get_transforms_dict(PandaTilesDataset.IMAGE_COLUMN),
cache_mode=self.cache_mode,
@ -86,6 +88,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
seed=self.get_effective_random_seed(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_slides_dataset(self) -> Optional[PandaDataset]:
@ -148,6 +151,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
return PandaSlidesDataModule(
root_path=self.local_datasets[0],
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
level=self.level,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
@ -162,6 +166,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def get_slides_dataset(self) -> PandaDataset:

View file

@ -25,7 +25,7 @@ from health_cpath.models.encoders import (
)
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda
from health_cpath.models.deepmil import SlidesDeepMILModule
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import MetricsKey, ModelKey, SlideKey
@ -64,6 +64,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
encoding_chunk_size=60,
max_bag_size=56,
batch_size=8, # effective batch size = batch_size * num_GPUs
batch_size_inf=8,
max_epochs=50,
l_rate=3e-4,
weight_decay=0,
@ -79,6 +80,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
# Params specific to fine-tuning
if self.tune_encoder:
self.batch_size = 2
self.batch_size_inf = 2
super().setup()
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
@ -100,8 +102,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
# Hence, inherited `PandaSlidesDataModuleBenchmark` from `SlidesDataModule`
return PandaSlidesDataModuleBenchmark(
root_path=self.local_datasets[0],
max_bag_size=self.max_bag_size,
batch_size=self.batch_size,
batch_size_inf=self.batch_size_inf,
max_bag_size=self.max_bag_size,
max_bag_size_inf=self.max_bag_size_inf,
level=self.level,
tile_size=self.tile_size,
@ -115,6 +118,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
)
def create_model(self) -> SlidesDeepMILModule:
@ -126,10 +130,10 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
n_classes=self.data_module.train_dataset.n_classes,
class_names=self.class_names,
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
classifier_params=create_from_matching_params(self, ClassifierParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss,

View file

@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, Type
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, DistributedSampler
from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
@ -47,18 +47,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
self,
root_path: Path,
batch_size: int = 1,
batch_size_inf: Optional[int] = None,
max_bag_size: int = 0,
max_bag_size_inf: int = 0,
seed: Optional[int] = None,
transforms_dict: Optional[Dict[ModelKey, Union[Callable, None]]] = None,
crossval_count: int = 0,
crossval_index: int = 0,
pl_replace_sampler_ddp: bool = True,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
dataframe_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
:param root_path: Root directory of the source dataset.
:param batch_size: Number of slides to load per batch.
:param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
:param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
will return all samples in each bag. If > 0, bags larger than `max_bag_size` will yield
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
@ -74,19 +77,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
By default (`None`).
:param crossval_count: Number of folds to perform.
:param crossval_index: Index of the cross validation split to be performed.
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
"""
batch_size_inf = batch_size_inf or batch_size
super().__init__()
self.root_path = root_path
self.transforms_dict = transforms_dict
self.batch_size = batch_size
self.max_bag_size = max_bag_size
self.max_bag_size_inf = max_bag_size_inf
self.batch_sizes = {ModelKey.TRAIN: batch_size, ModelKey.VAL: batch_size_inf, ModelKey.TEST: batch_size_inf}
self.bag_sizes = {ModelKey.TRAIN: max_bag_size, ModelKey.VAL: max_bag_size_inf, ModelKey.TEST: max_bag_size_inf}
self.crossval_count = crossval_count
self.crossval_index = crossval_index
self.pl_replace_sampler_ddp = pl_replace_sampler_ddp
self.train_dataset: _SlidesOrTilesDataset
self.val_dataset: _SlidesOrTilesDataset
self.test_dataset: _SlidesOrTilesDataset
@ -105,6 +110,13 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
) -> DataLoader:
raise NotImplementedError
def _get_ddp_sampler(self, dataset: Dataset, stage: ModelKey) -> Optional[DistributedSampler]:
is_distributed = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
if is_distributed and not self.pl_replace_sampler_ddp and stage == ModelKey.TRAIN:
assert self.seed is not None, "seed must be set when using distributed training for reproducibility"
return DistributedSampler(dataset, shuffle=True, seed=self.seed)
return None
def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, # type: ignore
shuffle=True,
@ -206,15 +218,10 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
generator = _create_generator(self.seed)
if stage in [ModelKey.VAL, ModelKey.TEST]:
eff_max_bag_size = self.max_bag_size_inf
else:
eff_max_bag_size = self.max_bag_size
bag_dataset = BagDataset(
tiles_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=eff_max_bag_size,
max_bag_size=self.bag_sizes[stage],
shuffle_samples=True,
generator=generator,
)
@ -241,11 +248,14 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
transformed_bag_dataset = self._load_dataset(dataset, stage=stage, shuffle=shuffle)
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
generator = bag_dataset.bag_sampler.generator
sampler = self._get_ddp_sampler(transformed_bag_dataset, stage)
return DataLoader(
transformed_bag_dataset,
batch_size=self.batch_size,
batch_size=self.batch_sizes[stage],
collate_fn=multibag_collate,
shuffle=shuffle,
sampler=sampler,
# sampler option is mutually exclusive with shuffle
shuffle=shuffle if sampler is None else None, # type: ignore
generator=generator,
**dataloader_kwargs,
)
@ -260,13 +270,13 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
def __init__(
self,
level: Optional[int] = 1,
tile_size: Optional[int] = 224,
level: int = 1,
tile_size: int = 224,
step: Optional[int] = None,
random_offset: Optional[bool] = True,
pad_full: Optional[bool] = False,
background_val: Optional[int] = 255,
filter_mode: Optional[str] = "min",
random_offset: bool = True,
pad_full: bool = False,
background_val: int = 255,
filter_mode: str = "min",
**kwargs: Any,
) -> None:
"""
@ -298,8 +308,9 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
self.filter_mode = filter_mode
# TileOnGridd transform expects None to select all foreground tiles, so we hardcode max_bag_size and
# max_bag_size_inf to None if set to 0
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore
for stage_key, max_bag_size in self.bag_sizes.items():
if max_bag_size == 0:
self.bag_sizes[stage_key] = None # type: ignore
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
base_transform = Compose(
@ -314,7 +325,7 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
),
TileOnGridd(
keys=slides_dataset.IMAGE_COLUMN,
tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
tile_count=self.bag_sizes[stage],
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
@ -339,11 +350,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
**dataloader_kwargs: Any) -> DataLoader:
transformed_slides_dataset = self._load_dataset(dataset, stage)
generator = _create_generator(self.seed)
sampler = self._get_ddp_sampler(transformed_slides_dataset, stage)
return DataLoader(
transformed_slides_dataset,
batch_size=self.batch_size,
batch_size=self.batch_sizes[stage],
collate_fn=image_collate,
shuffle=shuffle,
sampler=sampler,
# sampler option is mutually exclusive with shuffle
shuffle=shuffle if not sampler else None, # type: ignore
generator=generator,
**dataloader_kwargs,
)
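
For reference, a condensed sketch (assumed names, not the module's full code) of how the new sampler logic combines with the `shuffle`/`sampler` mutual exclusivity that both dataloader factories above guard against:

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler

def build_loader(dataset: Dataset, shuffle: bool, is_train: bool,
                 pl_replace_sampler_ddp: bool, seed: int) -> DataLoader:
    is_distributed = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
    sampler = None
    if is_distributed and not pl_replace_sampler_ddp and is_train:
        # Only the training set is sharded; validation and test run the
        # full set on every rank.
        sampler = DistributedSampler(dataset, shuffle=True, seed=seed)
    # DataLoader rejects shuffle=True together with an explicit sampler.
    return DataLoader(dataset, sampler=sampler,
                      shuffle=shuffle if sampler is None else None)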

View file

@ -6,23 +6,20 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pathlib import Path
from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, optim, round
from torchmetrics import (AUROC, F1, Accuracy, ConfusionMatrix, Precision,
Recall, CohenKappa, AveragePrecision, Specificity)
from health_ml.utils import log_on_epoch
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.models.encoders import IdentityEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams, set_module_gradients_enabled
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.datasets.base_dataset import TilesDataset
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
from health_cpath.utils.output_utils import (BatchResultsType, DeepMILOutputsHandler, EpochResultsType,
validate_class_names)
validate_class_names, EXTRA_PREFIX)
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
@ -42,34 +39,31 @@ class BaseDeepMILModule(LightningModule):
n_classes: int,
class_weights: Optional[Tensor] = None,
class_names: Optional[Sequence[str]] = None,
tune_classifier: bool = True,
pretrained_classifier: bool = False,
dropout_rate: Optional[float] = None,
verbose: bool = False,
outputs_folder: Optional[Path] = None,
encoder_params: EncoderParams = EncoderParams(),
pooling_params: PoolingParams = PoolingParams(),
classifier_params: ClassifierParams = ClassifierParams(),
optimizer_params: OptimizerParams = OptimizerParams(),
outputs_folder: Optional[Path] = None,
outputs_handler: Optional[DeepMILOutputsHandler] = None,
analyse_loss: Optional[bool] = False) -> None:
analyse_loss: Optional[bool] = False,
verbose: bool = False,
) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
set to 1.
:param class_weights: Tensor containing class weights (default=None).
:param class_names: The names of the classes if available (default=None).
:param tune_classifier: Whether to tune the classifier (default=True).
:param pretrained_classifier: Whether to use pretrained classifier (default=False for random init).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param verbose: if True statements about memory usage are output at each step.
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
:param encoder_params: Encoder parameters that specify all encoder specific attributes.
:param pooling_params: Pooling layer parameters that specify all encoder specific attributes.
:param classifier_params: Classifier parameters that specify all classifier specific attributes.
:param optimizer_params: Optimizer parameters that specify all specific attributes to be used for optimization.
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
:param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
metrics logging).
:param analyse_loss: If True, the loss is analysed per sample and analysed with LossAnalysisCallback.
:param verbose: If True, statements about memory usage are output at each step.
"""
super().__init__()
@ -79,29 +73,29 @@ class BaseDeepMILModule(LightningModule):
self.class_weights = class_weights
self.class_names = validate_class_names(class_names, self.n_classes)
self.dropout_rate = dropout_rate
self.encoder_params = encoder_params
self.pooling_params = pooling_params
self.classifier_params = classifier_params
self.optimizer_params = optimizer_params
self.save_hyperparameters()
self.verbose = verbose
self.outputs_handler = outputs_handler
self.pretrained_classifier = pretrained_classifier
# This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
# consuming validation outputs via calling self.on_run_extra_validation_epoch()
self._run_extra_val_epoch = False
self.tune_classifier = tune_classifier
self._on_extra_val_epoch = False
# Model components
self.encoder = encoder_params.get_encoder(outputs_folder)
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
self.classifier_fn = self.get_classifier()
self.classifier_fn = classifier_params.get_classifier(self.num_pooling, self.n_classes)
self.activation_fn = self.get_activation()
self.analyse_loss = analyse_loss
# Loss function
self.loss_fn = self.get_loss(reduction="mean")
self.loss_fn_no_reduction = self.get_loss(reduction="none")
self.analyse_loss = analyse_loss
# Metrics Objects
self.train_metrics = self.get_metrics()
@ -149,23 +143,12 @@ class BaseDeepMILModule(LightningModule):
if self.pooling_params.pretrained_pooling:
self.copy_weights(self.aggregation_fn, pretrained_model.aggregation_fn, DeepMILSubmodules.POOLING)
if self.pretrained_classifier:
if self.classifier_params.pretrained_classifier:
if pretrained_model.n_classes != self.n_classes:
raise ValueError(f"Number of classes in pretrained model {pretrained_model.n_classes} "
f"does not match number of classes in current model {self.n_classes}.")
self.copy_weights(self.classifier_fn, pretrained_model.classifier_fn, DeepMILSubmodules.CLASSIFIER)
def get_classifier(self) -> nn.Module:
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
set_module_gradients_enabled(classifier_layer, self.tune_classifier)
if self.dropout_rate is None:
return classifier_layer
elif 0 <= self.dropout_rate < 1:
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
else:
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
def get_loss(self, reduction: str = "mean") -> Callable:
if self.n_classes > 1:
if self.class_weights is None:
@ -217,8 +200,12 @@ class BaseDeepMILModule(LightningModule):
MetricsKey.RECALL: Recall(),
MetricsKey.SPECIFICITY: Specificity()})
def log_metrics(self, stage: str) -> None:
valid_stages = [stage for stage in ModelKey]
def get_extra_prefix(self) -> str:
"""Get extra prefix for the metrics name to avoir overriding best validation metrics."""
return EXTRA_PREFIX if self._on_extra_val_epoch else ""
def log_metrics(self, stage: str, prefix: str = '') -> None:
valid_stages = set([stage for stage in ModelKey])
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
@ -226,9 +213,9 @@ class BaseDeepMILModule(LightningModule):
metric_value = metric_object.compute()
metric_value_n = metric_value / metric_value.sum(axis=1, keepdims=True)
for i in range(metric_value_n.shape[0]):
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
log_on_epoch(self, f'{prefix}{stage}/{self.class_names[i]}', metric_value_n[i, i])
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
log_on_epoch(self, f'{prefix}{stage}/{metric_name}', metric_object)
def get_instance_features(self, instances: Tensor) -> Tensor:
if not self.encoder_params.tune_encoder:
@ -252,7 +239,7 @@ class BaseDeepMILModule(LightningModule):
return attentions, bag_features
def get_bag_logit(self, bag_features: Tensor) -> Tensor:
if not self.tune_classifier:
if not self.classifier_params.tune_classifier:
self.classifier_fn.eval()
bag_logit = self.classifier_fn(bag_features)
return bag_logit
@ -295,6 +282,14 @@ class BaseDeepMILModule(LightningModule):
"""Update training results with data specific info. This can be either tiles or slides related metadata."""
raise NotImplementedError
def update_slides_selection(self, stage: str, batch: Dict, results: Dict) -> None:
if (
(stage == ModelKey.TEST or (stage == ModelKey.VAL and self._on_extra_val_epoch))
and self.outputs_handler
and self.outputs_handler.tiles_selector
):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
if self.n_classes > 1:
return loss_fn(bag_logits, bag_labels.long())
@ -324,6 +319,7 @@ class BaseDeepMILModule(LightningModule):
predicted_probs = predicted_probs.squeeze(dim=1)
results = dict()
if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()
@ -340,18 +336,12 @@ class BaseDeepMILModule(LightningModule):
ResultsKey.BAG_ATTN: bag_attn_list
})
self.update_results_with_data_specific_info(batch=batch, results=results)
if (
(stage == ModelKey.TEST or (stage == ModelKey.VAL and self._run_extra_val_epoch))
and self.outputs_handler
and self.outputs_handler.tiles_selector
):
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
self.update_slides_selection(stage=stage, batch=batch, results=results)
return results
def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
if self.verbose:
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
@ -363,34 +353,29 @@ class BaseDeepMILModule(LightningModule):
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
name = f'{self.get_extra_prefix()}val/loss'
self.log(name, val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
return val_result
def test_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
test_result = self._shared_step(batch, batch_idx, ModelKey.TEST)
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
sync_dist=True)
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
return test_result
def training_epoch_end(self, outputs: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.TRAIN)
def validation_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.VAL)
self.log_metrics(stage=ModelKey.VAL, prefix=self.get_extra_prefix())
if self.outputs_handler:
self.outputs_handler.save_validation_outputs(
epoch_results=epoch_results,
metrics_dict=self.get_metrics_dict(ModelKey.VAL), # type: ignore
epoch=self.current_epoch,
is_global_rank_zero=self.global_rank == 0,
run_extra_val_epoch=self._run_extra_val_epoch
on_extra_val=self._on_extra_val_epoch
)
# Reset the top and bottom slides heaps
if self.outputs_handler.tiles_selector is not None:
self.outputs_handler.tiles_selector._clear_cached_slides_heaps()
def test_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
self.log_metrics(ModelKey.TEST)
if self.outputs_handler:
@ -402,7 +387,7 @@ class BaseDeepMILModule(LightningModule):
def on_run_extra_validation_epoch(self) -> None:
"""Hook to be called at the beginning of an extra validation epoch to set validation plots options to the same
as the test plots options."""
self._run_extra_val_epoch = True
self._on_extra_val_epoch = True
if self.outputs_handler:
self.outputs_handler.val_plots_handler.plot_options = self.outputs_handler.test_plots_handler.plot_options
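
A small sketch of the resulting metric naming (`metric_name` is a hypothetical helper mirroring get_extra_prefix() and log_metrics() above): extra validation epochs log under "extra_val/..." so they cannot overwrite the best "val/..." values recorded during training.

EXTRA_PREFIX = "extra_"  # as defined in output_utils.py

def metric_name(stage: str, name: str, on_extra_val_epoch: bool) -> str:
    # An extra validation epoch gets a distinct prefix, a regular one does not.
    prefix = EXTRA_PREFIX if on_extra_val_epoch else ""
    return f"{prefix}{stage}/{name}"

assert metric_name("val", "loss", True) == "extra_val/loss"
assert metric_name("val", "loss", False) == "val/loss"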

View file

@ -93,6 +93,7 @@ class LossAnalysisCallback(Callback):
save_tile_ids: bool = False,
log_exceptions: bool = True,
create_outputs_folders: bool = True,
val_set_is_dist: bool = True,
) -> None:
"""
@ -107,6 +108,7 @@ class LossAnalysisCallback(Callback):
:param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
False will raise the intercepted exceptions.
:param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
:param val_set_is_dist: If True, will assume that the validation set is distributed, defaults to True.
"""
self.patience = patience
@ -116,6 +118,7 @@ class LossAnalysisCallback(Callback):
self.num_slides_heatmap = num_slides_heatmap
self.save_tile_ids = save_tile_ids
self.log_exceptions = log_exceptions
self.val_set_is_dist = val_set_is_dist
self.outputs_folder = outputs_folder / "loss_analysis_callback"
if create_outputs_folders:
@ -216,7 +219,7 @@ class LossAnalysisCallback(Callback):
def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
"""Gathers the loss cache from all the workers"""
if torch.distributed.is_initialized():
if self.val_set_is_dist and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
if world_size > 1:
loss_caches = [None] * world_size
@ -533,7 +536,7 @@ class LossAnalysisCallback(Callback):
"""Hook called at the end of validation. Plot the loss heatmap and scratter plots after ranking the slides by
loss values."""
epoch = trainer.current_epoch
if pl_module.global_rank == 0 and not pl_module._run_extra_val_epoch and epoch == (self.max_epochs - 1):
if pl_module.global_rank == 0 and not pl_module._on_extra_val_epoch and epoch == (self.max_epochs - 1):
try:
self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
except Exception as e:

View file

@ -177,3 +177,22 @@ class PoolingParams(param.Parameterized):
num_features = num_encoding * self.pool_out_dim
set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
return pooling_layer, num_features
class ClassifierParams(param.Parameterized):
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
tune_classifier: bool = param.Boolean(
default=True,
doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
pretrained_classifier: bool = param.Boolean(
default=False,
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
"initiliaze classifier with random weights.")
def get_classifier(self, in_features: int, out_features: int) -> nn.Module:
classifier_layer = nn.Linear(in_features=in_features,
out_features=out_features)
set_module_gradients_enabled(classifier_layer, tuning_flag=self.tune_classifier)
if self.dropout_rate is None:
return classifier_layer
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
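
A hypothetical usage sketch of the refactored ClassifierParams (the feature sizes are made up):

from health_cpath.utils.deepmil_utils import ClassifierParams

params = ClassifierParams(dropout_rate=0.5, tune_classifier=False)
classifier = params.get_classifier(in_features=128, out_features=2)
# -> nn.Sequential(nn.Dropout(0.5), nn.Linear(128, 2)), with gradients
#    disabled on the linear layer because tune_classifier is False.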

View file

@ -26,6 +26,8 @@ OUTPUTS_CSV_FILENAME = "test_output.csv"
VAL_OUTPUTS_SUBDIR = "val"
PREV_VAL_OUTPUTS_SUBDIR = "val_old"
TEST_OUTPUTS_SUBDIR = "test"
EXTRA_VAL_OUTPUTS_SUBDIR = "extra_val"
EXTRA_PREFIX = "extra_"
AML_OUTPUTS_DIR = "outputs"
AML_LEGACY_TEST_OUTPUTS_CSV = "/".join([AML_OUTPUTS_DIR, OUTPUTS_CSV_FILENAME])
@ -198,7 +200,7 @@ class OutputsPolicy:
YAML().dump(contents, self.best_metric_file_path)
def should_save_validation_outputs(self, metrics_dict: Mapping[MetricsKey, Metric], epoch: int,
is_global_rank_zero: bool = True) -> bool:
is_global_rank_zero: bool = True, on_extra_val: bool = False) -> bool:
"""Determine whether validation outputs should be saved given the current epoch's metrics.
:param metrics_dict: Current epoch's metrics dictionary from
@ -206,8 +208,11 @@ class OutputsPolicy:
:param epoch: Current epoch number.
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
Set to `True` (default) if running a single process.
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
:return: Whether this is the best validation epoch so far.
"""
if on_extra_val:
return False
metric = metrics_dict[self.primary_val_metric]
# If the metric hasn't been updated we don't want to save it
if not metric._update_called:
@ -252,7 +257,8 @@ class DeepMILOutputsHandler:
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
maximise: bool, val_plot_options: Collection[PlotOption],
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True) -> None:
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
val_set_is_dist: bool = True) -> None:
"""
:param outputs_root: Root directory where to save all produced outputs.
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
@ -265,6 +271,10 @@ class DeepMILOutputsHandler:
:param val_plot_options: The desired plot options for validation time.
:param test_plot_options: The desired plot options for test time.
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
gather the validation set outputs across processes or not before saving them.
"""
self.outputs_root = outputs_root
self.n_classes = n_classes
@ -294,11 +304,16 @@ class DeepMILOutputsHandler:
stage=ModelKey.TEST,
wsi_has_mask=wsi_has_mask
)
self.val_set_is_dist = val_set_is_dist
@property
def validation_outputs_dir(self) -> Path:
return self.outputs_root / VAL_OUTPUTS_SUBDIR
@property
def extra_validation_outputs_dir(self) -> Path:
return self.outputs_root / EXTRA_VAL_OUTPUTS_SUBDIR
@property
def previous_validation_outputs_dir(self) -> Path:
return self.validation_outputs_dir.with_name(PREV_VAL_OUTPUTS_SUBDIR)
@ -313,11 +328,15 @@ class DeepMILOutputsHandler:
self.test_plots_handler.slides_dataset = slides_dataset
self.val_plots_handler.slides_dataset = slides_dataset
def should_gather_tiles(self, plots_handler: DeepMILPlotsHandler) -> bool:
return PlotOption.TOP_BOTTOM_TILES in plots_handler.plot_options and self.tiles_selector is not None
def _save_outputs(self, epoch_results: EpochResultsType, outputs_dir: Path, stage: ModelKey = ModelKey.VAL) -> None:
"""Trigger the rendering and saving of DeepMIL outputs and figures.
:param epoch_results: Aggregated results from all epoch batches.
:param outputs_dir: Specific directory into which outputs should be saved (different for validation and test).
:param stage: The stage of the model (e.g. `ModelKey.VAL` or `ModelKey.TEST`).
"""
# outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
# It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]
@ -334,8 +353,7 @@ class DeepMILOutputsHandler:
plots_handler.save_plots(outputs_dir, self.tiles_selector, results)
def save_validation_outputs(self, epoch_results: EpochResultsType, metrics_dict: Mapping[MetricsKey, Metric],
epoch: int, is_global_rank_zero: bool = True, run_extra_val_epoch: bool = False
) -> None:
epoch: int, is_global_rank_zero: bool = True, on_extra_val: bool = False) -> None:
"""Render and save validation epoch outputs, according to the configured :py:class:`OutputsPolicy`.
:param epoch_results: Aggregated results from all epoch batches, as passed to :py:meth:`validation_epoch_end()`.
@ -344,29 +362,35 @@ class DeepMILOutputsHandler:
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
Set to `True` (default) if running a single process.
:param epoch: Current epoch number.
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
"""
# All DDP processes must reach this point to allow synchronising epoch results
gathered_epoch_results = gather_results(epoch_results)
if PlotOption.TOP_BOTTOM_TILES in self.val_plots_handler.plot_options and self.tiles_selector:
self.tiles_selector.gather_selected_tiles_across_devices()
# All DDP processes must reach this point to allow synchronising epoch results if val_set_is_dist is True
if self.val_set_is_dist:
epoch_results = gather_results(epoch_results)
if self.should_gather_tiles(self.val_plots_handler):
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
# Only global rank-0 process should actually render and save the outputs
# We also want to save the plots of the extra validation epoch
if (
self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero)
or (run_extra_val_epoch and is_global_rank_zero)
):
if self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero, on_extra_val):
# First move existing outputs to a temporary directory, to avoid mixing
# outputs of different epochs in case writing fails halfway through
if self.validation_outputs_dir.exists():
replace_directory(source=self.validation_outputs_dir,
target=self.previous_validation_outputs_dir)
self._save_outputs(gathered_epoch_results, self.validation_outputs_dir, ModelKey.VAL)
self._save_outputs(epoch_results, self.validation_outputs_dir, ModelKey.VAL)
# Writing completed successfully; delete temporary back-up
if self.previous_validation_outputs_dir.exists():
shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)
elif on_extra_val and is_global_rank_zero:
self._save_outputs(epoch_results, self.extra_validation_outputs_dir, ModelKey.VAL)
# Reset the top and bottom slides heaps
if self.should_gather_tiles(self.val_plots_handler):
self.tiles_selector._clear_cached_slides_heaps() # type: ignore
def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
"""Render and save test epoch outputs.
@ -378,8 +402,8 @@ class DeepMILOutputsHandler:
"""
# All DDP processes must reach this point to allow synchronising epoch results
gathered_epoch_results = gather_results(epoch_results)
if PlotOption.TOP_BOTTOM_TILES in self.test_plots_handler.plot_options and self.tiles_selector:
self.tiles_selector.gather_selected_tiles_across_devices()
if self.should_gather_tiles(self.test_plots_handler):
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
# Only global rank-0 process should actually render and save the outputs-
if self.outputs_policy.should_save_test_outputs(is_global_rank_zero):
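
A condensed sketch of the resulting validation-saving policy (the `where_to_save` helper and the "outputs" root are hypothetical; the "val"/"extra_val" subdirectories come from this diff):

from pathlib import Path
from typing import Optional

def where_to_save(is_best_epoch: bool, on_extra_val: bool,
                  is_global_rank_zero: bool) -> Optional[Path]:
    if not is_global_rank_zero:
        return None  # only the global rank-0 process renders and saves outputs
    if on_extra_val:
        # Extra validation epochs never overwrite the best-epoch outputs.
        return Path("outputs") / "extra_val"
    return Path("outputs") / "val" if is_best_epoch else None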

View file

@ -8,11 +8,10 @@ import logging
import shutil
import sys
import uuid
import pytest
from pathlib import Path
from typing import Generator
import pytest
# temporary workaround until these hi-ml packages are released
testhisto_root_dir = Path(__file__).parent
print(f"Adding {testhisto_root_dir} to sys path")
@ -31,6 +30,7 @@ for package, subpackages in packages.items():
from health_ml.utils.fixed_paths import OutputFolderForTests # noqa: E402
from testhisto.mocks.base_data_generator import MockHistoDataType # noqa: E402
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator # noqa: E402
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType # noqa: E402
def remove_and_create_folder(folder: Path) -> None:
@ -96,3 +96,26 @@ def mock_panda_tiles_root_dir(
tiles_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
wsi_generator = MockPandaSlidesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=15,
n_channels=3,
n_levels=3,
tile_size=28,
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM
)
logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
wsi_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)

View file

@ -0,0 +1,184 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pytest
import torch
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from health_cpath.datamodules.base_module import HistoDataModule
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
from health_cpath.utils.naming import ModelKey, SlideKey
from health_ml.utils.common_utils import is_gpu_available
from testhisto.utils.utils_testhisto import run_distributed
no_gpu = not is_gpu_available()
def _assert_correct_bag_sizes(datamodule: HistoDataModule, max_bag_size: int, max_bag_size_inf: Optional[int],
true_bag_sizes: List[int]) -> None:
# The true bag sizes are those generated by the mock data generator for a fixed seed: the tile counts
# (and therefore the bag sizes) are random, to reflect real data with a varying number of tiles per slide.
for stage_key, bag_size in zip([m for m in ModelKey], [max_bag_size, max_bag_size_inf, max_bag_size_inf]):
assert datamodule.bag_sizes[stage_key] == bag_size
def _assert_bag_size_matching(dataloader: DataLoader, expected_bag_sizes: List[int]) -> None:
sample = next(iter(dataloader))
for i, slide in enumerate(sample[SlideKey.IMAGE]):
assert slide.shape[0] == expected_bag_sizes[i]
_assert_bag_size_matching(datamodule.train_dataloader(), [max_bag_size, max_bag_size])
expected_bag_sizes = true_bag_sizes if not max_bag_size_inf else [max_bag_size_inf, max_bag_size_inf]
_assert_bag_size_matching(datamodule.val_dataloader(), expected_bag_sizes)
_assert_bag_size_matching(datamodule.test_dataloader(), expected_bag_sizes)
def _assert_correct_batch_sizes(datamodule: HistoDataModule, batch_size: int, batch_size_inf: Optional[int]) -> None:
batch_size_inf = batch_size_inf if batch_size_inf is not None else batch_size
for stage_key, _batch_size in zip([m for m in ModelKey], [batch_size, batch_size_inf, batch_size_inf]):
assert datamodule.batch_sizes[stage_key] == _batch_size
def _assert_batch_size_matching(dataloader: DataLoader, expected_batch_size: int) -> None:
sample = next(iter(dataloader))
assert len(sample[SlideKey.IMAGE]) == expected_batch_size
_assert_batch_size_matching(datamodule.train_dataloader(), batch_size)
_assert_batch_size_matching(datamodule.val_dataloader(), batch_size_inf)
_assert_batch_size_matching(datamodule.test_dataloader(), batch_size_inf)
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_slides_datamodule_different_bag_sizes(
mock_panda_slides_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
batch_size=2,
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
tile_size=28,
level=0,
)
# Account for the fact that the slides datamodule maps 0 to None so that it's compatible with the TileOnGridd transform
max_bag_size_inf = max_bag_size_inf if max_bag_size_inf != 0 else None # type: ignore
# For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 4])
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_tiles_datamodule_different_bag_sizes(
mock_panda_tiles_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
datamodule = PandaTilesDataModule(
root_path=mock_panda_tiles_root_dir,
batch_size=2,
max_bag_size=max_bag_size,
max_bag_size_inf=max_bag_size_inf,
)
# For tiles datamodule, the true bag sizes [4, 5] were generated by the mock data generator for a fixed seed 42
# If test fails, check if the mock data generator has changed and update the true bag sizes
_assert_correct_bag_sizes(datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=[4, 5])
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_slides_datamodule_different_batch_sizes(
mock_panda_slides_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
batch_size=batch_size,
batch_size_inf=batch_size_inf,
max_bag_size=16,
max_bag_size_inf=16,
tile_size=28,
level=0,
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_tiles_datamodule_different_batch_sizes(
mock_panda_tiles_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
datamodule = PandaTilesDataModule(
root_path=mock_panda_tiles_root_dir,
batch_size=batch_size,
batch_size_inf=batch_size_inf,
max_bag_size=16,
max_bag_size_inf=16,
)
_assert_correct_batch_sizes(datamodule, batch_size, batch_size_inf)
def _validate_sampler_type(datamodule: HistoDataModule, stages: List[ModelKey], expected_none: bool) -> None:
expected_sampler_types = {
ModelKey.TRAIN: RandomSampler if expected_none else DistributedSampler,
ModelKey.VAL: SequentialSampler,
ModelKey.TEST: SequentialSampler,
}
for stage in stages:
datamodule_sampler = datamodule._get_ddp_sampler(getattr(datamodule, f'{stage}_dataset'), stage)
assert (datamodule_sampler is None) == expected_none
dataloader = getattr(datamodule, f'{stage.value}_dataloader')()
assert isinstance(dataloader.sampler, expected_sampler_types[stage])
def _test_datamodule_pl_ddp_sampler_true(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
_validate_sampler_type(datamodule, [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST], expected_none=True)
def _test_datamodule_pl_ddp_sampler_false(
datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
datamodule.setup()
_validate_sampler_type(datamodule, [ModelKey.VAL, ModelKey.TEST], expected_none=True)
_validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
slides_datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir,
pl_replace_sampler_ddp=True,
seed=42)
run_distributed(_test_datamodule_pl_ddp_sampler_true, [slides_datamodule], world_size=2)
slides_datamodule.pl_replace_sampler_ddp = False
run_distributed(_test_datamodule_pl_ddp_sampler_false, [slides_datamodule], world_size=2)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_true(mock_panda_tiles_root_dir: Path) -> None:
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=True)
run_distributed(_test_datamodule_pl_ddp_sampler_true, [tiles_datamodule], world_size=2)
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_false(mock_panda_tiles_root_dir: Path) -> None:
tiles_datamodule = PandaTilesDataModule(root_path=mock_panda_tiles_root_dir, seed=42, pl_replace_sampler_ddp=False)
run_distributed(_test_datamodule_pl_ddp_sampler_false, [tiles_datamodule], world_size=2)
def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
with pytest.raises(AssertionError, match="seed must be set when using distributed training for reproducibility"):
with patch("torch.distributed.is_initialized", return_value=True):
with patch("torch.distributed.get_world_size", return_value=2):
slides_datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
)
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)

View file

@ -1,17 +1,17 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import shutil
from typing import Generator, Dict, Callable, Union, Tuple
import pytest
import logging
import numpy as np
import torch
from pathlib import Path
from monai.transforms import RandFlipd
from typing import Generator, Dict, Callable, Union, Tuple
from torch.utils.data import DataLoader
from health_ml.utils.common_utils import is_gpu_available
from health_cpath.datamodules.base_module import SlidesDataModule
@ -29,7 +29,7 @@ no_gpu = not is_gpu_available()
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
def mock_panda_slides_root_dir_diagonal(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_wsi")
@ -38,7 +38,7 @@ def mock_panda_slides_root_dir(
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=1,
n_slides=10,
n_slides=16,
n_repeat_diag=4,
n_repeat_tile=2,
n_channels=3,
@ -83,7 +83,7 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = 16
tile_size = 28
@ -91,7 +91,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
channels = 3
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -105,14 +105,14 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
# check tiling on the fly
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for i in range(tile_count):
assert (original_tile == tiles[assert_batch_index][i].numpy()).all()
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> None:
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = None
tile_size = 28
@ -120,7 +120,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
assert_batch_index = 0
min_expected_tile_count = 16
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -135,14 +135,14 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("level", [0, 1, 2])
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -> None:
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
batch_size = 1
tile_count = 16
channels = 3
tile_size = 28 // 2 ** level
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
tile_size=tile_size,
@ -155,7 +155,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
# check tiling on the fly at different resolutions
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for i in range(tile_count):
# multi resolution mock data has been created via 2 factor downsampling
assert (original_tile[:, :: 2 ** level, :: 2 ** level] == tiles[assert_batch_index][i].numpy()).all()
@ -164,7 +164,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [1, 2])
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) -> None:
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
tile_size = 28
level = 0
step = 14
@ -172,7 +172,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
min_expected_tile_count = 32
assert_batch_index = 0
datamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
max_bag_size=None,
batch_size=batch_size,
tile_size=tile_size,
@ -184,7 +184,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
tiles, wsi_id = sample[SlideKey.IMAGE], sample[SlideKey.SLIDE_ID][assert_batch_index]
assert tiles[assert_batch_index].shape[0] >= min_expected_tile_count
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
tile_matches = 0
for _, tile in enumerate(tiles[assert_batch_index]):
tile_matches += int((tile.numpy() == original_tile).all())
@ -193,12 +193,12 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> None:
def get_transforms_dict() -> Dict[ModelKey, Union[Callable, None]]:
train_transform = RandFlipd(keys=[SlideKey.IMAGE], spatial_axis=0, prob=1.0)
return {ModelKey.TRAIN: train_transform, ModelKey.VAL: None, ModelKey.TEST: None} # type: ignore
def retrieve_tiles(dataloader: torch.utils.data.DataLoader) -> Dict[str, torch.Tensor]:
def retrieve_tiles(dataloader: DataLoader) -> Dict[str, torch.Tensor]:
tiles_dict = {}
assert_batch_index = 0
for sample in dataloader:
@ -211,7 +211,7 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
tile_size = 28
level = 0
flipdatamodule = PandaSlidesDataModule(
root_path=mock_panda_slides_root_dir,
root_path=mock_panda_slides_root_dir_diagonal,
batch_size=batch_size,
max_bag_size=tile_count,
max_bag_size_inf=0,
@ -224,20 +224,20 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
flip_test_tiles = retrieve_tiles(flipdatamodule.test_dataloader())
for wsi_id in flip_train_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
# the first dimension is the channel, flipping happened on the horizontal axis of the image
transformed_original_tile = np.flip(original_tile, axis=1)
for tile in flip_train_tiles[wsi_id]:
assert (tile.numpy() == transformed_original_tile).all()
for wsi_id in flip_val_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for tile in flip_val_tiles[wsi_id]:
# no transformation has been applied to val tiles
assert (tile.numpy() == original_tile).all()
for wsi_id in flip_test_tiles.keys():
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
for tile in flip_test_tiles[wsi_id]:
# no transformation has been applied to test tiles
assert (tile.numpy() == original_tile).all()
@ -251,7 +251,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
"""
def get_splits(self) -> Tuple[PandaDataset, PandaDataset, PandaDataset]:
return (PandaDataset(self.root_path), PandaDataset(self.root_path), PandaDataset(self.root_path))
@ -278,7 +277,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
tiles = sample[SlideKey.IMAGE]
assert tiles[assert_batch_index].shape[0] == tile_count
def assert_whole_slide_inference_with_all_tiles(dataloader: torch.utils.data.DataLoader) -> None:
def assert_whole_slide_inference_with_all_tiles(dataloader: DataLoader) -> None:
for i, sample in enumerate(dataloader):
tiles = sample[SlideKey.IMAGE]
assert tiles[assert_batch_index].shape[0] == n_tiles_list[i * batch_size]
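
Aside: the multi-resolution assertion above works because the mock pyramid is generated by factor-2 downsampling, so lower levels are exact strided views of the level-0 tile. A tiny self-contained illustration with synthetic data:

import numpy as np

# Synthetic CHW tile standing in for a level-0 mock tile (3 channels, 28 x 28).
tile = np.arange(3 * 28 * 28).reshape(3, 28, 28)

for level in range(3):
    # Factor-2 downsampling per level amounts to strided indexing at full resolution.
    downsampled = tile[:, :: 2 ** level, :: 2 ** level]
    assert downsampled.shape == (3, 28 // 2 ** level, 28 // 2 ** level)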

View file

@ -3,15 +3,13 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from copy import deepcopy
import logging
import os
import shutil
from pytorch_lightning import Trainer
import torch
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
from torch.utils.data._utils.collate import default_collate
@ -30,10 +28,9 @@ from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDatase
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
from health_cpath.models.deepmil import BaseDeepMILModule, TilesDeepMILModule
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, Resnet18, TileEncoder
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey
from testhisto.mocks.base_data_generator import MockHistoDataType
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
from health_ml.utils.common_utils import is_gpu_available
@ -75,7 +72,7 @@ def _test_lightningmodule(
module = TilesDeepMILModule(
label_column=DEFAULT_LABEL_COLUMN,
n_classes=n_classes,
dropout_rate=dropout_rate,
classifier_params=ClassifierParams(dropout_rate=dropout_rate),
encoder_params=get_supervised_imagenet_encoder_params(),
pooling_params=get_attention_pooling_layer_params(pool_out_dim)
)
@ -141,29 +138,6 @@ def _test_lightningmodule(
assert torch.all(score <= 1)
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
wsi_generator = MockPandaSlidesGenerator(
dest_data_path=tmp_root_dir,
src_data_path=tmp_path_to_pathmnist_dataset,
mock_type=MockHistoDataType.PATHMNIST,
n_tiles=4,
n_slides=10,
n_channels=3,
n_levels=3,
tile_size=28,
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM
)
logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
wsi_generator.generate_mock_histo_data()
yield tmp_root_dir
shutil.rmtree(tmp_root_dir)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("max_bag_size", [1, 5])
@ -326,9 +300,10 @@ def test_container(container_type: Type[BaseMILTiles], use_gpu: bool) -> None:
container = container_type()
container.setup()
container.batch_size = 10
container.batch_size_inf = 10
data_module: TilesDataModule = container.get_data_module() # type: ignore
data_module.max_bag_size = 10
module = container.create_model()
module.outputs_handler = MagicMock()
@ -461,12 +436,12 @@ def test_finetuning_options(
label_column=DEFAULT_LABEL_COLUMN,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
tune_classifier=tune_classifier,
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
)
assert module.encoder_params.tune_encoder == tune_encoder
assert module.pooling_params.tune_pooling == tune_pooling
assert module.tune_classifier == tune_classifier
assert module.classifier_params.tune_classifier == tune_classifier
for params in module.encoder.parameters():
assert params.requires_grad == tune_encoder
@ -513,7 +488,7 @@ def test_training_with_different_finetuning_options(
label_column=MockPandaTilesGenerator.ISUP_GRADE,
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
tune_classifier=tune_classifier,
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
)
def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
@ -550,7 +525,7 @@ def test_init_weights_options(pretrained_encoder: bool, pretrained_pooling: bool
)
module.encoder_params.pretrained_encoder = pretrained_encoder
module.pooling_params.pretrained_pooling = pretrained_pooling
module.pretrained_classifier = pretrained_classifier
module.classifier_params.pretrained_classifier = pretrained_classifier
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
with patch.object(module, "copy_weights") as mock_copy_weights:
@ -576,10 +551,11 @@ def _get_tiles_deepmil_module(
label_column=MockPandaTilesGenerator.ISUP_GRADE,
encoder_params=get_supervised_imagenet_encoder_params(),
pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim, transformer_dropout),
classifier_params=ClassifierParams(pretrained_classifier=pretrained_classifier),
)
module.encoder_params.pretrained_encoder = pretrained_encoder
module.pooling_params.pretrained_pooling = pretrained_pooling
module.pretrained_classifier = pretrained_classifier
module.classifier_params.pretrained_classifier = pretrained_classifier
return module
@ -711,13 +687,13 @@ def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
container.setup()
container.data_module = MagicMock()
container.create_lightning_module_and_store()
assert not container.model._run_extra_val_epoch
assert not container.model._on_extra_val_epoch
assert (
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
!= container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
)
container.on_run_extra_validation_epoch()
assert container.model._run_extra_val_epoch
assert container.model._on_extra_val_epoch
assert (
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
== container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
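
For orientation, a hedged reconstruction of the refactored ClassifierParams, limited to the fields these tests exercise; the real class in health_cpath.utils.deepmil_utils may define more, and the defaults below are assumptions:

import param


class ClassifierParams(param.Parameterized):
    # Fields inferred from the tests above; defaults are assumptions.
    dropout_rate = param.Number(default=None, allow_None=True, bounds=(0, 1),
                                doc="Dropout rate applied before the classifier head.")
    tune_classifier = param.Boolean(default=True,
                                    doc="Whether the classifier weights are updated during training.")
    pretrained_classifier = param.Boolean(default=False,
                                          doc="Whether to load classifier weights from a pretrained checkpoint.")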

View file

@ -191,7 +191,7 @@ def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
def test_on_train_and_val_end(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
max_epochs = 4
trainer = MagicMock(current_epoch=max_epochs - 1)
@ -224,7 +224,7 @@ def test_on_train_and_val_end(tmp_path: Path) -> None:
def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=True)
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=True)
max_epochs = 4
trainer = MagicMock(current_epoch=0)
loss_callback = LossAnalysisCallback(
@ -264,7 +264,7 @@ def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> Non
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
max_epochs = 3
trainer = MagicMock(current_epoch=max_epochs - 1)
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
loss_callback = LossAnalysisCallback(
outputs_folder=tmp_path, max_epochs=max_epochs,
num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
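
The renamed _on_extra_val_epoch flag gates the callback roughly as follows; this is a sketch inferred from the tests, not the actual LossAnalysisCallback code:

from pytorch_lightning import LightningModule, Trainer


class LossAnalysisCallbackSketch:
    """Illustrative only: mirrors the guard the tests above assert on."""

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if getattr(pl_module, "_on_extra_val_epoch", False):
            # The extra validation epoch runs after training has finished,
            # so skip saving to avoid overwriting the per-epoch loss caches.
            return
        self._save_loss_cache(trainer.current_epoch)

    def _save_loss_cache(self, epoch: int) -> None:
        # Placeholder for the real persistence logic.
        pass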

View file

@ -1,12 +1,13 @@
from pathlib import Path
from typing import Dict, List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.distributed
import torch.multiprocessing
from ruamel.yaml import YAML
from health_cpath.utils.tiles_selection_utils import TilesSelector
from testhisto.utils.utils_testhisto import run_distributed
from torch.testing import assert_close
from torchmetrics.metric import Metric
@ -251,3 +252,27 @@ def test_collate_results_multigpu() -> None:
epoch_results = _create_epoch_results(batch_size, num_batches, rank=0, device='cuda:0') \
+ _create_epoch_results(batch_size, num_batches, rank=1, device='cuda:1')
    _test_collate_results(epoch_results, total_num_samples=2 * num_batches * batch_size)


@pytest.mark.parametrize('val_set_is_dist', [True, False])
def test_results_gathering_with_val_set_is_dist_flag(val_set_is_dist: bool, tmp_path: Path) -> None:
outputs_handler = _create_outputs_handler(tmp_path)
outputs_handler.tiles_selector = TilesSelector(2, num_slides=2, num_tiles=2)
outputs_handler.val_set_is_dist = val_set_is_dist
outputs_handler._save_outputs = MagicMock() # type: ignore
metric_value = 0.5
with patch("health_cpath.utils.output_utils.gather_results") as mock_gather_results:
with patch.object(outputs_handler.tiles_selector, "gather_selected_tiles_across_devices") as mock_gather_tiles:
with patch.object(outputs_handler.tiles_selector, "_clear_cached_slides_heaps") as mock_clear_cache:
with patch.object(outputs_handler, "should_gather_tiles") as mock_should_gather_tiles:
mock_should_gather_tiles.return_value = True
for rank in range(2):
epoch_results = [{_PRIMARY_METRIC_KEY: [metric_value] * 5, _RANK_KEY: rank}]
outputs_handler.save_validation_outputs(
epoch_results=epoch_results, # type: ignore
metrics_dict=_get_mock_metrics_dict(metric_value),
epoch=0,
is_global_rank_zero=rank == 0)
assert mock_gather_results.called == val_set_is_dist
assert mock_gather_tiles.called == val_set_is_dist
mock_clear_cache.assert_called()
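
A sketch of the gating this test exercises, with the control flow inferred from the mocks; should_gather_tiles and the argument lists are assumptions, and the real method in health_cpath.utils.output_utils may differ:

from health_cpath.utils.output_utils import gather_results  # patched above, so the name lives there


def save_validation_outputs(self, epoch_results, metrics_dict, epoch, is_global_rank_zero):
    # Method of the outputs handler, shown standalone for brevity.
    if self.val_set_is_dist:
        # Validation data is sharded across DDP processes: gather partial
        # results (and selected tiles) before anything is saved.
        epoch_results = gather_results(epoch_results)
        if self.should_gather_tiles(epoch):
            self.tiles_selector.gather_selected_tiles_across_devices()
    if is_global_rank_zero:
        self._save_outputs(epoch_results, metrics_dict, epoch)
    # The tiles cache is cleared in both the distributed and the local case.
    self.tiles_selector._clear_cached_slides_heaps()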

View file

@ -122,7 +122,7 @@ class HelloRegression(LightningModule):
self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) # type: ignore
self.test_mse: List[torch.Tensor] = []
self.test_mae = MeanAbsoluteError()
self._run_extra_val_epoch = False
self._on_extra_val_epoch = False
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
"""
@ -233,7 +233,7 @@ class HelloRegression(LightningModule):
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
def on_run_extra_validation_epoch(self) -> None:
self._run_extra_val_epoch = True
self._on_extra_val_epoch = True
class HelloWorld(LightningContainer):

View file

@ -509,6 +509,11 @@ class TrainerParams(param.Parameterized):
pl_log_every_n_steps: int = param.Integer(default=50,
doc="PyTorch Lightning trainer flag 'log_every_n_steps': How often to"
"log within steps. Default to 50.")
pl_replace_sampler_ddp: bool = param.Boolean(default=True,
doc="PyTorch Lightning trainer flag 'replace_sampler_ddp' that "
"sets the sampler for distributed training with shuffle=True during "
"training and shuffle=False during validation. Default to True. Set to"
"False to set your own sampler.")
@property
def use_gpu(self) -> bool:
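
To opt out of Lightning's sampler replacement from a container, overriding the new flag is enough; a minimal sketch, where MyContainer and the import path are assumptions:

from health_ml.lightning_container import LightningContainer  # assumed module path


class MyContainer(LightningContainer):
    def __init__(self) -> None:
        super().__init__()
        # Keep the datamodule's own samplers so that, under DDP, every process
        # can validate on the entire validation set rather than a shard.
        self.pl_replace_sampler_ddp = False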

View file

@ -210,5 +210,6 @@ def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
multiple_trainloader_mode=multiple_trainloader_mode,
accumulate_grad_batches=container.pl_accumulate_grad_batches,
replace_sampler_ddp=container.pl_replace_sampler_ddp,
**additional_args)
return trainer, storing_logger
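
Downstream this is a straight pass-through to PyTorch Lightning; the effect is the same as constructing a trainer directly with the PL 1.x flag:

from pytorch_lightning import Trainer

# With replace_sampler_ddp=False, Lightning keeps whatever samplers the
# dataloaders already carry instead of injecting a DistributedSampler.
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", replace_sampler_ddp=False)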