ENH: Enable heatmaps and thumbnails plots for WSI without masks (#637)

Use LoadROId for WSI without masks.
This commit is contained in:
Kenza Bouzid 2022-10-19 16:35:52 +01:00 коммит произвёл GitHub
Родитель fa4e0984d0
Коммит 2f39328a56
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 37 добавлений и 10 удалений

Просмотреть файл

@ -86,6 +86,9 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
"very high for small num_gpus. This parameters sets an upper bound.")
wsi_has_mask: bool = param.Boolean(default=True,
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
"region of the WSI. If False, will load the whole WSI.")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@ -152,6 +155,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
maximise=self.maximise_primary_metric,
val_plot_options=self.get_val_plot_options(),
test_plot_options=self.get_test_plot_options(),
wsi_has_mask=self.wsi_has_mask,
)
if self.num_top_slides > 0:
outputs_handler.tiles_selector = TilesSelector(

Просмотреть файл

@ -10,6 +10,7 @@ import pandas as pd
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from health_cpath.utils.naming import SlideKey
from health_ml.utils import box_utils
@ -121,8 +122,9 @@ class LoadPandaROId(MapTransform):
# but relative region size in pixels at the chosen level
scale = mask_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
origin = (level0_bbox.y, level0_bbox.x)
get_data_kwargs = dict(
location=(level0_bbox.y, level0_bbox.x),
location=origin,
size=(scaled_bbox.h, scaled_bbox.w),
level=self.level,
)
@ -130,7 +132,8 @@ class LoadPandaROId(MapTransform):
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
data.update(get_data_kwargs)
data['scale'] = scale
data[SlideKey.SCALE] = scale
data[SlideKey.ORIGIN] = origin
mask_obj.close()
image_obj.close()

Просмотреть файл

@ -249,7 +249,7 @@ class DeepMILOutputsHandler:
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
maximise: bool, val_plot_options: Collection[PlotOption],
test_plot_options: Collection[PlotOption]) -> None:
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True) -> None:
"""
:param outputs_root: Root directory where to save all produced outputs.
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
@ -261,6 +261,7 @@ class DeepMILOutputsHandler:
:param maximise: Whether higher is better for `primary_val_metric`.
:param val_plot_options: The desired plot options for validation time.
:param test_plot_options: The desired plot options for test time.
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
"""
self.outputs_root = outputs_root
self.n_classes = n_classes
@ -279,14 +280,16 @@ class DeepMILOutputsHandler:
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.VAL
stage=ModelKey.VAL,
wsi_has_mask=wsi_has_mask
)
self.test_plots_handler = DeepMILPlotsHandler(
plot_options=test_plot_options,
level=self.level,
tile_size=self.tile_size,
class_names=self.class_names,
stage=ModelKey.TEST
stage=ModelKey.TEST,
wsi_has_mask=wsi_has_mask
)
@property

Просмотреть файл

@ -128,6 +128,7 @@ def save_slide_thumbnail_and_heatmap(
slides_dataset: SlidesDataset,
tile_size: int = 224,
level: int = 1,
wsi_has_mask: bool = True,
) -> None:
"""Plots and saves a slide thumbnail and attention heatmap
@ -139,13 +140,14 @@ def save_slide_thumbnail_and_heatmap(
:param tile_size: Size of each tile. Default 224.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
:param wsi_has_mask: Whether the slide has a mask or not. Default True.
"""
slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
slide_dict = slides_dataset[slide_index]
slide_dict = load_image_dict(slide_dict, level=level, margin=0)
slide_dict = load_image_dict(slide_dict, level=level, margin=0, wsi_has_mask=wsi_has_mask)
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.LOCATION]
location_bbox = slide_dict[SlideKey.ORIGIN]
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
@ -179,6 +181,7 @@ class DeepMILPlotsHandler:
figsize: Tuple[int, int] = (10, 10),
stage: str = '',
class_names: Optional[Sequence[str]] = None,
wsi_has_mask: bool = True,
) -> None:
"""_summary_
@ -199,6 +202,7 @@ class DeepMILPlotsHandler:
self.num_columns = num_columns
self.figsize = figsize
self.stage = stage
self.wsi_has_mask = wsi_has_mask
self.slides_dataset: Optional[SlidesDataset] = None
def save_slide_node_figures(
@ -226,6 +230,7 @@ class DeepMILPlotsHandler:
slides_dataset=self.slides_dataset,
tile_size=self.tile_size,
level=self.level,
wsi_has_mask=self.wsi_has_mask,
)
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:

Просмотреть файл

@ -18,6 +18,7 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
from monai.data.dataset import Dataset
from monai.data.image_reader import WSIReader
from torch.utils.data import DataLoader
from health_cpath.preprocessing.loading import LoadROId
from health_cpath.utils.naming import SlideKey
from health_cpath.utils.naming import ResultsKey
@ -26,7 +27,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
"""
Load image from metadata dictionary
:param sample: dict describing image metadata. Example:
@ -40,7 +41,8 @@ def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any
:param margin: margin to be included
:return: a dict containing the image data and metadata
"""
loader = LoadPandaROId(WSIReader("cuCIM"), level=level, margin=margin)
transform = LoadPandaROId if wsi_has_mask else LoadROId
loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
img = loader(sample)
return img

Просмотреть файл

@ -8,6 +8,7 @@ import math
import random
from pathlib import Path
from typing import List, Optional
from unittest.mock import MagicMock, patch
import matplotlib
import numpy as np
@ -24,7 +25,7 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
from health_cpath.utils.naming import ResultsKey
from health_cpath.utils.heatmap_utils import location_selected_tiles
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
from health_cpath.utils.viz_utils import save_figure
from health_cpath.utils.viz_utils import save_figure, load_image_dict
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path
@ -292,3 +293,12 @@ def test_location_selected_tiles(level: int) -> None:
assert max(tile_xs) <= slide_image.shape[2] // factor
assert min(tile_ys) >= 0
assert max(tile_ys) <= slide_image.shape[1] // factor
@pytest.mark.parametrize("wsi_has_mask", [True, False])
def test_load_image_dict(wsi_has_mask: bool) -> None:
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
assert mock_load_panda_roi.called == wsi_has_mask
assert mock_load_roi.called == (not wsi_has_mask)