зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add pl_replace_sampler_ddp flag for small validation sets (#648)
- Add pl_replace_sampler_ddp so that distributing the validation data across processes can be avoided -> all processes then validate on the entire validation set (reduce will give the same mean)
- Log extra validation outputs with a different prefix
- Refactor ClassifierParams
- Add a specific batch size for inference

Random [broken pipe error](https://ml.azure.com/experiments/id/c4ad55a7-ce24-46a5-be8e-0fb478edf47e/runs/refs_pull_648_merge_1667920946_f585577d?wsid=/subscriptions/a85ceddd-892e-4637-ae4b-67d15ddf5f2b/resourceGroups/health-ml/providers/Microsoft.MachineLearningServices/workspaces/hi-ml&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#outputsAndLogs)
This commit is contained in:
Родитель
e30a0d1f6d
Коммит
1034324c51
|
@ -142,7 +142,7 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_tcgacrckimagenetmil_aml
|
||||
|
||||
smoke_test_tcgacrcksslmil_aml:
|
||||
smoke_test_tcgacrcksslmil:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
@ -157,7 +157,7 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_tcgacrcksslmil_aml
|
||||
|
||||
smoke_test_crck_simclr_aml:
|
||||
smoke_test_crck_simclr:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
@ -172,7 +172,7 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_crck_simclr_aml
|
||||
|
||||
smoke_test_crck_flexible_finetuning_aml:
|
||||
smoke_test_crck_flexible_finetuning:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
@ -217,6 +217,36 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_slides_panda_loss_analysis_aml
|
||||
|
||||
smoke_test_slides_panda_no_ddp_sampler:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make smoke_test_slides_panda_no_ddp_sampler_aml
|
||||
|
||||
smoke_test_tiles_panda_no_ddp_sampler:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Prepare Conda environment
|
||||
uses: ./.github/actions/prepare_cpath_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make smoke_test_tiles_panda_no_ddp_sampler_aml
|
||||
|
||||
publish:
|
||||
runs-on: ubuntu-20.04
|
||||
needs: [
|
||||
|
@ -225,10 +255,14 @@ jobs:
|
|||
pytest,
|
||||
smoke_test_slidespandaimagenetmil,
|
||||
smoke_test_tilespandaimagenetmil,
|
||||
smoke_test_crck_simclr_aml,
|
||||
smoke_test_crck_flexible_finetuning_aml,
|
||||
smoke_test_tcgacrckimagenetmil,
|
||||
smoke_test_tcgacrcksslmil,
|
||||
smoke_test_crck_simclr,
|
||||
smoke_test_crck_flexible_finetuning,
|
||||
smoke_test_crck_loss_analysis,
|
||||
smoke_test_slides_panda_loss_analysis,
|
||||
smoke_test_slides_panda_no_ddp_sampler,
|
||||
smoke_test_tiles_panda_no_ddp_sampler,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
|
|
@ -82,7 +82,7 @@ pytest_coverage:
|
|||
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
|
||||
|
||||
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
|
||||
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667236343_af6e293f
|
||||
SRC_CKPT_RUN_ID_CRCK := TcgaCrckSSLMIL_1667832811_4e855804
|
||||
|
||||
# Run regression tests and compare performance
|
||||
define BASE_CPATH_RUNNER_COMMAND
|
||||
|
@ -190,6 +190,10 @@ define LOSS_ANALYSIS_ARGS
|
|||
--num_slides_scatter=2 --num_slides_heatmap=2 --save_tile_ids=True
|
||||
endef
|
||||
|
||||
define DDP_SAMPLER_ARGS
|
||||
--pl_replace_sampler_ddp=False
|
||||
endef
|
||||
|
||||
# The following test takes around 5 minutes
|
||||
smoke_test_slidespandaimagenetmil_local:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
|
@ -259,6 +263,22 @@ smoke_test_slides_panda_loss_analysis_aml:
|
|||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
|
||||
|
||||
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local
|
||||
smoke_test_slides_panda_no_ddp_sampler_local:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
|
||||
|
||||
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml
|
||||
smoke_test_slides_panda_no_ddp_sampler_aml:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
|
||||
|
||||
smoke_test_tiles_panda_no_ddp_sampler_local:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS};}
|
||||
|
||||
smoke_test_tiles_panda_no_ddp_sampler_aml:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${DDP_SAMPLER_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
|
||||
|
||||
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local smoke_test_slides_panda_no_ddp_sampler_local smoke_test_tiles_panda_no_ddp_sampler_local
|
||||
|
||||
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml smoke_test_slides_panda_no_ddp_sampler_aml smoke_test_tiles_panda_no_ddp_sampler_aml
|
||||
|
|
|
@ -26,22 +26,22 @@ from health_cpath.datamodules.base_module import CacheLocation, CacheMode, Histo
|
|||
from health_cpath.datasets.base_dataset import SlidesDataset
|
||||
from health_cpath.models.deepmil import TilesDeepMILModule, SlidesDeepMILModule, BaseDeepMILModule
|
||||
from health_cpath.models.transforms import EncodeTilesBatchd, LoadTilesBatchd
|
||||
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
|
||||
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
|
||||
from health_cpath.utils.output_utils import DeepMILOutputsHandler
|
||||
from health_cpath.utils.naming import MetricsKey, PlotOption, SlideKey, ModelKey
|
||||
from health_cpath.utils.tiles_selection_utils import TilesSelector
|
||||
|
||||
|
||||
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackParams):
|
||||
class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams, LossCallbackParams):
|
||||
"""BaseMIL is an abstract container defining basic functionality for running MIL experiments in both slides and
|
||||
tiles settings. It is responsible for instantiating the encoder and pooling layer. Subclasses should define the
|
||||
full DeepMIL model depending on the type of dataset (tiles/slides based).
|
||||
"""
|
||||
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
|
||||
class_names: Optional[Sequence[str]] = param.List(None, item_type=str, doc="List of class names. If `None`, "
|
||||
"defaults to `('0', '1', ...)`.")
|
||||
# Data module parameters:
|
||||
batch_size: int = param.Integer(16, bounds=(1, None), doc="Number of slides to load per batch.")
|
||||
batch_size_inf: int = param.Integer(16, bounds=(1, None), doc="Number of slides per batch during inference.")
|
||||
max_bag_size: int = param.Integer(1000, bounds=(0, None),
|
||||
doc="Upper bound on number of tiles in each loaded bag during training stage. "
|
||||
"If 0 (default), will return all samples in each bag. "
|
||||
|
@ -72,13 +72,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
"generating outputs.")
|
||||
maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
|
||||
"maximised (otherwise minimised).")
|
||||
tune_classifier: bool = param.Boolean(
|
||||
default=True,
|
||||
doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
|
||||
pretrained_classifier: bool = param.Boolean(
|
||||
default=False,
|
||||
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
|
||||
"initiliaze classifier with random weights.")
|
||||
max_num_workers: int = param.Integer(10, bounds=(0, None),
|
||||
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
|
||||
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
|
||||
|
@ -94,6 +87,10 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
self.best_checkpoint_filename = f"checkpoint_{metric_optim}_val_{self.primary_val_metric.value}"
|
||||
self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt"
|
||||
self.validate()
|
||||
if not self.pl_replace_sampler_ddp and self.max_num_gpus > 1:
|
||||
logging.info(
|
||||
"Replacing sampler with `DistributedSampler` is disabled. Make sure to set your own DDP sampler"
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
|
@ -154,6 +151,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
val_plot_options=self.get_val_plot_options(),
|
||||
test_plot_options=self.get_test_plot_options(),
|
||||
wsi_has_mask=self.wsi_has_mask,
|
||||
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
|
||||
)
|
||||
if self.num_top_slides > 0:
|
||||
outputs_handler.tiles_selector = TilesSelector(
|
||||
|
@ -176,7 +174,11 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
num_slides_scatter=self.num_slides_scatter,
|
||||
num_slides_heatmap=self.num_slides_heatmap,
|
||||
save_tile_ids=self.save_tile_ids,
|
||||
log_exceptions=self.log_exceptions))
|
||||
log_exceptions=self.log_exceptions,
|
||||
val_set_is_dist=(
|
||||
self.pl_replace_sampler_ddp and self.max_num_gpus > 1),
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def get_checkpoint_to_test(self) -> Path:
|
||||
|
@ -288,15 +290,14 @@ class BaseMILTiles(BaseMIL):
|
|||
n_classes=self.data_module.train_dataset.n_classes,
|
||||
class_names=self.class_names,
|
||||
class_weights=self.data_module.class_weights,
|
||||
tune_classifier=self.tune_classifier,
|
||||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
classifier_params=create_from_matching_params(self, ClassifierParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
outputs_folder=self.outputs_folder,
|
||||
outputs_handler=outputs_handler,
|
||||
analyse_loss=self.analyse_loss)
|
||||
analyse_loss=self.analyse_loss,
|
||||
)
|
||||
deepmil_module.transfer_weights(self.trained_weights_path)
|
||||
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
|
||||
return deepmil_module
|
||||
|
@ -331,15 +332,14 @@ class BaseMILSlides(BaseMIL):
|
|||
n_classes=self.data_module.train_dataset.n_classes,
|
||||
class_names=self.class_names,
|
||||
class_weights=self.data_module.class_weights,
|
||||
tune_classifier=self.tune_classifier,
|
||||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
classifier_params=create_from_matching_params(self, ClassifierParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
outputs_handler=outputs_handler,
|
||||
analyse_loss=self.analyse_loss)
|
||||
analyse_loss=self.analyse_loss,
|
||||
)
|
||||
deepmil_module.transfer_weights(self.trained_weights_path)
|
||||
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
|
||||
return deepmil_module
|
||||
|
|
|
@ -63,6 +63,7 @@ class DeepSMILECrck(BaseMILTiles):
|
|||
root_path=self.local_datasets[0],
|
||||
max_bag_size=self.max_bag_size,
|
||||
batch_size=self.batch_size,
|
||||
batch_size_inf=self.batch_size_inf,
|
||||
max_bag_size_inf=self.max_bag_size_inf,
|
||||
transforms_dict=self.get_transforms_dict(TcgaCrck_TilesDataset.IMAGE_COLUMN),
|
||||
cache_mode=self.cache_mode,
|
||||
|
@ -72,6 +73,7 @@ class DeepSMILECrck(BaseMILTiles):
|
|||
crossval_index=self.crossval_index,
|
||||
dataloader_kwargs=self.get_dataloader_kwargs(),
|
||||
seed=self.get_effective_random_seed(),
|
||||
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
|
||||
)
|
||||
|
||||
def get_test_plot_options(self) -> Set[PlotOption]:
|
||||
|
|
|
@ -65,6 +65,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
# declared in BaseMILTiles:
|
||||
is_caching=False,
|
||||
batch_size=8,
|
||||
batch_size_inf=8,
|
||||
azure_datasets=[PANDA_5X_TILES_DATASET_ID, PANDA_DATASET_ID])
|
||||
default_kwargs.update(kwargs)
|
||||
super().__init__(**default_kwargs)
|
||||
|
@ -75,8 +76,9 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
def get_data_module(self) -> PandaTilesDataModule:
|
||||
return PandaTilesDataModule(
|
||||
root_path=self.local_datasets[0],
|
||||
max_bag_size=self.max_bag_size,
|
||||
batch_size=self.batch_size,
|
||||
batch_size_inf=self.batch_size_inf,
|
||||
max_bag_size=self.max_bag_size,
|
||||
max_bag_size_inf=self.max_bag_size_inf,
|
||||
transforms_dict=self.get_transforms_dict(PandaTilesDataset.IMAGE_COLUMN),
|
||||
cache_mode=self.cache_mode,
|
||||
|
@ -86,6 +88,7 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
crossval_index=self.crossval_index,
|
||||
dataloader_kwargs=self.get_dataloader_kwargs(),
|
||||
seed=self.get_effective_random_seed(),
|
||||
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
|
||||
)
|
||||
|
||||
def get_slides_dataset(self) -> Optional[PandaDataset]:
|
||||
|
@ -148,6 +151,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
|
|||
return PandaSlidesDataModule(
|
||||
root_path=self.local_datasets[0],
|
||||
batch_size=self.batch_size,
|
||||
batch_size_inf=self.batch_size_inf,
|
||||
level=self.level,
|
||||
max_bag_size=self.max_bag_size,
|
||||
max_bag_size_inf=self.max_bag_size_inf,
|
||||
|
@ -162,6 +166,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
|
|||
crossval_count=self.crossval_count,
|
||||
crossval_index=self.crossval_index,
|
||||
dataloader_kwargs=self.get_dataloader_kwargs(),
|
||||
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
|
||||
)
|
||||
|
||||
def get_slides_dataset(self) -> PandaDataset:
|
||||
|
|
|
@ -25,7 +25,7 @@ from health_cpath.models.encoders import (
|
|||
)
|
||||
from health_cpath.configs.classification.DeepSMILEPanda import DeepSMILESlidesPanda
|
||||
from health_cpath.models.deepmil import SlidesDeepMILModule
|
||||
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
|
||||
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
|
||||
from health_cpath.utils.naming import MetricsKey, ModelKey, SlideKey
|
||||
|
||||
|
||||
|
@ -64,6 +64,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
encoding_chunk_size=60,
|
||||
max_bag_size=56,
|
||||
batch_size=8, # effective batch size = batch_size * num_GPUs
|
||||
batch_size_inf=8,
|
||||
max_epochs=50,
|
||||
l_rate=3e-4,
|
||||
weight_decay=0,
|
||||
|
@ -79,6 +80,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
# Params specific to fine-tuning
|
||||
if self.tune_encoder:
|
||||
self.batch_size = 2
|
||||
self.batch_size_inf = 2
|
||||
super().setup()
|
||||
|
||||
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
|
||||
|
@ -100,8 +102,9 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
# Hence, inherited `PandaSlidesDataModuleBenchmark` from `SlidesDataModule`
|
||||
return PandaSlidesDataModuleBenchmark(
|
||||
root_path=self.local_datasets[0],
|
||||
max_bag_size=self.max_bag_size,
|
||||
batch_size=self.batch_size,
|
||||
batch_size_inf=self.batch_size_inf,
|
||||
max_bag_size=self.max_bag_size,
|
||||
max_bag_size_inf=self.max_bag_size_inf,
|
||||
level=self.level,
|
||||
tile_size=self.tile_size,
|
||||
|
@ -115,6 +118,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
crossval_count=self.crossval_count,
|
||||
crossval_index=self.crossval_index,
|
||||
dataloader_kwargs=self.get_dataloader_kwargs(),
|
||||
pl_replace_sampler_ddp=self.pl_replace_sampler_ddp,
|
||||
)
|
||||
|
||||
def create_model(self) -> SlidesDeepMILModule:
|
||||
|
@ -126,10 +130,10 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
n_classes=self.data_module.train_dataset.n_classes,
|
||||
class_names=self.class_names,
|
||||
class_weights=self.data_module.class_weights,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
classifier_params=create_from_matching_params(self, ClassifierParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
outputs_handler=outputs_handler,
|
||||
analyse_loss=self.analyse_loss,
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, Type
|
|||
|
||||
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
from health_ml.utils.bag_utils import BagDataset, multibag_collate
|
||||
from health_ml.utils.common_utils import _create_generator
|
||||
|
@ -47,18 +47,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
self,
|
||||
root_path: Path,
|
||||
batch_size: int = 1,
|
||||
batch_size_inf: Optional[int] = None,
|
||||
max_bag_size: int = 0,
|
||||
max_bag_size_inf: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
transforms_dict: Optional[Dict[ModelKey, Union[Callable, None]]] = None,
|
||||
crossval_count: int = 0,
|
||||
crossval_index: int = 0,
|
||||
pl_replace_sampler_ddp: bool = True,
|
||||
dataloader_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dataframe_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param root_path: Root directory of the source dataset.
|
||||
:param batch_size: Number of slides to load per batch.
|
||||
:param batch_size_inf: Number of slides to load per batch during inference. If None, use batch_size.
|
||||
:param max_bag_size: Upper bound on number of tiles in each loaded bag during training stage. If 0 (default),
|
||||
will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield
|
||||
random subsets of instances. For SlideDataModule, this parameter is used in TileOnGridd Transform to set the
|
||||
|
@ -74,19 +77,21 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
By default (`None`).
|
||||
:param crossval_count: Number of folds to perform.
|
||||
:param crossval_index: Index of the cross validation split to be performed.
|
||||
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
|
||||
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
|
||||
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
|
||||
"""
|
||||
|
||||
batch_size_inf = batch_size_inf or batch_size
|
||||
super().__init__()
|
||||
|
||||
self.root_path = root_path
|
||||
self.transforms_dict = transforms_dict
|
||||
self.batch_size = batch_size
|
||||
self.max_bag_size = max_bag_size
|
||||
self.max_bag_size_inf = max_bag_size_inf
|
||||
self.batch_sizes = {ModelKey.TRAIN: batch_size, ModelKey.VAL: batch_size_inf, ModelKey.TEST: batch_size_inf}
|
||||
self.bag_sizes = {ModelKey.TRAIN: max_bag_size, ModelKey.VAL: max_bag_size_inf, ModelKey.TEST: max_bag_size_inf}
|
||||
self.crossval_count = crossval_count
|
||||
self.crossval_index = crossval_index
|
||||
self.pl_replace_sampler_ddp = pl_replace_sampler_ddp
|
||||
self.train_dataset: _SlidesOrTilesDataset
|
||||
self.val_dataset: _SlidesOrTilesDataset
|
||||
self.test_dataset: _SlidesOrTilesDataset
|
||||
|
@ -105,6 +110,13 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
) -> DataLoader:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_ddp_sampler(self, dataset: Dataset, stage: ModelKey) -> Optional[DistributedSampler]:
|
||||
is_distributed = torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1
|
||||
if is_distributed and not self.pl_replace_sampler_ddp and stage == ModelKey.TRAIN:
|
||||
assert self.seed is not None, "seed must be set when using distributed training for reproducibility"
|
||||
return DistributedSampler(dataset, shuffle=True, seed=self.seed)
|
||||
return None
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return self._get_dataloader(self.train_dataset, # type: ignore
|
||||
shuffle=True,
|
||||
|
@ -206,15 +218,10 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
|
|||
|
||||
generator = _create_generator(self.seed)
|
||||
|
||||
if stage in [ModelKey.VAL, ModelKey.TEST]:
|
||||
eff_max_bag_size = self.max_bag_size_inf
|
||||
else:
|
||||
eff_max_bag_size = self.max_bag_size
|
||||
|
||||
bag_dataset = BagDataset(
|
||||
tiles_dataset, # type: ignore
|
||||
bag_ids=tiles_dataset.slide_ids,
|
||||
max_bag_size=eff_max_bag_size,
|
||||
max_bag_size=self.bag_sizes[stage],
|
||||
shuffle_samples=True,
|
||||
generator=generator,
|
||||
)
|
||||
|
@ -241,11 +248,14 @@ class TilesDataModule(HistoDataModule[TilesDataset]):
|
|||
transformed_bag_dataset = self._load_dataset(dataset, stage=stage, shuffle=shuffle)
|
||||
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
|
||||
generator = bag_dataset.bag_sampler.generator
|
||||
sampler = self._get_ddp_sampler(transformed_bag_dataset, stage)
|
||||
return DataLoader(
|
||||
transformed_bag_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.batch_sizes[stage],
|
||||
collate_fn=multibag_collate,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
# sampler option is mutually exclusive with shuffle
|
||||
shuffle=shuffle if sampler is None else None, # type: ignore
|
||||
generator=generator,
|
||||
**dataloader_kwargs,
|
||||
)
|
||||
|
@ -260,13 +270,13 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
level: Optional[int] = 1,
|
||||
tile_size: Optional[int] = 224,
|
||||
level: int = 1,
|
||||
tile_size: int = 224,
|
||||
step: Optional[int] = None,
|
||||
random_offset: Optional[bool] = True,
|
||||
pad_full: Optional[bool] = False,
|
||||
background_val: Optional[int] = 255,
|
||||
filter_mode: Optional[str] = "min",
|
||||
random_offset: bool = True,
|
||||
pad_full: bool = False,
|
||||
background_val: int = 255,
|
||||
filter_mode: str = "min",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -298,8 +308,9 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
self.filter_mode = filter_mode
|
||||
# TileOnGridd transform expects None to select all foreground tiles, so we hardcode max_bag_size and
|
||||
# max_bag_size_inf to None if set to 0
|
||||
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore
|
||||
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore
|
||||
for stage_key, max_bag_size in self.bag_sizes.items():
|
||||
if max_bag_size == 0:
|
||||
self.bag_sizes[stage_key] = None # type: ignore
|
||||
|
||||
def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
|
||||
base_transform = Compose(
|
||||
|
@ -314,7 +325,7 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
),
|
||||
TileOnGridd(
|
||||
keys=slides_dataset.IMAGE_COLUMN,
|
||||
tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
|
||||
tile_count=self.bag_sizes[stage],
|
||||
tile_size=self.tile_size,
|
||||
step=self.step,
|
||||
random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
|
||||
|
@ -339,11 +350,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
|
|||
**dataloader_kwargs: Any) -> DataLoader:
|
||||
transformed_slides_dataset = self._load_dataset(dataset, stage)
|
||||
generator = _create_generator(self.seed)
|
||||
sampler = self._get_ddp_sampler(transformed_slides_dataset, stage)
|
||||
return DataLoader(
|
||||
transformed_slides_dataset,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.batch_sizes[stage],
|
||||
collate_fn=image_collate,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
# sampler option is mutually exclusive with shuffle
|
||||
shuffle=shuffle if not sampler else None, # type: ignore
|
||||
generator=generator,
|
||||
**dataloader_kwargs,
|
||||
)
|
||||
|
|
|
@ -6,23 +6,20 @@ import torch
|
|||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pathlib import Path
|
||||
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from torch import Tensor, argmax, mode, nn, optim, round
|
||||
from torchmetrics import (AUROC, F1, Accuracy, ConfusionMatrix, Precision,
|
||||
Recall, CohenKappa, AveragePrecision, Specificity)
|
||||
|
||||
|
||||
from health_ml.utils import log_on_epoch
|
||||
from health_ml.deep_learning_config import OptimizerParams
|
||||
from health_cpath.models.encoders import IdentityEncoder
|
||||
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams, set_module_gradients_enabled
|
||||
|
||||
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
|
||||
from health_cpath.datasets.base_dataset import TilesDataset
|
||||
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey, SlideKey, ModelKey, TileKey
|
||||
from health_cpath.utils.output_utils import (BatchResultsType, DeepMILOutputsHandler, EpochResultsType,
|
||||
validate_class_names)
|
||||
validate_class_names, EXTRA_PREFIX)
|
||||
|
||||
|
||||
RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
|
||||
ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
|
||||
|
@ -42,34 +39,31 @@ class BaseDeepMILModule(LightningModule):
|
|||
n_classes: int,
|
||||
class_weights: Optional[Tensor] = None,
|
||||
class_names: Optional[Sequence[str]] = None,
|
||||
tune_classifier: bool = True,
|
||||
pretrained_classifier: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
verbose: bool = False,
|
||||
outputs_folder: Optional[Path] = None,
|
||||
encoder_params: EncoderParams = EncoderParams(),
|
||||
pooling_params: PoolingParams = PoolingParams(),
|
||||
classifier_params: ClassifierParams = ClassifierParams(),
|
||||
optimizer_params: OptimizerParams = OptimizerParams(),
|
||||
outputs_folder: Optional[Path] = None,
|
||||
outputs_handler: Optional[DeepMILOutputsHandler] = None,
|
||||
analyse_loss: Optional[bool] = False) -> None:
|
||||
analyse_loss: Optional[bool] = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
:param label_column: Label key for input batch dictionary.
|
||||
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be
|
||||
set to 1.
|
||||
:param class_weights: Tensor containing class weights (default=None).
|
||||
:param class_names: The names of the classes if available (default=None).
|
||||
:param tune_classifier: Whether to tune the classifier (default=True).
|
||||
:param pretrained_classifier: Whether to use pretrained classifier (default=False for random init).
|
||||
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
|
||||
:param verbose: if True statements about memory usage are output at each step.
|
||||
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
|
||||
:param encoder_params: Encoder parameters that specify all encoder specific attributes.
|
||||
:param pooling_params: Pooling layer parameters that specify all pooling specific attributes.
|
||||
:param classifier_params: Classifier parameters that specify all classifier specific attributes.
|
||||
:param optimizer_params: Optimizer parameters that specify all specific attributes to be used for optimization.
|
||||
:param outputs_folder: Path to output folder where encoder checkpoint is downloaded.
|
||||
:param outputs_handler: A configured :py:class:`DeepMILOutputsHandler` object to save outputs for the best
|
||||
validation epoch and test stage. If omitted (default), no outputs will be saved to disk (aside from usual
|
||||
metrics logging).
|
||||
:param analyse_loss: If True, the loss is analysed per sample and analysed with LossAnalysisCallback.
|
||||
:param verbose: if True statements about memory usage are output at each step.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -79,29 +73,29 @@ class BaseDeepMILModule(LightningModule):
|
|||
self.class_weights = class_weights
|
||||
self.class_names = validate_class_names(class_names, self.n_classes)
|
||||
|
||||
self.dropout_rate = dropout_rate
|
||||
self.encoder_params = encoder_params
|
||||
self.pooling_params = pooling_params
|
||||
self.classifier_params = classifier_params
|
||||
self.optimizer_params = optimizer_params
|
||||
|
||||
self.save_hyperparameters()
|
||||
self.verbose = verbose
|
||||
self.outputs_handler = outputs_handler
|
||||
self.pretrained_classifier = pretrained_classifier
|
||||
|
||||
# This flag can be switched on before invoking trainer.validate() to enable saving additional time/memory
|
||||
# consuming validation outputs via calling self.on_run_extra_validation_epoch()
|
||||
self._run_extra_val_epoch = False
|
||||
self.tune_classifier = tune_classifier
|
||||
self._on_extra_val_epoch = False
|
||||
|
||||
# Model components
|
||||
self.encoder = encoder_params.get_encoder(outputs_folder)
|
||||
self.aggregation_fn, self.num_pooling = pooling_params.get_pooling_layer(self.encoder.num_encoding)
|
||||
self.classifier_fn = self.get_classifier()
|
||||
self.classifier_fn = classifier_params.get_classifier(self.num_pooling, self.n_classes)
|
||||
self.activation_fn = self.get_activation()
|
||||
self.analyse_loss = analyse_loss
|
||||
|
||||
# Loss function
|
||||
self.loss_fn = self.get_loss(reduction="mean")
|
||||
self.loss_fn_no_reduction = self.get_loss(reduction="none")
|
||||
self.analyse_loss = analyse_loss
|
||||
|
||||
# Metrics Objects
|
||||
self.train_metrics = self.get_metrics()
|
||||
|
@ -149,23 +143,12 @@ class BaseDeepMILModule(LightningModule):
|
|||
if self.pooling_params.pretrained_pooling:
|
||||
self.copy_weights(self.aggregation_fn, pretrained_model.aggregation_fn, DeepMILSubmodules.POOLING)
|
||||
|
||||
if self.pretrained_classifier:
|
||||
if self.classifier_params.pretrained_classifier:
|
||||
if pretrained_model.n_classes != self.n_classes:
|
||||
raise ValueError(f"Number of classes in pretrained model {pretrained_model.n_classes} "
|
||||
f"does not match number of classes in current model {self.n_classes}.")
|
||||
self.copy_weights(self.classifier_fn, pretrained_model.classifier_fn, DeepMILSubmodules.CLASSIFIER)
|
||||
|
||||
def get_classifier(self) -> nn.Module:
|
||||
classifier_layer = nn.Linear(in_features=self.num_pooling,
|
||||
out_features=self.n_classes)
|
||||
set_module_gradients_enabled(classifier_layer, self.tune_classifier)
|
||||
if self.dropout_rate is None:
|
||||
return classifier_layer
|
||||
elif 0 <= self.dropout_rate < 1:
|
||||
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
|
||||
else:
|
||||
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
|
||||
|
||||
def get_loss(self, reduction: str = "mean") -> Callable:
|
||||
if self.n_classes > 1:
|
||||
if self.class_weights is None:
|
||||
|
@ -217,8 +200,12 @@ class BaseDeepMILModule(LightningModule):
|
|||
MetricsKey.RECALL: Recall(),
|
||||
MetricsKey.SPECIFICITY: Specificity()})
|
||||
|
||||
def log_metrics(self, stage: str) -> None:
|
||||
valid_stages = [stage for stage in ModelKey]
|
||||
def get_extra_prefix(self) -> str:
    """Get the extra prefix for metric names, to avoid overriding the best validation metrics.

    :return: `EXTRA_PREFIX` while `self._on_extra_val_epoch` is True, otherwise an empty string.
    """
    return EXTRA_PREFIX if self._on_extra_val_epoch else ""
|
||||
|
||||
def log_metrics(self, stage: str, prefix: str = '') -> None:
|
||||
valid_stages = set([stage for stage in ModelKey])
|
||||
if stage not in valid_stages:
|
||||
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
|
||||
for metric_name, metric_object in self.get_metrics_dict(stage).items():
|
||||
|
@ -226,9 +213,9 @@ class BaseDeepMILModule(LightningModule):
|
|||
metric_value = metric_object.compute()
|
||||
metric_value_n = metric_value / metric_value.sum(axis=1, keepdims=True)
|
||||
for i in range(metric_value_n.shape[0]):
|
||||
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
|
||||
log_on_epoch(self, f'{prefix}{stage}/{self.class_names[i]}', metric_value_n[i, i])
|
||||
else:
|
||||
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
|
||||
log_on_epoch(self, f'{prefix}{stage}/{metric_name}', metric_object)
|
||||
|
||||
def get_instance_features(self, instances: Tensor) -> Tensor:
|
||||
if not self.encoder_params.tune_encoder:
|
||||
|
@ -252,7 +239,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
return attentions, bag_features
|
||||
|
||||
def get_bag_logit(self, bag_features: Tensor) -> Tensor:
|
||||
if not self.tune_classifier:
|
||||
if not self.classifier_params.tune_classifier:
|
||||
self.classifier_fn.eval()
|
||||
bag_logit = self.classifier_fn(bag_features)
|
||||
return bag_logit
|
||||
|
@ -295,6 +282,14 @@ class BaseDeepMILModule(LightningModule):
|
|||
"""Update training results with data specific info. This can be either tiles or slides related metadata."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_slides_selection(self, stage: str, batch: Dict, results: Dict) -> None:
    """Pass the batch and its results on to the tiles selector of the outputs handler.

    Nothing happens unless the stage is `ModelKey.TEST`, or `ModelKey.VAL` while `self._on_extra_val_epoch` is set,
    and an outputs handler with a tiles selector is available.

    :param stage: The current stage, compared against `ModelKey.TEST` and `ModelKey.VAL`.
    :param batch: The current batch, forwarded unchanged to the tiles selector.
    :param results: The results of the current batch, forwarded unchanged to the tiles selector.
    """
    collecting_outputs = stage == ModelKey.TEST or (stage == ModelKey.VAL and self._on_extra_val_epoch)
    if not collecting_outputs:
        return
    if self.outputs_handler and self.outputs_handler.tiles_selector:
        self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
|
||||
|
||||
def _compute_loss(self, loss_fn: Callable, bag_logits: Tensor, bag_labels: Tensor) -> Tensor:
|
||||
if self.n_classes > 1:
|
||||
return loss_fn(bag_logits, bag_labels.long())
|
||||
|
@ -324,6 +319,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
predicted_probs = predicted_probs.squeeze(dim=1)
|
||||
|
||||
results = dict()
|
||||
|
||||
if self.analyse_loss and stage in [ModelKey.TRAIN, ModelKey.VAL]:
|
||||
loss_per_sample = self._compute_loss(self.loss_fn_no_reduction, bag_logits, bag_labels)
|
||||
results[ResultsKey.LOSS_PER_SAMPLE] = loss_per_sample.detach().cpu().numpy()
|
||||
|
@ -340,18 +336,12 @@ class BaseDeepMILModule(LightningModule):
|
|||
ResultsKey.BAG_ATTN: bag_attn_list
|
||||
})
|
||||
self.update_results_with_data_specific_info(batch=batch, results=results)
|
||||
if (
|
||||
(stage == ModelKey.TEST or (stage == ModelKey.VAL and self._run_extra_val_epoch))
|
||||
and self.outputs_handler
|
||||
and self.outputs_handler.tiles_selector
|
||||
):
|
||||
self.outputs_handler.tiles_selector.update_slides_selection(batch, results)
|
||||
self.update_slides_selection(stage=stage, batch=batch, results=results)
|
||||
return results
|
||||
|
||||
def training_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
|
||||
train_result = self._shared_step(batch, batch_idx, ModelKey.TRAIN)
|
||||
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
|
||||
if self.verbose:
|
||||
print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats())
|
||||
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
|
||||
|
@ -363,34 +353,29 @@ class BaseDeepMILModule(LightningModule):
|
|||
|
||||
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
|
||||
val_result = self._shared_step(batch, batch_idx, ModelKey.VAL)
|
||||
self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
name = f'{self.get_extra_prefix()}val/loss'
|
||||
self.log(name, val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
|
||||
return val_result
|
||||
|
||||
def test_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
|
||||
test_result = self._shared_step(batch, batch_idx, ModelKey.TEST)
|
||||
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True,
|
||||
sync_dist=True)
|
||||
self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, sync_dist=True)
|
||||
return test_result
|
||||
|
||||
def training_epoch_end(self, outputs: EpochResultsType) -> None: # type: ignore
|
||||
self.log_metrics(ModelKey.TRAIN)
|
||||
|
||||
def validation_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
|
||||
self.log_metrics(ModelKey.VAL)
|
||||
self.log_metrics(stage=ModelKey.VAL, prefix=self.get_extra_prefix())
|
||||
if self.outputs_handler:
|
||||
self.outputs_handler.save_validation_outputs(
|
||||
epoch_results=epoch_results,
|
||||
metrics_dict=self.get_metrics_dict(ModelKey.VAL), # type: ignore
|
||||
epoch=self.current_epoch,
|
||||
is_global_rank_zero=self.global_rank == 0,
|
||||
run_extra_val_epoch=self._run_extra_val_epoch
|
||||
on_extra_val=self._on_extra_val_epoch
|
||||
)
|
||||
|
||||
# Reset the top and bottom slides heaps
|
||||
if self.outputs_handler.tiles_selector is not None:
|
||||
self.outputs_handler.tiles_selector._clear_cached_slides_heaps()
|
||||
|
||||
def test_epoch_end(self, epoch_results: EpochResultsType) -> None: # type: ignore
|
||||
self.log_metrics(ModelKey.TEST)
|
||||
if self.outputs_handler:
|
||||
|
@ -402,7 +387,7 @@ class BaseDeepMILModule(LightningModule):
|
|||
def on_run_extra_validation_epoch(self) -> None:
|
||||
"""Hook to be called at the beginning of an extra validation epoch to set validation plots options to the same
|
||||
as the test plots options."""
|
||||
self._run_extra_val_epoch = True
|
||||
self._on_extra_val_epoch = True
|
||||
if self.outputs_handler:
|
||||
self.outputs_handler.val_plots_handler.plot_options = self.outputs_handler.test_plots_handler.plot_options
|
||||
|
||||
|
|
|
@ -93,6 +93,7 @@ class LossAnalysisCallback(Callback):
|
|||
save_tile_ids: bool = False,
|
||||
log_exceptions: bool = True,
|
||||
create_outputs_folders: bool = True,
|
||||
val_set_is_dist: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
|
@ -107,6 +108,7 @@ class LossAnalysisCallback(Callback):
|
|||
:param log_exceptions: If True, will log exceptions raised during loss values analysis, defaults to True. If
|
||||
False will raise the intercepted exceptions.
|
||||
:param create_outputs_folders: If True, will create the output folders if they don't exist, defaults to True.
|
||||
:param val_set_is_dist: If True, will assume that the validation set is distributed, defaults to True.
|
||||
"""
|
||||
|
||||
self.patience = patience
|
||||
|
@ -116,6 +118,7 @@ class LossAnalysisCallback(Callback):
|
|||
self.num_slides_heatmap = num_slides_heatmap
|
||||
self.save_tile_ids = save_tile_ids
|
||||
self.log_exceptions = log_exceptions
|
||||
self.val_set_is_dist = val_set_is_dist
|
||||
|
||||
self.outputs_folder = outputs_folder / "loss_analysis_callback"
|
||||
if create_outputs_folders:
|
||||
|
@ -216,7 +219,7 @@ class LossAnalysisCallback(Callback):
|
|||
|
||||
def gather_loss_cache(self, rank: int, stage: ModelKey) -> None:
|
||||
"""Gathers the loss cache from all the workers"""
|
||||
if torch.distributed.is_initialized():
|
||||
if self.val_set_is_dist and torch.distributed.is_initialized():
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if world_size > 1:
|
||||
loss_caches = [None] * world_size
|
||||
|
@ -533,7 +536,7 @@ class LossAnalysisCallback(Callback):
|
|||
"""Hook called at the end of validation. Plot the loss heatmap and scatter plots after ranking the slides by
|
||||
loss values."""
|
||||
epoch = trainer.current_epoch
|
||||
if pl_module.global_rank == 0 and not pl_module._run_extra_val_epoch and epoch == (self.max_epochs - 1):
|
||||
if pl_module.global_rank == 0 and not pl_module._on_extra_val_epoch and epoch == (self.max_epochs - 1):
|
||||
try:
|
||||
self.save_loss_outliers_analaysis_results(stage=ModelKey.VAL)
|
||||
except Exception as e:
|
||||
|
|
|
@ -177,3 +177,22 @@ class PoolingParams(param.Parameterized):
|
|||
num_features = num_encoding * self.pool_out_dim
|
||||
set_module_gradients_enabled(pooling_layer, tuning_flag=self.tune_pooling)
|
||||
return pooling_layer, num_features
|
||||
|
||||
|
||||
class ClassifierParams(param.Parameterized):
    """Parameters of the classifier: pre-classifier dropout, tuning flag and pretrained-weights flag."""

    dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
    tune_classifier: bool = param.Boolean(
        default=True,
        doc="If True (default), fine-tune the classifier during training. If False, keep the classifier frozen.")
    pretrained_classifier: bool = param.Boolean(
        default=False,
        doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
            "initiliaze classifier with random weights.")

    def get_classifier(self, in_features: int, out_features: int) -> nn.Module:
        """Build the classifier module.

        :param in_features: Number of input features of the linear layer.
        :param out_features: Number of output features of the linear layer.
        :return: The linear layer on its own if `dropout_rate` is None, otherwise a dropout layer followed by the
            linear layer.
        """
        linear = nn.Linear(in_features=in_features, out_features=out_features)
        # Gradients of the linear layer are switched according to the tuning flag
        set_module_gradients_enabled(linear, tuning_flag=self.tune_classifier)
        if self.dropout_rate is not None:
            return nn.Sequential(nn.Dropout(self.dropout_rate), linear)
        return linear
|
||||
|
|
|
@ -26,6 +26,8 @@ OUTPUTS_CSV_FILENAME = "test_output.csv"
|
|||
VAL_OUTPUTS_SUBDIR = "val"
|
||||
PREV_VAL_OUTPUTS_SUBDIR = "val_old"
|
||||
TEST_OUTPUTS_SUBDIR = "test"
|
||||
EXTRA_VAL_OUTPUTS_SUBDIR = "extra_val"
|
||||
EXTRA_PREFIX = "extra_"
|
||||
|
||||
AML_OUTPUTS_DIR = "outputs"
|
||||
AML_LEGACY_TEST_OUTPUTS_CSV = "/".join([AML_OUTPUTS_DIR, OUTPUTS_CSV_FILENAME])
|
||||
|
@ -198,7 +200,7 @@ class OutputsPolicy:
|
|||
YAML().dump(contents, self.best_metric_file_path)
|
||||
|
||||
def should_save_validation_outputs(self, metrics_dict: Mapping[MetricsKey, Metric], epoch: int,
|
||||
is_global_rank_zero: bool = True) -> bool:
|
||||
is_global_rank_zero: bool = True, on_extra_val: bool = False) -> bool:
|
||||
"""Determine whether validation outputs should be saved given the current epoch's metrics.
|
||||
|
||||
:param metrics_dict: Current epoch's metrics dictionary from
|
||||
|
@ -206,8 +208,11 @@ class OutputsPolicy:
|
|||
:param epoch: Current epoch number.
|
||||
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
|
||||
Set to `True` (default) if running a single process.
|
||||
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
|
||||
:return: Whether this is the best validation epoch so far.
|
||||
"""
|
||||
if on_extra_val:
|
||||
return False
|
||||
metric = metrics_dict[self.primary_val_metric]
|
||||
# If the metric hasn't been updated we don't want to save it
|
||||
if not metric._update_called:
|
||||
|
@ -252,7 +257,8 @@ class DeepMILOutputsHandler:
|
|||
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
|
||||
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
|
||||
maximise: bool, val_plot_options: Collection[PlotOption],
|
||||
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True) -> None:
|
||||
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
|
||||
val_set_is_dist: bool = True) -> None:
|
||||
"""
|
||||
:param outputs_root: Root directory where to save all produced outputs.
|
||||
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
|
||||
|
@ -265,6 +271,10 @@ class DeepMILOutputsHandler:
|
|||
:param val_plot_options: The desired plot options for validation time.
|
||||
:param test_plot_options: The desired plot options for test time.
|
||||
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
|
||||
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
|
||||
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
|
||||
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
|
||||
gather the validation set outputs across processes or not before saving them.
|
||||
"""
|
||||
self.outputs_root = outputs_root
|
||||
self.n_classes = n_classes
|
||||
|
@ -294,11 +304,16 @@ class DeepMILOutputsHandler:
|
|||
stage=ModelKey.TEST,
|
||||
wsi_has_mask=wsi_has_mask
|
||||
)
|
||||
self.val_set_is_dist = val_set_is_dist
|
||||
|
||||
@property
|
||||
def validation_outputs_dir(self) -> Path:
|
||||
return self.outputs_root / VAL_OUTPUTS_SUBDIR
|
||||
|
||||
@property
|
||||
def extra_validation_outputs_dir(self) -> Path:
|
||||
return self.outputs_root / EXTRA_VAL_OUTPUTS_SUBDIR
|
||||
|
||||
@property
|
||||
def previous_validation_outputs_dir(self) -> Path:
|
||||
return self.validation_outputs_dir.with_name(PREV_VAL_OUTPUTS_SUBDIR)
|
||||
|
@ -313,11 +328,15 @@ class DeepMILOutputsHandler:
|
|||
self.test_plots_handler.slides_dataset = slides_dataset
|
||||
self.val_plots_handler.slides_dataset = slides_dataset
|
||||
|
||||
def should_gather_tiles(self, plots_handler: DeepMILPlotsHandler) -> bool:
|
||||
return PlotOption.TOP_BOTTOM_TILES in plots_handler.plot_options and self.tiles_selector is not None
|
||||
|
||||
def _save_outputs(self, epoch_results: EpochResultsType, outputs_dir: Path, stage: ModelKey = ModelKey.VAL) -> None:
|
||||
"""Trigger the rendering and saving of DeepMIL outputs and figures.
|
||||
|
||||
:param epoch_results: Aggregated results from all epoch batches.
|
||||
:param outputs_dir: Specific directory into which outputs should be saved (different for validation and test).
|
||||
:param stage: The stage of the model (e.g. `ModelKey.VAL` or `ModelKey.TEST`).
|
||||
"""
|
||||
# outputs object consists of a list of dictionaries (of metadata and results, including encoded features)
|
||||
# It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx]
|
||||
|
@ -334,8 +353,7 @@ class DeepMILOutputsHandler:
|
|||
plots_handler.save_plots(outputs_dir, self.tiles_selector, results)
|
||||
|
||||
def save_validation_outputs(self, epoch_results: EpochResultsType, metrics_dict: Mapping[MetricsKey, Metric],
|
||||
epoch: int, is_global_rank_zero: bool = True, run_extra_val_epoch: bool = False
|
||||
) -> None:
|
||||
epoch: int, is_global_rank_zero: bool = True, on_extra_val: bool = False) -> None:
|
||||
"""Render and save validation epoch outputs, according to the configured :py:class:`OutputsPolicy`.
|
||||
|
||||
:param epoch_results: Aggregated results from all epoch batches, as passed to :py:meth:`validation_epoch_end()`.
|
||||
|
@ -344,29 +362,35 @@ class DeepMILOutputsHandler:
|
|||
:param is_global_rank_zero: Whether this is the global rank-0 process in distributed scenarios.
|
||||
Set to `True` (default) if running a single process.
|
||||
:param epoch: Current epoch number.
|
||||
:param on_extra_val: Whether this is an extra validation epoch (e.g. after training).
|
||||
"""
|
||||
# All DDP processes must reach this point to allow synchronising epoch results
|
||||
gathered_epoch_results = gather_results(epoch_results)
|
||||
if PlotOption.TOP_BOTTOM_TILES in self.val_plots_handler.plot_options and self.tiles_selector:
|
||||
self.tiles_selector.gather_selected_tiles_across_devices()
|
||||
# All DDP processes must reach this point to allow synchronising epoch results if val_set_is_dist is True
|
||||
if self.val_set_is_dist:
|
||||
epoch_results = gather_results(epoch_results)
|
||||
|
||||
if self.should_gather_tiles(self.val_plots_handler):
|
||||
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
|
||||
|
||||
# Only global rank-0 process should actually render and save the outputs
|
||||
# We also want to save the plots of the extra validation epoch
|
||||
if (
|
||||
self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero)
|
||||
or (run_extra_val_epoch and is_global_rank_zero)
|
||||
):
|
||||
if self.outputs_policy.should_save_validation_outputs(metrics_dict, epoch, is_global_rank_zero, on_extra_val):
|
||||
# First move existing outputs to a temporary directory, to avoid mixing
|
||||
# outputs of different epochs in case writing fails halfway through
|
||||
if self.validation_outputs_dir.exists():
|
||||
replace_directory(source=self.validation_outputs_dir,
|
||||
target=self.previous_validation_outputs_dir)
|
||||
|
||||
self._save_outputs(gathered_epoch_results, self.validation_outputs_dir, ModelKey.VAL)
|
||||
self._save_outputs(epoch_results, self.validation_outputs_dir, ModelKey.VAL)
|
||||
|
||||
# Writing completed successfully; delete temporary back-up
|
||||
if self.previous_validation_outputs_dir.exists():
|
||||
shutil.rmtree(self.previous_validation_outputs_dir, ignore_errors=True)
|
||||
elif on_extra_val and is_global_rank_zero:
|
||||
self._save_outputs(epoch_results, self.extra_validation_outputs_dir, ModelKey.VAL)
|
||||
|
||||
# Reset the top and bottom slides heaps
|
||||
if self.should_gather_tiles(self.val_plots_handler):
|
||||
self.tiles_selector._clear_cached_slides_heaps() # type: ignore
|
||||
|
||||
def save_test_outputs(self, epoch_results: EpochResultsType, is_global_rank_zero: bool = True) -> None:
|
||||
"""Render and save test epoch outputs.
|
||||
|
@ -378,8 +402,8 @@ class DeepMILOutputsHandler:
|
|||
"""
|
||||
# All DDP processes must reach this point to allow synchronising epoch results
|
||||
gathered_epoch_results = gather_results(epoch_results)
|
||||
if PlotOption.TOP_BOTTOM_TILES in self.test_plots_handler.plot_options and self.tiles_selector:
|
||||
self.tiles_selector.gather_selected_tiles_across_devices()
|
||||
if self.should_gather_tiles(self.test_plots_handler):
|
||||
self.tiles_selector.gather_selected_tiles_across_devices() # type: ignore
|
||||
|
||||
# Only global rank-0 process should actually render and save the outputs-
|
||||
if self.outputs_policy.should_save_test_outputs(is_global_rank_zero):
|
||||
|
|
|
@ -8,11 +8,10 @@ import logging
|
|||
import shutil
|
||||
import sys
|
||||
import uuid
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
# temporary workaround until these hi-ml packages are released
|
||||
testhisto_root_dir = Path(__file__).parent
|
||||
print(f"Adding {testhisto_root_dir} to sys path")
|
||||
|
@ -31,6 +30,7 @@ for package, subpackages in packages.items():
|
|||
from health_ml.utils.fixed_paths import OutputFolderForTests # noqa: E402
|
||||
from testhisto.mocks.base_data_generator import MockHistoDataType # noqa: E402
|
||||
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator # noqa: E402
|
||||
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType # noqa: E402
|
||||
|
||||
|
||||
def remove_and_create_folder(folder: Path) -> None:
|
||||
|
@ -96,3 +96,26 @@ def mock_panda_tiles_root_dir(
|
|||
tiles_generator.generate_mock_histo_data()
|
||||
yield tmp_root_dir
|
||||
shutil.rmtree(tmp_root_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def mock_panda_slides_root_dir(
    tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
) -> Generator:
    """Session-scoped fixture that generates mock PANDA slides in a temporary directory.

    The directory is yielded to the tests and removed once the session is over.
    """
    slides_root = tmp_path_factory.mktemp("mock_slides")
    slides_generator = MockPandaSlidesGenerator(
        dest_data_path=slides_root,
        src_data_path=tmp_path_to_pathmnist_dataset,
        mock_type=MockHistoDataType.PATHMNIST,
        n_tiles=4,
        n_slides=15,
        n_channels=3,
        n_levels=3,
        tile_size=28,
        background_val=255,
        tiles_pos_type=TilesPositioningType.RANDOM
    )
    logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
    slides_generator.generate_mock_histo_data()
    yield slides_root
    # Teardown: runs after the last test of the session has used the fixture
    shutil.rmtree(slides_root)
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import pytest
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
|
||||
|
||||
from health_cpath.datamodules.base_module import HistoDataModule
|
||||
from health_cpath.datamodules.panda_module import PandaSlidesDataModule, PandaTilesDataModule
|
||||
from health_cpath.utils.naming import ModelKey, SlideKey
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from testhisto.utils.utils_testhisto import run_distributed
|
||||
|
||||
|
||||
# Negation of is_gpu_available(), evaluated once at import; used as the skip condition of the GPU-only tests below.
no_gpu = not is_gpu_available()
|
||||
|
||||
|
||||
def _assert_correct_bag_sizes(datamodule: HistoDataModule, max_bag_size: int, max_bag_size_inf: Optional[int],
                              true_bag_sizes: List[int]) -> None:
    """Check the bag sizes stored on the datamodule and the bag sizes of the first batch of each dataloader.

    :param datamodule: The datamodule under test.
    :param max_bag_size: Bag size expected for the first `ModelKey` stage and for the train dataloader.
    :param max_bag_size_inf: Bag size expected for the two remaining `ModelKey` stages. If falsy, the validation and
        test dataloaders are expected to return `true_bag_sizes` instead.
    :param true_bag_sizes: The bag sizes generated by the mock data generator for a fixed seed. The tiles count (and
        therefore the bag sizes) is random to reflect real data with a varying number of tiles per slide.
    """
    # NOTE(review): assumes ModelKey iterates in (train, val, test) order — confirm against its definition
    expected_per_stage = [max_bag_size, max_bag_size_inf, max_bag_size_inf]
    for stage_key, bag_size in zip(ModelKey, expected_per_stage):
        assert datamodule.bag_sizes[stage_key] == bag_size

    def _assert_bag_size_matching(dataloader: DataLoader, expected_bag_sizes: List[int]) -> None:
        # Only the first batch of the dataloader is inspected
        first_batch = next(iter(dataloader))
        for idx, slide in enumerate(first_batch[SlideKey.IMAGE]):
            assert slide.shape[0] == expected_bag_sizes[idx]

    _assert_bag_size_matching(datamodule.train_dataloader(), [max_bag_size, max_bag_size])
    if max_bag_size_inf:
        inference_bag_sizes = [max_bag_size_inf, max_bag_size_inf]
    else:
        inference_bag_sizes = true_bag_sizes
    _assert_bag_size_matching(datamodule.val_dataloader(), inference_bag_sizes)
    _assert_bag_size_matching(datamodule.test_dataloader(), inference_bag_sizes)
|
||||
|
||||
|
||||
def _assert_correct_batch_sizes(datamodule: HistoDataModule, batch_size: int, batch_size_inf: Optional[int]) -> None:
    """Check the batch sizes stored on the datamodule and the size of the first batch of each dataloader.

    :param datamodule: The datamodule under test.
    :param batch_size: Batch size expected for the first `ModelKey` stage and for the train dataloader.
    :param batch_size_inf: Batch size expected for the two remaining stages and for the validation and test
        dataloaders. If None, `batch_size` is expected there as well.
    """
    if batch_size_inf is None:
        batch_size_inf = batch_size
    # NOTE(review): assumes ModelKey iterates in (train, val, test) order — confirm against its definition
    expected_per_stage = [batch_size, batch_size_inf, batch_size_inf]
    for stage_key, _batch_size in zip(ModelKey, expected_per_stage):
        assert datamodule.batch_sizes[stage_key] == _batch_size

    def _assert_batch_size_matching(dataloader: DataLoader, expected_batch_size: int) -> None:
        # Only the first batch of the dataloader is inspected
        first_batch = next(iter(dataloader))
        assert len(first_batch[SlideKey.IMAGE]) == expected_batch_size

    _assert_batch_size_matching(datamodule.train_dataloader(), batch_size)
    _assert_batch_size_matching(datamodule.val_dataloader(), batch_size_inf)
    _assert_batch_size_matching(datamodule.test_dataloader(), batch_size_inf)
|
||||
|
||||
|
||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_slides_datamodule_different_bag_sizes(
    mock_panda_slides_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
    """Check the training and inference bag sizes of a slides datamodule built on the mock PANDA slides."""
    slides_datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        batch_size=2,
        max_bag_size=max_bag_size,
        max_bag_size_inf=max_bag_size_inf,
        tile_size=28,
        level=0,
    )
    # To account for the fact that the slides datamodule formats 0 to None so that it's compatible with the
    # TileOnGrid transform
    expected_bag_size_inf = None if max_bag_size_inf == 0 else max_bag_size_inf
    # For slides datamodule, the true bag sizes [4, 4] are the same as requested to TileOnGrid transform
    _assert_correct_bag_sizes(slides_datamodule, max_bag_size, expected_bag_size_inf, true_bag_sizes=[4, 4])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_bag_size, max_bag_size_inf", [(2, 0), (2, 3)])
def test_tiles_datamodule_different_bag_sizes(
    mock_panda_tiles_root_dir: Path, max_bag_size: int, max_bag_size_inf: int
) -> None:
    """Check the training and inference bag sizes of a tiles datamodule built on the mock PANDA tiles."""
    tiles_datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        batch_size=2,
        max_bag_size=max_bag_size,
        max_bag_size_inf=max_bag_size_inf,
    )
    # For tiles datamodule, the true bag sizes [4, 5] were generated by the mock data generator for a fixed seed 42
    # If test fails, check if the mock data generator has changed and update the true bag sizes
    mock_true_bag_sizes = [4, 5]
    _assert_correct_bag_sizes(tiles_datamodule, max_bag_size, max_bag_size_inf, true_bag_sizes=mock_true_bag_sizes)
|
||||
|
||||
|
||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_slides_datamodule_different_batch_sizes(
    mock_panda_slides_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
    """Check the training and inference batch sizes of a slides datamodule built on the mock PANDA slides."""
    slides_datamodule = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        batch_size=batch_size,
        batch_size_inf=batch_size_inf,
        max_bag_size=16,
        max_bag_size_inf=16,
        tile_size=28,
        level=0,
    )
    _assert_correct_batch_sizes(slides_datamodule, batch_size, batch_size_inf)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size, batch_size_inf", [(2, 2), (2, 1), (2, None)])
def test_tiles_datamodule_different_batch_sizes(
    mock_panda_tiles_root_dir: Path, batch_size: int, batch_size_inf: Optional[int],
) -> None:
    """Check the training and inference batch sizes of a tiles datamodule built on the mock PANDA tiles."""
    tiles_datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        batch_size=batch_size,
        batch_size_inf=batch_size_inf,
        max_bag_size=16,
        max_bag_size_inf=16,
    )
    _assert_correct_batch_sizes(tiles_datamodule, batch_size, batch_size_inf)
|
||||
|
||||
|
||||
def _validate_sampler_type(datamodule: HistoDataModule, stages: List[ModelKey], expected_none: bool) -> None:
    """Check the DDP sampler returned by the datamodule and the sampler type of its dataloaders.

    :param datamodule: The datamodule under test; `<stage>_dataset` and `<stage>_dataloader` are read from it.
    :param stages: The stages to check.
    :param expected_none: Whether `datamodule._get_ddp_sampler` is expected to return None for these stages. This
        also selects the sampler type expected for the train dataloader: `RandomSampler` if True,
        `DistributedSampler` otherwise. Validation and test dataloaders are always expected to use a
        `SequentialSampler`.
    """
    expected_sampler_types = {
        ModelKey.TRAIN: RandomSampler if expected_none else DistributedSampler,
        ModelKey.VAL: SequentialSampler,
        ModelKey.TEST: SequentialSampler,
    }
    for stage in stages:
        # Build both attribute names from `stage.value`: formatting an enum member directly in an f-string is
        # Python-version dependent, whereas `.value` always yields the plain stage name.
        stage_name = stage.value
        dataset = getattr(datamodule, f'{stage_name}_dataset')
        datamodule_sampler = datamodule._get_ddp_sampler(dataset, stage)
        assert (datamodule_sampler is None) == expected_none
        dataloader = getattr(datamodule, f'{stage_name}_dataloader')()
        assert isinstance(dataloader.sampler, expected_sampler_types[stage])
|
||||
|
||||
|
||||
def _test_datamodule_pl_ddp_sampler_true(
    datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
    """Set up the datamodule and check that it provides no DDP sampler for the train, val and test stages.

    `rank`, `world_size` and `device` are not used in the body; presumably they are part of the calling convention
    of `run_distributed` (to be confirmed against that helper).
    """
    datamodule.setup()
    every_stage = [ModelKey.TRAIN, ModelKey.VAL, ModelKey.TEST]
    _validate_sampler_type(datamodule, every_stage, expected_none=True)
|
||||
|
||||
|
||||
def _test_datamodule_pl_ddp_sampler_false(
    datamodule: HistoDataModule, rank: int = 0, world_size: int = 1, device: str = "cpu"
) -> None:
    """Set up the datamodule and check that only the train stage gets a DDP sampler.

    `rank`, `world_size` and `device` are not used in the body; presumably they are part of the calling convention
    of `run_distributed` (to be confirmed against that helper).
    """
    datamodule.setup()
    inference_stages = [ModelKey.VAL, ModelKey.TEST]
    # No DDP sampler is expected for validation and test
    _validate_sampler_type(datamodule, inference_stages, expected_none=True)
    # A DDP sampler is expected for training
    _validate_sampler_type(datamodule, [ModelKey.TRAIN], expected_none=False)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_slides_datamodule_pl_replace_sampler_ddp(mock_panda_slides_root_dir: Path) -> None:
    """Run the sampler checks on two processes for a slides datamodule, with `pl_replace_sampler_ddp` on then off."""
    datamodule = PandaSlidesDataModule(root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=True, seed=42)
    run_distributed(_test_datamodule_pl_ddp_sampler_true, [datamodule], world_size=2)
    # The same datamodule instance is reused with the flag switched off
    datamodule.pl_replace_sampler_ddp = False
    run_distributed(_test_datamodule_pl_ddp_sampler_false, [datamodule], world_size=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_true(mock_panda_tiles_root_dir: Path) -> None:
    """Run the sampler checks on two processes for a tiles datamodule with `pl_replace_sampler_ddp=True`."""
    datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        seed=42,
        pl_replace_sampler_ddp=True,
    )
    run_distributed(_test_datamodule_pl_ddp_sampler_true, [datamodule], world_size=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.distributed.is_available(), reason="PyTorch distributed unavailable")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Not enough GPUs available")
@pytest.mark.gpu
def test_tiles_datamodule_pl_replace_sampler_ddp_false(mock_panda_tiles_root_dir: Path) -> None:
    """Run the sampler checks on two processes for a tiles datamodule with `pl_replace_sampler_ddp=False`."""
    datamodule = PandaTilesDataModule(
        root_path=mock_panda_tiles_root_dir,
        seed=42,
        pl_replace_sampler_ddp=False,
    )
    run_distributed(_test_datamodule_pl_ddp_sampler_false, [datamodule], world_size=2)
|
||||
|
||||
|
||||
def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
    """Check that an AssertionError about the missing seed is raised when torch.distributed reports two processes.

    `torch.distributed.is_initialized` and `torch.distributed.get_world_size` are patched to return True and 2, so
    no real process group is needed.
    """
    expect_missing_seed_error = pytest.raises(
        AssertionError, match="seed must be set when using distributed training for reproducibility"
    )
    fake_initialized = patch("torch.distributed.is_initialized", return_value=True)
    fake_world_size = patch("torch.distributed.get_world_size", return_value=2)
    # The context managers are entered left to right, i.e. in the same order as the original nested blocks
    with expect_missing_seed_error, fake_initialized, fake_world_size:
        slides_datamodule = PandaSlidesDataModule(
            root_path=mock_panda_slides_root_dir, pl_replace_sampler_ddp=False
        )
        slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)
|
|
@ -1,17 +1,17 @@
|
|||
|
||||
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import shutil
|
||||
from typing import Generator, Dict, Callable, Union, Tuple
|
||||
import pytest
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from monai.transforms import RandFlipd
|
||||
from typing import Generator, Dict, Callable, Union, Tuple
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_cpath.datamodules.base_module import SlidesDataModule
|
||||
|
@ -29,7 +29,7 @@ no_gpu = not is_gpu_available()
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_panda_slides_root_dir(
|
||||
def mock_panda_slides_root_dir_diagonal(
|
||||
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
|
||||
) -> Generator:
|
||||
tmp_root_dir = tmp_path_factory.mktemp("mock_wsi")
|
||||
|
@ -38,7 +38,7 @@ def mock_panda_slides_root_dir(
|
|||
src_data_path=tmp_path_to_pathmnist_dataset,
|
||||
mock_type=MockHistoDataType.PATHMNIST,
|
||||
n_tiles=1,
|
||||
n_slides=10,
|
||||
n_slides=16,
|
||||
n_repeat_diag=4,
|
||||
n_repeat_tile=2,
|
||||
n_channels=3,
|
||||
|
@ -83,7 +83,7 @@ def get_original_tile(mock_dir: Path, wsi_id: str) -> np.ndarray:
|
|||
|
||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||
@pytest.mark.gpu
|
||||
def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
|
||||
def test_tiling_on_the_fly(mock_panda_slides_root_dir_diagonal: Path) -> None:
|
||||
batch_size = 1
|
||||
tile_count = 16
|
||||
tile_size = 28
|
||||
|
@ -91,7 +91,7 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
|
|||
channels = 3
|
||||
assert_batch_index = 0
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
root_path=mock_panda_slides_root_dir_diagonal,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=tile_count,
|
||||
tile_size=tile_size,
|
||||
|
@ -105,14 +105,14 @@ def test_tiling_on_the_fly(mock_panda_slides_root_dir: Path) -> None:
|
|||
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
|
||||
|
||||
# check tiling on the fly
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
for i in range(tile_count):
|
||||
assert (original_tile == tiles[assert_batch_index][i].numpy()).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||
@pytest.mark.gpu
|
||||
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> None:
|
||||
def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir_diagonal: Path) -> None:
|
||||
batch_size = 1
|
||||
tile_count = None
|
||||
tile_size = 28
|
||||
|
@ -120,7 +120,7 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
|
|||
assert_batch_index = 0
|
||||
min_expected_tile_count = 16
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
root_path=mock_panda_slides_root_dir_diagonal,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=tile_count,
|
||||
tile_size=tile_size,
|
||||
|
@ -135,14 +135,14 @@ def test_tiling_without_fixed_tile_count(mock_panda_slides_root_dir: Path) -> No
|
|||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.parametrize("level", [0, 1, 2])
|
||||
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -> None:
|
||||
def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
|
||||
batch_size = 1
|
||||
tile_count = 16
|
||||
channels = 3
|
||||
tile_size = 28 // 2 ** level
|
||||
assert_batch_index = 0
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
root_path=mock_panda_slides_root_dir_diagonal,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=tile_count,
|
||||
tile_size=tile_size,
|
||||
|
@ -155,7 +155,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
|
|||
assert tiles[assert_batch_index].shape == (tile_count, channels, tile_size, tile_size)
|
||||
|
||||
# check tiling on the fly at different resolutions
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
for i in range(tile_count):
|
||||
# multi resolution mock data has been created via 2 factor downsampling
|
||||
assert (original_tile[:, :: 2 ** level, :: 2 ** level] == tiles[assert_batch_index][i].numpy()).all()
|
||||
|
@ -164,7 +164,7 @@ def test_multi_resolution_tiling(level: int, mock_panda_slides_root_dir: Path) -
|
|||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||
@pytest.mark.gpu
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) -> None:
|
||||
def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir_diagonal: Path) -> None:
|
||||
tile_size = 28
|
||||
level = 0
|
||||
step = 14
|
||||
|
@ -172,7 +172,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
|
|||
min_expected_tile_count = 32
|
||||
assert_batch_index = 0
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
root_path=mock_panda_slides_root_dir_diagonal,
|
||||
max_bag_size=None,
|
||||
batch_size=batch_size,
|
||||
tile_size=tile_size,
|
||||
|
@ -184,7 +184,7 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
|
|||
tiles, wsi_id = sample[SlideKey.IMAGE], sample[SlideKey.SLIDE_ID][assert_batch_index]
|
||||
assert tiles[assert_batch_index].shape[0] >= min_expected_tile_count
|
||||
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
tile_matches = 0
|
||||
for _, tile in enumerate(tiles[assert_batch_index]):
|
||||
tile_matches += int((tile.numpy() == original_tile).all())
|
||||
|
@ -193,12 +193,12 @@ def test_overlapping_tiles(batch_size: int, mock_panda_slides_root_dir: Path) ->
|
|||
|
||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||
@pytest.mark.gpu
|
||||
def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
|
||||
def test_train_test_transforms(mock_panda_slides_root_dir_diagonal: Path) -> None:
|
||||
def get_transforms_dict() -> Dict[ModelKey, Union[Callable, None]]:
|
||||
train_transform = RandFlipd(keys=[SlideKey.IMAGE], spatial_axis=0, prob=1.0)
|
||||
return {ModelKey.TRAIN: train_transform, ModelKey.VAL: None, ModelKey.TEST: None} # type: ignore
|
||||
|
||||
def retrieve_tiles(dataloader: torch.utils.data.DataLoader) -> Dict[str, torch.Tensor]:
|
||||
def retrieve_tiles(dataloader: DataLoader) -> Dict[str, torch.Tensor]:
|
||||
tiles_dict = {}
|
||||
assert_batch_index = 0
|
||||
for sample in dataloader:
|
||||
|
@ -211,7 +211,7 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
|
|||
tile_size = 28
|
||||
level = 0
|
||||
flipdatamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
root_path=mock_panda_slides_root_dir_diagonal,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=tile_count,
|
||||
max_bag_size_inf=0,
|
||||
|
@ -224,20 +224,20 @@ def test_train_test_transforms(mock_panda_slides_root_dir: Path) -> None:
|
|||
flip_test_tiles = retrieve_tiles(flipdatamodule.test_dataloader())
|
||||
|
||||
for wsi_id in flip_train_tiles.keys():
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
# the first dimension is the channel, flipping happened on the horizontal axis of the image
|
||||
transformed_original_tile = np.flip(original_tile, axis=1)
|
||||
for tile in flip_train_tiles[wsi_id]:
|
||||
assert (tile.numpy() == transformed_original_tile).all()
|
||||
|
||||
for wsi_id in flip_val_tiles.keys():
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
for tile in flip_val_tiles[wsi_id]:
|
||||
# no transformation has been applied to val tiles
|
||||
assert (tile.numpy() == original_tile).all()
|
||||
|
||||
for wsi_id in flip_test_tiles.keys():
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir, wsi_id)
|
||||
original_tile = get_original_tile(mock_panda_slides_root_dir_diagonal, wsi_id)
|
||||
for tile in flip_test_tiles[wsi_id]:
|
||||
# no transformation has been applied to test tiles
|
||||
assert (tile.numpy() == original_tile).all()
|
||||
|
@ -251,7 +251,6 @@ class MockPandaSlidesDataModule(SlidesDataModule):
|
|||
"""
|
||||
|
||||
def get_splits(self) -> Tuple[PandaDataset, PandaDataset, PandaDataset]:
|
||||
|
||||
return (PandaDataset(self.root_path), PandaDataset(self.root_path), PandaDataset(self.root_path))
|
||||
|
||||
|
||||
|
@ -278,7 +277,7 @@ def test_whole_slide_inference(batch_size: int, mock_panda_slides_root_with_diff
|
|||
tiles = sample[SlideKey.IMAGE]
|
||||
assert tiles[assert_batch_index].shape[0] == tile_count
|
||||
|
||||
def assert_whole_slide_inference_with_all_tiles(dataloader: torch.utils.data.DataLoader) -> None:
|
||||
def assert_whole_slide_inference_with_all_tiles(dataloader: DataLoader) -> None:
|
||||
for i, sample in enumerate(dataloader):
|
||||
tiles = sample[SlideKey.IMAGE]
|
||||
assert tiles[assert_batch_index].shape[0] == n_tiles_list[i * batch_size]
|
||||
|
|
|
@ -3,15 +3,13 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from copy import deepcopy
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pytorch_lightning import Trainer
|
||||
import torch
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
|
||||
|
||||
from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
@ -30,10 +28,9 @@ from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDatase
|
|||
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
|
||||
from health_cpath.models.deepmil import BaseDeepMILModule, TilesDeepMILModule
|
||||
from health_cpath.models.encoders import IdentityEncoder, ImageNetEncoder, Resnet18, TileEncoder
|
||||
from health_cpath.utils.deepmil_utils import EncoderParams, PoolingParams
|
||||
from health_cpath.utils.deepmil_utils import ClassifierParams, EncoderParams, PoolingParams
|
||||
from health_cpath.utils.naming import DeepMILSubmodules, MetricsKey, ResultsKey
|
||||
from testhisto.mocks.base_data_generator import MockHistoDataType
|
||||
from testhisto.mocks.slides_generator import MockPandaSlidesGenerator, TilesPositioningType
|
||||
from testhisto.mocks.tiles_generator import MockPandaTilesGenerator
|
||||
from testhisto.mocks.container import MockDeepSMILETilesPanda, MockDeepSMILESlidesPanda
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
|
@ -75,7 +72,7 @@ def _test_lightningmodule(
|
|||
module = TilesDeepMILModule(
|
||||
label_column=DEFAULT_LABEL_COLUMN,
|
||||
n_classes=n_classes,
|
||||
dropout_rate=dropout_rate,
|
||||
classifier_params=ClassifierParams(dropout_rate=dropout_rate),
|
||||
encoder_params=get_supervised_imagenet_encoder_params(),
|
||||
pooling_params=get_attention_pooling_layer_params(pool_out_dim)
|
||||
)
|
||||
|
@ -141,29 +138,6 @@ def _test_lightningmodule(
|
|||
assert torch.all(score <= 1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_panda_slides_root_dir(
|
||||
tmp_path_factory: pytest.TempPathFactory, tmp_path_to_pathmnist_dataset: Path
|
||||
) -> Generator:
|
||||
tmp_root_dir = tmp_path_factory.mktemp("mock_slides")
|
||||
wsi_generator = MockPandaSlidesGenerator(
|
||||
dest_data_path=tmp_root_dir,
|
||||
src_data_path=tmp_path_to_pathmnist_dataset,
|
||||
mock_type=MockHistoDataType.PATHMNIST,
|
||||
n_tiles=4,
|
||||
n_slides=10,
|
||||
n_channels=3,
|
||||
n_levels=3,
|
||||
tile_size=28,
|
||||
background_val=255,
|
||||
tiles_pos_type=TilesPositioningType.RANDOM
|
||||
)
|
||||
logging.info("Generating temporary mock slides that will be deleted at the end of the session.")
|
||||
wsi_generator.generate_mock_histo_data()
|
||||
yield tmp_root_dir
|
||||
shutil.rmtree(tmp_root_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 5])
|
||||
|
@ -326,9 +300,10 @@ def test_container(container_type: Type[BaseMILTiles], use_gpu: bool) -> None:
|
|||
container = container_type()
|
||||
|
||||
container.setup()
|
||||
container.batch_size = 10
|
||||
container.batch_size_inf = 10
|
||||
|
||||
data_module: TilesDataModule = container.get_data_module() # type: ignore
|
||||
data_module.max_bag_size = 10
|
||||
|
||||
module = container.create_model()
|
||||
module.outputs_handler = MagicMock()
|
||||
|
@ -461,12 +436,12 @@ def test_finetuning_options(
|
|||
label_column=DEFAULT_LABEL_COLUMN,
|
||||
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
|
||||
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
|
||||
tune_classifier=tune_classifier,
|
||||
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
|
||||
)
|
||||
|
||||
assert module.encoder_params.tune_encoder == tune_encoder
|
||||
assert module.pooling_params.tune_pooling == tune_pooling
|
||||
assert module.tune_classifier == tune_classifier
|
||||
assert module.classifier_params.tune_classifier == tune_classifier
|
||||
|
||||
for params in module.encoder.parameters():
|
||||
assert params.requires_grad == tune_encoder
|
||||
|
@ -513,7 +488,7 @@ def test_training_with_different_finetuning_options(
|
|||
label_column=MockPandaTilesGenerator.ISUP_GRADE,
|
||||
encoder_params=get_supervised_imagenet_encoder_params(tune_encoder=tune_encoder),
|
||||
pooling_params=get_attention_pooling_layer_params(pool_out_dim=1, tune_pooling=tune_pooling),
|
||||
tune_classifier=tune_classifier,
|
||||
classifier_params=ClassifierParams(tune_classifier=tune_classifier),
|
||||
)
|
||||
|
||||
def _assert_existing_gradients(module: nn.Module, tuning_flag: bool) -> None:
|
||||
|
@ -550,7 +525,7 @@ def test_init_weights_options(pretrained_encoder: bool, pretrained_pooling: bool
|
|||
)
|
||||
module.encoder_params.pretrained_encoder = pretrained_encoder
|
||||
module.pooling_params.pretrained_pooling = pretrained_pooling
|
||||
module.pretrained_classifier = pretrained_classifier
|
||||
module.classifier_params.pretrained_classifier = pretrained_classifier
|
||||
|
||||
with patch.object(module, "load_from_checkpoint") as mock_load_from_checkpoint:
|
||||
with patch.object(module, "copy_weights") as mock_copy_weights:
|
||||
|
@ -576,10 +551,11 @@ def _get_tiles_deepmil_module(
|
|||
label_column=MockPandaTilesGenerator.ISUP_GRADE,
|
||||
encoder_params=get_supervised_imagenet_encoder_params(),
|
||||
pooling_params=get_transformer_pooling_layer_params(num_layers, num_heads, hidden_dim, transformer_dropout),
|
||||
classifier_params=ClassifierParams(pretrained_classifier=pretrained_classifier),
|
||||
)
|
||||
module.encoder_params.pretrained_encoder = pretrained_encoder
|
||||
module.pooling_params.pretrained_pooling = pretrained_pooling
|
||||
module.pretrained_classifier = pretrained_classifier
|
||||
module.classifier_params.pretrained_classifier = pretrained_classifier
|
||||
return module
|
||||
|
||||
|
||||
|
@ -711,13 +687,13 @@ def test_on_run_extra_val_epoch(mock_panda_tiles_root_dir: Path) -> None:
|
|||
container.setup()
|
||||
container.data_module = MagicMock()
|
||||
container.create_lightning_module_and_store()
|
||||
assert not container.model._run_extra_val_epoch
|
||||
assert not container.model._on_extra_val_epoch
|
||||
assert (
|
||||
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
|
||||
!= container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
|
||||
)
|
||||
container.on_run_extra_validation_epoch()
|
||||
assert container.model._run_extra_val_epoch
|
||||
assert container.model._on_extra_val_epoch
|
||||
assert (
|
||||
container.model.outputs_handler.test_plots_handler.plot_options # type: ignore
|
||||
== container.model.outputs_handler.val_plots_handler.plot_options # type: ignore
|
||||
|
|
|
@ -191,7 +191,7 @@ def test_on_train_epoch_end_distributed(tmp_path: Path) -> None:
|
|||
|
||||
|
||||
def test_on_train_and_val_end(tmp_path: Path) -> None:
|
||||
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
|
||||
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
|
||||
max_epochs = 4
|
||||
trainer = MagicMock(current_epoch=max_epochs - 1)
|
||||
|
||||
|
@ -224,7 +224,7 @@ def test_on_train_and_val_end(tmp_path: Path) -> None:
|
|||
|
||||
|
||||
def test_on_validation_end_not_called_if_extra_val_epoch(tmp_path: Path) -> None:
|
||||
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=True)
|
||||
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=True)
|
||||
max_epochs = 4
|
||||
trainer = MagicMock(current_epoch=0)
|
||||
loss_callback = LossAnalysisCallback(
|
||||
|
@ -264,7 +264,7 @@ def test_nans_detection(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> Non
|
|||
def test_log_exceptions_flag(log_exceptions: bool, tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
|
||||
max_epochs = 3
|
||||
trainer = MagicMock(current_epoch=max_epochs - 1)
|
||||
pl_module = MagicMock(global_rank=0, _run_extra_val_epoch=False)
|
||||
pl_module = MagicMock(global_rank=0, _on_extra_val_epoch=False)
|
||||
loss_callback = LossAnalysisCallback(
|
||||
outputs_folder=tmp_path, max_epochs=max_epochs,
|
||||
num_slides_heatmap=2, num_slides_scatter=2, log_exceptions=log_exceptions
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.multiprocessing
|
||||
from ruamel.yaml import YAML
|
||||
from health_cpath.utils.tiles_selection_utils import TilesSelector
|
||||
from testhisto.utils.utils_testhisto import run_distributed
|
||||
from torch.testing import assert_close
|
||||
from torchmetrics.metric import Metric
|
||||
|
@ -251,3 +252,27 @@ def test_collate_results_multigpu() -> None:
|
|||
epoch_results = _create_epoch_results(batch_size, num_batches, rank=0, device='cuda:0') \
|
||||
+ _create_epoch_results(batch_size, num_batches, rank=1, device='cuda:1')
|
||||
_test_collate_results(epoch_results, total_num_samples=2 * num_batches * batch_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('val_set_is_dist', [True, False])
|
||||
def test_results_gathering_with_val_set_is_dist_flag(val_set_is_dist: bool, tmp_path: Path) -> None:
|
||||
outputs_handler = _create_outputs_handler(tmp_path)
|
||||
outputs_handler.tiles_selector = TilesSelector(2, num_slides=2, num_tiles=2)
|
||||
outputs_handler.val_set_is_dist = val_set_is_dist
|
||||
outputs_handler._save_outputs = MagicMock() # type: ignore
|
||||
metric_value = 0.5
|
||||
with patch("health_cpath.utils.output_utils.gather_results") as mock_gather_results:
|
||||
with patch.object(outputs_handler.tiles_selector, "gather_selected_tiles_across_devices") as mock_gather_tiles:
|
||||
with patch.object(outputs_handler.tiles_selector, "_clear_cached_slides_heaps") as mock_clear_cache:
|
||||
with patch.object(outputs_handler, "should_gather_tiles") as mock_should_gather_tiles:
|
||||
mock_should_gather_tiles.return_value = True
|
||||
for rank in range(2):
|
||||
epoch_results = [{_PRIMARY_METRIC_KEY: [metric_value] * 5, _RANK_KEY: rank}]
|
||||
outputs_handler.save_validation_outputs(
|
||||
epoch_results=epoch_results, # type: ignore
|
||||
metrics_dict=_get_mock_metrics_dict(metric_value),
|
||||
epoch=0,
|
||||
is_global_rank_zero=rank == 0)
|
||||
assert mock_gather_results.called == val_set_is_dist
|
||||
assert mock_gather_tiles.called == val_set_is_dist
|
||||
mock_clear_cache.assert_called()
|
||||
|
|
|
@ -122,7 +122,7 @@ class HelloRegression(LightningModule):
|
|||
self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) # type: ignore
|
||||
self.test_mse: List[torch.Tensor] = []
|
||||
self.test_mae = MeanAbsoluteError()
|
||||
self._run_extra_val_epoch = False
|
||||
self._on_extra_val_epoch = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
|
@ -233,7 +233,7 @@ class HelloRegression(LightningModule):
|
|||
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
|
||||
|
||||
def on_run_extra_validation_epoch(self) -> None:
|
||||
self._run_extra_val_epoch = True
|
||||
self._on_extra_val_epoch = True
|
||||
|
||||
|
||||
class HelloWorld(LightningContainer):
|
||||
|
|
|
@ -509,6 +509,11 @@ class TrainerParams(param.Parameterized):
|
|||
# Logging interval (in steps) exposed as a trainer parameter.
# The doc text is assembled by implicit string-literal concatenation, so each fragment must end
# with its own separating space (previously rendered as "How often tolog within steps").
pl_log_every_n_steps: int = param.Integer(default=50,
                                          doc="PyTorch Lightning trainer flag 'log_every_n_steps': How often to "
                                              "log within steps. Default to 50.")
|
||||
# Switch controlling whether PyTorch Lightning installs its own distributed sampler.
# The doc text is assembled by implicit string-literal concatenation, so each fragment must end
# with its own separating space (previously rendered as "Set toFalse to set your own sampler.").
pl_replace_sampler_ddp: bool = param.Boolean(default=True,
                                             doc="PyTorch Lightning trainer flag 'replace_sampler_ddp' that "
                                                 "sets the sampler for distributed training with shuffle=True during "
                                                 "training and shuffle=False during validation. Default to True. Set to "
                                                 "False to set your own sampler.")
|
||||
|
||||
@property
|
||||
def use_gpu(self) -> bool:
|
||||
|
|
|
@ -210,5 +210,6 @@ def create_lightning_trainer(container: LightningContainer,
|
|||
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
|
||||
multiple_trainloader_mode=multiple_trainloader_mode,
|
||||
accumulate_grad_batches=container.pl_accumulate_grad_batches,
|
||||
replace_sampler_ddp=container.pl_replace_sampler_ddp,
|
||||
**additional_args)
|
||||
return trainer, storing_logger
|
||||
|
|
Загрузка…
Ссылка в новой задаче