diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index c65148ab..efbda524 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -79,6 +79,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams wsi_has_mask: bool = param.Boolean(default=True, doc="Whether the WSI has a mask. If True, will use the mask to load a specific" "region of the WSI. If False, will load the whole WSI.") + wsi_backend: str = param.String(default="cuCIM", doc="The backend to use for loading WSI. ") def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -152,6 +153,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams val_plot_options=self.get_val_plot_options(), test_plot_options=self.get_test_plot_options(), wsi_has_mask=self.wsi_has_mask, + backend=self.wsi_backend, val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1, ) if self.num_top_slides > 0: diff --git a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py index 6c0689a5..8a3d953b 100644 --- a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py +++ b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py @@ -25,6 +25,7 @@ from monai.transforms.io.dictionary import LoadImaged from monai.apps.pathology.transforms import TileOnGridd from monai.data.image_reader import WSIReader + _SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset) @@ -280,12 +281,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]): random_offset: bool = True, pad_full: bool = False, background_val: int = 255, - filter_mode: str = "min", + filter_mode: str = 'min', + backend: str = 'cuCIM', + wsi_reader_args: Dict[str, Any] = {}, **kwargs: Any, ) -> None: """ :param level: the whole 
slide image level at which the image is extracted, defaults to 1 - this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend + this param is passed to the LoadImaged monai transform that loads a WSI with the cuCIM backend by default :param tile_size: size of the square tile, defaults to 224 this param is passed to TileOnGridd monai transform for tiling on the fly. :param step: step size to create overlapping tiles, defaults to None (same as tile_size) @@ -301,15 +304,22 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]): tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd monai transform for tiling on the fly. + :param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform + :param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Multiprocessing is + enabled since monai 1.0.0 by specifying num_workers > 0, with the cuCIM backend only. 
""" super().__init__(**kwargs) - self.level = level + # Tiling on the fly params self.tile_size = tile_size self.step = step self.random_offset = random_offset self.pad_full = pad_full self.background_val = background_val self.filter_mode = filter_mode + # WSIReader params + self.level = level + self.backend = backend + self.wsi_reader_args = wsi_reader_args # TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and # max_bag_size_inf to None if set to 0 for stage_key, max_bag_size in self.bag_sizes.items(): @@ -322,10 +332,11 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]): LoadImaged( keys=slides_dataset.IMAGE_COLUMN, reader=WSIReader, - backend="cuCIM", dtype=np.uint8, - level=self.level, image_only=True, + level=self.level, + backend=self.backend, + **self.wsi_reader_args, ), TileOnGridd( keys=slides_dataset.IMAGE_COLUMN, diff --git a/hi-ml-cpath/src/health_cpath/utils/output_utils.py b/hi-ml-cpath/src/health_cpath/utils/output_utils.py index 8a99370a..d16e08a1 100644 --- a/hi-ml-cpath/src/health_cpath/utils/output_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/output_utils.py @@ -258,7 +258,7 @@ class DeepMILOutputsHandler: class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey, maximise: bool, val_plot_options: Collection[PlotOption], test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True, - val_set_is_dist: bool = True) -> None: + backend: str = "cuCIM", val_set_is_dist: bool = True) -> None: """ :param outputs_root: Root directory where to save all produced outputs. :param n_classes: Number of MIL classes (set `n_classes=1` for binary). @@ -271,6 +271,7 @@ class DeepMILOutputsHandler: :param val_plot_options: The desired plot options for validation time. :param test_plot_options: The desired plot options for test time. :param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs. + :param backend: The backend to use for reading the tiles. Default is "cuCIM". 
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to @@ -294,7 +295,8 @@ class DeepMILOutputsHandler: tile_size=self.tile_size, class_names=self.class_names, stage=ModelKey.VAL, - wsi_has_mask=wsi_has_mask + wsi_has_mask=wsi_has_mask, + backend=backend, ) self.test_plots_handler = DeepMILPlotsHandler( plot_options=test_plot_options, @@ -302,7 +304,8 @@ class DeepMILOutputsHandler: tile_size=self.tile_size, class_names=self.class_names, stage=ModelKey.TEST, - wsi_has_mask=wsi_has_mask + wsi_has_mask=wsi_has_mask, + backend=backend, ) self.val_set_is_dist = val_set_is_dist diff --git a/hi-ml-cpath/src/health_cpath/utils/plots_utils.py b/hi-ml-cpath/src/health_cpath/utils/plots_utils.py index 8e5230c4..1fbe8667 100644 --- a/hi-ml-cpath/src/health_cpath/utils/plots_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/plots_utils.py @@ -183,6 +183,7 @@ class DeepMILPlotsHandler: stage: str = '', class_names: Optional[Sequence[str]] = None, wsi_has_mask: bool = True, + backend: str = "cuCIM", ) -> None: """Class that handles the plotting of DeepMIL results. 
@@ -195,6 +196,8 @@ class DeepMILPlotsHandler: :param stage: Test or Validation, used to name the plots :param class_names: List of class names, defaults to None :param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None + :param wsi_has_mask: Whether the whole slide images have a mask, defaults to True + :param backend: The backend to use for loading the whole slide images, defaults to "cuCIM" """ self.plot_options = plot_options self.class_names = validate_class_names_for_plot_options(class_names, plot_options) @@ -204,6 +207,7 @@ class DeepMILPlotsHandler: self.figsize = figsize self.stage = stage self.wsi_has_mask = wsi_has_mask + self.backend = backend self.slides_dataset: Optional[SlidesDataset] = None def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType: @@ -212,7 +216,8 @@ class DeepMILPlotsHandler: slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id) assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}" slide_dict = self.slides_dataset[slide_index] - slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask) + slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask, + backend=self.backend) return slide_dict def save_slide_node_figures( diff --git a/hi-ml-cpath/src/health_cpath/utils/viz_utils.py b/hi-ml-cpath/src/health_cpath/utils/viz_utils.py index f2a1b7d1..6b91702f 100644 --- a/hi-ml-cpath/src/health_cpath/utils/viz_utils.py +++ b/hi-ml-cpath/src/health_cpath/utils/viz_utils.py @@ -27,7 +27,9 @@ from health_cpath.utils.tiles_selection_utils import SlideNode from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId -def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]: +def load_image_dict( + sample: dict, level: int, margin: int, wsi_has_mask: bool = True, backend: str = "cuCIM" +) 
-> Dict[SlideKey, Any]: """ Load image from metadata dictionary :param sample: dict describing image metadata. Example: @@ -39,10 +41,11 @@ def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = 'gleason_score': ['0+0']} :param level: level of resolution to be loaded :param margin: margin to be included - :return: a dict containing the image data and metadata + :param wsi_has_mask: whether the WSI has a mask + :param backend: backend to be used to load the image (cuCIM or OpenSlide) """ transform = LoadPandaROId if wsi_has_mask else LoadROId - loader = transform(WSIReader("cuCIM"), level=level, margin=margin) + loader = transform(WSIReader(backend=backend), level=level, margin=margin) img = loader(sample) return img diff --git a/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py b/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py index d6117f09..4f7d8662 100644 --- a/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py +++ b/hi-ml-cpath/testhisto/testhisto/utils/test_viz_utils.py @@ -297,10 +297,11 @@ def test_location_selected_tiles(level: int) -> None: assert max(tile_ys) <= slide_image.shape[1] // factor +@pytest.mark.parametrize("backend", ["cuCIM", "OpenSlide"]) @pytest.mark.parametrize("wsi_has_mask", [True, False]) -def test_load_image_dict(wsi_has_mask: bool) -> None: +def test_load_image_dict(wsi_has_mask: bool, backend: str) -> None: with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi: with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi: - _ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore + _ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask, backend=backend) assert mock_load_panda_roi.called == wsi_has_mask assert mock_load_roi.called == (not wsi_has_mask)