diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index edb646e3..0d046bb3 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -86,6 +86,9 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara doc="The maximum number of worker processes for dataloaders. Dataloaders use" "a heuristic num_cpus/num_gpus to set the number of workers, which can be" "very high for small num_gpus. This parameters sets an upper bound.") + wsi_has_mask: bool = param.Boolean(default=True, + doc="Whether the WSI has a mask. If True, will use the mask to load a specific " + "region of the WSI. If False, will load the whole WSI.") def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -152,6 +155,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara maximise=self.maximise_primary_metric, val_plot_options=self.get_val_plot_options(), test_plot_options=self.get_test_plot_options(), + wsi_has_mask=self.wsi_has_mask, ) if self.num_top_slides > 0: outputs_handler.tiles_selector = TilesSelector( diff --git a/hi-ml-cpath/src/health_cpath/datasets/panda_dataset.py b/hi-ml-cpath/src/health_cpath/datasets/panda_dataset.py index c1f19c3a..ad8812bc 100644 --- a/hi-ml-cpath/src/health_cpath/datasets/panda_dataset.py +++ b/hi-ml-cpath/src/health_cpath/datasets/panda_dataset.py @@ -10,6 +10,7 @@ import pandas as pd from monai.config import KeysCollection from monai.data.image_reader import ImageReader, WSIReader from monai.transforms import MapTransform +from health_cpath.utils.naming import SlideKey from health_ml.utils import box_utils @@ -121,8 +122,9 @@ class LoadPandaROId(MapTransform): # but relative region size in pixels at the chosen level scale = mask_obj.resolutions['level_downsamples'][self.level] scaled_bbox = level0_bbox / scale + 
origin = (level0_bbox.y, level0_bbox.x) get_data_kwargs = dict( - location=(level0_bbox.y, level0_bbox.x), + location=origin, size=(scaled_bbox.h, scaled_bbox.w), level=self.level, ) @@ -130,7 +132,8 @@ class LoadPandaROId(MapTransform): data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore data.update(get_data_kwargs) - data['scale'] = scale + data[SlideKey.SCALE] = scale + data[SlideKey.ORIGIN] = origin mask_obj.close() image_obj.close() diff --git a/hi-ml-cpath/src/health_cpath/utils/output_utils.py b/hi-ml-cpath/src/health_cpath/utils/output_utils.py index 78ce0799..5f295bd2 100644 --- a/hi-ml-cpath/src/health_cpath/utils/output_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/output_utils.py @@ -249,7 +249,7 @@ class DeepMILOutputsHandler: def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int, class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey, maximise: bool, val_plot_options: Collection[PlotOption], - test_plot_options: Collection[PlotOption]) -> None: + test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True) -> None: """ :param outputs_root: Root directory where to save all produced outputs. :param n_classes: Number of MIL classes (set `n_classes=1` for binary). @@ -261,6 +261,7 @@ class DeepMILOutputsHandler: :param maximise: Whether higher is better for `primary_val_metric`. :param val_plot_options: The desired plot options for validation time. :param test_plot_options: The desired plot options for test time. + :param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs. 
""" self.outputs_root = outputs_root self.n_classes = n_classes @@ -279,14 +280,16 @@ class DeepMILOutputsHandler: level=self.level, tile_size=self.tile_size, class_names=self.class_names, - stage=ModelKey.VAL + stage=ModelKey.VAL, + wsi_has_mask=wsi_has_mask ) self.test_plots_handler = DeepMILPlotsHandler( plot_options=test_plot_options, level=self.level, tile_size=self.tile_size, class_names=self.class_names, - stage=ModelKey.TEST + stage=ModelKey.TEST, + wsi_has_mask=wsi_has_mask ) @property diff --git a/hi-ml-cpath/src/health_cpath/utils/plots_utils.py b/hi-ml-cpath/src/health_cpath/utils/plots_utils.py index 738254a2..b8f3048a 100644 --- a/hi-ml-cpath/src/health_cpath/utils/plots_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/plots_utils.py @@ -128,6 +128,7 @@ def save_slide_thumbnail_and_heatmap( slides_dataset: SlidesDataset, tile_size: int = 224, level: int = 1, + wsi_has_mask: bool = True, ) -> None: """Plots and saves a slide thumbnail and attention heatmap @@ -139,13 +140,14 @@ def save_slide_thumbnail_and_heatmap( :param tile_size: Size of each tile. Default 224. :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1. + :param wsi_has_mask: Whether the slide has a mask or not. Default True. 
""" slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id) assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}" slide_dict = slides_dataset[slide_index] - slide_dict = load_image_dict(slide_dict, level=level, margin=0) + slide_dict = load_image_dict(slide_dict, level=level, margin=0, wsi_has_mask=wsi_has_mask) slide_image = slide_dict[SlideKey.IMAGE] - location_bbox = slide_dict[SlideKey.LOCATION] + location_bbox = slide_dict[SlideKey.ORIGIN] fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0) save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png") @@ -179,6 +181,7 @@ class DeepMILPlotsHandler: figsize: Tuple[int, int] = (10, 10), stage: str = '', class_names: Optional[Sequence[str]] = None, + wsi_has_mask: bool = True, ) -> None: """_summary_ @@ -199,6 +202,7 @@ class DeepMILPlotsHandler: self.num_columns = num_columns self.figsize = figsize self.stage = stage + self.wsi_has_mask = wsi_has_mask self.slides_dataset: Optional[SlidesDataset] = None def save_slide_node_figures( @@ -226,6 +230,7 @@ class DeepMILPlotsHandler: slides_dataset=self.slides_dataset, tile_size=self.tile_size, level=self.level, + wsi_has_mask=self.wsi_has_mask, ) def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None: diff --git a/hi-ml-cpath/src/health_cpath/utils/viz_utils.py b/hi-ml-cpath/src/health_cpath/utils/viz_utils.py index 1f42f61f..f2a1b7d1 100644 --- a/hi-ml-cpath/src/health_cpath/utils/viz_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/viz_utils.py @@ -18,6 +18,7 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple from monai.data.dataset import Dataset from monai.data.image_reader import WSIReader from torch.utils.data import DataLoader +from health_cpath.preprocessing.loading import LoadROId from health_cpath.utils.naming import SlideKey from health_cpath.utils.naming import ResultsKey @@ 
-26,7 +27,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId -def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]: +def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]: """ Load image from metadata dictionary :param sample: dict describing image metadata. Example: @@ -40,7 +41,9 @@ def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any :param margin: margin to be included + :param wsi_has_mask: if True, the sample is loaded with `LoadPandaROId`; if False, it is loaded with `LoadROId`. :return: a dict containing the image data and metadata """ - loader = LoadPandaROId(WSIReader("cuCIM"), level=level, margin=margin) + transform = LoadPandaROId if wsi_has_mask else LoadROId + loader = transform(WSIReader("cuCIM"), level=level, margin=margin) img = loader(sample) return img diff --git a/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py b/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py index 70e2a76a..12a0749f 100644 --- a/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py +++ b/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py @@ -8,6 +8,7 @@ import math import random from pathlib import Path from typing import List, Optional +from unittest.mock import MagicMock, patch import matplotlib import numpy as np @@ -24,7 +25,7 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist, from health_cpath.utils.naming import ResultsKey from health_cpath.utils.heatmap_utils import location_selected_tiles from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode -from health_cpath.utils.viz_utils import save_figure +from health_cpath.utils.viz_utils import save_figure, load_image_dict from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path @@ -292,3 +293,12 @@ def test_location_selected_tiles(level: int) -> None: assert max(tile_xs) <= slide_image.shape[2] // factor assert min(tile_ys) >= 0
assert max(tile_ys) <= slide_image.shape[1] // factor + + +@pytest.mark.parametrize("wsi_has_mask", [True, False]) +def test_load_image_dict(wsi_has_mask: bool) -> None: + with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi: + with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi: + _ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore + assert mock_load_panda_roi.called == wsi_has_mask + assert mock_load_roi.called == (not wsi_has_mask)