ENH: Parameterize CuCim backend (#661)

Add `backend` and `wsi_reader_args` parameters so that a WSI reader backend other than cuCIM can be used.
This commit is contained in:
Kenza Bouzid 2022-11-15 13:28:59 +00:00 коммит произвёл GitHub
Родитель f97d368870
Коммит a1f580576c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 39 добавлений и 14 удалений

Просмотреть файл

@ -79,6 +79,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
wsi_has_mask: bool = param.Boolean(default=True,
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
"region of the WSI. If False, will load the whole WSI.")
wsi_backend: str = param.String(default="cuCIM", doc="The backend to use for loading WSI. ")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@ -152,6 +153,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, ClassifierParams
val_plot_options=self.get_val_plot_options(),
test_plot_options=self.get_test_plot_options(),
wsi_has_mask=self.wsi_has_mask,
backend=self.wsi_backend,
val_set_is_dist=self.pl_replace_sampler_ddp and self.max_num_gpus > 1,
)
if self.num_top_slides > 0:

Просмотреть файл

@ -25,6 +25,7 @@ from monai.transforms.io.dictionary import LoadImaged
from monai.apps.pathology.transforms import TileOnGridd
from monai.data.image_reader import WSIReader
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
@ -280,12 +281,14 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
random_offset: bool = True,
pad_full: bool = False,
background_val: int = 255,
filter_mode: str = "min",
filter_mode: str = 'min',
backend: str = 'cuCIM',
wsi_reader_args: Dict[str, Any] = {},
**kwargs: Any,
) -> None:
"""
:param level: the whole slide image level at which the image is extracted, defaults to 1
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend by default
:param tile_size: size of the square tile, defaults to 224
this param is passed to TileOnGridd monai transform for tiling on the fly.
:param step: step size to create overlapping tiles, defaults to None (same as tile_size)
@ -301,15 +304,22 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
monai transform for tiling on the fly.
:param backend: the WSI reader backend, defaults to "cuCIM". This param is passed to LoadImaged monai transform
:param wsi_reader_args: Additional arguments to pass to the WSIReader, defaults to {}. Since monai 1.0.0,
multiprocessing can be enabled by specifying num_workers > 0; this is supported with the cuCIM backend only.
"""
super().__init__(**kwargs)
self.level = level
# Tiling on the fly params
self.tile_size = tile_size
self.step = step
self.random_offset = random_offset
self.pad_full = pad_full
self.background_val = background_val
self.filter_mode = filter_mode
# WSIReader params
self.level = level
self.backend = backend
self.wsi_reader_args = wsi_reader_args
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
# max_bag_size_inf to None if set to 0
for stage_key, max_bag_size in self.bag_sizes.items():
@ -322,10 +332,11 @@ class SlidesDataModule(HistoDataModule[SlidesDataset]):
LoadImaged(
keys=slides_dataset.IMAGE_COLUMN,
reader=WSIReader,
backend="cuCIM",
dtype=np.uint8,
level=self.level,
image_only=True,
level=self.level,
backend=self.backend,
**self.wsi_reader_args,
),
TileOnGridd(
keys=slides_dataset.IMAGE_COLUMN,

Просмотреть файл

@ -258,7 +258,7 @@ class DeepMILOutputsHandler:
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
maximise: bool, val_plot_options: Collection[PlotOption],
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True,
val_set_is_dist: bool = True) -> None:
backend: str = "cuCIM", val_set_is_dist: bool = True) -> None:
"""
:param outputs_root: Root directory where to save all produced outputs.
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
@ -271,6 +271,7 @@ class DeepMILOutputsHandler:
:param val_plot_options: The desired plot options for validation time.
:param test_plot_options: The desired plot options for test time.
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
:param backend: The backend to use for loading the whole slide images when plotting. Default is "cuCIM".
:param val_set_is_dist: If True, the validation set is distributed across processes. Otherwise, the validation
set is replicated on each process. This shouldn't affect the results, as we take the mean of the validation
set metrics across processes. This is only relevant for the outputs_handler, which needs to know whether to
@ -294,7 +295,8 @@ class DeepMILOutputsHandler:
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.VAL,
wsi_has_mask=wsi_has_mask
wsi_has_mask=wsi_has_mask,
backend=backend,
)
self.test_plots_handler = DeepMILPlotsHandler(
plot_options=test_plot_options,
@ -302,7 +304,8 @@ class DeepMILOutputsHandler:
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.TEST,
wsi_has_mask=wsi_has_mask
wsi_has_mask=wsi_has_mask,
backend=backend,
)
self.val_set_is_dist = val_set_is_dist

Просмотреть файл

@ -183,6 +183,7 @@ class DeepMILPlotsHandler:
stage: str = '',
class_names: Optional[Sequence[str]] = None,
wsi_has_mask: bool = True,
backend: str = "cuCIM",
) -> None:
"""Class that handles the plotting of DeepMIL results.
@ -195,6 +196,8 @@ class DeepMILPlotsHandler:
:param stage: Test or Validation, used to name the plots
:param class_names: List of class names, defaults to None
:param slides_dataset: The slides dataset from where to load the whole slide images, defaults to None
:param wsi_has_mask: Whether the whole slide images have a mask, defaults to True
:param backend: The backend to use for loading the whole slide images, defaults to "cuCIM"
"""
self.plot_options = plot_options
self.class_names = validate_class_names_for_plot_options(class_names, plot_options)
@ -204,6 +207,7 @@ class DeepMILPlotsHandler:
self.figsize = figsize
self.stage = stage
self.wsi_has_mask = wsi_has_mask
self.backend = backend
self.slides_dataset: Optional[SlidesDataset] = None
def get_slide_dict(self, slide_node: SlideNode) -> SlideDictType:
@ -212,7 +216,8 @@ class DeepMILPlotsHandler:
slide_index = self.slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = self.slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask)
slide_dict = load_image_dict(slide_dict, level=self.level, margin=0, wsi_has_mask=self.wsi_has_mask,
backend=self.backend)
return slide_dict
def save_slide_node_figures(

Просмотреть файл

@ -27,7 +27,9 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
def load_image_dict(
sample: dict, level: int, margin: int, wsi_has_mask: bool = True, backend: str = "cuCIM"
) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
@ -39,10 +41,11 @@ def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool =
'gleason_score': ['0+0']}
:param level: level of resolution to be loaded
:param margin: margin to be included
:return: a dict containing the image data and metadata
:param wsi_has_mask: whether the WSI has a mask
:param backend: backend to be used to load the image (cuCIM or OpenSlide)
"""
transform = LoadPandaROId if wsi_has_mask else LoadROId
loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
loader = transform(WSIReader(backend=backend), level=level, margin=margin)
img = loader(sample)
return img

Просмотреть файл

@ -297,10 +297,11 @@ def test_location_selected_tiles(level: int) -> None:
assert max(tile_ys) <= slide_image.shape[1] // factor
@pytest.mark.parametrize("backend", ["cuCIM", "OpenSlide"])
@pytest.mark.parametrize("wsi_has_mask", [True, False])
def test_load_image_dict(wsi_has_mask: bool) -> None:
def test_load_image_dict(wsi_has_mask: bool, backend: str) -> None:
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask, backend=backend)
assert mock_load_panda_roi.called == wsi_has_mask
assert mock_load_roi.called == (not wsi_has_mask)