зеркало из https://github.com/microsoft/hi-ml.git
ENH: Enable heatmaps and thumbnails plots for WSI without masks (#637)
Use LoadROId for WSI without masks.
This commit is contained in:
Родитель
fa4e0984d0
Коммит
2f39328a56
|
@ -86,6 +86,9 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
doc="The maximum number of worker processes for dataloaders. Dataloaders use"
|
||||
"a heuristic num_cpus/num_gpus to set the number of workers, which can be"
|
||||
"very high for small num_gpus. This parameters sets an upper bound.")
|
||||
wsi_has_mask: bool = param.Boolean(default=True,
|
||||
doc="Whether the WSI has a mask. If True, will use the mask to load a specific"
|
||||
"region of the WSI. If False, will load the whole WSI.")
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
@ -152,6 +155,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams, LossCallbackPara
|
|||
maximise=self.maximise_primary_metric,
|
||||
val_plot_options=self.get_val_plot_options(),
|
||||
test_plot_options=self.get_test_plot_options(),
|
||||
wsi_has_mask=self.wsi_has_mask,
|
||||
)
|
||||
if self.num_top_slides > 0:
|
||||
outputs_handler.tiles_selector = TilesSelector(
|
||||
|
|
|
@ -10,6 +10,7 @@ import pandas as pd
|
|||
from monai.config import KeysCollection
|
||||
from monai.data.image_reader import ImageReader, WSIReader
|
||||
from monai.transforms import MapTransform
|
||||
from health_cpath.utils.naming import SlideKey
|
||||
|
||||
from health_ml.utils import box_utils
|
||||
|
||||
|
@ -121,8 +122,9 @@ class LoadPandaROId(MapTransform):
|
|||
# but relative region size in pixels at the chosen level
|
||||
scale = mask_obj.resolutions['level_downsamples'][self.level]
|
||||
scaled_bbox = level0_bbox / scale
|
||||
origin = (level0_bbox.y, level0_bbox.x)
|
||||
get_data_kwargs = dict(
|
||||
location=(level0_bbox.y, level0_bbox.x),
|
||||
location=origin,
|
||||
size=(scaled_bbox.h, scaled_bbox.w),
|
||||
level=self.level,
|
||||
)
|
||||
|
@ -130,7 +132,8 @@ class LoadPandaROId(MapTransform):
|
|||
data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel
|
||||
data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore
|
||||
data.update(get_data_kwargs)
|
||||
data['scale'] = scale
|
||||
data[SlideKey.SCALE] = scale
|
||||
data[SlideKey.ORIGIN] = origin
|
||||
|
||||
mask_obj.close()
|
||||
image_obj.close()
|
||||
|
|
|
@ -249,7 +249,7 @@ class DeepMILOutputsHandler:
|
|||
def __init__(self, outputs_root: Path, n_classes: int, tile_size: int, level: int,
|
||||
class_names: Optional[Sequence[str]], primary_val_metric: MetricsKey,
|
||||
maximise: bool, val_plot_options: Collection[PlotOption],
|
||||
test_plot_options: Collection[PlotOption]) -> None:
|
||||
test_plot_options: Collection[PlotOption], wsi_has_mask: bool = True) -> None:
|
||||
"""
|
||||
:param outputs_root: Root directory where to save all produced outputs.
|
||||
:param n_classes: Number of MIL classes (set `n_classes=1` for binary).
|
||||
|
@ -261,6 +261,7 @@ class DeepMILOutputsHandler:
|
|||
:param maximise: Whether higher is better for `primary_val_metric`.
|
||||
:param val_plot_options: The desired plot options for validation time.
|
||||
:param test_plot_options: The desired plot options for test time.
|
||||
:param wsi_has_mask: Whether the whole slides have a mask to crop specific ROIs.
|
||||
"""
|
||||
self.outputs_root = outputs_root
|
||||
self.n_classes = n_classes
|
||||
|
@ -279,14 +280,16 @@ class DeepMILOutputsHandler:
|
|||
level=self.level,
|
||||
tile_size=self.tile_size,
|
||||
class_names=self.class_names,
|
||||
stage=ModelKey.VAL
|
||||
stage=ModelKey.VAL,
|
||||
wsi_has_mask=wsi_has_mask
|
||||
)
|
||||
self.test_plots_handler = DeepMILPlotsHandler(
|
||||
plot_options=test_plot_options,
|
||||
level=self.level,
|
||||
tile_size=self.tile_size,
|
||||
class_names=self.class_names,
|
||||
stage=ModelKey.TEST
|
||||
stage=ModelKey.TEST,
|
||||
wsi_has_mask=wsi_has_mask
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
@ -128,6 +128,7 @@ def save_slide_thumbnail_and_heatmap(
|
|||
slides_dataset: SlidesDataset,
|
||||
tile_size: int = 224,
|
||||
level: int = 1,
|
||||
wsi_has_mask: bool = True,
|
||||
) -> None:
|
||||
"""Plots and saves a slide thumbnail and attention heatmap
|
||||
|
||||
|
@ -139,13 +140,14 @@ def save_slide_thumbnail_and_heatmap(
|
|||
:param tile_size: Size of each tile. Default 224.
|
||||
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original,
|
||||
1 for 4x downsampled, 2 for 16x downsampled). Default 1.
|
||||
:param wsi_has_mask: Whether the slide has a mask or not. Default True.
|
||||
"""
|
||||
slide_index = slides_dataset.dataset_df.index.get_loc(slide_node.slide_id)
|
||||
assert isinstance(slide_index, int), f"Got non-unique slide ID: {slide_node.slide_id}"
|
||||
slide_dict = slides_dataset[slide_index]
|
||||
slide_dict = load_image_dict(slide_dict, level=level, margin=0)
|
||||
slide_dict = load_image_dict(slide_dict, level=level, margin=0, wsi_has_mask=wsi_has_mask)
|
||||
slide_image = slide_dict[SlideKey.IMAGE]
|
||||
location_bbox = slide_dict[SlideKey.LOCATION]
|
||||
location_bbox = slide_dict[SlideKey.ORIGIN]
|
||||
|
||||
fig = plot_slide(case=case, slide_node=slide_node, slide_image=slide_image, scale=1.0)
|
||||
save_figure(fig=fig, figpath=figures_dir / f"{slide_node.slide_id}_thumbnail.png")
|
||||
|
@ -179,6 +181,7 @@ class DeepMILPlotsHandler:
|
|||
figsize: Tuple[int, int] = (10, 10),
|
||||
stage: str = '',
|
||||
class_names: Optional[Sequence[str]] = None,
|
||||
wsi_has_mask: bool = True,
|
||||
) -> None:
|
||||
"""_summary_
|
||||
|
||||
|
@ -199,6 +202,7 @@ class DeepMILPlotsHandler:
|
|||
self.num_columns = num_columns
|
||||
self.figsize = figsize
|
||||
self.stage = stage
|
||||
self.wsi_has_mask = wsi_has_mask
|
||||
self.slides_dataset: Optional[SlidesDataset] = None
|
||||
|
||||
def save_slide_node_figures(
|
||||
|
@ -226,6 +230,7 @@ class DeepMILPlotsHandler:
|
|||
slides_dataset=self.slides_dataset,
|
||||
tile_size=self.tile_size,
|
||||
level=self.level,
|
||||
wsi_has_mask=self.wsi_has_mask,
|
||||
)
|
||||
|
||||
def save_plots(self, outputs_dir: Path, tiles_selector: Optional[TilesSelector], results: ResultsType) -> None:
|
||||
|
|
|
@ -18,6 +18,7 @@ from typing import Sequence, List, Any, Dict, Optional, Union, Tuple
|
|||
from monai.data.dataset import Dataset
|
||||
from monai.data.image_reader import WSIReader
|
||||
from torch.utils.data import DataLoader
|
||||
from health_cpath.preprocessing.loading import LoadROId
|
||||
|
||||
from health_cpath.utils.naming import SlideKey
|
||||
from health_cpath.utils.naming import ResultsKey
|
||||
|
@ -26,7 +27,7 @@ from health_cpath.utils.tiles_selection_utils import SlideNode
|
|||
from health_cpath.datasets.panda_dataset import PandaDataset, LoadPandaROId
|
||||
|
||||
|
||||
def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any]:
|
||||
def load_image_dict(sample: dict, level: int, margin: int, wsi_has_mask: bool = True) -> Dict[SlideKey, Any]:
|
||||
"""
|
||||
Load image from metadata dictionary
|
||||
:param sample: dict describing image metadata. Example:
|
||||
|
@ -40,7 +41,8 @@ def load_image_dict(sample: dict, level: int, margin: int) -> Dict[SlideKey, Any
|
|||
:param margin: margin to be included
|
||||
:return: a dict containing the image data and metadata
|
||||
"""
|
||||
loader = LoadPandaROId(WSIReader("cuCIM"), level=level, margin=margin)
|
||||
transform = LoadPandaROId if wsi_has_mask else LoadROId
|
||||
loader = transform(WSIReader("cuCIM"), level=level, margin=margin)
|
||||
img = loader(sample)
|
||||
return img
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import math
|
|||
import random
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
@ -24,7 +25,7 @@ from health_cpath.utils.viz_utils import plot_attention_tiles, plot_scores_hist,
|
|||
from health_cpath.utils.naming import ResultsKey
|
||||
from health_cpath.utils.heatmap_utils import location_selected_tiles
|
||||
from health_cpath.utils.tiles_selection_utils import SlideNode, TileNode
|
||||
from health_cpath.utils.viz_utils import save_figure
|
||||
from health_cpath.utils.viz_utils import save_figure, load_image_dict
|
||||
from testhisto.utils.utils_testhisto import assert_binary_files_match, full_ml_test_data_path
|
||||
|
||||
|
||||
|
@ -292,3 +293,12 @@ def test_location_selected_tiles(level: int) -> None:
|
|||
assert max(tile_xs) <= slide_image.shape[2] // factor
|
||||
assert min(tile_ys) >= 0
|
||||
assert max(tile_ys) <= slide_image.shape[1] // factor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wsi_has_mask", [True, False])
|
||||
def test_load_image_dict(wsi_has_mask: bool) -> None:
|
||||
with patch("health_cpath.utils.viz_utils.LoadPandaROId") as mock_load_panda_roi:
|
||||
with patch("health_cpath.utils.viz_utils.LoadROId") as mock_load_roi:
|
||||
_ = load_image_dict(sample=MagicMock(), level=0, margin=0, wsi_has_mask=wsi_has_mask) # type: ignore
|
||||
assert mock_load_panda_roi.called == wsi_has_mask
|
||||
assert mock_load_roi.called == (not wsi_has_mask)
|
||||
|
|
Загрузка…
Ссылка в новой задаче